PyTorch 图像分割模型教程
PyTorch 图像分割模型教程
在图像分割任务中,目标是将图像的每个像素归类为某一类,以分割出特定的物体。PyTorch 提供了非常灵活的工具,可以用于构建和训练图像分割模型。我们将使用 PyTorch 的经典网络架构,如 UNet 和 DeepLabV3,并演示如何构建、训练和测试这些模型。
1. 图像分割概述
图像分割的目标是将图像的每个像素进行分类。常见的应用场景有医学图像分割(如肿瘤检测)、自动驾驶(道路、车辆、行人分割)等。
- 语义分割:每个像素被分配给某个类别,例如道路、天空或车辆。
- 实例分割:不仅对物体分类,还要区分物体实例,如区分不同的行人。
PyTorch 中有许多预训练的模型可以直接用于图像分割任务,常用的模型包括 UNet、FCN (Fully Convolutional Network)、DeepLabV3 等。
2. 官方文档链接
- PyTorch 官方文档
- Torchvision 模型
3. 准备工作
在开始训练之前,我们需要安装 torch
, torchvision
和 PIL
等依赖项,并准备图像数据集。您可以使用自己的图像数据集,或者使用 COCO、VOC 等常用数据集。
pip install torch torchvision pillow
4. 使用预训练的 DeepLabV3 模型
DeepLabV3 是一个性能优异的语义分割模型,PyTorch 的 torchvision
提供了预训练的 DeepLabV3 模型。我们将使用 COCO 数据集中的预训练模型,并进行推理和测试。
import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
# 加载预训练的 DeepLabV3 模型
model = models.segmentation.deeplabv3_resnet50(pretrained=True)
model.eval() # 切换到评估模式
# 定义预处理步骤
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载图像
input_image = Image.open("test_image.jpg")
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # 创建 batch 维度
# 将输入移到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_batch = input_batch.to(device)
# 进行预测
with torch.no_grad():
output = model(input_batch)['out'][0] # DeepLabV3 的输出包含 'out' 字段
# 将输出转换为类别索引(每个像素对应一个类别)
output_predictions = output.argmax(0).cpu().numpy()
# 显示分割结果
plt.imshow(output_predictions)
plt.show()
说明:
models.segmentation.deeplabv3_resnet50(pretrained=True)
:加载使用 ResNet-50 作为主干网络的 DeepLabV3 模型,预训练于 COCO 数据集。preprocess
:对输入图像进行预处理,包括调整大小、裁剪、归一化等。output_predictions
:模型的输出是每个像素的类别索引,经过argmax
操作,获取每个像素的类别。
5. UNet 模型
UNet 是一个广泛用于医学图像分割的经典模型。我们将从头实现 UNet 模型,并在简单的合成数据上进行训练。
5.1 UNet 网络结构
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
# 下采样(编码器部分)
self.encoder1 = self.double_conv(1, 64)
self.encoder2 = self.double_conv(64, 128)
self.encoder3 = self.double_conv(128, 256)
self.encoder4 = self.double_conv(256, 512)
# 中间部分
self.middle = self.double_conv(512, 1024)
# 上采样(解码器部分)
self.upconv4 = self.up_conv(1024, 512)
self.decoder4 = self.double_conv(1024, 512)
self.upconv3 = self.up_conv(512, 256)
self.decoder3 = self.double_conv(512, 256)
self.upconv2 = self.up_conv(256, 128)
self.decoder2 = self.double_conv(256, 128)
self.upconv1 = self.up_conv(128, 64)
self.decoder1 = self.double_conv(128, 64)
# 最后的分类层
self.final = nn.Conv2d(64, 1, kernel_size=1)
def double_conv(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
def up_conv(self, in_channels, out_channels):
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
def forward(self, x):
# 编码器部分
e1 = self.encoder1(x)
e2 = self.encoder2(F.max_pool2d(e1, 2))
e3 = self.encoder3(F.max_pool2d(e2, 2))
e4 = self.encoder4(F.max_pool2d(e3, 2))
# 中间部分
middle = self.middle(F.max_pool2d(e4, 2))
# 解码器部分
d4 = self.upconv4(middle)
d4 = torch.cat((e4, d4), dim=1)
d4 = self.decoder4(d4)
d3 = self.upconv3(d4)
d3 = torch.cat((e3, d3), dim=1)
d3 = self.decoder3(d3)
d2 = self.upconv2(d3)
d2 = torch.cat((e2, d2), dim=1)
d2 = self.decoder2(d2)
d1 = self.upconv1(d2)
d1 = torch.cat((e1, d1), dim=1)
d1 = self.decoder1(d1)
return self.final(d1)
# 创建模型实例
unet_model = UNet()
print(unet_model)
说明:
UNet
是一种编码-解码结构,包含多个下采样(编码器)和上采样(解码器)层。每次下采样都会减少特征图的大小,并增加特征通道数,上采样则恢复原始图像的大小。ConvTranspose2d
用于进行上采样操作。
5.2 训练 UNet 模型
为了训练 UNet 模型,我们需要构建一个数据加载器并定义损失函数和优化器。我们以一个简单的二分类分割任务为例。
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# 创建合成数据集
class SyntheticSegmentationDataset(Dataset):
def __init__(self, num_samples, image_size):
self.num_samples = num_samples
self.image_size = image_size
self.transform = transforms.Compose([
transforms.ToTensor(),
])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
image = torch.rand(1, self.image_size, self.image_size)
mask = (image > 0.5).float() # 简单的二分类掩码
return image, mask
# 创建数据集
dataset = SyntheticSegmentationDataset(num_samples=1000, image_size=128)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
# 定义损失函数和优化器
criterion = nn.BCEWithLogitsLoss() # 二分类交叉熵损失
optimizer = torch.optim.Adam(unet_model.parameters(), lr=0.001)
# 训练循环
unet_model.train()
for epoch in range(5): # 简单训练 5 个 epoch
for images, masks in dataloader:
# 前向传播
outputs = unet_model(images)
loss = criterion(outputs, masks)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch [{epoch
+1}/5], Loss: {loss.item():.4f}')
说明:
BCEWithLogitsLoss
是二分类任务的标准损失函数,适合输出为单通道(1 表示目标类,0 表示背景)的分割任务。- 我们创建了一个合成数据集,其中图像为随机值,掩码为图像中值大于 0.5 的部分。
6. 总结
- DeepLabV3 是一种非常强大的图像分割模型,适用于各种复杂场景,PyTorch 提供了预训练模型,适合快速部署。
- UNet 是经典的医学图像分割模型,适用于更细致的分割任务。
通过使用 PyTorch,您可以轻松实现并训练图像分割模型,利用 GPU 加速并扩展到大规模数据集。