当前位置: 首页 > article >正文

Chapter5.4 Loading and saving model weights in PyTorch

5 Pretraining on Unlabeled Data

5.4 Loading and saving model weights in PyTorch

  • 训练LLM的计算成本很高,因此能够保存和加载LLM的权重至关重要。

  • 在PyTorch中,推荐的方式是通过将torch.save函数应用于.state_dict()方法来保存模型权重,即所谓的state_dict

    torch.save(model.state_dict(),"model.pth")
    

    我们可以将模型权重加载到新的 GPTModel 模型实例中

    model = GPTModel(GPT_CONFIG_124M)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.eval();
    
  • 自适应优化器(如 AdamW)为每个模型权重存储额外的参数。AdamW 使用历史数据动态调整每个模型参数的学习率。如果没有这些参数,优化器会重置,模型可能会学习效果不佳,甚至无法正确收敛,这意味着模型将失去生成连贯文本的能力。使用 torch.save,我们可以保存模型和优化器的 state_dict 内容,如下所示

    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        }, 
        "model_and_optimizer.pth"
    )
    

    然后,我们可以通过以下方式恢复模型和优化器状态:首先通过 torch.load 加载保存的数据,然后使用 load_state_dict 方法:

    checkpoint = torch.load("model_and_optimizer.pth", weights_only=True)
    
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(checkpoint["model_state_dict"])
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    model.train();
    


http://www.kler.cn/a/505990.html

相关文章:

  • Android BottomNavigationView不加icon使text垂直居中,完美解决。
  • vue3使用vue-native-websocket-vue3通讯
  • SK海力士(SK Hynix)是全球领先的半导体制造商之一,其在无锡的工厂主要生产DRAM和NAND闪存等存储器产品。
  • Spring Boot + MyBatis-Flex 配置 ProxySQL 的完整指南
  • 前端常见的设计模式之【单例模式】
  • 微信小程序在使用页面栈保存页面信息时,如何避免数据丢失?
  • 【机器学习实战入门项目】基于机器学习的鸢尾花分类项目
  • C++:工具VSCode的编译和调试文件内容:
  • Python爬虫:从入门到实践
  • 路由环路的产生原因与解决方法(1)
  • 在Android 15的设备上关闭edge-to-edge功能
  • uniapp 页面铺满屏幕
  • STM32 FreeRTOS 信号量
  • 使用docker-compose安装ELK(elasticsearch,logstash,kibana)并简单使用
  • Web基础-分层解耦-IOC与DI入门(具体的是实现步骤)
  • 遥感原理及图像处理
  • 向量数据库Milvus详解
  • day_2_排序算法和树
  • IOS工程师
  • 隧道网络:为数据传输开辟安全通道
  • HttpClient和HttpGet实现音频数据的高效爬取与分析
  • Unity中实现倒计时结束后干一些事情
  • Leetcode 72. 编辑距离 动态规划
  • ASP.NET Core - 配置系统之自定义配置提供程序
  • pytorch小记(六):pytorch中的clone和detach操作:克隆/复制数据 vs 共享相同数据但 与计算图断开联系
  • 【Vue】父组件向子组件传递参数;子组件向父组件触发自定义事件