当前位置: 首页 > article >正文

PyTorch训练Celeba

CelebFaces属性数据集(CelebA)是一个大规模的人脸属性数据集,包含超过20万张名人图像,每张图像都有40个属性标注。该数据集中的图像涵盖了大范围的姿态变化和复杂的背景。CelebA具有高度的多样性、大量的数据和丰富的标注信息,包括:

  • 10,177个身份,
  • 202,599张人脸图像,
  • 每张图像有5个标志点位置和40个二进制属性标注。

该数据集可用于以下计算机视觉任务的训练和测试集:人脸属性识别、人脸识别、人脸检测、标志点(或面部部位)定位以及人脸编辑与合成。

1. 安装必要的依赖

如果还没有安装PyTorch和Torchvision,可以通过pip安装:

%pip install torch torchvision

2. 加载CelebA数据集

我们可以通过torchvision.datasets模块轻松获取它。
如果没有数据集且无法通过google drive下载,可以通过这篇notebook关联的数据集下载

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 图像的预处理:缩放到 64x64 大小,并归一化
# 这个变换管道通常用于图像分类任务中的数据预处理步骤。通过调整图像大小、转换为张量和标准化,可以确保输入数据的一致性和模型训练的稳定性。
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    #将图像从 PIL 图像或 numpy 数组转换为 PyTorch 张量,并且会自动将图像的像素值从 [0, 255] 范围缩放到 [0, 1] 范围。
    transforms.ToTensor(),
    #使用给定的均值和标准差对图像进行标准化。这里的均值和标准差分别是 [0.5, 0.5, 0.5],表示对每个通道(RGB)进行标准化。标准化公式为: [ text{output} = frac{text{input} - text{mean}}{text{std}} ] 由于输入的像素值已经被 transforms.ToTensor 缩放到 [0, 1] 范围,标准化后的像素值将被调整到 [-1, 1] 范围。
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 加载 CelebA 数据集,如果google drive可以用,download=True,可以直接下载,也可以自己下载后解压到/root/celeba目录下
root = '/root'
train_dataset = torchvision.datasets.CelebA(root=root, split='train', download=False, transform=transform)
test_dataset = torchvision.datasets.CelebA(root=root, split='test', download=False, transform=transform)

# 创建数据加载器,改大batch_size,在GPU T4下显存使用率没增加?
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 检查数据加载是否成功
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")

3. 定义CNN模型

接下来,我们定义一个简单的卷积神经网络模型,结构可以根据任务的复杂性调整。这里的例子是一个基础的CNN。

3.1 初始化

卷积层:
  • self.conv1:第一个卷积层,输入通道数为 3(RGB 图像),输出通道数为 64,卷积核大小为 3x3,步幅为 1,填充为 1。
  • self.conv2:第二个卷积层,输入通道数为 64,输出通道数为 128,卷积核大小为 3x3,步幅为 1,填充为 1。
  • self.conv3:第三个卷积层,输入通道数为 128,输出通道数为 256,卷积核大小为 3x3,步幅为 1,填充为 1。
全连接层:
  • self.fc1:第一个全连接层,输入大小为 25688,输出大小为 512。
  • self.fc2:第二个全连接层,输入大小为 512,输出大小为 40(CelebA 数据集有 40 个属性标签)。
池化层和激活函数:
  • self.pool:最大池化层,池化核大小为 2x2,步幅为 2。
  • self.relu:ReLU 激活函数。
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(256*8*8, 512)
        self.fc2 = nn.Linear(512, 40)  # CelebA 有 40 个属性标签

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.pool(x)
        
        x = x.view(x.size(0), -1)  # 展平操作
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x
# 实例化模型
model = SimpleCNN()

4. 定义损失函数和优化器

对于多标签分类任务,我们使用 BCEWithLogitsLoss 损失函数。优化器可以选择 Adam 或 SGD。

criterion = nn.BCEWithLogitsLoss()  # 二分类交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)

5. 训练模型

下面是训练模型的代码,每个 epoch 后对模型进行验证。

5.1. train 函数

train 函数用于训练模型。

参数解释
  • model:要训练的神经网络模型。
  • train_loader:训练数据的 DataLoader 对象。
  • criterion:损失函数,用于计算预测值和真实值之间的误差。
  • optimizer:优化器,用于更新模型的参数。
  • num_epochs:训练的轮数。
  • device:计算设备(如 CPU 或 GPU)。

5.2. evaluate 函数

evaluate 函数用于评估模型的性能。

参数解释
  • model:要评估的神经网络模型。
  • test_loader:测试数据的 DataLoader 对象。
  • criterion:损失函数,用于计算预测值和真实值之间的误差。
  • device:计算设备(如 CPU 或 GPU)。
def train(model, train_loader, criterion, optimizer, num_epochs, device):
    model.train() #设置模型为训练模式
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            # 将数据移动到选定的设备
            images, labels = images.to(device), labels.to(device).float()
            
            # 前向传播
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # 反向传播和优化
            optimizer.zero_grad()#清零梯度
            loss.backward()#计算梯度
            optimizer.step()#更新参数
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
                
def evaluate(model, test_loader, criterion, device):
    model.eval() # 设置模型为评估模式
    with torch.no_grad(): # 禁用梯度计算
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in test_loader:
            # 将数据移动到选定的设备
            images, labels = images.to(device), labels.to(device).float()
            
            # 前向传播
            outputs = model(images)
            # 计算损失,并累加到 total_loss
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            # 计算准确率
            predicted = (outputs > 0.5).float()
            total += labels.size(0) * labels.size(1)  # 总标签数
            correct += (predicted == labels).sum().item()
        # 打印平均损失和准确率
        print(f'Average loss: {total_loss / len(test_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')
  1. 开始训练

通过调用train函数进行训练,完成后用evaluate函数进行验证。

# 检查 CUDA 是否可用,并选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 将模型移动到选定的设备
model = model.to(device)

# 训练模型
train(model, train_loader, criterion, optimizer, num_epochs=10, device=device)

# 验证模型
evaluate(model, test_loader, criterion, device=device)
  1. 保存和加载模型

训练完成后,可以保存模型以便之后使用。

# 保存模型
torch.save(model.state_dict(), 'celeba_cnn.pth')

# 加载模型
model.load_state_dict(torch.load('celeba_cnn.pth'))

http://www.kler.cn/news/332887.html

相关文章:

  • 论文笔记:基于细粒度融合网络和跨模态一致性学习的多模态假新闻检测
  • 基于SpringBoot+Vue+MySQL的智能垃圾分类系统
  • RabbitMQ篇(基本介绍)
  • 高炉计算笔记
  • 【网络安全】IP切换绕过2FA身份验证
  • 19款奔驰E300升级新款触摸屏人机交互系统
  • 使用MTVerseXR SDK实现VR串流
  • 使用微服务Spring Cloud集成Kafka实现异步通信(消费者)
  • C++11--智能指针
  • 【Linux】进程替换、命令行参数及环境变量(超详解)
  • 长期提供APX515/B原装二手APX525/B音频分析仪
  • Linux学习笔记(五):shell脚本,强大的文本处理工具awk,sed
  • Pyhton爬虫使用Selenium实现浏览器自动化操作抓取网页
  • 1516-函数指针
  • 微信小程序操作蓝牙
  • Navicat Premium 12 for Mac中文永久版
  • 初识Django
  • 基于大数据技术的共享单车数据分析与辅助管理系统
  • 免费 Oracle 各版本 离线帮助使用和介绍
  • 基于GitLab 的持续集成环境