简易了解Pytorch中的@ 和 * 运算符(附Demo)
目录
- 1. 基本知识
- 2. @
- 3. *
1. 基本知识
在 PyTorch 中,@ 和 * 运算符用于不同类型的数学运算,具体是矩阵乘法和逐元素乘法
基本知识
运算符 | 功能 | 适用场景 | 示例 |
---|---|---|---|
@ | 矩阵乘法(或点乘) | 用于执行线性代数中的矩阵乘法 | C = A @ B,其中 A 和 B 为矩阵 |
* | 逐元素乘法 | 用于对同一形状的张量进行逐元素相乘 | C = A * B,其中 A 和 B 为同形状张量 |
两者的差异总结如下:
特点 | 矩阵乘法 (@) | 逐元素乘法 (*) |
---|---|---|
运算类型 | 矩阵乘法(线性代数) | 逐元素运算 |
适用条件 | 列数等于行数 | 形状相同 |
返回结果形状 | (m, p) | 与输入张量相同 |
使用示例 | C = A @ B | C = A * B |
适用场景 | 线性变换、深度学习中的权重计算 | 图像处理、逐元素操作等 |
- 使用 @ 运算符进行矩阵乘法适合线性代数操作,常用于深度学习中的层与权重的运算
- 使用 * 运算符进行逐元素乘法适合需要对张量进行元素级操作的场景,如数据处理和图像增强等
2. @
@ 运算符用于执行矩阵乘法或向量点乘
对于两个矩阵 A 和 B,其结果 C 是一个新矩阵,其中 C[i][j] 是 A 的第 i 行与 B 的第 j 列的点积
适用条件: A 的列数必须等于 B 的行数,即 A 的形状为 (m, n),B 的形状为 (n, p),则结果 C 的形状为 (m, p)
import torch
# 创建两个矩阵
A = torch.tensor([[1, 2], [3, 4]]) # 2x2 矩阵
B = torch.tensor([[5, 6], [7, 8]]) # 2x2 矩阵
# 使用 @ 运算符进行矩阵乘法
C = A @ B
print("矩阵乘法结果:\n", C)
截图如下:
3. *
*运算符用于对两个相同形状的张量进行逐元素相乘
结果张量的每个元素是操作数张量中对应元素的乘积
适用条件: A 和 B 必须具有相同的形状(或能够通过广播规则兼容)
import torch
# 创建两个相同形状的张量
A = torch.tensor([[1, 2], [3, 4]]) # 2x2 矩阵
B = torch.tensor([[5, 6], [7, 8]]) # 2x2 矩阵
# 使用 * 运算符进行逐元素乘法
C = A * B
print("逐元素乘法结果:\n", C)
截图如下: