当前位置: 首页 > article >正文

使用pytorch进行迁移学习的两个步骤

1. 步骤及代码

迁移学习一般都会使用两个步骤进行训练:

  1. 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;
  2. 使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。
import os
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as T
import numpy as np

# 设置均值、方差
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# 还原减均值除以方差之前的数据,用于可视化
def reduction_img_show(tensor, mean, std) -> None:
    to_img = T.ToPILImage()
    reduced_img = to_img(tensor * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1))
    reduced_img.show()


def getResNet(*, class_names: str, loadfile: str = None):
    if loadfile is not None:
        model = torchvision.models.resnet18()
        model.load_state_dict(torch.load('resnet18-f37072fd.pth'))  # 加载权重
    else:
        model = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)  # 模型自动下载到C:\Users\GaryLau\.cache\torch\hub\checkpoints

    # 将所有的参数层冻结,设置模型除最后一层以外都不可以进行训练,使模型只针对最后一层进行微调
    for param in model.parameters():
        param.requires_grad = False
    # 输出全连接层信息
    print(model.fc)
    x = model.fc.in_features  # 获取全连接层输入维度
    model.fc = torch.nn.Linear(in_features=x, out_features=len(class_names))  # 创建新的全连接层
    print(model.fc)  # 输出新的全连接层
    return model


# 定义训练函数
def train(model, device, train_loader, criterion, optimizer, epoch):
    model.train()
    all_loss = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        y_pred = model(data)
        loss = criterion(y_pred, target)
        loss.backward()
        all_loss.append(loss.item())
        optimizer.step()
        if batch_idx % 10 == 0:
            print(
                'Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),
                                                               np.mean(all_loss)))

def val(model, device, val_loader, criterion):
    model.eval()
    test_loss = []
    correct = []
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            y_pred = model(data)
            test_loss.append(criterion(y_pred, target).item())
            pred = y_pred.argmax(dim=1, keepdim=True)
            correct.append(pred.eq(target.view_as(pred)).sum().item()/pred.size(0))
    print('-->Test: Average loss:{:.4f}, Accuracy:({:.0f}%)\n'.format(np.mean(test_loss), 100 * sum(correct) / len(correct)))

# 训练,验证时的预处理
transform = {
    'train': T.Compose([
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ]),
    'val': T.Compose([
        T.Resize((224,224)),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])}

# 加载训练、验证数据
dataset_train = ImageFolder(r'./train', transform=transform['train'])
dataset_val = ImageFolder(r'./test', transform=transform['val'])

# 类别标签
class_names = dataset_train.classes
print(dataset_train.class_to_idx)
print(dataset_val.class_to_idx)

# 显示一张训练、验证图
# reduction_img_show(dataset_train[0][0], mean, std)
# reduction_img_show(dataset_val[0][0], mean, std)

# 使用DataLoader遍历数据
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, sampler=None, num_workers=0,
                              pin_memory=False, drop_last=False)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, sampler=None, num_workers=0,
                            pin_memory=False, drop_last=False)

# 使用方式一,使用next不断获取一个batch的数据
dataiter_train = iter(dataloader_train)
imgs, labels = next(dataiter_train)
print(imgs.size())
# reduction_img_show(imgs[0], mean, std)
# reduction_img_show(imgs[1], mean, std)
multi_imgs = torchvision.utils.make_grid(imgs, nrow=10)  # 拼接一个batch的图像用于展示
# reduction_img_show(multi_imgs, mean, std)

# 获取ResNet模型,并加载预训练模型权重,将最后一层(输出层)去掉,换成一个新的全连接层,新全连接层输出的节点数是新数据的类别数
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# 构建模型
model = getResNet(class_names=class_names, loadfile='resnet18-f37072fd.pth')
model.to(device)

# 构建损失函数
criterion = torch.nn.CrossEntropyLoss()
# 指定新加的全连接层为要更新的参数
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)  # 只需要更新最后一层fc的参数

if __name__ == '__main__':
    ### 步骤一,微调最后一层
    first_model = 'resnet18-f37072fd_finetune_fcLayer.pth'
    for epoch in range(1, 6):
        train(model, device, dataloader_train, criterion, optimizer, epoch)
        val(model, device, dataloader_val, criterion)
    # 仅保存了最后新添加的全连接层的参数
    #torch.save(model.fc.state_dict(), first_model)
    torch.save(model.state_dict(), first_model)

    ### 步骤二,小学习率微调所有层
    second_model = 'resnet18-f37072fd_finetune_allLayer.pth'
    optimizer2 = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=3, gamma=0.9)
    # 将所有的参数层设为可训练的
    for param in model.parameters():
        param.requires_grad = True

    if os.path.exists(second_model):
        model.load_state_dict(torch.load(second_model))   # 加载本地模型
    else:
        model.load_state_dict(torch.load(first_model))    # 加载步骤一训练得到的本地模型
    print('Finetune all layers with small learning rate......')
    for epoch in range(1, 101):
        train(model, device, dataloader_train, criterion, optimizer2, epoch)
        if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:
            exp_lr_scheduler.step()
            print(f"learning rate: {optimizer2.state_dict()['param_groups'][0]['lr']}")
        val(model, device, dataloader_val, criterion)
    # 保存整个模型
    torch.save(model.state_dict(), second_model)

print('Done.')

2. 完整资源

https://download.csdn.net/download/liugan528/89833913


http://www.kler.cn/news/334623.html

相关文章:

  • Redis终极入门指南:万字解析帮你从零基础到掌握命令与五大数据结构
  • ARM Assembly 6: Shift 和 Rotate
  • SQL进阶技巧:如何优雅求解指标累计去重问题?
  • SpringBoot在线教育系统:构建与优化
  • react-问卷星项目(6)
  • CMake教程:第一步:一个基本的起点
  • mysql中 and or not的执行顺序
  • 解决Vue应用中遇到路由刷新后出现 404 错误
  • 高等数学 第二讲 数列极限_收敛数列_海涅定理_单调有界准则
  • SkyWalking 高可用
  • Redis SpringBoot项目学习
  • 图文深入理解Oracle Network配置管理(一)
  • Windows系统编程(三)进程与线程二
  • sentinel原理源码分析系列(一)-总述
  • Centos Stream 9备份与恢复、实体小主机安装PVE系统、PVE安装Centos Stream 9
  • C++面试速通宝典——9
  • rabbitMq-----消费者管理模块
  • Perforce静态分析工具2024.2新增功能:Helix QAC全新CI/CD集成支持、Klocwork分析引擎改进和安全增强
  • 使用指标进行量化交易时,有哪些需要注意的风险点呢
  • Spring Data JPA中的锁机制