Chapter5.4 Loading and saving model weights in PyTorch
5 Pretraining on Unlabeled Data
5.4 Loading and saving model weights in PyTorch
-
训练LLM的计算成本很高,因此能够保存和加载LLM的权重至关重要。
-
在PyTorch中,推荐的方式是通过将
torch.save
函数应用于.state_dict()
方法来保存模型权重,即所谓的state_dict
:torch.save(model.state_dict(),"model.pth")
我们可以将模型权重加载到新的 GPTModel 模型实例中
model = GPTModel(GPT_CONFIG_124M) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True)) model.eval();
-
自适应优化器(如 AdamW)为每个模型权重存储额外的参数。AdamW 使用历史数据动态调整每个模型参数的学习率。如果没有这些参数,优化器会重置,模型可能会学习效果不佳,甚至无法正确收敛,这意味着模型将失去生成连贯文本的能力。使用
torch.save
,我们可以保存模型和优化器的state_dict
内容,如下所示torch.save({ "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), }, "model_and_optimizer.pth" )
然后,我们可以通过以下方式恢复模型和优化器状态:首先通过 torch.load 加载保存的数据,然后使用 load_state_dict 方法:
checkpoint = torch.load("model_and_optimizer.pth", weights_only=True) model = GPTModel(GPT_CONFIG_124M) model.load_state_dict(checkpoint["model_state_dict"]) optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) model.train();