UNet-二维全景X射线图像牙齿分割(代码和模型修改)
1.背景
使用了之前用于眼底分割的模型,刚好查到有牙齿分割的数据,就想着修改一下,自己也做一个牙齿分割的模型。这里主要使用的是“太阳花的小绿豆”博主的代码,自己还添加了以resnet34作为backbone的模型文件。关于UNet的相关知识点可以参考链接中的GitHub地址,里面有相关视频和博文,我也是借鉴之后自己发挥了一下。
2.数据集下载
STS2D2023数据集:训练集提供2000张牙齿全景图像、测试集500张。训练集包括原图以及对应的mask,测试集仅提供原图。
(1)原始的数据下载地址:DRIVE.zip
链接: https://pan.baidu.com/s/1SnGiiK-R-s8RwzkVxDDngg
提取码: 5fdz
(2)修改后的数据下载地址:DRIVE-change.zip
链接: https://pan.baidu.com/s/1M8-Fu3kidlWVbOlMEl6vEA
提取码: a2y3
DRIVE-change的文件目录结构:
```
DRIVE-change
├── test
│   ├── image
│   ├── images
│   └── mask
└── train
    ├── images
    └── mask
```
修改后的数据只是调整了文件夹结构:把2000张带mask的全景图片划分成train和test两部分;test中的image文件夹则是官方测试集的500张原图(没有mask),训练中没有用到,仅供预测测试使用。
3.代码展示
(1) train.py文件
1.需要先计算图片的均值和标准差(mean/std),使用项目中的compute_mean_std.py文件,具体见文章末尾的项目链接
2.修改了模型的保存路径,可以根据自己的实际路径修改或者用默认的路径
3.解释一下args参数中的'--resume':训练中断的话,在训练时指定'--resume model_pth(已经训练的模型地址)'就可以接着训练。checkpoint中保存了模型权重、optimizer、lr_scheduler以及epoch等参数,具体看代码内容
4.修改模型的话只需要导入相应的模型结构文件
step1:from src import UNet,VGG16UNet,MobileV3Unet, ResNetUnet
step2:def create_model(num_classes):
# model = UNet(in_channels=3, num_classes=num_classes, base_c=32)
# model = MobileV3Unet(num_classes=num_classes)
# model = VGG16UNet(num_classes=num_classes)
model = ResNetUnet(num_classes=num_classes)
return model
import os
import time
import datetime
import torch
from src import UNet,VGG16UNet,MobileV3Unet, ResNetUnet
from train_utils import train_one_epoch, evaluate, create_lr_scheduler
from my_dataset import DriveDataset
import transforms as T
class SegmentationPresetTrain:
    """Joint image/target augmentation pipeline used for training.

    Steps, in order: ``T.RandomResize`` with bounds
    ``int(0.5 * base_size)`` .. ``int(1.2 * base_size)``, an optional random
    horizontal flip, an optional random vertical flip (each only added when
    its probability is > 0), ``T.RandomCrop(crop_size)``, ``T.ToTensor()`` and
    ``T.Normalize(mean, std)``.

    The default mean/std are the dataset statistics produced by
    compute_mean_std.py (the ImageNet values were used previously).
    """

    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.50281333, 0.50281333, 0.50281333), std=(0.13062168, 0.13062168, 0.13062168)):
        steps = [T.RandomResize(int(0.5 * base_size), int(1.2 * base_size))]
        # Flips are only inserted when their probability is positive.
        flip_ops = ((hflip_prob, T.RandomHorizontalFlip), (vflip_prob, T.RandomVerticalFlip))
        steps.extend(make_flip(prob) for prob, make_flip in flip_ops if prob > 0)
        steps += [T.RandomCrop(crop_size), T.ToTensor(), T.Normalize(mean=mean, std=std)]
        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        return self.transforms(img, target)
