使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器实现可迭代式数据集
2023 年 11 月,Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元(数据集和数据加载器)的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理顺序数据访问模式的可迭代样式数据集。在上一篇文章中,我介绍了适用于 Pytorch 的 S3 连接器,并详细描述了它打算解决的问题。我还介绍了过去即将弃用的库,以支持 S3 连接器。具体而言,请勿使用适用于 PyTorch 的 Amazon S3 插件和基于 CPP 的 S3 IO DataPipe。最后,我介绍了地图样式的数据集。我不会在这里回顾所有这些介绍性信息,所以如果你还没有读过我之前的文章,请尽早查看。在这篇文章中,我将重点介绍可迭代样式的数据集。此连接器的文档仅显示了从 Amazon S3 加载数据的示例 – 在这里,我将向您展示如何对 MinIO 使用它。适用于 PyTorch 的 S3 连接器还包括一个检查点接口,用于将检查点直接保存和加载到 S3 存储桶中,而无需先保存到本地存储。如果您还没有准备好采用正式的 MLOps 工具,而只需要一种简单的方法来保存模型,那么这是一个非常好的选择。我将在以后的文章中介绍此功能。
手动构建 Iterable Style 数据集
可迭代样式的数据集是通过实现一个类来创建的,该类覆盖了 PyTorch 的 IterableDataset 基类中的 iter() 方法。与地图样式数据集不同,没有 len() 方法,也没有 getitem() 方法。如果使用 Python 的 len() 函数查询可交互数据集,则会收到错误,因为 len() 方法不存在。
在训练循环期间调用可迭代样式的数据集时,您可以返回多个样本。具体来说,您将返回一个迭代器对象,数据加载器将迭代该对象以创建所需的批处理。让我们构建一个非常简单的自定义可迭代样式数据集,以更好地了解它们的工作原理。下面的代码显示了如何覆盖 iter() 方法。完整的代码下载可以在这里找到。
class MyIterableDataset(IterableDataset):
def __init__(self, start: int, end: int, transpose):
self.start = start
self.end = end
self.transpose = transpose
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
worker_id = -1
else: # in a worker process
worker_id = worker_info.id
# split workload
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
samples = []
for sample in range(iter_start, iter_end):
samples.append(sample)
return map(self.transpose, samples)
def my_transpose(x):
return torch.tensor(x)
请注意,您必须自己跟踪分片。当您创建具有多个具有此数据集的工作程序的数据加载程序时,您需要确保确定每个工作程序要处理的数据份额(每个分片)。这是使用 worker_info 对象和一些简单的数学运算来完成的。我们可以创建这个数据集并使用下面的代码循环它,这类似于训练循环。
batch_size = 2
ds = MyIterableDataset3(start=0, end=10, transpose=my_transpose)
dl = DataLoader(ds, batch_size=batch_size, num_workers=2)
for sample in dl:
print(sample)
输出将为:
tensor([0, 1])
tensor([5, 6])
tensor([2, 3])
tensor([7, 8])
tensor([4])
tensor([9])
现在我们已经了解了可迭代数据集,让我们使用 S3 连接器的可迭代数据集。但在执行此操作之前,让我们看看如何让 S3 连接器连接到 MinIO。
将 S3 连接器连接到 MinIO
将 S3 连接器连接到 MinIO 就像设置环境变量一样简单。之后,一切都会顺利进行。诀窍是以正确的方式设置正确的环境变量。本文的代码下载使用 .env 文件来设置环境变量,如下所示。此文件还显示了我用于使用 MinIO Python SDK 直接连接到 MinIO 的环境变量。请注意,AWS_ENDPOINT_URL 需要 protocol,而 MinIO 变量不需要。此外,你可能会注意到 AWS_REGION 变量的一些奇怪行为。从技术上讲,访问 MinIO 时不需要它,但如果为此变量选择错误的值,则 S3 连接器中的内部检查可能会失败。如果您收到这些错误之一,请仔细阅读该消息并指定它请求的值。
AWS_ACCESS_KEY_ID=admin
AWS_ENDPOINT_URL=http://172.31.128.1:9000
AWS_REGION=us-east-1
AWS_SECRET_ACCESS_KEY=password
MINIO_ENDPOINT=172.31.128.1:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=password
MINIO_SECURE=false
使用 S3 连接器创建可迭代样式的数据集
要使用 S3 连接器创建可迭代样式的数据集,您无需像以前那样编写和创建类。S3IterableDataset.from_prefix() 函数将为您完成所有工作。此函数假定您已设置环境变量以连接到 S3 对象存储,如上一节所述。它还要求可以通过 S3 前缀找到您的对象。下面是一个演示如何使用此函数的代码段。
from s3torchconnector import S3IterableDataset
uri = f's3://{bucket_name}/{split}'
aws_region = os.environ['AWS_REGION']
dataset = S3IterableDataset.from_prefix(uri, region=aws_region,
enable_sharding=True,
transform=S3IterTransform(transform))
loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
return loader, (time.perf_counter()-start_time)
请注意,URI 是 S3 路径。在路径 mnist/train 下可以递归找到的每个对象都应该是属于训练集的对象。如果要使用数据加载器中的多个工作程序(num_workers 参数),请务必在数据集上设置 enable_sharding 参数。上述函数还需要一个 transform 来将对象转换为张量并确定标签。这是通过如下所示的可调用类的实例完成的。
from s3torchconnector import S3Reader
class S3IterTransform:
def __init__(self, transform):
self.transform = transform
def __call__(self, object: S3Reader) -> torch.Tensor:
content = object.read()
image_pil = Image.open(BytesIO(content))
image_tensor = self.transform(image_pil)
label = int(object.key.split('/')[1])
return (image_tensor, label)
这就是使用 S3 Connector for PyTorch 创建地图样式数据集所需要做的全部工作。
结论
适用于 PyTorch 的 S3 连接器易于使用,工程师在使用时编写的数据访问代码更少。在本文中,我展示了如何将其配置为使用环境变量连接到 MinIO。配置完成后,三行代码创建了一个可迭代的数据集对象,该对象使用简单的可调用类进行转换。
后续步骤
如果您的网络是训练管道中最薄弱的环节,请考虑创建包含多个样本的对象,您甚至可以对其进行 tar 或 zip 处理。遗憾的是,S3 连接器无法对地图样式或可迭代样式的数据集执行此操作。在以后的博文中,我将展示如何使用自定义构建的可迭代样式数据集来完成此操作。