当前位置: 首页 > article >正文

pytorch nn.Parameter模块介绍

在 PyTorch 中,nn.Parameter 是一个用于定义可训练参数的模块。它通常用于自定义模型时,将张量注册为模型的一部分,使其在训练过程中能够被优化。

nn.Parameter 的作用

  1. 可训练性:将一个普通张量转换为 Parameter 后,它会被自动添加到模型的参数列表中(model.parameters()),并参与梯度计算和优化。
  2. 模块关联Parameter 通常与 nn.Module 配合使用,用于定义模型的权重或偏置。

方法签名

torch.nn.Parameter(data, requires_grad=True)
参数说明
  • data: 初始化 Parameter 的张量。
  • requires_grad: 是否计算梯度。默认为 True,意味着它会参与反向传播。

用法示例

示例 1:将张量定义为可训练参数
import torch
from torch.nn import Parameter

# 创建一个普通张量
tensor = torch.randn(3, 3)

# 转换为 nn.Parameter
param = Parameter(tensor)
print("参数值:\n", param)
print("是否计算梯度:", param.requires_grad)
示例 2:在自定义模型中使用 nn.Parameter
import torch
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        # 使用 nn.Parameter 定义一个可训练参数
        self.weight = nn.Parameter(torch.randn(5, 5))
        self.bias = nn.Parameter(torch.randn(5))

    def forward(self, x):
        # 使用定义的参数进行计算
        return x @ self.weight + self.bias

# 实例化模型
model = CustomModel()
print("模型参数:")
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")
示例 3:控制 requires_grad
param = nn.Parameter(torch.randn(4, 4), requires_grad=False)
print("是否计算梯度:", param.requires_grad)

如果 requires_grad=False,则参数不会在反向传播中更新。

注意事项

  1. 与 torch.Tensor 的区别

    • 普通张量不会被自动添加到模型的参数列表中。
    • 使用 nn.Parameter 可以确保张量是模型的一部分,参与优化。
  2. 冻结参数: 如果需要临时冻结 nn.Parameter 的更新,可以手动设置其 requires_grad=False

  3. model.weight.requires_grad = False
    
  4. 自定义参数初始化: 可以在定义 nn.Parameter 时使用自定义初始化:

  5. self.weight = nn.Parameter(torch.zeros(10, 10))
    

常见应用场景

  • 自定义权重和偏置:当模型结构中需要手动定义权重或偏置时,nn.Parameter 是最佳选择。
  • 实现特殊模块:比如需要权重共享或参数固定的模型模块。
  • 控制参数是否参与优化:通过 requires_grad,可以灵活控制某些参数是否更新。

通过 nn.Parameter,开发者可以更加灵活地构造自定义模型,并充分利用 PyTorch 的自动梯度和优化功能。


http://www.kler.cn/a/457153.html

相关文章:

  • Three.js教程004:坐标辅助器与轨道控制器
  • android studio gradle 如何解决下载依赖一直卡住的问题
  • 【Wi-Fi】802.11u、WPA、WPA2/WPA3-ENterprise、Hotspot 、IEEE802.11x的关系
  • VScode 格式化代码空格记录
  • Linux文件描述符
  • C#-使用StbSharp库读写图片
  • Python|Pyppeteer实现自动化获取reCaptcha验证码图片以及提示词(29)
  • Debian-linux运维-ssh配置(兼容Jenkins插件的ssh连接公钥类型)
  • 【JS笔记】快速安装nodejs(九)
  • 雪花算法(Snowflake algorithm)介绍、优缺点及代码示例
  • upload-labs关卡记录17
  • 服务器时间不同步
  • Redis到底支不支持事务啊?
  • Docker安装GPUStack详细教程
  • 知识碎片-环境配置
  • 设计模式通俗解释
  • 基于springboot校园招聘系统源码和论文
  • c++---------------------------string
  • 深入解析JVM中对象的创建过程
  • 用 Python 从零开始创建神经网络(十八):模型对象(Model Object)
  • 隨筆20241226 ExcdlJs 將數據寫入excel
  • C# winform 报错:类型“System.Int32”的对象无法转换为类型“System.Int16”。
  • WPF编程excel表格操作
  • PDB数据库解析:
  • C 语言中 strlen 函数的深入剖析
  • leetcdoe 1670.设计前中后队列