当前位置: 首页 > article >正文

8、深入剖析PyTorch的state_dict、parameters、modules源码

文章目录

  • 1. 重要类
  • 2. 保存模型
  • 3. 代码测试

1. 重要类

  • container.py
  • nn.Sequential
  • nn.ModuleList
  • save_state_dict

2. 保存模型

PyTorch 官网教程

3. 代码测试

时间比较紧,这里先给出测试代码和运行结果,后续再完善说明。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :ToTest01.py
# @Time      :2024/11/24 10:37
# @Author    :Jason Zhang
import torch
from torch import nn
from torch.nn import Module


class MyModel(nn.Module):
    """Small demo network: two stacked linear layers plus a BatchNorm2d layer.

    The batch-norm layer is registered as a submodule but is not applied in
    ``forward``; it is present so that the model owns buffers in addition to
    parameters.
    """

    def __init__(self):
        super().__init__()
        # Registration order determines the order of entries in _modules,
        # state_dict(), named_parameters(), etc.
        self.linear1 = nn.Linear(in_features=2, out_features=3)
        self.linear2 = nn.Linear(in_features=3, out_features=4)
        self.batch_norm4 = nn.BatchNorm2d(num_features=4)

    def forward(self, x):
        """Apply ``linear1`` and then ``linear2`` to ``x`` and return the result."""
        return self.linear2(self.linear1(x))


if __name__ == "__main__":
    # Demo script: build MyModel and print what each nn.Module introspection
    # hook returns, with a row of asterisks between the sections.
    separator = "*" * 50

    # The input is drawn before the model is constructed, so random numbers
    # are consumed in the same order as before.
    input_x = torch.randn((1, 2))
    test_model = MyModel()
    test_model(input_x)  # forward pass as a smoke test; the output is not used

    # _modules maps attribute name -> submodule registered on this module.
    model_modules = test_model._modules
    print(separator)
    print(f"model_modules=\n{model_modules}")
    print(separator)
    linear1 = model_modules['linear1']
    print(separator)
    print(f"linear1={linear1}")
    print(separator)
    print(f"linear1.weight=\n{linear1.weight}")
    print(separator)

    # Module.to() converts the model in place, so the dtype change is visible
    # through the `linear1` reference fetched above.
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(separator)
    test_model.to(torch.double)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(separator)
    test_model.to(torch.float32)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(separator)

    # _parameters / _buffers are the private per-module dicts of the top-level
    # model; compare them with the named_* listings printed further down.
    model_parameters = test_model._parameters
    print(f"model_parameters={model_parameters}")
    print(separator)
    model_buffers = test_model._buffers
    print(f"model_buffer={model_buffers}")
    print(separator)

    # Build the state dict once and reuse it for the single-key lookup.
    model_state_dict = test_model.state_dict()
    print(f"model_state_dict=\n{model_state_dict}")
    print(separator)
    model_state_dict_linear2 = model_state_dict['linear2.weight']
    print(f"model_state_dict_linear2=\n{model_state_dict_linear2}")
    print(separator)

    # The named_* methods return generators; materialise them for printing.
    model_named_para = list(test_model.named_parameters())
    print(f"model_named_para=\n{model_named_para}")
    print(separator)
    model_named_modules = list(test_model.named_modules())
    print(f"model_named_modules=\n{model_named_modules}")
    print(separator)
    model_named_buffers = list(test_model.named_buffers())
    print(f"model_named_buffers=\n{model_named_buffers}")
    print(separator)
    model_named_children = list(test_model.named_children())
    print(f"model_named_children=\n{model_named_children}")


  • 结果:
