当前位置: 首页 > article >正文

【深度学习】DataLoader自定义数据集制作

第一步 导包

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

第二步 自定义数据集

# Dataset root and the per-split image directories. os.path.join inserts a
# path separator only when one is needed, so no "//" appears in the result.
data_dir = "./flower_data/"
train_dir = os.path.join(data_dir, "train_filelist")
valid_dir = os.path.join(data_dir, "val_filelist")
from torch.utils.data import Dataset,DataLoader
class FlowerDataset(Dataset):
    """Image-classification dataset driven by a plain-text annotation file.

    Each non-blank line of ``ann_file`` holds an image path (relative to
    ``root_dir``) followed by an integer class label, separated by
    whitespace, e.g. ``xxx.jpg 0``.

    Args:
        root_dir: Directory that the image paths in ``ann_file`` are
            relative to.
        ann_file: Path to the annotation text file.
        transform: Optional callable applied to each PIL image in
            ``__getitem__`` (e.g. a ``torchvision.transforms.Compose``).
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        # Mapping: relative image path -> 0-d np.int64 label array.
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir, name) for name in self.img_label]
        self.label = list(self.img_label.values())
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.img)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to 3-channel RGB so grayscale, RGBA and
        palette files all yield the same channel count. The file handle is
        closed right away instead of lingering until garbage collection.
        ``label`` is a 0-d ``torch.int64`` tensor.
        """
        with Image.open(self.img[idx]) as img:
            # convert() loads the pixel data, so the result stays valid after
            # the underlying file is closed by the with-block.
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        label = torch.from_numpy(np.array(self.label[idx], dtype=np.int64))
        return image, label

    def load_annotations(self):
        """Parse ``self.ann_file`` into ``{relative_path: label}``.

        Blank lines are skipped. The label is the last whitespace-separated
        token on the line, so runs of spaces or tabs are tolerated.

        Returns:
            dict mapping each image path (str) to a 0-d ``np.int64`` array.
            If a path appears more than once, its last label wins.

        Raises:
            ValueError: if a non-blank line does not consist of a path
                followed by an integer label.
        """
        data_infos = {}
        # utf-8-sig also strips a leading BOM (e.g. files saved by Notepad).
        with open(self.ann_file, encoding="utf-8-sig") as f:
            for line_no, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                parts = line.rsplit(None, 1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{self.ann_file}:{line_no}: expected '<path> <label>', got {line!r}"
                    )
                filename, gt_label = parts
                try:
                    label = int(gt_label)
                except ValueError as err:
                    raise ValueError(
                        f"{self.ann_file}:{line_no}: label {gt_label!r} is not an integer"
                    ) from err
                data_infos[filename] = np.array(label, dtype=np.int64)
        return data_infos

注：ann_file 是一个纯文本标注文件，每行描述一个样本，格式为“图片文件名 类别标签”（两者以空格分隔，文件名相对于 root_dir），例如：`xxx.jpg 0`

第三步 自定义transform

