解决 TypeError: Expected state_dict to be dict-like, got <class ‘*‘>.
这是一个简洁的错误复现和解决文章
文章目录
- 错误原因
- 错误重现
- 正确加载演示
- 拓展阅读
错误原因
一般是因为混合使用不同的保存和加载方式,问题出在你用 load_state_dict()
去加载别人使用torch.save(model)
保存的整个模型。
错误重现
下面我们来复现它,看是不是和你的操作一致:
- 错误地保存整个
model
而不是其state_dict
:import torch import torch.nn as nn # 定义一个线性模型进行演示 class LinearModel(nn.Module): def __init__(self, input_size, output_size): super(LinearModel, self).__init__() self.linear = nn.Linear(input_size, output_size) def forward(self, x): return self.linear(x) # 创建模型实例 model = LinearModel(input_size=10, output_size=1) # 打印模型结构 print("Model:", model) # 保存模型的 state_dict torch.save(model.state_dict(), './linear_model_state_dict.pth')
- 加载时传入
model
对象:
输出:# 创建一个新的模型实例 new_model = LinearModel(input_size=10, output_size=1) # 加载 state_dict 到新模型 new_model.load_state_dict(torch.load('./linear_model_state_dict.pth')) # 打印加载后的新模型结构 print("Model loaded with state_dict:", new_model)
Error: Expected state_dict to be dict-like, got <class '__main__.LinearModel'>.
正确加载演示
下面是两种保存和加载的方法,任选其一即可。
import torch
import torch.nn as nn
# 定义一个线性模型
class LinearModel(nn.Module):
def __init__(self, input_size, output_size):
super(LinearModel, self).__init__()
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
return self.linear(x)
# 创建模型实例
model = LinearModel(input_size=10, output_size=1)
print("Model:", model)
# 方法 1:保存和加载 state_dict
# 保存模型的 state_dict
torch.save(model.state_dict(), './linear_model_state_dict.pth')
# 创建一个新的模型实例
new_model = LinearModel(input_size=10, output_size=1)
# 加载 state_dict 到新模型
new_model.load_state_dict(torch.load('./linear_model_state_dict.pth'))
# 方法 2:保存和加载整个模型
# 保存整个模型
torch.save(model, './linear_model.pth')
# 加载整个模型
loaded_model = torch.load('./linear_model.pth')
拓展阅读
PyTorch 模型保存与加载的三种常用方式