当前位置: 首页 > article >正文

PyTorch多机训练Loss不一致问题排查指南:基于算子级一致性验证

比较二次训练过程中所有算子的误差,定位存在一致性问题的pytorch算子

    • 一.背景
    • 二.技术方案
      • 1.核心思路
      • 2.关键技术点
    • 三.代码

一.背景

在分布式训练场景中,观察到以下现象:

  1. 相同超参配置下,多次训练的Loss曲线存在显著差异(波动幅度>5%)
  2. 模型收敛稳定性受训练节点数影响,节点越多差异越明显
  3. 梯度检查(Gradient Check)未发现异常,初步排除模型结构问题

二.技术方案

1.核心思路

通过算子级数值一致性验证,定位导致多机训练结果不一致的PyTorch原生算子。关键技术路径:

  1. 算子拦截:利用__torch_dispatch__机制捕获所有ATen算子调用
  2. 双模校验
    • 基准模式:首次运行保存各算子输入/输出的统计特征
    • 验证模式:后续运行实时校验数值一致性
  3. 差异定位:当检测到统计特征偏离时,打印完整的调用栈信息

2.关键技术点

模块实现方案
算子拦截继承TorchDispatchMode重写调度逻辑
特征提取计算张量均值(排除形状/类型等非数值因素)
差异检测使用torch.allclose进行容差对比(默认atol=1e-4)
结果持久化按rank序列化存储基准数据到磁盘
黑名单机制过滤empty_like等非计算类算子,减少误报

三.代码

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
from datetime import datetime
import numpy as np
import torch.nn as nn
import time
import os
import pickle
import inspect
import traceback

@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None

def is_tensor(val):
    return isinstance(val, (torch.Tensor, nn.Parameter))

def do_compare(rank, name, tensor):
    timestamp = time.time()
    seconds = int(timestamp)
    millis = int((timestamp - seconds) * 1000)
    now=f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(seconds))}.{millis:03d}"
    has_nan = torch.isnan(tensor).any()
    if has_nan:
        print(f"has_nan {now} {rank} {name} {tensor.shape} {tensor.dtype} {tensor.device}")
        return False
    current_mean = torch.mean(tensor.float()).cpu().item()
    cache_file = f"/logs/rank_{rank}.pkl"
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    stored_data = {}
    if os.path.exists(cache_file):
        try:
            with open(cache_file, 'rb') as f:
                stored_data = pickle.load(f)
        except:
            print("load failed:",cache_file)
            traceback.print_exc()
    if name in stored_data:
        if torch.allclose(torch.tensor(current_mean),
                         torch.tensor(stored_data[name]),
                         atol=1e-4):
            return True
        print("------------------------------In-----------------------------------")
        print(f"MisMatch {now} {rank} {name} {tensor.shape} {tensor.dtype} {tensor.device} {current_mean} {stored_data[name]} min:{torch.min(tensor)} max:{torch.max(tensor)}")
        return False
    else:
        stored_data[name] = current_mean
        with open(cache_file, 'wb') as f:
            pickle.dump(stored_data, f)
        return True

index_counter = 0
def compare_tensor(name, tensor):
    global index_counter
    index_counter += 1
    rank=torch.distributed.get_rank()
    if is_tensor(tensor):
        if not do_compare(rank,f"{name}-{index_counter}",tensor):
            return False
    elif isinstance(tensor, (tuple, list)):
        for idx, t in enumerate(tensor):
            if is_tensor(t):
                if not do_compare(rank,f"{name}-{index_counter}-{idx}",t):
                    return False
    return True

class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self,parent):
        super().__init__()
        self.parent=parent

    def is_allow_dump(self,name):
        black_list=["empty","like","zero","detach","has","view",
                    "copy","arange","fill","ones","lift_fresh","alias",
                    "scalar_tensor","clone","stack","slice","source","barrier",
                    "select","random","unsqueeze","expand","normal_"]
        for i in black_list:
            if name.find(i)>=0:
                return False
        return True

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        func_packet = func._overloadpacket
        op_name=f"{func}"
        enable_dump= self.is_allow_dump(op_name)
        if kwargs is None:
            kwargs = {}
        if enable_dump:
            torch.cuda.synchronize()
            if not compare_tensor(f"{op_name}-[input]", args):
                stack = inspect.stack()
                i=0
                for frame_info in reversed(stack):
                    msg=f"{i}:{frame_info.filename}:{frame_info.lineno}"
                    print(msg)
                    i+=1
                print("------------------------------Out-----------------------------------")
        ret= func(*args, **kwargs)
        if enable_dump:
            torch.cuda.synchronize()
            if not compare_tensor(f"{op_name}-[output0]", ret):
                stack = inspect.stack()
                i=0
                for frame_info in reversed(stack):
                    msg=f"{i} {frame_info.filename}:{frame_info.lineno}"
                    print(msg)
                    i+=1
                print("------------------------------Out0-----------------------------------")
            if not compare_tensor(f"{op_name}-[output1]", args):
                stack = inspect.stack()
                i=0
                for frame_info in reversed(stack):
                    msg=f"{i} {frame_info.filename}:{frame_info.lineno}"
                    print(msg)
                    i+=1
                print("------------------------------Out1-----------------------------------")
        return ret

class TorchDebugDumper:
    _CURRENT_Dumper = None
    def __init__(self):
        self.p= _ProfilerState(TorchDumpDispatchMode)

    def __enter__(self):
        assert TorchDebugDumper._CURRENT_Dumper is None
        TorchDebugDumper._CURRENT_Dumper = self
        if self.p.object is None:
            o = self.p.cls(self)
            o.__enter__()
            self.p.object = o
        else:
            self.p.object.step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDebugDumper._CURRENT_Dumper = None
        if self.p.object is not None:
            self.p.object.__exit__(exc_type, exc_val, exc_tb)
            del self.p.object

def main():
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        extra_args_provider=llama_argument_handler,
        args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
    )
if __name__ == "__main__":
    with TorchDebugDumper():
		main()

http://www.kler.cn/a/586315.html

相关文章:

  • TGARS2024 | LGP | 面向目标检测的通用且可控攻击
  • Deepseek-R1 VS QwQ-32B 评测对比:文本理解与生成(2)
  • 计算机网络OSI七层模型
  • 专题|Python梯度提升实例合集:GBM、XGBoost、SMOTE重采样、贝叶斯、逻辑回归、随机森林分析信贷、破产数据...
  • 若依(RuoYi)前后端分离项目前端部署宝塔访问不到接口
  • 鸿蒙 @ohos.arkui.componentSnapshot (组件截图)
  • OpnenHarmony 开源鸿蒙北向开发——1.开发环境搭建(DevEco Studio 5.03)
  • Flutter嵌套问题解决方案
  • 专业的IP干净度检测工具
  • 【 <一> 炼丹初探:JavaWeb 的起源与基础】之 JavaWeb 项目的部署:从开发环境到生产环境
  • 涨薪技术|Kubernetes(k8s)之Pod生命周期(下)
  • 面向神经机器翻译的多语言去噪预训练
  • 力扣 Hot 100 刷题记录 - 对称二叉树
  • C++前缀和
  • env.development.local 和 env.development 的区别
  • MySQL EXPLAIN 详解
  • python 提取视频中的音频
  • 9月论文学习
  • Houdini学习笔记
  • 蓝桥杯15届省C