torch jit 动态获取buffer
前言
截止目前(torch2.5)似乎是一个不支持的功能。
样例
class MyClass(object):
def __init__(self, w1,w2,w3):
self.regist_buffer('w1', w1)
self.regist_buffer('w2', w2)
self.regist_buffer('w3', w3)
def forward(self,x, i):
o = 0
for i in range(1,4):
w_name = f'w{i}'
w = self.get_buffer(w_name)
o += w*x
return o
model = MyClass()
script_model = torch.jit.script(model)
这样会有问题。
在script之后,调用 self.get_buffer()
报错, Unknown type name 'torch.nn.Module'
法2
def forward(self,x, i):
o = 0
for i in range(1,4):
w_name = f'w{i}'
w = self._buffers[w_name]
# or
w = self.__dict__['_buffers'][w_name]
o += w*x
return o
w_name = f'w{i}'
w = self._buffers[w_name]
# or
w = self.__dict__['_buffers'][w_name]
return w*x
script之后报错
Module 'MyClass' has no attribute '_buffers'
Module 'MyClass' has no attribute '__dict__
不存在self._buffers
和 __dict__
法3
def forward(self,x):
o = 0
for i in range(1,4):
w_name = f'w{i}'
w = getattr(self, w_name)
o += w*x
return o
script之后报错
getattr's second argument must be a string literal
getattr只支持静态字面量。
但有时候我们是希望动态获取的。
PS:
torch.compile 暂时也不支持 cuda.stream 相关的操作。