当前位置: 首页 > article >正文

Pytorch学习笔记(十六)Image and Video - Transfer Learning for Computer Vision Tutorial

这篇博客瞄准的是 pytorch 官方教程中 Image and Video 章节的 Transfer Learning for Computer Vision Tutorial 部分。

  • 官网链接:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
完整网盘链接: https://pan.baidu.com/s/1L9PVZ-KRDGVER-AJnXOvlQ?pwd=aa2m 提取码: aa2m 

Transfer Learning for Computer Vision Tutorial

这个示例中将介绍如何使用迁移学习训练卷积神经网络进行图像分类。

实际上,很少有人从头开始训练整个卷积网络,通常在非常大的数据集上预训练 ConvNet(例如 ImageNet,其中包含 120 万张图像和 1000 个类别),然后将 ConvNet 用作初始化或固定特征提取器来完成感兴趣的任务。

两个主要的迁移学习场景如下所示:

  • 微调 ConvNet:使用预训练网络来初始化网络,其余训练与常规一致;
  • ConvNet 作为固定特征提取器,冻结除最终全连接层之外的所有网络的权重,最后一个全连接层将被替换为具有随机权重的新层,并且只训练这一层;

导入依赖包:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import os, time
from PIL import Image
from tempfile import TemporaryDirectory

cudnn.benchmark = True
plt.ion()

Load Data

使用 torchvisiontorch.utils.data 来加载数据。目标是训练一个模型来对蚂蚁和蜜蜂进行分类,为蚂蚁和蜜蜂各准备了大约 120 张训练图像,每个类别有 75 张验证图像,该数据集是 imagenet 的一个非常小的子集。从这个 链接 中下载并解压数据。

定义一个数据增强函数

data_transforms = {
    'train': transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
    ),
    'val': transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
    )  
}

定义数据加载器

data_dir = 'data/hymenoptera_data'

image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']
}
data_loaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']
}
dataset_size = {
    x: len(image_datasets[x]) for x in ['train', 'val']
}
class_names = image_datasets['train'].classes

检查可用设备

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'

抽查几个数据

def imshow(inp, title=None):
    """Display image for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

inputs, classes = next(iter(data_loaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

Training the model

定义训练函数

def train_model(model, criterion, optmizer, scheduler, num_epochs=25):
    since = time.time()
    
    with TemporyDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
        print(f"Best model save as {best_model_params_path}")
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0
        
        for epoch in range(num_epochs):
            print('-' * 30)
            print(f"Epoch {epoch+1}/{num_epochs}")
            
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()
                
                running_loss = 0.0
                running_corrects = 0
                
                for inputs, labels in data_loaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    optimizer.zero_grad()
                    
                    # train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(preds, labels)
                        
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
                    
                    # staistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                
                if phase == 'train':
                    scheduler.step()
                    
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc  = running_corrects.double() / dataset_sizes[phase]
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)
            print()

        time_elpased = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')
            
        model.load_static_dict(torch.load(best_model_params_path, weights_only=True))
    return model

Visualizing the model predictions

定义模型可视化工具

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig.plt.figure()
    
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(data_loaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])
                
                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return 
        model.train(mode=was_training)

Finetuning the ConvNet

拉取预训练模型

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features

第一优化器与损失函数

model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

训练模型

model_ft = train_model(model_ft, criterion, optimizer, exp_lr_scheduler, num_epochs=25)

抽查可视化

visualize_model(model_ft)

在这里插入图片描述


ConvNet as fixed feature extractor

冻结除最后一层之外的所有参数,设置 require_grad = False 来冻结参数,这样梯度就不会在 Backward() 中计算。

加载预训练模型

model_conv = torchvision.models.resnet18(weights="IMAGENET1K_V1")

for param in model_conv.parameters():
    param.requires_grad = False

替换掉模型的最后一层

num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

定义优化器与损失函数

model_conv = model_conv.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_conv.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

训练模型

model_conv = train_model(model_conv, criterion, optimizer, exp_lr_scheduler, num_epochs=20)

抽查可视化

visualize_model(model_conv)

plt.ioff()
plt.show()

Inference on custom images

使用指定路径文件进行推理

def visualize_model_predictions(model, img_path):
    was_training = model.training
    model.eval()
    
    img = Image.open(img_path)
    img = data_transforms['val'](img)
    img = img.unsqueeze(0)
    img = img.to(device)
    
    with torch.no_grad():
        outputs = model(img)
        _, preds = torch.max(outputs, 1)
        
        ax = plt.subplot(2,2,1)
        ax.axis('off')
        ax.set_title(f"Predicted: {class_names[preds[0]]}")
        imshow(img.cpu().data[0])
        
        model.train(mode=was_training)

绘制图像

visualize_model_predictions(
    model_conv,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)

plt.ioff()
plt.show()

http://www.kler.cn/a/612955.html

相关文章:

  • Mysql-DML
  • Linux命令大全:从入门到高效运维
  • Mac: 运行python读取CSV出现 permissionError
  • 【LeetCode 题解】数据库:180. 连续出现的数字
  • 提示词应用:IT模拟面试
  • CSS学习笔记5——渐变属性+盒子模型阶段案例
  • 构建高可用性西门子Camstar服务守护者:异常监控与自愈实践
  • k近邻算法K-Nearest Neighbors(KNN)
  • office_word中使用宏以及DeepSeek
  • 如何让DeepSeek-R1在内网稳定运行并实现随时随地远程在线调用
  • Redis原理:setnx
  • 基于深度学习的图像超分辨率技术研究与实现
  • 解决 Apache Kylin 加载 Hive 表失败的问题:深入分析与解决方案
  • 逗万DareWorks|创意重构书写美学,引领新潮无界的文创革命
  • 从物理学到机器学习:用技术手段量化分析职场被动攻击行为
  • 配置完nfs后vmware虚拟机下ubuntu/无法联网问题
  • 生成信息提取的大型语言模型综述
  • 看懂roslunch输出
  • Neo4j【环境部署 03】插件APOC和ALGO配置使用实例分享(网盘分享3.5.5兼容版本插件)
  • Python 爬虫案例