当前位置: 首页 > article >正文

训练模型过程中优雅的指定GPU

目前训练模型大部分在单机多卡的环境下,我们通常会指定一个GPU来训练模型。在不指定GPU情况下,默认使用GPU0来训练,但是很不巧GPU0被别人占了一半显存,导致OOM错误。每次跑模型都要去看下哪张卡显存最大,然后再来修改代码,指定GPU,是不是超级烦人呢!😶‍🌫️,今天就介绍一个每次都由程序自动选择剩余最大的显存的GPU来训练。

1. Quick Start

  • step1: 安装依赖包
    安装管理NVIDIA显卡的python依赖包pynvml
    pip install nvidia-ml-py
    
  • 使用pynvml监控GPU
    import psutil
    import pynvml #导包
    
    
    UNIT = 1024 * 1024
    
    
    pynvml.nvmlInit() #初始化
    gpuDeriveInfo = pynvml.nvmlSystemGetDriverVersion()
    print("Drive版本: ", str(gpuDeriveInfo, encoding='utf-8')) #显示驱动信息
    
    
    gpuDeviceCount = pynvml.nvmlDeviceGetCount()#获取Nvidia GPU块数
    print("GPU个数:", gpuDeviceCount )
    
    
    for i in range(gpuDeviceCount):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)#获取GPU i的handle,后续通过handle来处理
    
        memoryInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)#通过handle获取GPU i的信息
    
        gpuName = str(pynvml.nvmlDeviceGetName(handle), encoding='utf-8')
    
        gpuTemperature = pynvml.nvmlDeviceGetTemperature(handle, 0)
    
        gpuFanSpeed = pynvml.nvmlDeviceGetFanSpeed(handle)
    
        gpuPowerState = pynvml.nvmlDeviceGetPowerState(handle)
    
        gpuUtilRate = pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
        gpuMemoryRate = pynvml.nvmlDeviceGetUtilizationRates(handle).memory
    
        print("第 %d 张卡:"%i, "-"*30)
        print("显卡名:", gpuName)
        print("内存总容量:", memoryInfo.total/UNIT, "MB")
        print("使用容量:", memoryInfo.used/UNIT, "MB")
        print("剩余容量:", memoryInfo.free/UNIT, "MB")
        print("显存空闲率:", memoryInfo.free/memoryInfo.total)
        print("温度:", gpuTemperature, "摄氏度")
        print("风扇速率:", gpuFanSpeed)
        print("供电水平:", gpuPowerState)
        print("gpu计算核心满速使用率:", gpuUtilRate)
        print("gpu内存读写满速使用率:", gpuMemoryRate)
        print("内存占用率:", memoryInfo.used/memoryInfo.total)
    
        """
        # 设置显卡工作模式
        # 设置完显卡驱动模式后,需要重启才能生效
        # 0 为 WDDM模式,1为TCC 模式
        gpuMode = 0     # WDDM
        gpuMode = 1     # TCC
        pynvml.nvmlDeviceSetDriverModel(handle, gpuMode)
        # 很多显卡不支持设置模式,会报错
        # pynvml.nvml.NVMLError_NotSupported: Not Supported
        """
    
        # 对pid的gpu消耗进行统计
        pidAllInfo = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)#获取所有GPU上正在运行的进程信息
        for pidInfo in pidAllInfo:
            pidUser = psutil.Process(pidInfo.pid).username()
            print("进程pid:", pidInfo.pid, "用户名:", pidUser, 
                "显存占有:", pidInfo.usedGpuMemory/UNIT, "Mb") # 统计某pid使用的显存
    
    
    pynvml.nvmlShutdown() #最后关闭管理工具
    

2. Advanced Tutorial

使用 pynvml 写一个自动化脚本,使其在程序开始时自动选择显存最大的GPU

def select_best_gpu():
    import pynvml
    pynvml.nvmlInit()  # 初始化
    gpu_count = pynvml.nvmlDeviceGetCount()
    if gpu_count == 0:
        device = "cpu"
    else:
        gpu_id, max_free_mem = 0, 0.
        for i in range(gpu_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            memory_free = round(pynvml.nvmlDeviceGetMemoryInfo(handle).free/(1024*1024*1024), 3)  # 单位GB
            if memory_free > max_free_mem:
                gpu_id = i
                max_free_mem = memory_free
        device = f"cuda:{gpu_id}"
        print(f"total have {gpu_count} gpus, max gpu free memory is {max_free_mem}, which gpu id is {gpu_id}")
    return device


available_device = select_best_gpu()

# 方法1:直接通过os全局设置GPU
import os
if available_device.startswith("cuda"):
    os.environ['CUDA_VISIBLE_DEVICES'] = available_device.split(":")[1]
    
# 方法2:在模型处指定
model = Model()   # 初始化模型
model.to(available_device)

注意:以上方法一定放到程序最开始处,否则指定GPU可能会失败,通常在import torch后,通过os指定GPU就会失败

3. REFERENCE

python查看gpu信息


http://www.kler.cn/a/226643.html

相关文章:

  • 【Java程序设计】【C00239】基于Springboot的漫画之家管理系统(有论文)
  • CentOS 8 下载
  • django+flask警务案件信息管理系统python-5dg53-vue
  • go中的WaitGroups
  • 正点原子--STM32定时器学习笔记(1)
  • LVGL部件8
  • npm ERR! reason: certificate has expired(淘宝镜像过期)
  • 【备战蓝桥杯】——循环结构终篇
  • 【自动化测试教程】Java+Selenium自动化测试环境搭建
  • 上班族学习方法系列文章目录
  • Modbus协议学习第七篇之libmodbus库API介绍(modbus_write_bits等)
  • Linux下find命令详解
  • PriorityBlockingQueue的tryGrow方法
  • 【Spring连载】使用Spring Data访问Redis(三)----连接模式
  • 【misc | CTF】攻防世界 2017_Dating_in_Singapore
  • CSS实现文字大小自适应
  • 【Java程序设计】【C00232】基于Springboot的抗疫物资管理系统(有论文)
  • ubuntu22.04@laptop 常用基础环境安装
  • 二层设备与三层设备的区别--总结
  • 微服务框架go-zero集成swagger在线接口文档