当前位置: 首页 > article >正文

浅谈AI落地之-加速训练

前言

曾在游戏世界挥洒创意,也曾在前端和后端的浪潮间穿梭,如今,而立的我仰望AI的璀璨星空,心潮澎湃,步履不停!愿你我皆乘风破浪,逐梦星辰!

混合精度:

FL32是目前模型存储数据一个比较普遍的格式。有时候过于浪费,根本用不着那么多。所以如果有一种方法能动态调整存储数据的大小,就能节省不少显存占用,从来提高批次大小,加速学习。

混合精度简单简单的来说,就是用FL16+FL32来代替原来的清一色FL32数据。具体实现是,在初始化好scaler以后,在遍历批次中的数据的时候,使用autocast自动对前向传播和损失处理混合精度。然后再使用梯度缩放器来缩放损失,并反向传播。

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
# 使用autocast自动处理混合精度
with autocast():
    out = model(**data)  # 前向传播
    loss = out['loss']  # 获取损失
# 使用梯度缩放器来缩放损失,并反向传播
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()


http://www.kler.cn/a/586949.html

相关文章:

  • 模型蒸馏系列——开源项目
  • Mininet树形拓扑解析
  • 条件运算符
  • 洛谷 P1357 花园
  • c语言zixue
  • Java基础编程练习第31题-String类和StringBuffer类
  • 【八股文】ArrayList和LinkedList的区别
  • 【Python 语法】排序算法
  • 个人博客系统测试报告
  • C++程序设计语言笔记——抽象机制:模板
  • eclipse-mosquitt之docker部署安装与使用
  • 现在有分段、句子数量可能不一致的中英文文本,如何用python实现中英文对照翻译(即每行英文对应相应的中文)
  • MySQL事务及索引复习笔记
  • Qt从入门到入土(十) -数据库操作--SQLITE
  • JAVA EE(10)——线程安全——synchronized JUC(java.util.concurrent) 的常见类 线程安全的集合类
  • 机器学习编译器(二)
  • Java中的访问修饰符有哪些
  • Swagger 从 .NET 9 中删除:有哪些替代方案
  • 洛谷 P4933 大师
  • LRU(最近最少使用)算法实现