当前位置: 首页 > article >正文

【人工智能】Python常用库-PyTorch常用方法教程

PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。


1. 安装与导入

1.1 安装 PyTorch

访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安装命令。

常用安装命令:

pip install torch torchvision torchaudio
1.2 导入库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

2. PyTorch 基础

2.1 张量(Tensor)

张量是 PyTorch 的核心数据结构,可以看作是一个高维数组。

# 创建张量
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# 基本运算
c = a + b
print(c)  # 输出 tensor([5., 7., 9.])

# 随机张量
random_tensor = torch.rand((2, 3))  # 2行3列随机数
print(random_tensor)

输出结果

tensor([5., 7., 9.])
tensor([[0.9980, 0.2970, 0.5257],
        [0.8807, 0.0471, 0.7896]])
2.2 自动求导

PyTorch 提供动态计算图支持自动求导。

x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 4

y.backward()  # 自动求导
print(x.grad)  # 输出 dy/dx = 2*x + 3 = 7.0

输出结果

tensor(7.)

3. 数据加载

PyTorch 提供强大的数据加载功能。

import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

4. 构建神经网络

4.1 使用 nn.Module 构建模型
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # 展平输入
        x = self.relu(self.fc1(x))
        x = self.softmax(self.fc2(x))
        return x


model = SimpleNN()

print(model)

输出结果

SimpleNN(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)

5. 模型训练

5.1 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)
5.2 训练循环
for epoch in range(5):
    for images, labels in train_loader:
        optimizer.zero_grad()  # 梯度清零
        outputs = model(images)
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

完整代码

from torch import nn, optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # 展平输入
        x = self.relu(self.fc1(x))
        x = self.softmax(self.fc2(x))
        return x


model = SimpleNN()

# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5):
    for images, labels in train_loader:
        optimizer.zero_grad()  # 梯度清零
        outputs = model(images)
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

输出结果

Epoch 1, Loss: 1.482284665107727
Epoch 2, Loss: 1.4968496561050415
Epoch 3, Loss: 1.5289227962493896
Epoch 4, Loss: 1.4832825660705566
Epoch 5, Loss: 1.5070817470550537

6. 模型评估

6.1 在测试集上评估
test_data = MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

correct = 0
total = 0
with torch.no_grad():  # 禁用梯度计算
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {correct / total * 100:.2f}%")

输出结果

Test Accuracy: 10.32%

7. GPU 加速

PyTorch 支持使用 GPU 加速。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 将数据也移动到 GPU
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)

8. 保存与加载模型

8.1 保存模型
torch.save(model.state_dict(), 'model.pth')
8.2 加载模型
model = SimpleNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()  # 切换到评估模式

9. 实际案例

9.1 CIFAR-10 图像分类
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms

# CIFAR-10 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 16 * 16, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 16 * 16 * 16)
        x = self.fc1(x)
        return x


model = CNN()
# 后续训练步骤类似

10. PyTorch 优势总结

  1. 动态计算图:支持动态构建与修改模型。
  2. 灵活性:适合研究和开发,易于调试。
  3. 强大的社区支持:广泛的教程、示例和扩展工具。

通过实践,PyTorch 能够帮助用户更好地理解和实现深度学习算法!


http://www.kler.cn/a/412659.html

相关文章:

  • UE5 fieldSystemActor类
  • UE5 的DOP简化碰撞的基本概念
  • Unity 中 Application 四种常用目录总结
  • golang 定时器的不同任务
  • 单片机main函数执行结束干嘛?
  • YOLO系列论文综述(从YOLOv1到YOLOv11)【第3篇:YOLOv1——YOLO的开山之作】
  • 【深度学习基础】一篇入门模型评估指标(分类篇)
  • Linux 时间属性
  • SurfaceFlinger学习之一:概览
  • 大模型专栏--大模型开发框架
  • Spring | (七)AOP概念及工作流程
  • 【速通GO】数据类型与变量和常量
  • 丹摩 | 基于PyTorch的CIFAR-10图像分类实现
  • 第三方数据库连接免费使用和安装
  • 白光干涉仪:表面粗糙度形貌台阶高测量解决方案
  • Flutter 共性元素动画
  • 工业网络安全 智能电网,SCADA和其他工业控制系统等关键基础设施的网络安全(总结)
  • 无法通过外网连接访问mysql问题排查
  • 如何通过终端连接无线网
  • echarts使用示例