当前位置: 首页 > article >正文

Lightning初探

portch-lightning是pytorch的抽象和包装,它的好处是可复用性强,易维护,逻辑清晰等。

学习使用portch-lightning可以使我们专注于模型,而不是其他的重复脏活。

一、传统的pytorch训练流程

# 模型
model = Modle(nn.Module)

# datasets
train_data_loader = DataLoader()
val_data_loader = DataLoader()

# 优化器
optimizer = torch.optim.SGD(model.parameters,lr)

for epoch in range(num_epoches):
    # 设置为训练模式
    model.train()
    for batch in train_data_loader:
        # 提取数据
        x,target = batch
        # 前向传播
        y = model.forward(x)
        # 计算损失
        loss = MSELoss(y,target)
        # 梯度清零
        optimizer.zero_grad()
        # 反向传播
        loss.backward()
        # 参数优化
        optimizer.step()
    
    # 设置为评估模式
    model.eval()
    # 停止梯度更新
    with torch.no_grad():
        Y = []
        Y_target = []
        for batch in val_data_loader:
            # 提取数据
            x,target = batch
            # 前向传播
            y = model.forward(x)
            # 存储结果
            Y_target.append(target)
            Y.append(y)

        # 统一计算分数
        score = evaluation(Y_target,Y)

这个流程除了模型需要改之外,其他的部分每次都是相同的,需要做很多重复的工作,且加入可视化的逻辑越多,代码就越臃肿,越繁杂。

二、新的portch-lightning的训练流程

"""
原本定义模型时继承的是nn.Module,但现在换成了L.LightningModule
其相对于nn.Module又把部分训练过程封装成了函数
最后又把整体流程封装成了一个大函数
最后实现了:模型定义好之后,只需要调一个大函数就可以完成训练流程
"""
class model(L.LightningModule):
    def init(self):
        pass
    def forward(self):
        pass
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(parameters,lr)
        return optimizer
    def train_step():
        loss = []
        return loss
    def evaluation_step():
        loss = []
"""
函数替代后的具体过程
"""
# 模型(替换,包含优化器)
model

# datasets(没有变化)
train_data_loader = DataLoader()
val_data_loader = DataLoader()

for epoch in range(num_epoches):
    # 设置为训练模式
    model.train()
    for batch in train_data_loader:
        # 训练步骤
        loss = train_step(self,batch,batch_index)
        # 反向传播
        backward()
        # 参数优化
        optimizers_step()
    
    # 设置为评估模式
    model.eval()
    # 停止梯度更新
    with torch.no_grad():
        val_loss_list = []
        for batch in val_data_loader:
            # 评估步骤
            val_loss = evaluation_step(self,batch,batch_index)
            # 保存结果
            val_loss_list.append(val_loss)

        # 统一计算分数
        score = mean(val_loss_list)
"""
把替代后的流程封装为函数
"""
model = model()
trainer = L.Trainer("GPU")
trainer.fit(model,train_data_loader,val_data_loader)


http://www.kler.cn/a/512303.html

相关文章:

  • 海康工业相机的应用部署不是简简单单!?
  • Javascript 将页面缓存存储到 IndexedDB
  • 用户中心项目教程(五)---MyBatis-Plus完成后端初始化+测试方法
  • Ubuntu 24.04 LTS 安装 tailscale 并访问 SMB共享文件夹
  • 20250119面试鸭特训营第27天
  • Spring Boot项目集成Redisson 原始依赖与 Spring Boot Starter 的流程
  • Go channel关闭方法
  • JAVA-IO模型的理解(BIO、NIO)
  • 在VSCode中使用Jupyter Notebook
  • Centos 8 交换空间管理
  • LeetCodeHOT100:60. n个骰子的点数、4. 寻找两个正序数组的中位数
  • 以“智慧建造”为理念,综合应用云、大、物、移、智等数字化技术的智慧工地云平台源码
  • 愿景是什么?
  • JSON-stringify和parse
  • 48V电气架构全面科普和解析:下一代智能电动汽车核心驱动
  • Android 空包签名(详细版)
  • AI刷题-病毒在封闭空间中的传播时间
  • 企业级流程架构设计思路-基于价值链的流程架构
  • 数据结构(二)栈/队列和二叉树/堆
  • centos虚拟机异常关闭,导致数据出现问题
  • 【2024年度个人生活与博客事业的融合与平衡总结】
  • 深入解析 C++17 中的 u8 字符字面量:提升 Unicode 处理能力
  • 大模型LLM-微调 RAG
  • Java-数据结构-二叉树习题(1)
  • AUTOSAR OS模块详解(三) Alarm
  • 我想通过python语言,学习数据结构和算法该如何入手?