class SegmentationPresetEval:
    """Evaluation pipeline: ``T.ToTensor()`` followed by ``T.Normalize(mean, std)``.

    No geometric augmentation is applied, so image and target keep their size.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        to_tensor = T.ToTensor()
        normalize = T.Normalize(mean=mean, std=std)
        self.transforms = T.Compose([to_tensor, normalize])

    def __call__(self, img, target):
        return self.transforms(img, target)
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Return the training or evaluation preset for the given normalization stats.

    Args:
        train: truthy for the augmenting training preset, falsy for eval.
        mean: per-channel normalization mean.
        std: per-channel normalization standard deviation.
    """
    if not train:
        return SegmentationPresetEval(mean=mean, std=std)
    # Resize base and random-crop size used during training.
    base_size, crop_size = 565, 480
    return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
def create_model(num_classes):
    """Build the segmentation network used for training.

    ``num_classes`` already includes the background class. To switch
    architectures, return one of the other imported models instead, e.g.
    ``UNet(in_channels=3, num_classes=num_classes, base_c=32)`` or
    ``VGG16UNet(num_classes=num_classes)``.
    """
    return ResNetUnet(num_classes=num_classes)
def main(args):
    """Train the segmentation model configured by ``args``.

    Builds train/val ``DriveDataset`` loaders, an SGD optimizer and a per-step
    LR scheduler, optionally resumes from ``args.resume``, then for every epoch
    trains, evaluates (confusion matrix + Dice), appends the metrics to
    ``results<timestamp>.txt`` and writes a checkpoint into
    ``save_weights_resnet_unet/``. With ``args.save_best`` only improving
    epochs are saved (to ``best_model.pth``); otherwise one file per epoch.

    Args:
        args: namespace produced by ``parse_args()``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    # device = torch.device(args.device if torch.backends.mps.is_available() else "cpu")
    print(device)
    batch_size = args.batch_size
    # segmentation num_classes + background
    num_classes = args.num_classes + 1
    print(num_classes)

    # Normalization statistics computed with compute_mean_std.py
    mean = (0.50281333, 0.50281333, 0.50281333)
    std = (0.13062168, 0.13062168, 0.13062168)

    # Text file collecting per-epoch training and validation info
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

    train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))
    val_dataset = DriveDataset(args.data_path,
                               train=False,
                               transforms=get_transform(train=False, mean=mean, std=std))

    # os.cpu_count() may return None; treat that as one CPU instead of crashing in min().
    num_workers = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=num_classes)
    model.to(device)

    params_to_optimize = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params_to_optimize,
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # The LR schedule is stepped once per iteration (not once per epoch).
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)

    best_dice = 0.
    if args.resume:
        # NOTE(review): the checkpoint also pickles ``args`` (an argparse.Namespace);
        # recent PyTorch releases default torch.load to weights_only=True, which may
        # refuse to unpickle it -- confirm against the installed version.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        # Restore the best score so a resumed --save-best run does not overwrite
        # best_model.pth with a worse model on its first epoch. Older checkpoints
        # do not contain the key and fall back to the previous behaviour.
        best_dice = checkpoint.get("best_dice", 0.)
        # A checkpoint written without --amp has no scaler state; start fresh then.
        if args.amp and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])

    # Checkpoints are written below; make sure the folder exists even when
    # main() is called without going through the script entry point.
    os.makedirs("save_weights_resnet_unet", exist_ok=True)

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch, num_classes,
                                        lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)

        confmat, dice = evaluate(model, val_loader, device=device, num_classes=num_classes)
        val_info = str(confmat)
        print(val_info)
        print(f"dice coefficient: {dice:.3f}")
        # write into txt
        with open(results_file, "a") as f:
            # Record this epoch's train_loss, lr and validation metrics
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n" \
                         f"dice coefficient: {dice:.3f}\n"
            f.write(train_info + val_info + "\n\n")

        improved = best_dice < dice
        if improved:
            best_dice = dice
        if args.save_best and not improved:
            continue

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "lr_scheduler": lr_scheduler.state_dict(),
                     "epoch": epoch,
                     "best_dice": best_dice,
                     "args": args}
        if args.amp:
            save_file["scaler"] = scaler.state_dict()

        if args.save_best:
            torch.save(save_file, "save_weights_resnet_unet/best_model.pth")
        else:
            torch.save(save_file, "save_weights_resnet_unet/model_{}.pth".format(epoch))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("training time {}".format(total_time_str))
def parse_args():
    """Parse the training command line.

    Boolean options (``--save-best``, ``--amp``) accept true/false style
    values (``True``/``False``, ``1``/``0``, ``yes``/``no``, ``on``/``off``)
    and may also be given as a bare flag meaning ``True``.

    Returns:
        argparse.Namespace with the training options.
    """
    import argparse

    def _str2bool(value):
        # argparse's ``type=bool`` returns True for any non-empty string,
        # including "False", so the value has to be parsed explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes", "y", "on"):
            return True
        if lowered in ("false", "0", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser(description="pytorch unet training")

    parser.add_argument("--data-path", default="../", help="DRIVE root")
    # exclude background
    parser.add_argument("--num-classes", default=1, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=200, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=1, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--save-best', default=True, type=_str2bool, nargs='?', const=True,
                        help='only save best dice weights')
    # Mixed precision training parameters
    parser.add_argument("--amp", default=False, type=_str2bool, nargs='?', const=True,
                        help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()

    return args
if __name__ == '__main__':
    args = parse_args()

    # main() saves checkpoints into this folder, so create it before training.
    weights_dir = "./save_weights_resnet_unet"
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)

    main(args)
(2) src目录中的模型定义文件
这里我只展示以resnet34作为backbone的模型文件。也可以自己换成resnet50或者resnet101,关键是让stage_out_channels与backbone各特征层的输出通道数匹配:resnet34为[64, 64, 128, 256, 512],resnet50/101为[64, 256, 512, 1024, 2048]。模型不同,取出来的特征层名字也可能不同,可以先打印出backbone的模型结构,再选取对应的输出层。下面展示一下PyTorch打印的模型结构:
self.stage_out_channels = [64, 64, 128, 256, 512]
return_layers = {"relu": "stage0", "layer1": "stage1", "layer2": "stage2","layer3": "stage3", "layer4": "stage4"}
stage0:激活函数不会改变特征图的shape,所以relu后的输出通道就是conv1/bn1的64,对应stage_out_channels中的第一个64
stage1:layer1最后的bn层输出为64,对应stage_out_channels中的第二个64
stage2:layer2最后的bn层输出为128,对应stage_out_channels中的128
stage3:layer3最后的bn层输出为256,对应stage_out_channels中的256
stage4:layer4最后的bn层输出为512,对应stage_out_channels中的512
然后就是上采样了,没什么特别的
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(up1): Up(
(up): Upsample(scale_factor=2.0, mode=bilinear)
(conv): DoubleConv(
(0): Conv2d(768, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up2): Up(
(up): Upsample(scale_factor=2.0, mode=bilinear)
(conv): DoubleConv(
(0): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up3): Up(
(up): Upsample(scale_factor=2.0, mode=bilinear)
(conv): DoubleConv(
(0): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up4): Up(
(up): Upsample(scale_factor=2.0, mode=bilinear)
(conv): DoubleConv(
(0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(conv): OutConv(
(0): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
resnet模型文件
from collections import OrderedDict
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import resnet34
# from .unet import Up, OutConv
class Down(nn.Sequential):
    """Encoder step: 2x2 max-pool with stride 2, then a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        pool = nn.MaxPool2d(2, stride=2)
        convs = DoubleConv(in_channels, out_channels)
        super(Down, self).__init__(pool, convs)
