pytorch torch.linalg模块介绍
torch.linalg
是 PyTorch 的 线性代数 (Linear Algebra) 子模块,它提供了许多 高效的矩阵操作和分解方法,类似于 NumPy 的 numpy.linalg
或 SciPy 的 scipy.linalg
,但针对 GPU 加速和自动微分 进行了优化。
1. 矩阵基本运算
矩阵乘法
torch.linalg.matmul
:执行两个张量的矩阵乘法。它可以处理多种情况,包括批量矩阵乘法。
import torch
# 创建两个矩阵
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.linalg.matmul(A, B)
print(result)
矩阵转置
torch.linalg.t
:用于计算矩阵的转置。
import torch
A = torch.tensor([[1, 2], [3, 4]])
transposed_A = torch.linalg.t(A)
print(transposed_A)
2. 矩阵分解
奇异值分解(SVD)
import torch
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
U, S, Vh = torch.linalg.svd(A)
print("U:", U)
print("S:", S)
print("Vh:", Vh)
QR 分解
torch.linalg.qr
:将矩阵分解为一个正交矩阵 和一个上三角矩阵 的乘积,即 A= QR
import torch
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
Q, R = torch.linalg.qr(A)
print("Q:", Q)
print("R:", R)
3. 矩阵的特征值和特征向量
计算特征值和特征向量
torch.linalg.eig
:用于计算方阵的特征值和特征向量。对于实对称矩阵,可以使用torch.linalg.eigh
以获得更高的效率。
import torch
A = torch.tensor([[1, 2], [2, 1]], dtype=torch.float32)
eigenvalues, eigenvectors = torch.linalg.eig(A)
print("Eigenvalues:", eigenvalues)
print("Eigenvectors:", eigenvectors)
4. 求解线性方程组
import torch
# 系数矩阵 A
A = torch.tensor([[3, 1], [1, 2]], dtype=torch.float32)
# 常数向量 b
b = torch.tensor([9, 8], dtype=torch.float32)
x = torch.linalg.solve(A, b)
print("Solution x:", x)
5. 矩阵的范数计算
计算矩阵的范数
torch.linalg.norm
:用于计算矩阵或向量的范数,支持多种范数类型,如 1 - 范数、2 - 范数、无穷范数等。
import torch
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
norm_2 = torch.linalg.norm(A, ord=2)
print("2 - norm of A:", norm_2)
torch.linalg
模块为 PyTorch 用户提供了一套完整的线性代数工具,这些功能在机器学习、深度学习、计算机图形学等领域都有广泛的应用,例如在优化算法、降维技术、图像和信号处理等方面。