当前位置: 首页 > article >正文

简化深度学习实验管理:批量训练和自动记录方案

简化深度学习实验管理:批量训练和自动记录方案

在深度学习模型的训练过程中,经常需要多次运行模型,以测试不同参数组合的效果,或确保模型在相同配置下的表现稳定。然而,每次手动记录训练结果不仅耗时,还容易出错。为了提高效率并简化分析流程,我们可以构建一个系统,通过自动执行训练、记录训练结果并生成一张表格来总结不同实验的性能表现。

本文将逐步讲解如何实现这一自动化流程,包括修改训练脚本以记录结果、编写批量运行的 Bash 脚本,以及使用数据分析工具查看和分析最终的训练结果。


1. 修改训练脚本以自动记录训练结果

首先,我们需要确保训练结束后能够自动保存实验的关键参数(如数据集、网络结构、延迟帧数等)和模型的性能指标(如验证精度 accVal)。将这些信息保存到 CSV 文件中,使得每次训练结束后结果都能自动追加到表格文件中,方便后续分析和比较。

实现步骤

在本例中,我们假设需要记录以下参数和结果:

  • nameDataset:数据集名称
  • nameNetwork:网络结构类型(如 ResNet、VGG 等)
  • numFrames:延迟帧数 T
  • accVal:验证精度

我们可以定义一个 save_results 函数,将当前实验的参数和精度追加到一个 CSV 文件中。

代码示例:定义结果保存函数

以下是 save_results 函数的实现示例,该函数可以在训练结束时自动保存训练参数和性能结果。

import csv
import os

