当前位置: 首页 > article >正文

5.深度学习计算

5.2 参数管理

每个网络都由各层组成,一个网络模块中的层可由索引访问

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
print(net[2])

输出:

Linear(in_features=8, out_features=1, bias=True)

5.2.1 参数访问

网络中的参数一般是指各层权重和偏置

若想访问某层的参数,用层来调用state_dict()函数

print(net[2].state_dict())

输出:

OrderedDict([('weight', tensor([[-0.3343,  0.3289, -0.0063,  0.0594, -0.1051, -0.3419,  0.2796,  0.0557]])), ('bias', tensor([0.3026]))])

可直接对各层参数进行调用

由于需要目标函数对参数求梯度进行优化,所以需要记录梯度

所以各层的参数也具有属性grad,梯度初始化为None

net[2].weight.grad == None

5.2.1.2 一次性访问所有参数

使用named_parameters()访问所有参数

各层或者整个模型都可以调用

print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print([(name, param.shape) for name, param in net.named_parameters()])

输出:

('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))
[('0.weight', torch.Size([8, 4])), ('0.bias', torch.Size([8])), ('2.weight', torch.Size([1, 8])), ('2.bias', torch.Size([1]))]
注:

print 函数的 * 操作符用于将列表中的每个元组作为独立的参数传递给 print,这样 print 函数会直接打印列表中的元组


 


http://www.kler.cn/news/355731.html

相关文章:

  • 【MR开发】在Pico设备上接入MRTK3(二)——在Unity中配置Pico SDK
  • 利用Arcgis进行沟道形态分析
  • maven工程怎么将除工程源码外的三方依赖一起打包到jar中
  • Vue3 集成Monaco Editor编辑器
  • AI学习指南深度学习篇-自编码器的基本原理
  • MobaXterm 中文乱码
  • WordPress添加meta标签做seo优化
  • node.js下载安装以及环境配置超详细教程【Windows版本】
  • SAP 为 Copilot Joule 增添协作功能
  • 【算法系列-字符串】反转字符串中的单词
  • 深入探讨ASP.NET Core中间件及其请求处理管道特性
  • 1.1 C++语言基础面试问题
  • 基于Arduino做的“鱿鱼游戏”BOSS面具,支持动作检测
  • 软件设计——数据流图
  • arm-none-linux-gnueabi-strip的作用
  • Springboot接入Elastic
  • AWTK fscript 中的 widget 扩展函数
  • C++ 学习笔记 十二 结构体
  • 深度学习框架-Keras的常用内置数据集总结
  • nacos组件介绍