当前位置: 首页 > article >正文

LLM-pruner源码解析

1.超参数

模型剪枝的超参数

模型
模型检查点和日志的保存地址
剪枝比例,这里默认0.5
剪枝类型,这里模型L2

 模型生成时的超参数

温度
top_p
最大序列长度

逐通道,逐块,逐层,这个逐层我不记得在论文里面提过啊
layer:保留前n层

注意力模块和线性层的开始层和结束层

剪枝的迭代次数
组的计算策略:这里采用的求和

是否使用全局剪枝

泰勒:包括在向量维度上的,在元素维度上的一阶、二阶(混合参数不知道指的是啥)
向量维度:
更细的元素维度:

这几个生成参数也没啥好说的,加载设备

确定torch的版本

官方给的bloom的配置是7B1的模型

我这里用的3B后面还要根据分析的结果改一下

3B有30层,

bloom7b有24层 这个4,20代表从0开始的排序是20还是从1开始的排序是20呀

按照配置获取的超参数

2.程序运行逻辑

第一步先固定下随机种子

设置日志

获得tokenizer

获得模型,这个类是llm-pruner自己写的,

有个问题:为啥要自己重新写一个加载类呢

自定义加载类

这段代码可以不用看直接看对q,k,v重新排序,这里自己写的加载类和transformers自带的没有区别,猜测应该是大佬防止模块名不一致,自己又重新写了一遍

下面套的类比较多这里为了区别,运行到哪一类提前说一下是继承顺序 

BloomForCausallm:BloomForCausallm类先继承BloomPreTrainedModel类

from transformers.models.bloom.configuration_bloom import BloomConfig

BloomForCausallm继承BloomPreTrainedModel:BloomPreTrainedModel类继承PreTrainedModel类

from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig

BloomForCausallm调用BloomModel:实例化BloomModel类,获取模型结构并初始化权重

BloomForCausallm调用BloomModel,BloomModel继承BloomPreTrainedModel:这个BloomModel类也是继承BloomPreTrainedModel类

BloomForCausallm调用BloomModel,BloomModel调用BloomBlockBloomModel类调用BloomBlock类

BloomForCausallm调用BloomModel,BloomModel调用BloomAttentionBloomAttention

BloomForCausallm调用BloomModel,BloomModel调用BloomMLP:BloomMLP

BloomForCausallm调用BloomModel,BloomModel调用BloomGelu:BloomGelu

BloomForCausallm调用BloomModel,BloomModel继承PreTrainedModel中的self.post_init()实例方法:BloomModel中的post_init,这个方法是PreTrainedModel中的方法

BloomForCausallm调用BloomModel,BloomModel继承PreTrainedModel中的self.post_init()实例方法,self.post_init()调用PreTrainedModel中的self.init_weight()实例方法和self._back_compatibity_gradient_ckeckpointing()实例方法
init_weights:如果需要,修剪并可能初始化权重。如果使用自定义的 `PreTrainedModel`,你需要在 `_init_weights` 中实现任何初始化逻辑。

BloomForCausallm:对于 self.transformer,post_init方法中的判定都是否,所以还是保持原来的参数不变,并没有对self.transformer包含的模块进行重新初始化

BloomForCausallm:self.lm_head 

BloomForCausallm:self.post_init()判断都是否,没有对网络进行重新初始化

BloomForCausallm继承BloomPreTrainedModel: 我还是没明白什么时候调用的_init_weight方法,只对self.lm_head进行标准化

q,k,v重排

从这里开始需要看,将q,k,v的顺序进行重排

torch.view 函数的追踪在某些情况下比较复杂,因此,查询、键、值的索引映射有时会遇到问题。为了避免这些问题,函数通过分离查询、键、值的方式来重新组织权重和偏置。


将模型转化为fp16 

from LLMPruner.templates.prompts import prompts

model.generate也是自己写的,先不看这里直接看pruner

pruner

