当前位置: 首页 > article >正文

使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器进行模型检查点处理

2023 年 11 月,Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元(数据集和数据加载器)的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理顺序数据访问模式的可迭代样式数据集。适用于 PyTorch 的 S3 连接器还包括一个检查点接口,用于将检查点直接保存和加载到 S3 存储桶,而无需先保存到本地存储。如果您还没有准备好采用正式的 MLOps 工具,而只需要一种简单的方法来保存模型,那么这是一个非常好的选择。这就是我将在这篇文章中介绍的内容。S3 连接器的文档仅展示了如何将其与 Amazon S3 一起使用 - 我将在此处向您展示如何将其用于 MinIO。让我们先执行此作 - 让我们设置 S3 连接器,以便它从 MinIO 写入和读取检查点。

将 S3 连接器连接到 MinIO

将 S3 连接器连接到 MinIO 就像设置环境变量一样简单。之后,一切都会顺利进行。诀窍是以正确的方式设置正确的环境变量。

本文的代码下载使用 .env 文件来设置环境变量,如下所示。此文件还显示了我用于使用 MinIO Python SDK 直接连接到 MinIO 的环境变量。请注意,AWS_ENDPOINT_URL 需要 protocol,而 MinIO 变量不需要。

AWS_ACCESS_KEY_ID=admin
AWS_ENDPOINT_URL=http://172.31.128.1:9000
AWS_REGION=us-east-1
AWS_SECRET_ACCESS_KEY=password
MINIO_ENDPOINT=172.31.128.1:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=password
MINIO_SECURE=false

写入和读取 Checkpoint

我从一个简单的例子开始。下面的代码段创建了一个 S3Checkpointing 对象,并使用其 writer() 方法将模型的状态字典发送到 MinIO。我还使用 Torchvision 创建了一个 ResNet-18(18 层)模型,用于演示目的。

import os

from dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch

# Load the credentials and connection information.
load_dotenv()

model = torchvision.models.resnet18()
model_name = 'resnet18.pth'
bucket_name = 'checkpoints'

checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])

# Save checkpoint to S3
with s3_checkpoint.writer(checkpoint_uri) as writer:
   torch.save(model.state_dict(), writer)

请注意,该区域有一个强制参数。从技术上讲,访问 MinIO 时没有必要,但如果为此变量选择错误的值,内部检查可能会失败。此外,您的存储桶必须存在,上述代码才能正常工作。如果 writer() 方法不存在,它将引发错误。不幸的是,无论出了什么问题,writer() 方法都会引发相同的错误。例如,如果您的存储桶不存在,您将收到如下所示的错误。如果 writer() 方法不喜欢您指定的区域,您也会收到相同的错误。希望未来的版本将提供更具描述性的错误消息。

S3Exception: Client error: Request canceled

将以前保存的模型读取到内存中的代码类似于写入 MinIO。使用 reader() 方法,而不是 writer() 方法。下面的代码显示了如何执行此作。

import os

from dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch

# Load the credentials and connection information.
load_dotenv()

model_name = 'resnet18.pth'
bucket_name = 'checkpoints'

checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])

# Load checkpoint from S3
with s3_checkpoint.reader(checkpoint_uri) as reader:
   state_dict = torch.load(reader, weights_only=True)

model.load_state_dict(state_dict)

接下来,让我们看看模型训练期间检查点的一些实际注意事项。

在模型训练期间编写检查点

如果您使用大型数据集训练大型模型,请考虑在每个 epoch 后设置检查点。这些训练运行可能需要数小时甚至数天才能完成,因此在发生故障时能够从上次中断的地方继续非常重要。此外,我们假设您必须使用共享存储桶来保存来自多个团队的多个模型的模型检查点。MLOps 约定是按试验组织训练运行。例如,如果您正在研究具有四个隐藏层的架构,那么在寻找各种超参数的最佳值时,您将使用此架构进行多次运行。如果同事使用五层体系结构运行实验,则需要一种方法来防止名称冲突。这可以通过模拟如下所示的层次结构的对象路径来解决。

