当前位置: 首页 > article >正文

Pytorch如何精准记录函数运行时间

0. 引言

参考Pytorch官方文档对CUDA的描述,GPU的运算是异步执行的。一般来说,异步计算的效果对于调用者来说是不可见的,因为

  1. 每个设备按照排队的顺序执行操作
  2. Pytorch对于CPU和GPU的同步,GPU间的同步是自动执行的,不需要显示写在代码中

异步计算的后果是,没有同步的时间测量是不准确的

1. 解决方案

参考引言中提到的帮助文档,Pytorch官方给出的解决方案是使用torch.cuda.Event记录时间,具体代码如下:

# import torch
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# Run your code snippet here

end_event.record()
torch.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)  # elapsed time (ms)

将你的代码插入start_event.record()end_event.record()中间以测量时间(单位毫秒)。

有能力的读者也可以包装为装饰器或者with语句使用:

先书写一个自定义with类(ContextManager)

class CudaTimer:
    def __init__(self):
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        self.start_event.record()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_event.record()
        torch.cuda.synchronize()
        self.elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000 # ms -> s

再安装如下with语句返回:

with CudaTimer() as timer:
	# run your code here
dt = timer.elapsed_time  # s

这样保证了多个文件调用时语句的简单性。特别提醒:获取timer.elapsed_time操作不要写在with语句内部。在with语句未结束时,是无法获取timer的成员变量的。

2. 补充

对于CPU和GPU混合操作的函数,使用torch.cuda.event可能会使统计时间比实际时间短,此时可以使用time.time()代替,标准的with对象书写如下:

# import time
class Timer:
    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.cuda.synchronize()
        self.elapsed_time = time.time() - self.start_time

然后只需要将上文的with CudaTimer() as timer替换为with Timer() as timer即可。


http://www.kler.cn/a/390644.html

相关文章:

  • 大数据技术实训:Hadoop完全分布式运行模式配置
  • 2024下半年软考中项考试成绩多久出来?成绩合格标准是多少?
  • MySQL的ibtmp1文件详解及过大处理策略
  • laravel php artisan storage:link 后通过nginx代理访问图片404 not found问题
  • 人工智能--自然语言处理简介
  • 干部调整辅助决策系统:为干部管理注入新活力
  • 三周精通FastAPI:37 包含 WSGI - Flask,Django,Pyramid 以及其它
  • Fortinet Security Fabric安全平台
  • 【GPTs】Get Simpsonized:一键变身趣味辛普森角色
  • 微服务电商平台课程三:搭建后台服务
  • 20241112_高级工程数学作业
  • 观影新境界:nastool自动化管理结合cpolar助力群晖NAS远程影音享受——“cpolar内网穿透”
  • linux基础-完结(详讲补充)
  • elementUI中2个日期组件实现开始时间、结束时间(禁用日期面板、控制开始时间不能超过结束时间的时分秒)实现方案
  • PaddleOCR安装教程
  • 【架构论文-1】面向服务架构(SOA)
  • springboot yml配置信息书写与获取
  • 大数据数据存储层MemSQL, HBase与HDFS
  • SQL相关常见的面试题
  • css的z-index图层使用有什么要求
  • Python小游戏25——黄金矿工