# 定义保存结果的函数
def save_results(args, accVal, file_path="training_results.csv"):
    """
    将当前实验的参数和精度追加到 CSV 文件中。
    
    参数:
    - args: 包含实验参数的字典
    - accVal: 验证精度
    - file_path: CSV 文件路径
    """
    # 检查文件是否已存在
    file_exists = os.path.isfile(file_path)
    
    # 定义要保存的参数和结果
    data = {
        "Dataset": args["nameDataset"],
        "Network": args["nameNetwork"],
        "Frames": args["numFrames"],
        "Accuracy": accVal
    }
    
    # 将数据写入 CSV 文件
    with open(file_path, mode="a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=data.keys())
        
        # 如果文件是新建的,写入表头
        if not file_exists:
            writer.writeheader()
        
        # 写入当前的训练结果
        writer.writerow(data)
解释
  • save_results 函数通过检查 file_path 文件是否存在,决定是否写入表头,以确保文件在首次写入时有清晰的列名。
  • data 字典包含了本次实验的核心参数和精度。每次调用该函数时,都会将当前实验的数据写入 CSV 文件。

示例参数和验证精度

在运行训练脚本时,我们可以定义实验参数 args 并生成验证精度 accVal。实际的验证精度应从模型评估中提取,这里使用一个随机数进行示例:

import random  # 生成示例精度

# 假设这些是实验参数
args = {
    "nameDataset": "CIFAR-10",
    "nameNetwork": "resnet-18",
    "numFrames": 6
}

# 假设训练完成后得到的验证精度
accVal = random.uniform(0.8, 0.9)  # 示例精度,实际应用中从模型评估获取

# 保存结果
save_results(args, accVal)

2. 修改 train.py 以记录训练结果

在实际使用中,我们需要确保 train.py 在训练结束时能够提取并记录最佳验证精度 accVal。如果使用 PyTorch Lightning 或类似的深度学习框架,可以通过 trainer 对象管理训练流程,并从中提取最佳验证精度。

修改步骤

在修改 train.py 之前,确保可以提取验证集上的最佳精度并记录结果:

  1. 提取验证精度:从 trainer 对象中提取最佳验证精度的方法,这通常在模型的回调函数或监控指标中可以找到。具体方法取决于所用框架,如 PyTorch Lightning 或 Keras。
  2. 记录到 CSV 文件:在训练完成后,将 best_accValargs 传递给 save_results 函数,以便将结果保存到 CSV 文件中。
代码示例:提取验证精度并记录

以下是如何在训练结束后提取最佳验证精度并调用 save_results 记录结果的示例代码:

# train.py
import recordResult  # 引入结果记录模块

# 假设使用 argparse 获取训练参数
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args()

# 配置 Trainer 对象
trainer = Trainer(
    default_root_dir=args.dirLogs,
    max_epochs=args.numEpoches,
    devices="gpu",
    accelerator="gpu",
    callbacks=[checkpoint_callback],
    log_every_n_steps=50
)

# 开始训练
trainer.fit(model)

# 提取最佳验证精度并记录结果
best_accVal = trainer.callback_metrics.get("val_acc").item()  # 使用回调指标获取验证精度

# 将参数和最佳验证精度记录到 CSV 文件中
recordResult.save_results(args, best_accVal)
说明
  • trainer.fit(model):启动训练过程,通过配置的回调函数自动保存验证精度最高的模型。
  • trainer.callback_metrics.get("val_acc"):从 trainer 的回调指标中提取验证集的最佳精度,适用于 PyTorch Lightning(请根据具体框架调整代码)。
  • recordResult.save_results(args, best_accVal):将训练参数和验证精度传递给 save_results 函数,追加到 CSV 文件中。

3. 编写 Bash 脚本批量运行训练任务

为了简化多次运行 train.py 的过程,可以编写一个 Bash 脚本,自动循环执行训练并记录结果。该脚本会按指定次数循环运行训练脚本,每次运行结束后将结果追加到 CSV 文件中。

Bash 脚本示例:run_training.sh

#!/bin/bash

# 设置默认执行次数为 5
NUM_RUNS=${1:-5}

# 循环执行指定次数
for ((i=1; i<=NUM_RUNS; i++))
do
  echo "开始第 $i 次训练..."
  
  # 执行训练脚本
  python train.py
  
  echo "第 $i 次训练完成。"
done

echo "所有训练任务已完成,总计运行 $NUM_RUNS 次。"
解释
  • NUM_RUNS=${1:-5}:设置默认执行次数为 5,用户可以在运行脚本时通过参数指定执行次数。
  • 每次运行 python train.py 后,训练结果会自动追加到 training_results.csv 文件中,实现批量记录。
使用方法
  1. 确保脚本具有执行权限:首次运行前,需要为脚本添加可执行权限。

    chmod +x run_training.sh
    
  2. 直接运行脚本(默认执行 5 次):

    ./run_training.sh
    
  3. 自定义运行次数:可以在运行时指定执行次数。例如,执行 10 次:

    ./run_training.sh 10
    
解释
  • chmod +x run_training.sh:为脚本添加执行权限,使其可以被直接运行。
  • ./run_training.sh:执行脚本,若不指定参数,默认运行 5 次。
  • ./run_training.sh 10:指定执行次数为 10 次。

4. 分析训练结果并选择最佳模型

当所有训练任务完成后,可以使用 Pandas 等数据分析工具来加载和查看 training_results.csv 文件,快速分析不同参数组合下的模型性能,进而确定最佳模型配置。

使用 Pandas 查看结果并选取最佳模型

以下是使用 Pandas 加载和分析 CSV 文件的示例代码:

import pandas as pd

# 读取 CSV 文件
df = pd.read_csv("training_results.csv")

# 查看完整结果
print("所有训练结果:")
print(df)

# 获取验证精度最高的配置
best_result = df.loc[df['Accuracy'].idxmax()]
print("\n最佳配置:")
print(best_result)
输出示例
所有训练结果:
    Dataset     Network  Frames  Accuracy
0  CIFAR-10   resnet-18       6     0.850
1  CIFAR-10   resnet-18       6     0.870
2  CIFAR-10   resnet-18       6     0.880
...

最佳配置:
Dataset    CIFAR

-10
Network    resnet-18
Frames             6
Accuracy         0.88
Name: 2, dtype: object

解释

  • df['Accuracy'].idxmax():找到验证精度最高的实验配置。
  • df.loc[...]:通过索引提取该配置对应的所有参数,便于进一步分析或复现实验。

总结

通过上述方法,我们构建了一个自动化的批量训练和记录系统,具体流程如下:

  1. 修改训练脚本:使 train.py 在每次训练结束后自动将实验参数和性能指标记录到 CSV 文件中。
  2. 编写批量执行脚本:通过 Bash 脚本 run_training.sh,自动执行训练多次,并将每次结果追加到 CSV 文件中。
  3. 数据分析和模型选择:使用 Pandas 加载 CSV 文件,以表格形式查看不同实验的参数和精度,进而选择最佳实验结果。

这种自动化流程不仅减少了手动记录的工作量,还有效提升了实验管理的效率,使我们可以轻松对比不同参数组合的效果并快速选出最佳模型。


http://www.kler.cn/news/367788.html

相关文章:

  • C#实现简单的文件夹对比程序
  • JAVA篇之类和对象
  • 我准备写一份Stable Diffusion入门指南-part1
  • 高级SQL技巧掌握
  • 前端处理API接口故障:多接口自动切换的实现方案
  • 【大数据学习 | kafka】kafuka的基础架构
  • 暴力匹配算法 (BF):字符串匹配算法的演进之路
  • springboot 网上影院订票系统-计算机毕业设计源码06993
  • 小程序视频SDK解决方案,提供个性化开发和特效定制设计
  • 笔记整理—linux驱动开发部分(1)驱动梗概
  • 第五十二章 安全元素的详细信息 - EncryptedData 详情
  • 【含开题报告+文档+PPT+源码】基于SpringBoot爱之屋摄影预约管理系统的设计与实现
  • Depcheck——专门用于检测 JavaScript 和 Node.js 项目中未使用依赖项的工具
  • 安全知识见闻-通信协议安全
  • uniapp 报错Invalid Host header
  • 求个数不确定的整数的最大公约数(java)
  • WSL(Ubuntu20.04)编译和安装DPDK
  • PHP const 和 define主要区别
  • 关闭钉钉AI助理
  • 【WiFi7】 支持wifi7的手机
  • 机器视觉运动控制一体机在DELTA并联机械手视觉上下料应用
  • 5550 取数(max)
  • Qt:窗口风格设置
  • SQL实战训练之,力扣:1532最近的三笔订单
  • Python | Leetcode Python题解之第503题下一个更大元素II
  • console.log(“res.data = “ + JSON.stringify(res.data));