当前位置: 首页 > article >正文

深度学习| 快速上手深度学习代码的阅读和改写

前言:本文主要是基于pytorch的讲解,方便新手对深度学习代码结构有个概念。了解后,可以学会如何去阅读模型和修改模型,从而快速复现和修改现有的模型。

深度学习代码结构

  • 介绍
  • 1. 模型
  • 2. 数据集
  • 3. 损失
  • 4. 评估
  • 5. 训练
  • 9. 预测
  • 如何阅读和修改深度学习代码

介绍

深度学习的实现是依赖于神经网络的,所以我们讨论的深度学习代码也就是在讨论神经网络。

简单回顾一下神经网络训练过程,有利于我们更好的理解代码结构。

神经网络模型(model)其实就是一堆运算单元的组合,只不过神经网络是模仿了大脑神经元来组合这些运算单元;前向传播是输入(input)通过模型的一堆运算单元组合得到一个输出(output);损失其实就是计算输出(output)和标签(label)的差异得到值;反向传播是通过这个损失值更新模型里面运算单元的参数。

对于深度学习代码其实分为六个部分:模型 数据集 训练 预测 评估 损失。

  • 模型(model):模型的结构、前向传播、保存和读取。
  • 数据集(dataset):数据集的读取、处理和保存。
  • 训练(train):训练模型,读取数据集后,设置patch、epoch,并输出每轮的损失、进行反向传播和保存模型。
  • 预测(predict):模型预测结果,读取数据集和模型后,前向传播得到output结果,并输出评估结果以及保存结果。
  • 评估(evaluation):对output和label做一些指标上评估计算。
  • 损失(Loss):对output和label做一些指标上损失计算,通常评估和损失的指标会比较通用。

1. 模型

以一个简单的分类模型为例子,只是说明代码结构,不考虑具体使用。

import torch.nn as nn
# 模型通常都会继承nn.Module,复杂模型里面的子模块也是继承nn.Module
class NeuralNetwork(nn.Module):
	# num_class是分类的数量
	# 定义了模型具有那些结构
	def __init__(self,num_class):
		super().__init__()# 继承
		self.flatten=nn.Flatten()# 展平,因为要做全连接
		# 全连接层,做分类
		self.classifer=nn.Sequential(
			nn.Linear(28*28, 512),# 全连接
            nn.ReLU(),# 激活
            nn.Linear(512, num_class)# 分类
		)
	# 前向传播
	# 模型结构是如何排列
	def forward(self,x):
		x = self.flatten(x)
        x = self.classifer(x)
        return x

一般来说模型都会比较复杂,就是好几个模块结合在一起,会有很多个这样继承nn.Module的模块,最后有一个总的(也继承了nn.Module)调用了所有的模块。

2. 数据集

网上公开的数据集,通常可以用Dataset直接获取,例如MNIST手写体数据集。

from torch.utils.data import Dataset
training_data = datasets.MNIST(
    root="data",# 将数据保存在本地什么位置
    train=True,# 是否训练数据
    download=True,# 是否下载
    transform=ToTensor(),# 将图像转为 tensor
)

不是很复杂的数据也可以用transformer进行数据预处理,在用dataset读取。

from torchvision import transforms
from torch.utils.data import Dataset
data_transform = transforms.Compose([
    transforms.Resize(256256),# 图像数据大小调整
    transforms.ToTensor(),# 将图像转为 tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))# 归一化,前面是mean后面是std,执行(图像-mean)/ std
])
data_dataset=datasets.ImageFolder(root='地址',transform=data_transform)

对于自己定义的数据集,常常会需要重写Dataset,来更好管理数据。
这里以图像分割二分类数据为例子,读取input和label文件夹中的图像数据。

from torch.utils.data import Dataset
class MyDataSet(Dataset):
	# 读取input和label文件夹下所有的数据
    def __init__(self, _dir, label_dir):
        self.input_dir =input_dir
        self.label_dir = label_dir
        # 获得指定目录下的所有文件
        self.input_file = glob(self.input_dir + '/*')
        self.label_file = glob(self.label_dir + '/*')
    # 获取单个文件
    def __getitem__(self, i):
        img = cv2.imread(self.input_file[i], 0)
        mask = cv2.imread(self.label_file[i], 0)
        # 二分类,做个二值化,因为有时候放进去的图像不是二值化的
        mask[np.where(mask >0)] = 1
        mask[np.where(mask == 0)] = 0
        return {'input': img, 'label': mask}
    # 获取文件长度
    def __len__(self):
        return len(self.input_file)

3. 损失

损失函数是在训练模型的时候使用的,前向传播结束后调用,计算用来参与反向传播调参。

如果是自带的损失函数可以直接调用,例如CE。

criterion = nn.CrossEntropyLoss()

如果是自定义的,需要重写。
可以发现重写的loss是继承nn.Module并且需要写前向传播,这是因为损失计算的值是要参与反向传播的。反向传播过程中,需要计算损失函数相对于每个参数的梯度。这要求损失函数在前向传播中提供了明确的计算路径,以便在反向传播中能够计算这些梯度。
下面给出CE和Dice的组合Loss的例子。

# Dice loss
class DiceCoeff(nn.Module):
    def __init__(self):
        super(DiceCoeff, self).__init__()

    def forward(self, pred, target):
        inter = torch.dot(pred, target) + 0.0001
        union = torch.sum(pred ** 2) + torch.sum(target ** 2) + 0.0001
        t = 2 * inter.float() / union.float()
        return t

