当前位置: 首页 > article >正文

torch jit 动态获取buffer

前言

截止目前(torch2.5)似乎是一个不支持的功能。

样例

class MyClass(object):
    def __init__(self, w1,w2,w3):
        self.regist_buffer('w1', w1)
        self.regist_buffer('w2', w2)
        self.regist_buffer('w3', w3)
    def forward(self,x, i):
   		o = 0
		for i in range(1,4):
		   	w_name = f'w{i}'
		    w = self.get_buffer(w_name)
		    o += w*x
	    return o

model = MyClass()
script_model = torch.jit.script(model)

这样会有问题。
在script之后,调用 self.get_buffer() 报错, Unknown type name 'torch.nn.Module'

法2

def forward(self,x, i):
	o = 0
	for i in range(1,4):
	   	w_name = f'w{i}'
	    w = self._buffers[w_name]
	    # or
	    w = self.__dict__['_buffers'][w_name]
	    o += w*x
    return o
   	w_name = f'w{i}'
    w = self._buffers[w_name]
    # or
    w = self.__dict__['_buffers'][w_name]
    return w*x

script之后报错
Module 'MyClass' has no attribute '_buffers'
Module 'MyClass' has no attribute '__dict__
不存在self._buffers__dict__

法3

def forward(self,x):
	o = 0
	for i in range(1,4):
	   	w_name = f'w{i}'
	    w = getattr(self, w_name)
	    o += w*x
    return o

script之后报错
getattr's second argument must be a string literal

getattr只支持静态字面量。
但有时候我们是希望动态获取的。

PS:
torch.compile 暂时也不支持 cuda.stream 相关的操作。


http://www.kler.cn/a/392241.html

相关文章:

  • 01-Ajax入门与axios使用、URL知识
  • 一种基于深度学习的反无人机无人值守系统及方法
  • 闯关leetcode——3174. Clear Digits
  • Mit6.S081-实验环境搭建
  • python实战(八)——情感识别(多分类)
  • 黄色校正电容102j100
  • upload-labs通关练习
  • 闲鱼监控助手货源获取技巧(轻松找到优质货源的方法)
  • 【大数据测试spark+kafka-详细教程(附带实例)】
  • Unity3D设置3D物体不超出相机视角范围(物体一直保持在相机视角范围内)
  • Android S长按文件或视频或编辑中文字或输入框中文字不会弹出分享菜单
  • 零基础入门转录组下游分析——预后模型之多因素cox模型
  • 小西作业1_third order plant(SPM)
  • Linux也有百度云喔~
  • 在Java中使用ModelMapper简化Shapefile属性转JavaBean实战
  • 信令服务器设计之websocket基础
  • BERT配置详解1:构建强大的自然语言处理模型
  • 【Echarts图轮播显示label】
  • PHP动物收容所管理系统-计算机设计毕业源码94164
  • 初阶C++之C++入门基础
  • OKG Research:用户意图驱动的Web3应用变革
  • 系统架构设计师论文:论湖仓一体架构及其应用
  • ECharts实现数据可视化入门详解
  • 技术专家之路:深耕高门槛领域的策略
  • Tangram利用深度学习完成空间与单细胞数据的整合
  • 电脑浏览器打不开网页怎么办 浏览器无法访问网页解决方法