pytorch 模型的参数查看函数介绍
在 PyTorch 中,查看和访问模型的参数是非常常见的操作。以下是一些常用的函数和方法,用于查看和操作 PyTorch 模型的参数。
1. model.parameters()
该方法返回一个生成器,它生成模型的所有可训练参数。这些参数通常是模型中的权重和偏置。
示例:
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = self.fc1(x)
return self.fc2(x)
# 创建模型实例
model = SimpleModel()
# 查看模型的所有参数
for param in model.parameters():
print(param)
输出的将是模型中所有层的权重和偏置。例如 fc1.weight
和 fc1.bias
等。
2. model.named_parameters()
该方法与 model.parameters()
类似,但它返回的是一个生成器,生成 (name, parameter)
元组,其中 name
是每个参数的名字,parameter
是对应的张量。