选择pruner类型,将模块参数全都转化为要求梯度,组依赖关系的计算方式选择求和,给定提示prompt

选择taylor方式,一阶,求和,实例化TaylorImportance类

from LLMPruner.pruner import hf_llama_pruner as Pruner

Pruner这个类是自己定义的,如果模型不一样对应的类也不一样,具体怎么根据自己的模型改还得继续向下看

在baichuan中

组依赖关系求和,不进行归一化,一阶泰勒
import LLMPruner.torch_pruning as tp

逐模块计算

获取参数

从开始层到结束层 将q,k,v的pruner比例平分
"ch_sparsity_dict": { model.transformer.h[i].self_attention.query_key_value: args.pruning_ratio / 3 for i in range(args.block_attention_layer_start, args.block_attention_layer_end) },

"root_instances": [model.transformer.h[i].mlp.dense_h_to_4h for i in range(args.block_mlp_layer_start, args.block_mlp_layer_end)] +
                  [model.transformer.h[i].self_attention.query_key_value for i in range(args.block_attention_layer_start, args.block_attention_layer_end)],

开始pruner

import LLMPruner.torch_pruning as tp 

获取MetaPruner类的实例属性

对DependencyGraph类实例化

from ... import ops, dependency

from . import _helpers, utils, ops

这几个函数都是调用的ops中的类,CUSTOMIZED定制为None

已经注册的编辑器,更新编辑器。定制pruner,忽略的层
from .pruner import function

 

刚开始的编辑器

更新之后的编辑器

 按照编辑器中的值获取pruner输入通道的函数,定制里面是空值

 按照编辑器中的值获取pruner输出通道的函数,定制里面是空值

每个pruner类型获取pruner输入/出通道的函数,这里只是举例子

_op_id为0,记录pruner的历史,现在为空 

调用build_dependency方法,输入参数包括模型,前面设置的提示词,刚才参数中的前向函数model(example_inputs),剩下的三个output_transform,unwrapped_parameters,customized_pruners都是None或空值

build_dependency 的实例属性包括:获取模型,命名的模块
customized_pruners,CUSTOMIZED_PRUNERS之前已经知道了两个字典都是空

self._module2name有命名的模块有337个 

检测没有包装的参数

新建已包装模块的列表,获取注册表中的pruner

获得每个模块的类型,如果操作类型在pruner模块类型中且不是逐元素(后面不用看后面是空),把参数加入到已包装参数的类中

已包装的参数有336个

新建unwrapped_detected和_param_to_name

遍历看看哪个是未包装的,加入unwrapped_detected

 

最后的运行结果,没有不被包装的,同样_param_to_name里面也是空的

从 unwrapped_detected 列表中移除所有出现在 unwrapped_parameters 列表中的元素,最终的结果存回 unwrapped_detected 变量,这个unwrapped_parameters是类属性,为空列表

如果 unwrapped_detected不为空的话相关的处理手段,对 unwrapped_detected 列表中的每个元素进行处理,找出其 最后一个大于1的维度,并将该元素与对应的维度信息存入 unwrapped_parameters 列表。最后,将 unwrapped_parameters 保存到 self.unwrapped_parameters

开始追踪计算图了,输入有模型,提示输入,前向,output_transform为None

model.eval()
初始化gradfn2module和visited

获得之前pruner编辑器中的内容 

还是之前的那18个 

  如果模型模块不在忽略的层中且在注册表内,在模型的每一层上注册一个前向钩子(forward_hook),这些钩子将在模型进行前向传播时触发并调用 _record_grad_fn 函数