**************************************************
model_modules=
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
**************************************************
**************************************************
linear1=Linear(in_features=2, out_features=3, bias=True)
**************************************************
linear1.weight=
Parameter containing:
tensor([[-0.5518,  0.0687],
        [-0.7013,  0.4869],
        [-0.1157, -0.1287]], requires_grad=True)
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
linear1.weight.dtype=torch.float64
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
model_parameters=OrderedDict()
**************************************************
model_buffer=OrderedDict()
**************************************************
model_state_dict=
OrderedDict([('linear1.weight', tensor([[-0.5518,  0.0687],
        [-0.7013,  0.4869],
        [-0.1157, -0.1287]])), ('linear1.bias', tensor([-0.2915, -0.4807,  0.0071])), ('linear2.weight', tensor([[ 0.4185,  0.1556,  0.1371],
        [ 0.4751,  0.2029, -0.0679],
        [ 0.1264, -0.0288, -0.3661],
        [ 0.4423, -0.5370,  0.3930]])), ('linear2.bias', tensor([ 0.2746, -0.1798,  0.0218,  0.5465])), ('batch_norm4.weight', tensor([1., 1., 1., 1.])), ('batch_norm4.bias', tensor([0., 0., 0., 0.])), ('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))])
**************************************************
model_state_dict_linear2=
tensor([[ 0.4185,  0.1556,  0.1371],
        [ 0.4751,  0.2029, -0.0679],
        [ 0.1264, -0.0288, -0.3661],
        [ 0.4423, -0.5370,  0.3930]])
**************************************************
model_named_para=
[('linear1.weight', Parameter containing:
tensor([[-0.5518,  0.0687],
        [-0.7013,  0.4869],
        [-0.1157, -0.1287]], requires_grad=True)), ('linear1.bias', Parameter containing:
tensor([-0.2915, -0.4807,  0.0071], requires_grad=True)), ('linear2.weight', Parameter containing:
tensor([[ 0.4185,  0.1556,  0.1371],
        [ 0.4751,  0.2029, -0.0679],
        [ 0.1264, -0.0288, -0.3661],
        [ 0.4423, -0.5370,  0.3930]], requires_grad=True)), ('linear2.bias', Parameter containing:
tensor([ 0.2746, -0.1798,  0.0218,  0.5465], requires_grad=True)), ('batch_norm4.weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)), ('batch_norm4.bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]
**************************************************
model_named_modules=
[('', MyModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)), ('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]
**************************************************
model_named_buffers=
[('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))]
**************************************************
model_named_children=
[('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]

http://www.kler.cn/a/408668.html

相关文章:

  • 基于STM32的火灾报警装置的Proteus仿真
  • 基于之前的秒杀功能的优化(包括Sentinel在SpringBoot中的简单应用)
  • Typora+PicGo+云服务器搭建博客图床
  • neo4j图数据库community-5.50创建多个数据库————————————————
  • key-value存储实现
  • 基于物联网设计的人工淡水湖养殖系统(华为云IOT)_253
  • GCC编译过程(预处理,编译,汇编,链接)及GCC命令
  • 如果在docker 容器中安装ros遇到的问题
  • 《MySQL 事务隔离级别详解》
  • 学习Servlet(Servlet实现方式3)
  • Knife4j快速入门
  • 【redis】哈希类型详解
  • 【pip install报SSL类错误】
  • 【Anaconda】Pycharm如何配置conda虚拟环境
  • 深入理解 JVM 中的 G1 垃圾收集器原理、算法、过程和参数配置
  • YOLOv11融合[ECCV 2018]RCAN中的RCAB模块及相关改进思路
  • _computed _destinations() 为什么模板不写()
  • 渗透测试---shell(6)if条件判断与for循环结构
  • Vue小项目(开发一个购物车)
  • realme gt neo6官方刷机包 全量升级包下载
  • jar包解压和重新打包
  • 微信小程序 表单验证(async-validator)
  • 基于Gradle搭建Spring6.2.x版本源码阅读环境
  • Alluxio在小红书的实践:加速云端机器学习
  • HarmonyOS Next 浅谈 发布-订阅模式
  • 【热门主题】000062 云原生后端:开启高效开发新时代