最后,为了确保您在每个 epoch 中获得新版本的模型,请确保在用于保存检查点的存储桶上启用版本控制。下面的训练函数使用上述路径结构在每个 epoch 后对模型进行检查点作。(可以在本文的代码下载中找到此训练函数的更强大版本。

def train_model(model: nn.Module, loader: DataLoader, 
                training_parameters: Dict[str, Any]) -> List[float]:

   if training_parameters['checkpoint']:
       checkpoint_uri = f's3://{training_parameters["checkpoint_bucket"]} \
                          /{training_parameters["project_name"]} \
                          /{training_parameters["experiment_name"]} \
                          /{training_parameters["run_id"]} \
                          /{training_parameters["model_name"]}'
       s3_checkpoint = S3Checkpoint(region=os.environ['AWS_REGION'])

   loss_func = nn.NLLLoss()
   optimizer = optim.SGD(model.parameters(), lr=training_parameters['lr'], 
                         momentum=training_parameters['momentum'])

   # Epoch loop
   compute_time_by_epoch = []
   for epoch in range(training_parameters['epochs']):
       # Batch loop
       for images, labels in loader:

           # Flatten MNIST images into a 784 long vector.
           # shape = [32, 784]
           images = images.view(images.shape[0], -1)

           # Training pass
           optimizer.zero_grad()
           output = model(images)
           loss = loss_func(output, labels)
           loss.backward()
           optimizer.step()

       # Save checkpoint to S3
       if training_parameters['checkpoint']:
           with s3_checkpoint.writer(checkpoint_uri) as writer:
               torch.save(model.state_dict(), writer)

请注意,模型名称不包含指示纪元的子字符串。如前所述,我使用了启用了版本控制的存储桶 - 换句话说,版本号表示纪元。这种方法的优点在于,您无需知道引用最新模型的 epoch 数。在上述训练代码运行了 10 个 epoch 后,我的检查点存储桶如下面的屏幕截图所示。

此培训演示可被视为 DIY MLOps 解决方案的开始。

结论

适用于 PyTorch 的 S3 连接器易于使用,工程师在使用时编写的数据访问代码行数会更少。在本文中,我展示了如何将其配置为使用环境变量连接到 MinIO。配置完成后,工程师可以分别使用 writer() 和 reader() 方法将检查点写入和读取 MinIO。在本文中,我展示了如何配置 S3 Connect 以连接到 MinIO。我还演示了 S3Checkpoint 类及其 reader() 和 writer() 方法的基本用法。最后,我展示了一种在实际训练函数中针对启用了版本的检查点存储桶使用这些检查点功能的方法。在这篇文章中,我没有介绍在分布式训练期间检查点所需的技术和工具,这可能有点棘手。分布式训练期间的检查点设置会有所不同,具体取决于您使用的框架(PyTorch、Ray 或 DeepSpeed 等)和您正在进行的分布式训练类型:数据并行(每个工作程序都有模型的完整副本)或模型并行(每个工作程序只有一个模型分片)。在以后的文章中,我将介绍其中的一些技术。


http://www.kler.cn/a/540685.html

相关文章:

  • 【C++八股】 前置 ++i vs. 后置 i++ 的区别
  • CSDN 博客之星 2024:肖哥弹架构的社区耕耘总结
  • 防御综合实验
  • 循环神经网络学习01——transformer:输入部分-嵌入层位置编码
  • Centos Ollama + Deepseek-r1+Chatbox运行环境搭建
  • 工业相机在工业生产制造过程中的视觉检测技术应用
  • php 实现 deepSeek聊天对话
  • MacOS安装Milvus向量数据库
  • 【AIGC】在VSCode中集成 DeepSeek(OPEN AI同理)
  • 蓝桥杯算法日记|贪心、双指针
  • 石英表与机械表的世纪之争(Quartz vs. Mechanical Watches):瑞士钟表业的危机与重生(中英双语)
  • 如何在Kickstart自动化安装完成后ISO内拷贝文件到新系统或者执行命令
  • 目标检测数据集合集(持续更新中)
  • centos docker安装
  • 【C#零基础从入门到精通】(八)——C#String字符串详解
  • 【华为OD-E卷 - 120 分割数组的最大差值 100分(python、java、c++、js、c)】
  • ABP框架9——自定义拦截器的实现与使用
  • 如何使用Socket编程在Python中实现实时聊天应用
  • 笔试-字符串2
  • Web前端开发--HTML
  • java后端开发day10--综合练习(一)
  • 基于“感知–规划–行动”的闭环系统架构
  • DeepSeek+3D视觉机器人应用场景、前景和简单设计思路
  • 深入理解TCP/IP协议栈:从原理到实践
  • Linux: ASoC 声卡硬件参数的设置过程简析
  • 协议-ACLLite-ffmpeg