visited 是一个字典,键是模块对象,值是该模块被调用的次数。每当某个模块被前向传播执行时,visited[module] 增加 1
当前模块是否是 nn.Linear 层,并且该层的输出张量维度是否为 3(例如,(batch_size, seq_len, hidden_dim))。如果条件成立,设置 self._2d_4d = False,表示模型输出的维度不再是 2D 或 4D,而是 3D。
有些层(如 LSTM, GRU)的输出是一个元组,通常包括输出张量和一些附加信息(如隐藏状态)。这个条件会检查输出是否是一个元组,如果是,则提取元组的第一个元素作为最终的输出
PackedSequence 是 PyTorch 中用于表示 RNN 变长序列的输出格式。此条件用于检查 outputs 是否是一个 PackedSequence 对象,如果是,则将其 .data 提取出来。.data 是实际的张量数据
outputs.grad_fn 是 outputs 张量的梯度计算函数(grad_fn),它记录了张量如何计算出来。这里的 gradfn2module 是一个字典,将梯度计算函数 (grad_fn) 映射到对应的模块 module

之前自定义了前向函数 forward_fn,同时会调用之前注册的hook函数 
前向完成后移除掉hook函数

在前向的过程中填充记录模块调用次数的visited 字典和记录计算梯度函数的gradfn2module列表

针对递归模型或层,找到被调用多次的模块记录到reused列表中

这里没有被条用多次的,是空的列表 

这里output_transform是None,如果有的话对模型的输出结果进行转换

from . import _helpers, utils, ops

对于utils.flatten_as_list()

如果 obj 是一个张量,则将其包装成一个列表并返回

检查 obj 是否是一个列表(list)或元组(tuple)。如果是,创建一个空的 flattened_list 用于存储展开后的元素。然后,递归地调用 flatten_as_list 来展开列表或元组中的每个元素(sub_obj)。使用 extend 方法将每次递归得到的展开结果合并到 flattened_list 中。最终返回这个完全展开的列表

检查 obj 是否是一个字典(dict)。如果是,创建一个空的 flattened_list 用于存储展开后的元素。
然后,递归地调用 flatten_as_list 来展开字典中每个键对应的值(sub_obj)。使用 extend 将每次递归得到的展开结果合并到 flattened_list 中。最终返回这个完全展开的列表

如果 obj 既不是 torch.Tensor、列表、元组或字典,那么直接返回 obj 本身。这部分适用于基本类型(如整数、浮点数、字符串等)。

调用是实力属性_trace_computational_graph追踪计算图
输入包括module2node字典,模块梯度计算函数,记录模块计算函数的gradfn2module字典,和纪律被多次调用的模块的字典

 非递归计算图构建

processing_stack: 用于存储待处理的梯度函数(grad_fn)节点,类似栈(stack)的数据结构。
visited: 用于跟踪已经处理过的 grad_fn,避免重复处理。
visited_as_output_node: 用于追踪作为输出节点的计算图节点。

在每次循环中,弹出栈顶的梯度函数(grad_fn)并开始处理。
如果当前 grad_fn 已经处理过,则跳过(防止重复计算)

 调用create_node_if_not_exists