class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.dice_loss = DiceLoss()

    def forward(self, pred, target):
        target = target.type(torch.LongTensor).cuda()
        pred_soft = F.softmax(pred,dim=1)# 归一化 维度为1进行归一化,one-hot编码变成预测标签
        y1 = torch.mean(self.cross_entropy_loss.forward(pred, target))
        y2 = torch.mean(self.dice_loss(pred_soft, target))
        y=0.5*y1+0.5*y2
        return y

4. 评估

损失是用来让模型收敛的,评估是用来反应模型的效果,通常会在训练的时候验证集上评估效果,以及在测试上评估效果。

通常会有很多个评估的指标,例如DSC、HD等,说白就是一些函数输入output和label能够得到某些值。
损失和评估很多指标是相同的,例如DSC和Dice的计算逻辑是相同的,代码思路和损失函数中前向传播是基本一样的。只不过损失函数是会参与的优化的,需要继承nn.Module并且写前向传播,而评估只是函数。
损失指标很多的时候,会同一写在一个文件里面,方便管理。

5. 训练

训练,就是读取数据,然后规定好batch和epoch,进行前向传播,计算loss,然后反向传播。

os.environ['CUDA_VISIBLE_DEVICES'] = '1'# 使用那个gpu进行训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 没有gpu就用cpu训练

# DataLoader加载数据Dataset
batch_size = 64# 每次训练拿取多少个
train_loader = DataLoader(training_dataset, batch_size=batch_size)
model = NeuralNetwork().to(device)

loss_fn = nn.CrossEntropyLoss()# 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# 就是采样梯度更新模型的可学习参数,使得损失减小。

num_epochs=200# 所有数据训练多少轮
for epoch in range(num_epochs):
	model.train()
	runing_loss=0.0
	for i, (x, y) in enumerate(train_loader):
		# i表示第i个batch,x和y是输入和输出
		optimizer.zero_grad()# 梯度清理
        x, y = x.to(device), y.to(device)# 转到数据gpu上
        pred = model(x)# 前向传播预测
        loss = loss_fn(pred, y)# 计算损失
        loss.backward()# 反向传播
        optimizer.step()# 更新参数
        runing_loss+=loss
	epoch_loss=running_loss/len(train_loader.dataset)# 计算该轮损失
	print(f"epoch{epoch+1},train_loss{epoch_loss:.4f}")# 输出该轮的损失
	if (epoch+1)%20==0:
		torch.save(model.state_dict(), f'seg_epoch{epoch + 1}_deeplab.pth')# 保存每轮模型

9. 预测

预测就是读取数据,放入模型前向传播,只不过不需要计算梯度来反向传播罢了,一般会带有评价指标。

# 读取模型
model=NeuralNetwork().to(device)
model.load_state_dict(torch.load('模型路径'))
# DataLoader加载数据Dataset
batch_size = 64# 每次训练拿取多少个
test_loader = DataLoader(testing_dataset, batch_size=batch_size)# 读取dataset数据
size = len(dataloader.dataset)
total_dice=0.0
with torch.no_grad():# 不需要计算梯度,因为不需要反向传播
	for x,y in test_loader:
		x, y = z.to(device), y.to(device)# 放到gpu设备中
		pred = model(X)# 前向传播预测
		dice=diceCoeff(pred,y)# dsc评估指标
		total_dice+=dice
total_dice/=size
print(f"DiceCoeff:{total_dice}")# 输出DSC值

如何阅读和修改深度学习代码

阅读深度学习代码无非就是为了复现使用和改进模型,在知道深度学习代码组成后,其实上手思路就很明确了。

复现通常关注在于数据的输入和输出,可以从模型的dataset以及训练和预测数据输入和输出入手。

需要修改模型结构的时候,可以先去搜模型相关的论文了解模型结构的思路,然后再从模型的前向传播开始看起,因为前向传播能反应结构的顺序,然后跟着反向传播调用结构顺序慢慢看到其它的小结构。

要修改损失和评价那些也是一样的,去找到对应的位置即可。


http://www.kler.cn/a/305802.html

相关文章:

  • 2024-11-13 学习人工智能的Day26 sklearn(2)
  • python魔术方法的学习
  • WebSocket和HTTP协议的性能比较与选择
  • 界面控件Kendo UI for Angular中文教程:如何构建带图表的仪表板?(一)
  • CentOS 服务
  • 类别变量分析——卡方独立性检验卡方拟合优度检验
  • 6.1 溪降技术:绳结
  • 小阿轩yx-Zabbix企业级分布式监控环境部署
  • 期望极大算法(Expectation Maximization Algorithm,EM)
  • 基于SpringBoot的校园新闻网站设计与实现
  • 视觉SLAM ch5——相机与图像
  • AIGC-初体验
  • python 大模型验证码识别
  • C++11的部分新特性
  • Vue路由:Vue router
  • 使用ESP8266和OLED屏幕实现一个小型电脑性能监控
  • 优化深度学习模型训练过程:提升PASCAL VOC 2012数据集上Deeplabv3+模型训练效率的策略
  • 【leetcode-python】最接近的三数之和
  • Acrobat 9 安装教程
  • Redis入门2
  • 驾校预约学习系统的设计与实现
  • 关于决策树的一些介绍
  • 让孩子们动手又动脑,用学优马电子积木,探索电路的奥秘
  • 计算机毕业设计Python深度学习垃圾邮件分类检测系统 朴素贝叶斯算法 机器学习 人工智能 数据可视化 大数据毕业设计 Python爬虫 知识图谱 文本分类
  • Visual Studio安装教程
  • 如何使用ssm实现流浪动物救助站+vue