class DoubleConv(nn.Sequential):
    """Two consecutive (3x3 conv, no bias -> BatchNorm -> ReLU) stages.

    padding=1 keeps the spatial size unchanged.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second conv.
        mid_channels: channels between the two convs; defaults to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        hidden = out_channels if mid_channels is None else mid_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        super(DoubleConv, self).__init__(*layers)
class Up(nn.Module):
    """Decoder step: upsample ``x1`` by 2, align it to ``x2``, concatenate, DoubleConv.

    With ``bilinear=True`` the upsampling is ``nn.Upsample`` (no channel
    change) and ``in_channels`` must equal the channel count of the
    concatenation ``[x2, x1]``; the DoubleConv uses ``in_channels // 2`` as its
    middle width. With ``bilinear=False`` a 2x2 transposed conv maps
    ``in_channels`` -> ``in_channels // 2``, i.e. it expects ``x1`` itself to
    have ``in_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        upsampled = self.up(x1)
        # Size mismatch on H (dim 2) and W (dim 3) after upsampling, e.g. for odd input sizes.
        dh = x2.shape[2] - upsampled.shape[2]
        dw = x2.shape[3] - upsampled.shape[3]
        # F.pad order for the last two dims is (left, right, top, bottom);
        # negative amounts crop instead of pad.
        upsampled = F.pad(upsampled, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        # Skip connection first, upsampled features second.
        return self.conv(torch.cat([x2, upsampled], dim=1))
class OutConv(nn.Sequential):
    """Segmentation head: a single 1x1 convolution (with bias) to ``num_classes`` channels."""

    def __init__(self, in_channels, num_classes):
        head = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        super(OutConv, self).__init__(head)
class IntermediateLayerGetter(nn.ModuleDict):
    """Wrap a model's top-level children and return selected intermediate outputs.

    The direct children of ``model`` are copied in order, stopping after the
    last child named in ``return_layers``; any later children are dropped.
    ``forward`` runs the kept children one after another and collects the
    outputs of the requested ones. It therefore only reproduces the wrapped
    model's computation when that model's forward is a plain chain of its
    children.

    Args:
        model: the network to wrap.
        return_layers: maps child name -> key under which its output is returned.

    Raises:
        ValueError: if a key of ``return_layers`` is not a direct child of ``model``.
    """

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        child_names = [name for name, _ in model.named_children()]
        if not set(return_layers).issubset(child_names):
            raise ValueError("return_layers are not present in model")
        original_mapping = return_layers
        # Work on a string-keyed copy so the caller's dict is left untouched.
        pending = {str(k): str(v) for k, v in return_layers.items()}

        kept = OrderedDict()
        for name, module in model.named_children():
            kept[name] = module
            pending.pop(name, None)
            if not pending:
                break

        super(IntermediateLayerGetter, self).__init__(kept)
        self.return_layers = original_mapping

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        outputs = OrderedDict()
        for name, layer in self.items():
            x = layer(x)
            if name not in self.return_layers:
                continue
            outputs[self.return_layers[name]] = x
        return outputs
class ResNetUnet(nn.Module):
    """U-Net style decoder on top of a torchvision ResNet-34 encoder.

    Encoder taps (see ``return_layers``): ``stage0`` is the output of the stem
    ReLU and ``stage1``..``stage4`` are ``layer1``..``layer4``, with channel
    counts ``stage_out_channels = [64, 64, 128, 256, 512]``. Four ``Up`` blocks
    fuse each deeper feature map with the next shallower skip, a 1x1
    ``OutConv`` produces per-class logits, and the logits are bilinearly
    resized to the input's H x W.

    Args:
        num_classes: number of output channels (including background).
        pretrain_backbone: forwarded to ``resnet34(pretrained=...)``.
            NOTE: ``pretrained=`` is the legacy torchvision argument; newer
            torchvision versions deprecate it in favour of ``weights=``.

    Returns (forward):
        ``{"out": logits}`` with logits of shape (N, num_classes, H, W).
    """

    def __init__(self, num_classes, pretrain_backbone: bool = False):
        super(ResNetUnet, self).__init__()
        backbone = resnet34(pretrained=pretrain_backbone)

        self.stage_out_channels = [64, 64, 128, 256, 512]
        return_layers = {"relu": "stage0", "layer1": "stage1", "layer2": "stage2", "layer3": "stage3", "layer4": "stage4"}
        self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

        ch = self.stage_out_channels
        # Each Up block takes (deeper + skip) channels and outputs the skip's width.
        self.up1 = Up(ch[4] + ch[3], ch[3])
        self.up2 = Up(ch[3] + ch[2], ch[2])
        self.up3 = Up(ch[2] + ch[1], ch[1])
        self.up4 = Up(ch[1] + ch[0], ch[0])
        self.conv = OutConv(ch[0], num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        out_hw = x.shape[-2:]
        feats = self.backbone(x)

        y = feats['stage4']
        decoder = ((self.up1, 'stage3'), (self.up2, 'stage2'),
                   (self.up3, 'stage1'), (self.up4, 'stage0'))
        for up_block, skip_key in decoder:
            y = up_block(y, feats[skip_key])

        # stage0 is below input resolution, so resize logits back to the input size.
        logits = F.interpolate(self.conv(y), size=out_hw, mode="bilinear", align_corners=False)
        return {"out": logits}
if __name__ == '__main__':
    # Quick sanity check: build a 2-class model and print its module tree.
    model = ResNetUnet(num_classes=2)
    print(model)
(3) 预测代码
import os
import time
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from src import UNet, MobileV3Unet,ResNetUnet
def time_synchronized():
    """Return ``time.time()`` after waiting for queued CUDA work, when CUDA is available.

    Synchronizing first makes wall-clock timings of GPU inference meaningful.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def main(onnx_path=None):
    """Segment one panoramic X-ray with a trained ResNetUnet and save the result.

    Loads ``weights_path``, normalizes the image with the training mean/std,
    runs inference and writes:

    * ``./<image-stem>.npy`` -- the argmax class-index map (uint8).
    * ``test_result-res.png`` -- the same map with class 1 set to 255.

    Args:
        onnx_path: if given, additionally export the model to ONNX at this
            path. Previously every run exported unconditionally (with verbose
            graph printing) to a file misleadingly named ``mobilenetv3.onnx``;
            pass ``onnx_path="mobilenetv3.onnx"`` to get that behaviour back.

    Raises:
        FileNotFoundError: if the weights file or the input image is missing.
    """
    classes = 1  # exclude background
    weights_path = "./save_weights_resnet_unet/best_model.pth"
    img_path = "../DRIVE/test/image/13.png"
    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"weights {weights_path} not found.")
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"image {img_path} not found.")

    # Same normalization statistics as used in training (compute_mean_std.py).
    mean = (0.50281333, 0.50281333, 0.50281333)
    std = (0.13062168, 0.13062168, 0.13062168)

    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model (alternative: UNet(in_channels=3, num_classes=classes + 1, base_c=32))
    model = ResNetUnet(num_classes=classes + 1)

    # The checkpoint holds the full training state; only the 'model' weights are needed.
    model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
    model.to(device)

    # load image, convert to a normalized tensor and add the batch dimension
    original_img = Image.open(img_path).convert('RGB')
    data_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(mean=mean, std=std)])
    img = torch.unsqueeze(data_transform(original_img), dim=0)

    model.eval()  # switch to evaluation mode

    if onnx_path is not None:
        # Export uses a fixed 1x3x320x640 dummy input.
        dummy_input = torch.randn(1, 3, 320, 640).to(device)
        torch.onnx.export(model, dummy_input, onnx_path, verbose=True, opset_version=11)

    with torch.no_grad():
        # Warm-up pass on a zero image of the same size so one-time setup cost
        # is not included in the timed inference below.
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        print(init_img.shape)
        model(init_img)

        t_start = time_synchronized()
        output = model(img.to(device))
        t_end = time_synchronized()
        print("inference time: {}".format(t_end - t_start))

        prediction = output['out'].argmax(1).squeeze(0)
        prediction = prediction.to("cpu").numpy().astype(np.uint8)
        print(prediction)
        # Name the raw map after the input image (``13.png`` -> ``./13.npy``).
        stem = os.path.splitext(os.path.basename(img_path))[0]
        np.save("./{}.npy".format(stem), prediction)
        # Set foreground pixels to 255 (white) for visualization
        prediction[prediction == 1] = 255
        mask = Image.fromarray(prediction)
        mask.save("test_result-res.png")
if __name__ == '__main__':
    # Single-image inference using the paths hard-coded in main().
    main()
4.测试结果
![]() |
![]() |
结果还不错吧
5.整体项目链接
通过网盘分享的文件:unet-2d.zip
链接: https://pan.baidu.com/s/1xxn2rRZ60bCrqMuL32ODBw
提取码: qf54