8、深入剖析PyTorch的state_dict、parameters、modules源码
文章目录
- 1. 重要类
- 2. 保存模型
- 3. 代码测试
1. 重要类
- container.py
- nn.Sequential
- nn.ModuleList
- save_state_dict
2. 保存模型
参见 PyTorch 官网教程
3. 代码测试
时间仓促，内容待后续完善。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :ToTest01.py
# @Time :2024/11/24 10:37
# @Author :Jason Zhang
import torch
from torch import nn
from torch.nn import Module
class MyModel(nn.Module):
    """Toy model used to inspect ``nn.Module`` bookkeeping.

    Registers two ``nn.Linear`` layers and one ``nn.BatchNorm2d`` as child
    modules. Only the two linear layers take part in ``forward``; the
    batch-norm layer is not called, but its parameters and buffers
    (``running_mean``, ``running_var``, ``num_batches_tracked``) still show
    up in ``state_dict()`` and ``named_buffers()``.
    """

    def __init__(self):
        super().__init__()
        # Assignment order fixes the registration order of the children,
        # and therefore the key order of _modules / state_dict() / named_*().
        self.linear1 = nn.Linear(2, 3)
        self.linear2 = nn.Linear(3, 4)
        self.batch_norm4 = nn.BatchNorm2d(4)  # registered but unused in forward

    def forward(self, x):
        """Apply ``linear1`` then ``linear2``.

        The last dimension of ``x`` must be 2 (``linear1``'s ``in_features``);
        the output's last dimension is 4.
        """
        hidden = self.linear1(x)
        return self.linear2(hidden)
if __name__ == "__main__":
    # Visual divider printed between sections of the demo output.
    separator = "*" * 50

    input_x = torch.randn((1, 2))
    test_model = MyModel()
    # One forward pass as a smoke test; the result itself is not inspected.
    test_model(input_x)

    # _modules: OrderedDict of the direct child modules, keyed by attribute name.
    model_modules = test_model._modules
    print(separator)
    print(f"model_modules=\n{model_modules}")
    print(separator)

    # The values in _modules are the child module objects themselves.
    linear1 = model_modules['linear1']
    print(separator)
    print(f"linear1={linear1}")
    print(separator)
    print(f"linear1.weight=\n{linear1.weight}")
    print(separator)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(separator)

    # Module.to(dtype) converts the whole model, children included; `linear1`
    # still refers to the same child module, so it reports the new dtype.
    test_model.to(torch.double)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(separator)
    test_model.to(torch.float32)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(separator)

    # _parameters / _buffers only hold tensors registered directly on this
    # module. MyModel itself registers none (everything lives on its
    # children), so both print as empty OrderedDicts.
    model_parameters = test_model._parameters
    print(f"model_parameters={model_parameters}")
    print(separator)
    model_buffers = test_model._buffers
    print(f"model_buffer={model_buffers}")
    print(separator)

    # state_dict() walks the whole module tree and returns parameters AND
    # buffers under dotted names such as 'linear2.weight'.
    model_state_dict = test_model.state_dict()
    print(f"model_state_dict=\n{model_state_dict}")
    print(separator)
    # Reuse the dict built above rather than calling state_dict() again.
    model_state_dict_linear2 = model_state_dict['linear2.weight']
    print(f"model_state_dict_linear2=\n{model_state_dict_linear2}")
    print(separator)

    # named_parameters(): (dotted name, Parameter) pairs for the whole tree;
    # buffers such as running_mean are not included.
    model_named_para = list(test_model.named_parameters())
    print(f"model_named_para=\n{model_named_para}")
    print(separator)
    # named_modules(): every module in the tree, the root itself first under ''.
    model_named_modules = list(test_model.named_modules())
    print(f"model_named_modules=\n{model_named_modules}")
    print(separator)
    # named_buffers(): the non-parameter state (here only batch_norm4's).
    model_named_buffers = list(test_model.named_buffers())
    print(f"model_named_buffers=\n{model_named_buffers}")
    print(separator)
    # named_children(): only the direct children, without the root.
    model_named_children = list(test_model.named_children())
    print(f"model_named_children=\n{model_named_children}")
- 运行结果：
**************************************************
model_modules=
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
**************************************************
**************************************************
linear1=Linear(in_features=2, out_features=3, bias=True)
**************************************************
linear1.weight=
Parameter containing:
tensor([[-0.5518, 0.0687],
[-0.7013, 0.4869],
[-0.1157, -0.1287]], requires_grad=True)
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
linear1.weight.dtype=torch.float64
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
model_parameters=OrderedDict()
**************************************************
model_buffer=OrderedDict()
**************************************************
model_state_dict=
OrderedDict([('linear1.weight', tensor([[-0.5518, 0.0687],
[-0.7013, 0.4869],
[-0.1157, -0.1287]])), ('linear1.bias', tensor([-0.2915, -0.4807, 0.0071])), ('linear2.weight', tensor([[ 0.4185, 0.1556, 0.1371],
[ 0.4751, 0.2029, -0.0679],
[ 0.1264, -0.0288, -0.3661],
[ 0.4423, -0.5370, 0.3930]])), ('linear2.bias', tensor([ 0.2746, -0.1798, 0.0218, 0.5465])), ('batch_norm4.weight', tensor([1., 1., 1., 1.])), ('batch_norm4.bias', tensor([0., 0., 0., 0.])), ('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))])
**************************************************
model_state_dict_linear2=
tensor([[ 0.4185, 0.1556, 0.1371],
[ 0.4751, 0.2029, -0.0679],
[ 0.1264, -0.0288, -0.3661],
[ 0.4423, -0.5370, 0.3930]])
**************************************************
model_named_para=
[('linear1.weight', Parameter containing:
tensor([[-0.5518, 0.0687],
[-0.7013, 0.4869],
[-0.1157, -0.1287]], requires_grad=True)), ('linear1.bias', Parameter containing:
tensor([-0.2915, -0.4807, 0.0071], requires_grad=True)), ('linear2.weight', Parameter containing:
tensor([[ 0.4185, 0.1556, 0.1371],
[ 0.4751, 0.2029, -0.0679],
[ 0.1264, -0.0288, -0.3661],
[ 0.4423, -0.5370, 0.3930]], requires_grad=True)), ('linear2.bias', Parameter containing:
tensor([ 0.2746, -0.1798, 0.0218, 0.5465], requires_grad=True)), ('batch_norm4.weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)), ('batch_norm4.bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]
**************************************************
model_named_modules=
[('', MyModel(
(linear1): Linear(in_features=2, out_features=3, bias=True)
(linear2): Linear(in_features=3, out_features=4, bias=True)
(batch_norm4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)), ('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]
**************************************************
model_named_buffers=
[('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))]
**************************************************
model_named_children=
[('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]