当前位置: 首页 > article >正文

详解 @符号在 PyTorch 中的矩阵乘法规则

详解 @ 符号在 PyTorch 中的矩阵乘法规则

在 PyTorch 和 NumPy 中,@ 符号被用作矩阵乘法运算符,它本质上等价于 torch.matmul()numpy.matmul(),用于执行张量之间的矩阵乘法。

在本篇博客中,我们将深入探讨:

  • @ 运算符的基本概念
  • @ 在不同维度张量上的计算规则
  • @(d, k) @ (d, 1) 这种情况下的运算细节
  • PyTorch 自动广播机制
  • 代码示例与直观理解

1. 什么是 @

在 Python 3.5 之后,@ 被引入作为 矩阵乘法运算符,它在 NumPyPyTorch 中与 matmul() 等价。例如:

import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5], [6]])

C = A @ B  # 矩阵乘法
print(C)

输出:

[[17]
 [39]]

等价于

C = np.matmul(A, B)

PyTorch 中,@ 也适用于张量计算:

import torch
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
B = torch.tensor([[5], [6]], dtype=torch.float32)

C = A @ B  # PyTorch 版本的矩阵乘法
print(C)

2. @ 在不同维度张量上的计算规则

2.1 规则概述

@ 的运算规则依赖于输入张量的维度:

  1. 两个标量(0D):返回标量
  2. 标量和张量:标量与张量的元素逐个相乘
  3. 一维向量(1D)
    • (N,) @ (N,) → 标量(点积)
    • (N,) @ (N, M) → (M,)(左向量 × 矩阵)
    • (N, M) @ (M,) → (N,)(矩阵 × 右向量)
  4. 二维矩阵(2D)
    • (N, M) @ (M, K) → (N, K)(标准矩阵乘法)
  5. 高维张量(≥3D)
    • (A, B, C) @ (C, D) → (A, B, D)(批量矩阵乘法)

3. 重点解析 (d, k) @ (d, 1)

PyTorch 中,如果 A.shape = (d, k)B.shape = (d, 1)A @ B非法操作,因为矩阵乘法要求 A 的列数(k)等于 B 的行数(d),但这里 B 的形状 (d, 1) 无法与 (d, k) 匹配。

3.1 (d, k) @ (d, 1) 为什么不合法?

假设:

import torch
d, k = 4, 3

A = torch.randn(d, k)  # (4, 3)
B = torch.randn(d, 1)  # (4, 1)

C = A @ B  # ❌ 错误:形状不匹配

会报错:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x1)

原因:

  • 矩阵乘法规则: A 的列数(k)必须等于 B 的行数(d)。
  • (d, k) @ (d, 1) 不符合这个规则,因为 d ≠ k

3.2 如何让 (d, k) @ (d, 1) 变成合法操作?

我们需要 调整矩阵的形状,使其满足矩阵乘法的规则。

方法 1:交换操作数顺序

如果计算 B.T @ A

C = B.T @ A  # shape (1, d) @ (d, k) → (1, k)

就变成了合法操作。

方法 2:转置 A

如果我们计算:

C = A.T @ B  # shape (k, d) @ (d, 1) → (k, 1)

这个计算是 合法的,因为 A.T.shape = (k, d)B.shape = (d, 1),满足矩阵乘法规则。

示例:

C = A.T @ B  # (k, d) @ (d, 1) → (k, 1)

现在 A.T 变成 (k, d)B 仍然是 (d, 1),最终 C 的形状是 (k, 1)


3.3 PyTorch 如何正确处理 (d, k) @ (d,)

在 PyTorch 代码中,我们常见这样的计算:

q = P_q @ x  # (h, d, k) @ (d,)

为什么这里不需要转置 P_q

  • x.shape = (d,),PyTorch 自动扩展为 (d, 1) 使其成为列向量
  • 计算 (d, k) @ (d, 1)非法的,PyTorch 自动调整计算规则
  • PyTorch 实际执行的是 P_q.T @ x,确保计算正确
  • 最终返回 (h, k),去掉了多余的维度

因此 PyTorch 不需要我们手动转置 P_q,它会自动处理 x 为列向量进行计算!


4. 代码示例

import torch

d, k = 4, 3
torch.manual_seed(42)

A = torch.randn(d, k)  # (4, 3)
x = torch.randn(d)     # (4,)

# PyTorch 自动扩展 x,使其符合矩阵乘法规则
C = A.T @ x  # (k, d) @ (d,) → (k,)

print("A shape:", A.shape)  # (4, 3)
print("x shape:", x.shape)  # (4,)
print("C shape:", C.shape)  # (3,)

5. 结论

  • @矩阵乘法运算符,等价于 torch.matmul(A, B)
  • (d, k) @ (d, 1) 是不合法的矩阵乘法
  • PyTorch 会自动扩展 (d,) → (d, 1) 并进行正确的矩阵计算
  • (d, k) @ (d,) 实际等价于 (k, d) @ (d, 1),避免了显式转置

🚀 PyTorch 的 @ 计算规则很智能,能够自动扩展维度,让矩阵乘法符合数学规则! 🎯

q = P_q @ x 计算中,P_q.T 转置的是哪个维度?如何判断?

在 PyTorch 代码:

q = P_q @ x  # (h, d, k) @ (d,)

核心问题

  • P_q.shape = (h, d, k)
  • x.shape = (d,)

为什么 不需要手动转置 P_q?以及 PyTorch 在计算 P_q @ x 时转置了哪个维度


