当前位置: 首页 > article >正文

详细分析Pytorch中的register_buffer基本知识(附Demo)

目录

  • 1. 基本知识
  • 2. Demo
  • 3. 与自动注册的差异
    • 3.1 torch.nn.Parameter
    • 3.2 自动注册子模块
    • 3.3 总结

1. 基本知识

register_buffer 是 PyTorch 中 torch.nn.Module 提供的一个方法,允许用户将某些张量注册为模块的一部分,但不会被视为可训练参数。这些张量会随模型保存和加载,但在反向传播过程中不会更新

register_buffer 的作用

  • 将张量注册为模型的缓冲区(buffer),意味着这些张量会与模型一起保存和加载
  • 与参数不同,缓冲区不会参与梯度计算,因此不会在训练时更新
  • 常用于存储像均值、方差、掩码或其他状态信息

与模型参数的区别

  • 模型参数(register_parameterself.param = nn.Parameter(tensor)):这些张量会被认为是可学习的权重,在训练过程中会被优化器更新
  • 缓冲区(register_buffer):这些张量不会被优化器更新,适合用于保存模型的常量或中间状态

使用的场景有如下:

  1. 存储一些与训练无关但随模型保存的常量
  2. 存储一些在 eval() 模式下需要使用的统计数据(如 BatchNorm 层中的均值和方差)
  3. 缓存计算中间的状态或掩码,不希望它们在训练中被更新

2. Demo

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 注册一个缓冲区, 不参与训练
        self.register_buffer('constant_tensor', torch.tensor([1.0, 2.0, 3.0]))
    
    def forward(self, x):
        # 使用缓冲区中的张量
        return x + self.constant_tensor

model = MyModel()
print(model.constant_tensor)  # 输出缓冲区内容 tensor([1., 2., 3.])

在训练模式和评估模式中的差异

model.train()  # 训练模式
print(model(torch.tensor([1.0, 1.0, 1.0])))  # tensor([2.0, 3.0, 4.0])

model.eval()  # 评估模式
print(model(torch.tensor([1.0, 1.0, 1.0])))  # tensor([2.0, 3.0, 4.0])

缓冲区内容不会因为模型模式的切换(train() 或 eval())而改变,因为缓冲区不是可训练参数,它仅存储数据

示例 1: 存储中间计算的掩码

class MaskedModel(nn.Module):
    def __init__(self):
        super(MaskedModel, self).__init__()
        self.register_buffer('mask', torch.tensor([1.0, 0.0, 1.0]))
    
    def forward(self, x):
        # 应用掩码到输入上
        return x * self.mask

model = MaskedModel()
input_tensor = torch.tensor([4.0, 3.0, 2.0])
output = model(input_tensor)
print(output)  # tensor([4.0, 0.0, 2.0])

示例 2: 保存 BatchNorm 层的均值和方差
在批归一化层(BatchNorm)中,均值和方差是通过 register_buffer 来存储的
这些值在训练时会动态更新,但在推理时会固定使用

