pytorch torch.matmul函数介绍
torch.matmul
是 PyTorch 中用于进行矩阵乘法的函数。它可以执行两维矩阵、向量和更高维张量之间的乘法运算,支持的运算取决于输入张量的维度。
1. 函数签名
torch.matmul(input, other, out=None)
- input: 左乘的张量。
- other: 右乘的张量。
- out: 可选,用于存储输出结果的张量。
2. 不同维度的乘法规则
torch.matmul
根据输入张量的维度执行不同类型的乘法:
- 两个标量(0维):相当于两个数的普通乘法。
- 一个向量(1维)与一个矩阵(2维):
- 若第一个张量是一维向量,第二个是二维矩阵,则执行向量和矩阵的乘法,输出是一个向量。
- 若第一个张量是二维矩阵,第二个是一维向量,则执行矩阵和向量的乘法,输出是一个向量。
- 两个矩阵(2维):执行标准的矩阵乘法(即行×列相乘)。
- 高维张量:高维张量可以被看作是包含多个矩阵,
torch.matmul
将沿着批量维度(batch dimension)执行批量矩阵乘法。