A method for recording operator input/output information during PyTorch's internal computation
Recently I had a requirement: when comparing how the same network runs on different devices and the results diverge, I need to know where the divergence occurs and what actually happened on each device. The procedure below exposes the execution information of every intermediate operator, which can be used to assist operator development and debugging.
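The core of the approach is TorchDispatchMode from torch.utils._python_dispatch: while such a mode is active as a context manager, every ATen-level operator call is routed through its __torch_dispatch__ hook before being executed, so the arguments and results of each op can be inspected in one place. A minimal sketch of just this interception step (the class name PrintingMode is illustrative, not part of dumper.py below):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PrintingMode(TorchDispatchMode):
    # Called for every ATen op issued while the mode is active.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # e.g. aten.add.Tensor
        return func(*args, **(kwargs or {}))

with PrintingMode():
    torch.ones(2) + torch.ones(2)  # prints aten.ones.default, then aten.add.Tensor

dumper.py below builds on exactly this hook, adding file naming, per-op min/max statistics, and optional tensor dumping.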
dumper.py
import torch
from dataclasses import dataclass
from typing import Any

from torch.utils._python_dispatch import TorchDispatchMode


@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None


# Monotonically increasing index used to name dump files across all ops.
global_index = 0


class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self, parent):
        super().__init__()
        self.parent = parent

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        global global_index
        func_packet = func._overloadpacket
        if kwargs is None:
            kwargs = {}
        if func_packet.__name__ != "_to_copy":
            first_dtype = args[0].dtype if args and isinstance(args[0], torch.Tensor) else None
            print(f"Profiling {func_packet.__name__, first_dtype}")
        print(func_packet.__name__, "->in")
        # Encode the shape/dtype of each positional argument into the file name.
        names = f"./golden_dir/input_{global_index}_"
        for m in args:
            if isinstance(m, torch.Tensor):
                names += str(m.shape) + "_" + str(m.dtype) + "_"
            else:
                names += str(m) + "_"
        names += ".pt"
        print("kwargs is:", kwargs, names)
        # Only the first non-empty tensor argument is inspected here.
        index = 0
        for m in args:
            if isinstance(m, torch.Tensor) and m.numel() != 0:
                # torch.save(m.cpu(), names)  # ./golden_dir must exist before enabling
                print("args:", index, m.dtype)
                print("max:", m.cpu().reshape(-1).max(), "min:", m.cpu().reshape(-1).min())
                break
        ret = func(*args, **kwargs)
        if isinstance(ret, torch.Tensor) and ret.numel() != 0:
            # torch.save(ret.cpu(), names)
            print(names, "->out")
            print("max:", ret.cpu().reshape(-1).max(), "min:", ret.cpu().reshape(-1).min())
            global_index += 1
        elif isinstance(ret, tuple):
            # Ops such as native_layer_norm return a tuple; entries may be None.
            for result in ret:
                shape = result.shape if isinstance(result, torch.Tensor) else None
                dtype = result.dtype if isinstance(result, torch.Tensor) else None
                names = f"./golden_dir/result_{global_index}_{shape}_{dtype}_{func_packet.__name__}.pt"
                print("out op:", names)
                if isinstance(result, torch.Tensor) and result.numel() != 0:
                    # torch.save(result.cpu(), names)
                    print("max:", result.cpu().reshape(-1).max(), "min:", result.cpu().reshape(-1).min())
                global_index += 1
        torch.cuda.synchronize()
        return ret


class TorchDumper:
    _CURRENT_Dumper = None

    def __init__(self, schedule: Any):
        self.p = _ProfilerState(schedule)

    def __enter__(self):
        assert TorchDumper._CURRENT_Dumper is None
        TorchDumper._CURRENT_Dumper = self
        if self.p.object is None:
            o = self.p.cls(self)
            o.__enter__()
            self.p.object = o
        else:
            # Re-enter the existing mode when the dumper is reused.
            self.p.object.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDumper._CURRENT_Dumper = None
        if self.p.object is not None:
            self.p.object.__exit__(exc_type, exc_val, exc_tb)
Usage:
from dumper import *
import torch


def layernorm_test(dtype=torch.float16):
    activations = torch.randn(1, 4096, 320).cuda().to(dtype)
    layernorm = torch.nn.LayerNorm(320).cuda().to(dtype)
    result = layernorm(activations)
    sum_out = result.sum() * 0.0001
    sum_out.backward()


with TorchDumper(TorchDumpDispatchMode):
    layernorm_test(torch.bfloat16)
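Since the goal is to compare runs across devices, the intended workflow is to uncomment the torch.save calls, run the same test on each device with a separate dump directory, and then diff the saved tensors pairwise. A sketch of such a comparison under those assumptions (compare_dumps and the directory names are hypothetical, not part of dumper.py):

import glob
import os
import torch

# Hypothetical helper: assumes the torch.save calls in dumper.py were
# uncommented and each device's run wrote into its own directory.
def compare_dumps(golden_dir="./golden_dir", device_dir="./device_dir"):
    for golden_path in sorted(glob.glob(os.path.join(golden_dir, "*.pt"))):
        name = os.path.basename(golden_path)
        device_path = os.path.join(device_dir, name)
        if not os.path.exists(device_path):
            print("missing on device:", name)
            continue
        golden = torch.load(golden_path).float()
        actual = torch.load(device_path).float()
        # The max absolute difference is usually the first signal of divergence.
        print(name, "max abs diff:", (golden - actual).abs().max().item())

Because each file name encodes the op index, shapes, and dtypes, matching file names across the two directories pairs up the same op invocation on both devices.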
Output:
Profiling ('randn', None)
randn ->in
kwargs is: {'device': device(type='cpu'), 'pin_memory': False} ./golden_dir/input_0_[1, 4096, 320]_.pt
./golden_dir/input_0_[1, 4096, 320]_.pt ->out
max: tensor(4.8994) min: tensor(4.8994)
_to_copy ->in
kwargs is: {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cuda', index=0)} ./golden_dir/input_1_torch.Size([1, 4096, 320])_torch.float32_.pt
args: 0 torch.float32
max: tensor(4.8994) min: tensor(-4.8003)
./golden_dir/input_1_torch.Size([1, 4096, 320])_torch.float32_.pt ->out
max: tensor(4.8994) min: tensor(4.8994)
_to_copy ->in
kwargs is: {'dtype': torch.bfloat16} ./golden_dir/input_2_torch.Size([1, 4096, 320])_torch.float32_.pt
args: 0 torch.float32
max: tensor(4.8994) min: tensor(-4.8003)
./golden_dir/input_2_torch.Size([1, 4096, 320])_torch.float32_.pt ->out
max: tensor(4.9062, dtype=torch.bfloat16) min: tensor(4.9062, dtype=torch.bfloat16)
Profiling ('empty', None)
empty ->in
kwargs is: {'device': device(type='cpu'), 'pin_memory': False} ./golden_dir/input_3_[320]_.pt
./golden_dir/input_3_[320]_.pt ->out
max: tensor(nan) min: tensor(nan)
Profiling ('empty', None)
empty ->in
kwargs is: {'device': device(type='cpu'), 'pin_memory': False} ./golden_dir/input_4_[320]_.pt
./golden_dir/input_4_[320]_.pt ->out
max: tensor(2.5819) min: tensor(2.5819)
Profiling ('fill_', torch.float32)
fill_ ->in
kwargs is: {} ./golden_dir/input_5_torch.Size([320])_torch.float32_1.0_.pt
args: 0 torch.float32
max: tensor(nan) min: tensor(nan)
./golden_dir/input_5_torch.Size([320])_torch.float32_1.0_.pt ->out
max: tensor(1.) min: tensor(1.)
Profiling ('zero_', torch.float32)
zero_ ->in
kwargs is: {} ./golden_dir/input_6_torch.Size([320])_torch.float32_.pt
args: 0 torch.float32
max: tensor(2.5819) min: tensor(-2.5820)
./golden_dir/input_6_torch.Size([320])_torch.float32_.pt ->out
max: tensor(0.) min: tensor(0.)
_to_copy ->in
kwargs is: {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cuda', index=0)} ./golden_dir/input_7_torch.Size([320])_torch.float32_.pt
args: 0 torch.float32
max: tensor(1.) min: tensor(1.)
./golden_dir/input_7_torch.Size([320])_torch.float32_.pt ->out
max: tensor(1.) min: tensor(1.)
_to_copy ->in
kwargs is: {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cuda', index=0)} ./golden_dir/input_8_torch.Size([320])_torch.float32_.pt
args: 0 torch.float32
max: tensor(0.) min: tensor(0.)
./golden_dir/input_8_torch.Size([320])_torch.float32_.pt ->out
max: tensor(0.) min: tensor(0.)
_to_copy ->in
kwargs is: {'dtype': torch.bfloat16} ./golden_dir/input_9_torch.Size([320])_torch.float32_.pt
args: 0 torch.float32
max: tensor(1.) min: tensor(1.)
./golden_dir/input_9_torch.Size([320])_torch.float32_.pt ->out
max: tensor(1., dtype=torch.bfloat16) min: tensor(1., dtype=torch.bfloat16)
_to_copy ->in
kwargs is: {'dtype': torch.bfloat16} ./golden_dir/input_10_torch.Size([320])_torch.float32_.pt
args: 0 torch.float32
max: tensor(0.) min: tensor(0.)
./golden_dir/input_10_torch.Size([320])_torch.float32_.pt ->out
max: tensor(0., dtype=torch.bfloat16) min: tensor(0., dtype=torch.bfloat16)
Profiling ('native_layer_norm', torch.bfloat16)
native_layer_norm ->in
kwargs is: {} ./golden_dir/input_11_torch.Size([1, 4096, 320])_torch.bfloat16_[320]_torch.Size([320])_torch.bfloat16_torch.Size([320])_torch.bfloat16_1e-05_.pt
args: 0 torch.bfloat16
max: tensor(4.9062, dtype=torch.bfloat16) min: tensor(-4.8125, dtype=torch.bfloat16)
out op: ./golden_dir/result_11_torch.Size([1, 4096, 320])_torch.bfloat16_native_layer_norm.pt
max: tensor(4.8125, dtype=torch.bfloat16) min: tensor(4.8125, dtype=torch.bfloat16)
out op: ./golden_dir/result_12_torch.Size([1, 4096, 1])_torch.float32_native_layer_norm.pt
max: tensor(0.2048) min: tensor(0.2048)
out op: ./golden_dir/result_13_torch.Size([1, 4096, 1])_torch.float32_native_layer_norm.pt
max: tensor(1.1668) min: tensor(1.1668)
Profiling ('sum', torch.bfloat16)
sum ->in
kwargs is: {} ./golden_dir/input_14_torch.Size([1, 4096, 320])_torch.bfloat16_.pt
args: 0 torch.bfloat16
max: tensor(4.8125, dtype=torch.bfloat16) min: tensor(-4.7812, dtype=torch.bfloat16)
./golden_dir/input_14_torch.Size([1, 4096, 320])_torch.bfloat16_.pt ->out
max: tensor(1.8438, dtype=torch.bfloat16) min: tensor(1.8438, dtype=torch.bfloat16)
Profiling ('mul', torch.bfloat16)
mul ->in
kwargs is: {} ./golden_dir/input_15_torch.Size([])_torch.bfloat16_0.0001_.pt
args: 0 torch.bfloat16
max: tensor(1.8438, dtype=torch.bfloat16) min: tensor(1.8438, dtype=torch.bfloat16)
./golden_dir/input_15_torch.Size([])_torch.bfloat16_0.0001_.pt ->out
max: tensor(0.0002, dtype=torch.bfloat16) min: tensor(0.0002, dtype=torch.bfloat16)
Profiling ('ones_like', torch.bfloat16)
ones_like ->in
kwargs is: {'pin_memory': False, 'memory_format': torch.preserve_format} ./golden_dir/input_16_torch.Size([])_torch.bfloat16_.pt
args: 0 torch.bfloat16
max: tensor(0.0002, dtype=torch.bfloat16) min: tensor(0.0002, dtype=torch.bfloat16)
./golden_dir/input_16_torch.Size([])_torch.bfloat16_.pt ->out
max: tensor(1., dtype=torch.bfloat16) min: tensor(1., dtype=torch.bfloat16)
Profiling ('mul', torch.bfloat16)
mul ->in
kwargs is: {} ./golden_dir/input_17_torch.Size([])_torch.bfloat16_0.0001_.pt
args: 0 torch.bfloat16
max: tensor(1., dtype=torch.bfloat16) min: tensor(1., dtype=torch.bfloat16)
./golden_dir/input_17_torch.Size([])_torch.bfloat16_0.0001_.pt ->out
max: tensor(0.0001, dtype=torch.bfloat16) min: tensor(0.0001, dtype=torch.bfloat16)
Profiling ('expand', torch.bfloat16)
expand ->in
kwargs is: {} ./golden_dir/input_18_torch.Size([])_torch.bfloat16_[1, 4096, 320]_.pt
args: 0 torch.bfloat16
max: tensor(0.0001, dtype=torch.bfloat16) min: tensor(0.0001, dtype=torch.bfloat16)
./golden_dir/input_18_torch.Size([])_torch.bfloat16_[1, 4096, 320]_.pt ->out
max: tensor(0.0001, dtype=torch.bfloat16) min: tensor(0.0001, dtype=torch.bfloat16)
Profiling ('native_layer_norm_backward', torch.bfloat16)
native_layer_norm_backward ->in
kwargs is: {} ./golden_dir/input_19_torch.Size([1, 4096, 320])_torch.bfloat16_torch.Size([1, 4096, 320])_torch.bfloat16_[320]_torch.Size([1, 4096, 1])_torch.float32_torch.Size([1, 4096, 1])_torch.float32_torch.Size([320])_torch.bfloat16_torch.Size([320])_torch.bfloat16_[False, True, True]_.pt
args: 0 torch.bfloat16
max: tensor(0.0001, dtype=torch.bfloat16) min: tensor(0.0001, dtype=torch.bfloat16)
out op: ./golden_dir/result_19_None_None_native_layer_norm_backward.pt
out op: ./golden_dir/result_20_torch.Size([320])_torch.bfloat16_native_layer_norm_backward.pt
max: tensor(0.0220, dtype=torch.bfloat16) min: tensor(0.0220, dtype=torch.bfloat16)
out op: ./golden_dir/result_21_torch.Size([320])_torch.bfloat16_native_layer_norm_backward.pt
max: tensor(0.4102, dtype=torch.bfloat16) min: tensor(0.4102, dtype=torch.bfloat16)
Profiling ('detach', torch.bfloat16)
detach ->in
kwargs is: {} ./golden_dir/input_22_torch.Size([320])_torch.bfloat16_.pt
args: 0 torch.bfloat16
max: tensor(0.0220, dtype=torch.bfloat16) min: tensor(-0.0164, dtype=torch.bfloat16)
./golden_dir/input_22_torch.Size([320])_torch.bfloat16_.pt ->out
max: tensor(0.0220, dtype=torch.bfloat16) min: tensor(0.0220, dtype=torch.bfloat16)
Profiling ('detach', torch.bfloat16)
detach ->in
kwargs is: {} ./golden_dir/input_23_torch.Size([320])_torch.bfloat16_.pt
args: 0 torch.bfloat16
max: tensor(0.0220, dtype=torch.bfloat16) min: tensor(-0.0164, dtype=torch.bfloat16)
./golden_dir/input_23_torch.Size([320])_torch.bfloat16_.pt ->out
max: tensor(0.0220, dtype=torch.bfloat16) min: tensor(0.0220, dtype=torch.bfloat16)
Profiling ('detach', torch.bfloat16)
detach ->in
kwargs is: {} ./golden_dir/input_24_torch.Size([320])_torch.bfloat16_.pt
args: 0 torch.bfloat16
max: tensor(0.4102, dtype=torch.bfloat16) min: tensor(0.4102, dtype=torch.bfloat16)
./golden_dir/input_24_torch.Size([320])_torch.bfloat16_.pt ->out
max: tensor(0.4102, dtype=torch.bfloat16) min: tensor(0.4102, dtype=torch.bfloat16)
Profiling ('detach', torch.bfloat16)
detach ->in
kwargs is: {} ./golden_dir/input_25_torch.Size([320])_torch.bfloat16_.pt
args: 0 torch.bfloat16
max: tensor(0.4102, dtype=torch.bfloat16) min: tensor(0.4102, dtype=torch.bfloat16)
./golden_dir/input_25_torch.Size([320])_torch.bfloat16_.pt ->out
max: tensor(0.4102, dtype=torch.bfloat16) min: tensor(0.4102, dtype=torch.bfloat16)
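Note how the log already surfaces a precision effect: the second _to_copy (input_2) casts the activations to bfloat16, and the reported max changes from 4.8994 to 4.9062. That is ordinary bfloat16 rounding (an 8-bit significand), which can be reproduced independently:

import torch

# 4.8994 is not representable in bfloat16; it rounds to the nearest
# representable value 4.90625, matching the 4.9062 printed in the log.
print(torch.tensor(4.8994).to(torch.bfloat16))  # tensor(4.9062, dtype=torch.bfloat16)

Divergence of this kind is expected between dtypes; when comparing two devices at the same dtype, the per-op min/max usually pinpoints the first genuinely divergent operator.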