如果 module 已经存在,并且该 module 已经在 module2node 字典中关联了一个节点(即已存在对应的计算图节点),并且该 module 不在 reused 中(表示该节点没有被标记为“已重用”),那么直接返回现有的节点 module2node[module

如果 module 为空(表示这是一个新模块,之前没有创建过),则会根据 grad_fn 创建一个新的模块并与其关联
如果 grad_fn 没有 name 属性(说明它是一个不常见的或自定义的操作),则将其视为一个 逐元素操作(如加法、减法等),并使用 ops._ElementWiseOp 创建一个新的操作模块 module,并给这个模块分配一个唯一的 op_id。self._op_id 会在每次创建模块后自增。
如果 verbose 为 True,则发出警告,提示遇到了一个未知操作,默认将其视为逐元素操作。
如果 grad_fn.name() 包含特定的字符串(如 "catbackward"、"split"、"view" 等),则根据操作类型创建对应的模块(例如 ops._ConcatOp 表示拼接操作,ops._SplitOp 表示拆分操作,ops._ReshapeOp 表示形状变化操作等)。
如果没有匹配到特定类型的操作,则默认将其视为 逐元素操作。
创建好模块后,会将 grad_fn 与新创建的模块存储到 gradfn2module 字典中,以便以后查找。

 如果 module 还没有在 module2node 字典中找到对应的节点,则创建一个新的节点 Node 对象。
该节点包含以下信息:
module: 关联的操作模块
grad_fn: 关联的梯度计算函数
name: 从 _module2name 字典中获取模块的名称,如果没有,则为 None。
如果该模块是自定义的修剪器(CUSTOMIZED_PRUNERS),则将节点类型设置为 CUSTOMIZED。
将新节点添加到 module2node 字典中,以便后续访问。
如果 module 已经有对应的节点,则直接使用已存在的节点

hasattr() 是 Python 内置的一个函数,用来检查一个对象是否具有指定的属性。
检查当前的 grad_fn(计算图中的节点)是否有 next_functions 属性。
grad_fn.next_functions 是一个可迭代对象,每个元素表示当前梯度函数(操作)依赖的输入(上游节点)。遍历 next_functions 列表中的每个元素,来处理每个输入。

如果 f[0] 为 None,表示该输入没有有效的梯度函数,因此跳过这个输入

这行代码检查 f[0](即当前输入的 grad_fn)是否有 name 属性,并且其名称是否包含 "accumulategrad"(表示该输入是一个叶子变量)。这种叶子变量通常对应于模型参数(如权重或偏置),它们不是由其他操作计算得到的,而是计算图中的输入

如果 f[0] 是叶子变量,进一步检查它是否属于未包装的参数(即 unwrapped_parameters)。
如果找到了匹配的参数,gradfn2module[f[0]] = p 将 grad_fn 映射到该参数(p)。同时,使用 self._module2name 为该参数生成一个名称 "UnwrappedParameter_j (shape)",并将其赋值为 grad_fn 的名称。
如果没有找到匹配的参数,跳过当前输入 
调用 create_node_if_not_exists(f[0]) 为输入 f[0] 创建一个节点

node.add_input(input_node, allow_dumplicated=False) 将当前的 input_node作为输入添加到 node中。allow_dumplicated=False 表示不允许重复连接相同的输入。
input_node.add_output(node, allow_dumplicated=False) 将 ndoe\作为输出添加到 input_node中。

f[0] 被添加到 processing_stack 中,表示该输入已经被处理

visited.add(grad_fn) 将当前的 grad_fn 标记为已访问,表示该节点已经被处理过。
visited_as_output_node.add(node) 将当前的 node 标记为已访问的输出节点,防止后续重复处理

对于没有包装的节点

最后返回模块和节点之间的关系

打个节点,下次再看


http://www.kler.cn/a/413064.html

相关文章:

  • VS2022的MFC的ReadString的问题
  • 熔断限流:业务实现自我保护
  • C++ ADL参数依赖查找
  • scala统计词频
  • 嵌入式工程师面试笔试总结——day2
  • TorchMoji使用教程/环境配置(2024)
  • 记录下在html文件中如何直接使用npm依赖,以threejs为例
  • sentry前端接入 报错403
  • 2022 年 3 月青少年软编等考 C 语言三级真题解析
  • YourPHPCMS Register_checkEmail存在sql注入漏洞
  • uniapp中的事件:v-on
  • Spring Boot 3 集成 Spring Security(3)数据管理
  • 同时多平台git配置:GitHub和Gitee生成不同的SSH Key
  • WPF——自定义ToolTip
  • Git远程仓库过大导致clone失败的解决方法
  • pytorch 和tensorflow loss.item()` 只能用于只有一个元素的张量. 防止显存爆炸
  • 什么是缓存击穿?如何避免之布隆过滤器
  • 07 初始 Oracle 优化器
  • Java设计模式笔记(一)
  • 14、保存与加载PyTorch训练的模型和超参数