1. @ 运算规则

PyTorch 处理 torch.matmul(A, B) 时,遵循 广播机制矩阵乘法规则

  1. 最后两个维度 参与矩阵乘法
  2. 如果 B 是 1D 张量(即 B.shape = (d,)),PyTorch 会自动扩展为 (d, 1) 但不会影响计算逻辑

2. q = P_q @ x 具体计算

2.1 P_q.shape = (h, d, k), x.shape = (d,)

按照 PyTorch 规则:

  1. 扩展 x 形状
    • x.shape = (d,) 自动扩展为 (d, 1),使其符合矩阵乘法规则:
    x = x.unsqueeze(-1)  # (d,) → (d, 1)
    
  2. 选择 P_q 参与矩阵乘法的维度
    • P_q.shape = (h, d, k),表示:
      • h:注意力头数(不参与矩阵计算)
      • d:输入维度(x 匹配
      • k:查询维度(计算目标)
    • P_q @ x 的计算目标是:
      ( h , d , k ) @ ( d , 1 ) (h, d, k) @ (d, 1) (h,d,k)@(d,1)
      需要 P_q d 维度与 xd 维度对齐,才能进行矩阵乘法。

2.2 PyTorch 自动调整 P_q 计算方式

PyTorch 不会转置完整的 P_q,但会 调整最后两个维度 (d, k) 进行计算

  • 等价于
    q = ( h , k , d ) @ ( d , 1 ) = ( h , k , 1 ) q = (h, k, d) @ (d, 1) = (h, k, 1) q=(h,k,d)@(d,1)=(h,k,1)
  • 等价于
    q = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1))  # shape (h, k, 1)
    
    其中 P_q.transpose(-2, -1) 交换 (d, k)(k, d)

最终 PyTorch 计算:

q = (h, d, k) @ (d,) = (h, k)

其中 PyTorch 自动去除了 1 维度,返回 (h, k),而不是 (h, k, 1)


3. 如何判断 PyTorch 进行了哪些维度调整?

我们可以用 transpose()matmul() 手动验证

import torch

h, d, k = 2, 4, 3  # 2 个注意力头, 输入维度 4, 投影到 3 维
torch.manual_seed(42)

P_q = torch.randn(h, d, k)  # shape (h, d, k)
x = torch.randn(d)  # shape (d,)

# PyTorch 计算
q1 = P_q @ x  # (h, d, k) @ (d,) → (h, k)

# 手动转置 + matmul
q2 = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1)).squeeze(-1)  # (h, k)

print("q1 shape:", q1.shape)  # (h, k)
print("q2 shape:", q2.shape)  # (h, k)
print(torch.allclose(q1, q2))  # True

结果:

q1 shape: torch.Size([2, 3])
q2 shape: torch.Size([2, 3])
True

说明 PyTorch 自动进行了 P_q.transpose(-2, -1),使 d 维度匹配 xd 维度


4. 结论

💡 PyTorch 只会转置 P_qd, k 维度,确保矩阵乘法合法,但不会改变 h 维度

判断 PyTorch 何时自动调整维度

操作等效 PyTorch 计算
(d, k) @ (d,)自动转置 (d, k)(k, d), 计算 (k, d) @ (d, 1)
(h, d, k) @ (d,)自动调整 (d, k)(k, d), 计算 (h, k, d) @ (d, 1)
(d, k) @ (k, 1)直接符合矩阵乘法规则,正常计算
(h, d, k) @ (k, 1)符合矩阵乘法规则,正常计算

5. 关键点总结

P_qd, k 维度会被 PyTorch 自动调整,以匹配 x.shape = (d,)
PyTorch 计算 (h, d, k) @ (d,),本质等价于 P_q.transpose(-2, -1) @ x.unsqueeze(-1)
最终 q.shape = (h, k),符合多头注意力计算要求

🚀 PyTorch 的 @ 操作非常智能,会自动调整张量的形状,使矩阵乘法符合数学规则! 🎯

后记

2025年2月23日07点49分于上海,在GPT4o大模型辅助下完成。


http://www.kler.cn/a/559667.html

相关文章:

  • 二:前端发送POST请求,后端获取数据
  • 基于python实现机器学习的心脏病预测系统
  • 23贪心算法
  • 在 Mac mini M2 上使用Docker快速部署MaxKB:打造本地知识库问答系统
  • C语言多人聊天室 ---chat(客户端聊天)
  • 记录一下VScode可以使用nvcc编译,但VS不行的解决方案
  • DeepSeek技术演进史:从MoE到当前架构
  • 彻底卸载kubeadm安装的k8s集群
  • 深入理解P2P网络架构与实现
  • ubuntu离线安装ollama
  • 【在 Debian Linux下安装 privoxy 将 Socks5 转换为 HTTP 代理与privoxy的过滤配置】
  • 《深度学习实战》第2集:卷积神经网络(CNN)与图像分类
  • pytorch入门级项目--基于卷积神经网络的数字识别
  • 【Python爬虫(45)】Python爬虫新境界:分布式与大数据框架的融合之旅
  • Java List 自定义对象排序 Java 8 及以上版本使用 Stream API
  • 打破常规:用 Python Enum 管理常量的趣味之旅
  • 【计算机网络】传输层TCP协议
  • 详解 为什么 tcp 会出现 粘包 拆包 问题
  • MySQL数据库习题(选择题)
  • 蓝思科技赋能灵伴科技:AI眼镜产能与供应链双升级