torch.nn.Linear(p_input, p_output,bias)
文章目录
- 介绍
- 实例
介绍
在 PyTorch 中,nn.Linear 是一个用于实现全连接层(线性层)的模块。它的作用是对输入数据进行线性变换,公式如下:
y = x W T + b y=xW^T+b y=xWT+b
其中:
- x x x 是输入张量
- W W W 是权重矩阵
- b b b 是偏置向量(如果 bias=True)
torch.nn.Linear(p_input, p_output,bias)
- p_input: 输入数据的变量个数
- p_output: 输出数据的变量个数
- bias: 是否使用偏置
实例
import torch
import torch.nn as nn
# 定义一个线性层
linear = nn.Linear(4, 3, bias=True)
# 查看权重和偏置的形状
print(linear.weight.shape) # torch.Size([3, 4]) -> 输出特征数 x 输入特征数
print(linear.bias.shape) # torch.Size([3]) -> 输出特征数
# 输入一个张量
x = torch.rand(2, 4) # 输入形状为 (batch_size=2, input_features=4)
output = linear(x) # 输出形状为 (batch_size=2, output_features=3)
print(output.shape) # torch.Size([2, 3])