# Preprocessing pipelines keyed by split. Both pipelines resize to 64 px,
# center-crop to 64x64, convert to a tensor and normalise per channel with
# the given mean/std (these appear to be the standard ImageNet statistics).
# The training pipeline additionally applies random rotation and
# horizontal/vertical flips as data augmentation.
data_transforms = dict(
    train=transforms.Compose(
        [
            transforms.Resize(64),
            transforms.RandomRotation(45),
            transforms.CenterCrop(64),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    valid=transforms.Compose(
        [
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
)

第四步 根据自定义Dataset实例化DataLoader

①实例化Dataset

# Build one dataset per split. Each split reads its images from its own
# directory (train_dir / valid_dir); the annotation files sit under data_dir.
train_dataset = FlowerDataset(
    root_dir=train_dir,
    ann_file=os.path.join(data_dir, "train.txt"),
    transform=data_transforms["train"],
)
valid_dataset = FlowerDataset(
    root_dir=valid_dir,  # must be the validation directory, not train_dir
    ann_file=os.path.join(data_dir, "val.txt"),
    transform=data_transforms["valid"],
)

②实例化DataLoader

# Shuffle only the training split. The validation metrics computed in
# train_model are sums over every sample, so order does not matter there;
# a fixed order keeps validation runs reproducible.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

③验证图片是否加载正确

# Sanity check: take one training batch, undo Normalize on the first image
# (x * std + mean) and display it together with its label.
image, label = next(iter(train_loader))
# (C, H, W) -> (H, W, C) as expected by imshow. The copy keeps the in-place
# arithmetic below from modifying the batch tensor itself.
sample = image[0].permute(1, 2, 0).numpy().copy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
# Floating-point rounding can leave values marginally outside [0, 1]; clip
# explicitly rather than letting imshow clip with a warning.
plt.imshow(np.clip(sample, 0, 1))
plt.show()
print('Label is: {}'.format(label[0].numpy()))


第五步 训练

①前置准备

# Phase name -> loader, as consumed by train_model.
dataloaders = dict(train=train_loader, valid=val_loader)

model_name = "resnet"
# NOTE(review): model_name and feature_extract are not read anywhere in the
# visible code. In particular no parameters are frozen, so every layer of the
# network is trained.
feature_extract = True

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# ResNet-18 constructed without loading pretrained weights. Its final fully
# connected layer is replaced by a 102-way linear classifier.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))

# Loss, optimizer (Adam over all parameters) and a StepLR schedule that
# multiplies the learning rate by 0.1 every 7 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

②定义训练函数

def _run_phase(model, loader, criterion, optimizer, dev, train):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)`` as floats.

    With ``train=True`` the model is put in training mode and every batch
    does a backward pass plus an optimizer step. Otherwise the model is in
    eval mode and gradient tracking is disabled.
    """
    model.train(train)  # model.train(False) is exactly model.eval()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in loader:
        inputs = inputs.to(dev)
        labels = labels.to(dev)

        # Gradients are only tracked (and applied) during the training pass.
        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Assumes criterion averages over the batch (the default reduction);
        # weighting by batch size makes the epoch value a per-sample mean.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    num_samples = len(loader.dataset)
    return running_loss / num_samples, running_corrects / num_samples


def _step_scheduler(sched, val_loss):
    """Advance ``sched`` by one epoch; a ``None`` scheduler is a no-op.

    ReduceLROnPlateau needs the monitored metric. Epoch-based schedulers such
    as StepLR must be stepped without arguments, because a positional
    argument there is interpreted as an epoch index, not as a loss.
    """
    if sched is None:
        return
    if isinstance(sched, optim.lr_scheduler.ReduceLROnPlateau):
        sched.step(val_loss)
    else:
        sched.step()


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,
                is_inception=False, filename="best.pth", *,
                lr_scheduler=None, target_device=None):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Each epoch runs a ``"train"`` pass (forward, backward, optimizer step)
    followed by a ``"valid"`` pass (forward only). Whenever validation
    accuracy improves, ``{"state_dict", "best_acc", "optimizer"}`` is written
    to ``filename`` with ``torch.save``. After the last epoch the best weights
    are loaded back into ``model``.

    Args:
        model: network to train; moved to the target device in place.
        dataloaders: dict with ``"train"`` and ``"valid"`` loaders yielding
            ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train+valid epochs.
        is_inception: unused; kept for backward compatibility.
        filename: checkpoint path for the best model.
        lr_scheduler: scheduler stepped once per epoch after validation.
            Defaults to the module-level ``scheduler``.
        target_device: device for the model and batches. Defaults to the
            module-level ``device``.

    Returns:
        ``(model, val_acc_history, train_acc_history, valid_losses,
        train_losses, LRs)``. The histories hold one float per epoch. ``LRs``
        holds the initial learning rate followed by the rate after each epoch.
    """
    # Fall back to the script-level globals so existing calls keep working.
    # The conditional expressions only evaluate the globals when needed.
    dev = device if target_device is None else target_device
    sched = scheduler if lr_scheduler is None else lr_scheduler

    since = time.time()
    best_acc = 0.0
    model.to(dev)

    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]["lr"]]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        for phase in ["train", "valid"]:
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], criterion, optimizer, dev,
                train=(phase == "train"),
            )

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == "train":
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
                continue

            # Validation phase: snapshot and checkpoint the best model so far.
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    "state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(state, filename)

            val_acc_history.append(epoch_acc)
            valid_losses.append(epoch_loss)
            _step_scheduler(sched, epoch_loss)  # learning-rate decay

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Hand back the best-performing weights rather than the last epoch's.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

③训练模型

# Train for 20 epochs. The best checkpoint is written to best.pth, and the
# returned model carries the weights with the best validation accuracy.
(model_ft, val_acc_history, train_acc_history,
 valid_losses, train_losses, LRs) = train_model(
    model_ft,
    dataloaders,
    criterion,
    optimizer_ft,
    num_epochs=20,
    filename='best.pth',
)


http://www.kler.cn/a/536210.html

相关文章:

  • C32.【C++ Cont】静态实现双向链表及STL库的list
  • nuxt3中使用useFetch请求刷新不返回数据或返回html结构问题解决-完整nuxt3useFetchtch请求封装
  • mounted钩子函数里如何操作子组件的DOM?
  • 解析PHP文件路径相关常量
  • 深入理解和使用定时线程池ScheduledThreadPoolExecutor
  • Java常用类
  • 海康威视豆干型网络相机QT的Demo
  • 【学习总结|DAY036】Vue工程化+ElementPlus
  • 华为小艺助手接入DeepSeek,升级鸿蒙HarmonyOS NEXT即可体验
  • Linux中DataX使用第三期
  • Java 8的Stream API
  • 栈和队列的实现(C语言)
  • 解决aspose将Excel转成PDF中文变成方框的乱码问题
  • esp32 udp 客户端 广播
  • 【Elasticsearch】nested聚合
  • Day67:类的继承
  • 树莓派5添加摄像头 在C++下调用opencv
  • Junit5使用教程(6)--高级特性2
  • HTML学习之CSS三种引入方式
  • 基于JavaWeb开发的java Springboot实现教务管理系统
  • 介绍10个比较优秀好用的Qt相关的开源库
  • Linux后台运行进程
  • 网络安全 | 什么是XSS跨站脚本攻击?
  • 如何利用 Python 爬虫按关键字搜索淘宝商品
  • C++基础系列【5】namespace using
  • JAVA异步的TCP 通讯-客户端