class MyBatchNorm(nn.Module):
    def __init__(self, num_features):
        super(MyBatchNorm, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
    
    def forward(self, x):
        # 这里简单模拟了批归一化的逻辑
        return (x - self.running_mean) / torch.sqrt(self.running_var + 1e-5)

bn_layer = MyBatchNorm(3)
input_tensor = torch.tensor([2.0, 3.0, 4.0])
output = bn_layer(input_tensor)
print(output)

示例 3: 用于固定的均值和方差
不希望在训练时更新某些统计数据(比如均值和方差),而希望使用固定的值

class FixedNorm(nn.Module):
    def __init__(self):
        super(FixedNorm, self).__init__()
        # 注册固定的均值和方差
        self.register_buffer('mean', torch.tensor([0.5]))
        self.register_buffer('std', torch.tensor([0.25]))
    
    def forward(self, x):
        return (x - self.mean) / self.std

model = FixedNorm()
input_tensor = torch.tensor([1.0, 0.5, 0.0])
output = model(input_tensor)
print(output)  # tensor([ 2.0, 0.0, -2.0])

3. 与自动注册的差异

补充与上述不同的知识点,上述为手动注册,下面为自动注册

3.1 torch.nn.Parameter

这些 Parameter 会参与反向传播和梯度更新。
通过 model.parameters() 可以获得所有自动注册的参数

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # torch.nn.Parameter 会自动注册为模型参数
        self.weight = nn.Parameter(torch.randn(3, 3))
    
    def forward(self, x):
        return x @ self.weight

model = MyModel()
# weight 自动注册为参数
print(list(model.parameters()))  # 输出: [Parameter containing: tensor(...)]

3.2 自动注册子模块

在模型的 __init__ 函数中定义了 torch.nn.Module 子模块(例如卷积层、线性层等),PyTorch 会自动将这些子模块注册到模型中,并且它们的参数也会一并注册

  • 通过 model.children() 或 model.named_children() 获取所有子模块
  • 子模块的所有参数也会自动注册,可以通过 model.parameters() 或 model.named_parameters() 获取
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # nn.Module 子模块会自动注册
        self.conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
    
    def forward(self, x):
        return self.conv(x)

model = MyModel()
# conv 层自动注册为模型的子模块
print(list(model.children()))  # 输出: [Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))]

3.3 总结

  • register_buffer:用于注册不需要梯度的张量,比如存储模型状态的变量,不参与反向传播和优化过程
  • register_parameter:用于注册模型的可训练参数,会参与梯度计算和优化过程

两者结合的Demo

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 自动注册的子模块
        self.fc = nn.Linear(5, 5)

        # 手动注册一个可训练的参数
        self.register_parameter('my_weight', nn.Parameter(torch.randn(5, 5)))

        # 手动注册一个不可训练的缓冲区
        self.register_buffer('my_buffer', torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))

    def forward(self, x):
        x = self.fc(x)
        return x @ self.my_weight + self.my_buffer


model = MyModel()
# 输出所有注册的参数和缓冲区
print("Parameters:", list(model.named_parameters()))
print("Buffers:", list(model.named_buffers()))

在这里插入图片描述


http://www.kler.cn/news/311729.html

相关文章:

  • 9.19工作笔记
  • fmql之驱动程序编写(首次)
  • 浏览器插件利器--allWebPluginV2.0.0.20-beta版发布
  • 安科瑞智能塑壳断路器适用于物联网配电电网中-安科瑞黄安南
  • 算法打卡:第十一章 图论part01
  • 每天五分钟深度学习框架pytorch:pytorch中已经定义好的损失函数
  • Redis的Key的过期策略是怎样实现的?
  • 【WPF】02 按钮控件圆角配置及状态切换
  • 华为昇腾智算中心-智算中心测试方案与标准
  • JavaEE:网络编程(UDP)
  • pg入门3—详解tablespaces2
  • Elasticsearch 8.+ 版本查询方式
  • JVM 内存结构?
  • 卖家必看:利用亚马逊自养号测评精选热门产品,增强店铺权重
  • 运维工程师面试整理-监控与报警监控系统
  • uni-app-通过vue-cli命令行快速上手
  • 数据结构(Day16)
  • 虎先锋,你也喜欢线程控制嘛
  • UAC2.0 麦克风——音量控制
  • etcd之etcd简介和安装(一)
  • 全面整理的Oracel 数据库性能调优方案
  • 关系运算符
  • vue选项式写法项目案例(购物车)
  • 制作网上3D展馆需要什么技术并投入多少费用?
  • JSP分页功能实现案例:从基础到应用的全面解析
  • python SQLAlchemy 数据库连接池
  • 《拿下奇怪的前端报错》序章:报错输出个数值数组Buffer(475) [Uint8Array],我来教它说人话!
  • 【Unity】检测鼠标点击位置是否有2D对象
  • Modbus_tcp
  • 数据结构-3.2.栈的顺序存储实现