当前位置: 首页 > article >正文

剪枝与重参第四课:NVIDIA的2:4剪枝方案

目录

  • NVIDIA的2:4 pattern稀疏方案
    • 前言
    • 1.稀疏性的研究现状
    • 2.图解nvidia2-4稀疏方案
    • 3.训练策略
    • 4.手写复现
      • 4.1 大体框架
      • 4.2 ASP类的实现
      • 4.3 mask的实现
      • 4.4 模型初始化
      • 4.5 Layer嵌入稀疏特性
      • 4.6 优化器初始化
      • 4.7 拓展-dynamic function assignment
      • 4.8 完整示例代码
    • 总结

NVIDIA的2:4 pattern稀疏方案

前言

手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。

本次课程主要讲解NVIDIA的2:4剪枝方案。

reference:

ASP nvidia 2:4 pattern pruning

paper:

  • Accelerating Sparse Deep Neural Networks

code:

  • https://github.com/NVIDIA/apex/tree/master/apex/contrib/sparsity

blog:

  • https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/

tensor core:

  • https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/

课程大纲可看下面的思维导图

在这里插入图片描述

1.稀疏性的研究现状

许多研究集中在两方面:

  • 大量(80-95%)的非结构化、细粒度稀疏
  • 用于简单加速的粗粒度稀疏

这些方法所面临的挑战有:

  • 精度损失

    • 高稀疏度往往会导致准确率损失几个百分点,即使拥有先进的训练技术也是如此
  • 缺少一种适用于不同任务和网络的训练方法

    • 恢复准确性的训练方法因网络而异,通常需要超参数搜索
  • 缺少加速

    • Math:非结构数据难以利用现代向量/矩阵数学指令的优势
    • Memory access:非结构化数据往往不能很好地利用内存总线,由于读操作之间存在依赖关系,导致延迟增加
    • Storage overheads:metadata占用的存储空间比非零权重多消耗2倍,从而抵消了一些压缩的好处。(metadata通常指的是对于权重矩阵的稀疏性描述信息,例如哪些位置是零元素,哪些位置是非零元素)

2.图解nvidia2-4稀疏方案

NVIDIA在处理稀疏矩阵W时,会采用2:4稀疏方案。在这个方案中,稀疏矩阵W首先会被压缩,压缩后的矩阵存储着非零的数据值,而metadata则存储着对应非零元素在原矩阵W中的索引信息。具体来说,metadata会将W中非零元素的行号和列号压缩成两个独立的一维数组,这两个数组就是metadata中存储的索引信息。如下图所示:

在这里插入图片描述

对于大型矩阵相乘时,我们可以采用2:4稀疏方案减少计算量,假设矩阵A和B相乘得到C,正常运算如下图所示:

在这里插入图片描述

我们可以将A矩阵进行剪枝使其变得稀疏,如下图所示:

在这里插入图片描述

而针对于稀疏矩阵,我们可以通过上述的NVIDIA方案将其变为2:4的结构,可以将A矩阵进行压缩,而对矩阵B的稀疏是通过硬件上面的Sparse Tensor Cores进行选择,如下图所示:

在这里插入图片描述

3.训练策略

NVIDIA提供的2:4稀疏训练方案步骤如下:

  • 1)训练网络
  • 2)2:4稀疏剪枝
  • 3)重复原始的训练流程
    • 超参数的选择与步骤1一致
    • 权重的初始化与步骤2一致
    • 保持步骤2中的 0 patter:不需要重新计算mask

图示如下:

在这里插入图片描述

4.手写复现

4.1 大体框架

示例代码如下:

import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

model = None
optimizer = None

class ToyDataset(Dataset):
    def __init__(self):
        x = torch.round(torch.rand(1000) * 200) # (1000,)
        x = x.unsqueeze(1) # (1000,1)
        x = torch.cat((x, x * 2, x * 3, x * 4, x * 5, x * 6, x * 7, x * 8), 1) # (1000,8)
        self.X = x
        self.Y = self.X
    
    def __getitem__(self, index):
        return self.X[index], self.Y[index]
    
    def __len__(self):
        return len(self.X)

training_loader = DataLoader(ToyDataset(), batch_size=100, shuffle=True)

def train():
    criterion = nn.MSELoss()
    for i in range(500):
        for x, y in training_loader:
            loss = criterion(model(x.to("cuda")), y.to("cuda"))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    print("epoch #%d: loss: %f" % (i, loss.item()))

def test():
    x = torch.tensor([2, 4, 6, 8, 10, 12, 14, 16]).float()
    y_hat = model(x.to("cuda"))
    print("input: ", x, "\n", "predict: ", y_hat)

def get_model(path):
    global model, optimizer
    if os.path.exists(path):
        model = torch.load(path).cuda()
        optimizer = optim.Adam(model.parameters(), lr=0.01)
    else:
        model = nn.Sequential(
            nn.Linear(8, 16),
            nn.PReLU(),
            nn.Linear(16, 8)
        ).cuda()

        optimizer = optim.Adam(model.parameters(), lr=0.01)
        train()
        torch.save(model, path)

class ASP():
    ...

if __name__ == "__main__":
    
    # ---------------- train ----------------
    get_model("./model.pt")
    print("-------orig-------")
    test()
    
    # ---------------- prune ----------------
    ASP.prune_trained_model(model, optimizer)
    print("-------pruned-------")
    test()

    # ---------------- finetune ----------------
    train()
    print("-------retrain-------")
    test()
    torch.save(model, "./model_sparse.pt")

上述示例代码演示了2:4稀疏方案的大体框架,包括数据集准备、模型训练、模型剪枝、模型微调和模型保存等步骤。剪枝方案为ASP(Automatic SParsity),主要实现的是前面提到过的2:4稀疏剪枝,其具体实现细节在ASP类中。

4.2 ASP类的实现

ASP类的实现示例代码如下:

class ASP():
    
    @classmethod
    def init_model_for_pruning(model, mask_calculater, whitelist):
        pass

    @classmethod
    def init_optimizer_for_pruning(optimizer):
        pass
    
    @classmethod
    def compute_sparse_masks():
        pass

    @classmethod
    def prune_trained_model(cls, model, optimizer):
        cls.init_model_for_pruning(
            model,
            mask_calculater = "m4n2_1d",
            whitelist = [torch.nn.Linear, torch.nn.Conv2d]
        )
        cls.init_optimizer_for_pruning(optimizer)

        cls.compute_sparse_masks()  # 2:4

在上面的示例代码中,ASP的类方法prune_trained_model会对训练好的模型进行剪枝操作,首先它会去调用init_model_for_pruninginit_optimizer_for_pruning对模型和优化器进行初始化,然后调用compute_sparse_masks生成稀疏掩码(具体首先见4.3),最后使用掩码对模型进行剪枝。

4.3 mask的实现

我们来看下核心部分,mask的实现,2:4的方案就是在一张密集的weights中实现每4个weight取其中两个比较大的,其他两个置0,如下图所示:

在这里插入图片描述

最简单的实现方案就是遍历所有的weights,每4个进行比较,然后将较大的weight所对应的mask置1,其他mask置0,如下图所示:

在这里插入图片描述

而NVIDIA的方案是首先创建一个patterns,如下图所示,由于是2:4的方案,所有总共有6种不同的pattern;然后将weight matrix变换成nx4的格式方便与pattern进行矩阵运算,运算后的结果为nx6的矩阵,在n的维度上进行argmax取得最大的索引(索引对应pattern),然后将索引对应的pattern值填充到mask中即可。

在这里插入图片描述

示例代码如下:

import sys
import torch
import numpy as np
from itertools import permutations

def reshape_1d(matrix, m):
    # If not a nice multiple of m, fill with zeros
    if matrix.shape[1] % m > 0:
        mat = torch.cuda.FloatTensor(
            matrix.shape[0], matrix.shape[1] + (m - matrix.shape[1] % m)
        ).fill_(0)
        mat[:, : matrix.shape[1]] = matrix
        shape = mat.shape
        return mat.view(-1, m), shape
    else:
        return matrix.view(-1, m), matrix.shape

def compute_valid_1d_patterns(m,n):
    patterns = torch.zeros(m)
    patterns[:n] = 1
    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
    return valid_patterns

def mn_1d_best(matrix, m, n):
    # find all possible patterns
    patterns = compute_valid_1d_patterns(m,n).cuda()

    # find the best m:n pattern
    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)
    mat, shape = reshape_1d(matrix, m)
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]
    mask = mask.view(matrix.shape)
    return mask

def m4n2_1d(mat, density):
    return mn_1d_best(mat, 4, 2)

def m4n3_1d(mat, density):
    pass

def create_mask(weight, pattern, density=0.5):
    t = weight.float().contiguous()
    shape = weight.shape
    ttype = weight.type()

    func = getattr(sys.modules[__name__], pattern, None) # automatically find the function you want, and call it
    mask = func(t, density)

    return mask.view(shape).type(ttype)


if __name__ == "__main__":
    weight = torch.randn(8, 16).to("cuda")

    def create_mask_from_pattern(weight):
        return create_mask(weight, "m4n2_1d").bool() # 工厂模式 factory method 不同的情况创建不同的对象
    
    mask = create_mask_from_pattern(weight)
    mask = ~mask # for visualize
    # visualize the weight
    import matplotlib.pyplot as plt
    # Calculate the absolute values
    abs_weight = torch.abs(weight)
    
    # Convert to a numpy array for plotting
    abs_weight_np = abs_weight.cpu().numpy()

    # visualize the mask
    mask = mask.cpu().numpy().astype(float)
    # Plot the matrix
    def annotate_image(image_data, ax=None, text_color='red', fontsize=50):
        if ax is None:
            ax = plt.gca()
        for i in range(image_data.shape[0]):
            for j in range(image_data.shape[1]):
                ax.text(j, i, f"{image_data[i, j]:.2f}", ha="center", va="center", color=text_color, fontsize=fontsize)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(80, 20))

    ax1.imshow(abs_weight_np, cmap="gray", vmin=0, vmax=torch.max(abs_weight).item())
    ax1.axis("off")
    annotate_image(abs_weight_np, ax=ax1)
    ax1.set_title("weight", fontsize=100)

    ax2.imshow(mask, cmap="gray", vmin=0, vmax=np.max(mask).item())
    ax2.axis("off")
    ax2.set_title("mask", fontsize=100)

    plt.savefig("param_and_mask.jpg", bbox_inches='tight', dpi=100)

在上面的示例代码中,mn_1d_best函数实现了在指定大小的矩阵中寻找最佳的m:4的mask矩阵,具体实现可看上述的图示流程,m4n2_1d函数则是对mn_1d_best函数的进一步封装,指定了m=4,n=4,即寻找最佳的2:4的mask矩阵,create_mask函数则根据给定的权重矩阵、mask生成函数名和稀疏度生成相应的mask矩阵。

可视化结果如下,其中mask中白色区域填充的是0,黑色区域代表的是1:

在这里插入图片描述

4.4 模型初始化

示例代码如下:

class ASP():
    model = None
    optimizer = None
    sparse_parameters = []
    calculate_mask = None

    @classmethod
    def init_model_for_pruning(
        cls,
        model,
        mask_calculater,
        whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d],
        custom_layer_dict={}
    ):
        
        assert cls.model is None, "ASP has initialized already"

        cls.model = model
        if isinstance(mask_calculater, str):
            def create_mask_from_pattern(param):
                return create_mask(param, mask_calculater).bool()
        
        cls.calculate_mask = create_mask_from_pattern # dynamic function assignment

        sparse_parameter_list = {
            torch.nn.Linear: ["weight"],
            torch.nn.Conv1d: ["weight"],
            torch.nn.Conv2d: ["weight"]
        }
        if (custom_layer_dict):
            # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune
            sparse_parameter_list.update(custom_layer_dict)
            whitelist += list(custom_layer_dict.keys())

        for module_type in whitelist:
            assert module_type in sparse_parameter_list, (
                "Module %s : Don't know how to sparsify module." % module_type
            )

        # find all sparse modules, extract sparse parameters and decorate
        def add_sparse_attributes(module_name, module):
            ...
        
        def eligible_modules(model, whitelist_layer_types):
            eligible_modules_list = []
            for name, mod in model.named_modules():
                if(isinstance(mod, whitelist_layer_types)):
                    eligible_modules_list.append((name, mod))
            return eligible_modules_list

        for name, sparse_module in eligible_modules(model, tuple(whitelist)):
            add_sparse_attributes(name, sparse_module)

上面示例代码主要实现ASP中的类方法init_model_for_pruning,它的作用是初始化模型,该类方法主要有以下几点说明:

  • 该方法通过传入的参数mask_calculater调用函数create_mask_from_pattern,这个函数的作用是根据传入的参数生成一个稀疏矩阵掩码,也就是4.3小节的内容
  • 该方法会根据传入的参数whitelistcustom_layer_dict找到所有需要进行稀疏化的模块,这些模块的类型必须在whitelist中指定,并且每种模块包含一个或多个需要稀疏化的参数。这些信息都被保存在一个字典sparse_parameter_list
  • 如果这个模块的类型在whitelist中,那么就会调用add_sparse_attributes方法对这个模块进行稀疏化处理(该函数的具体实现可参考4.5小节)

4.5 Layer嵌入稀疏特性

示例代码如下:

# find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module):
    sparse_parameters = sparse_parameter_list[type(module)]
    for p_name, p in module.named_parameters():
        if p_name in sparse_parameters and p.requires_grad:
            # check for NVIDIA's TC compatibility: we check along the horizontal direction
            if p.dtype == torch.float32 and (
                (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0
            ):  # User defines FP32 and APEX internally uses FP16 math
                print(
                    "[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"
                    % (module_name, p_name, str(p.size()), str(p.dtype))
                )
                continue
                if p.dtype == torch.float16 and (
                    (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0
                ):  # For Conv2d dim= K x CRS; we prune along C
                    print(
                        "[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"
                        % (module_name, p_name, str(p.size()), str(p.dtype))
                    )
                    continue

                    mask = torch.ones_like(p).bool()
                    buffname = p_name.split(".")[-1] # buffer name cannot contain "."
                    module.register_buffer("__%s_mma_mask" % buffname, mask)
                    cls.sparse_parameters.append(
                        (module_name, module, p_name, p, mask)
                    )

函数add_sparse_attributes的作用是给模型中的每个可稀疏化的参数添加相应的稀疏度掩码。具体来说,函数首先检查模型中每个模块的参数是否在可稀疏化的参数列表中,并且梯度需要计算。然后,函数会检查参数的尺寸是否满足NVIDIA TC(Tensor Cores)的要求。如果满足,则添加一个与参数形状相同的稀疏度掩码。掩码是一个布尔张量,对应于参数中的每个元素。掩码初始化全1,表示所有参数都被保留。在后续的稀疏化操作中,将根据每个参数的稀疏度掩码来确定哪些参数需要被稀疏化。最后,函数将所有的稀疏化参数(包括稀疏度掩码)的元组添加到类变量sparse_parameters中。

4.6 优化器初始化

示例代码如下:

def init_optimizer_for_pruning(cls, optimizer):
    assert cls.optimizer is None, "ASP has initialized optimizer already."

    assert (
        cls.calculate_mask is not None
    ), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."

    cls.optimizer = optimizer
    cls.optimizer.__step = optimizer.step

    def __step(opt_self, *args, **kwargs): # two pruning part: 1) grad 2) weight
        # p.grad p.data
        with torch.no_grad():
            for (module_name, module, p_name, p, mask) in cls.sparse_parameters:
                if p.grad is not None:
                    p.grad.mul_(mask) # inplace

                    # call original optimizer.step
                    rval = opt_self.__step(*args, **kwargs)

                    # prune parameter after step method
                    with torch.no_grad():
                        for (module_name, module, p_name, p, mask) in cls.sparse_parameters:
                            p.mul_(mask)

                            return rval

上面的示例代码为初始化优化器,主要是为优化器注册一个新的step方法,以便在每次更新权重之前进行剪枝。__step方法先对梯度进行剪枝操作,再调用原优化器对象的__step方法完成权重更新,然后对权重进行裁剪。先对梯度进行裁剪是因为最终的结果会影响权重的裁剪,如果不对梯度进行裁剪而只对权重进行裁剪可能导致权重大的元素被裁剪。

4.7 拓展-dynamic function assignment

动态函数赋值(Dynamic Function Assignment)是指在运行时动态地指定对象的某个方法实现。

在Python中,我们可以使用函数名作为变量名,将函数赋值给变量。这意味着我们可以根据不同的条件,将不同的函数赋值给同一个变量,以便在后续的代码中调用该变量的函数时,根据不同的条件执行不同的函数。这就是动态函数赋值。

下面是使用了DFA的示例代码:

class Pruner:
    def __init__(self, pruning_pattern):
        self.pruning_pattern = pruning_pattern

        if pruning_pattern == 'pattern_A':
            self.prune = self.prune_pattern_A
        elif pruning_pattern == 'pattern_B':
            self.prune = self.prune_pattern_B
        elif pruning_pattern == 'pattern_C':
            self.prune = self.prune_pattern_C

    def prune_pattern_A(self, network):
        # Perform pruning with pattern A logic
        pruned_network = ...
        return pruned_network

    def prune_pattern_B(self, network):
        # Perform pruning with pattern B logic
        pruned_network = ...
        return pruned_network
    
    def prune_pattern_C(self, network):
        # Perform pruning with pattern B logic
        pruned_network = ...
        return pruned_network


pruner_A = Pruner('pattern_A')
pruned_network_A = pruner_A.prune(network)

pruner_B = Pruner('pattern_B')
pruned_network_B = pruner_B.prune(network)

pruner_C = Pruner('pattern_C')
pruned_network_C = pruner_C.prune(network)

下面是没有使用DFA的示例代码:

class Pruner:
    def __init__(self, pruning_pattern):
        self.pruning_pattern = pruning_pattern

    def prune(self, network):
        if self.pruning_pattern == 'pattern_A':
            return self.prune_pattern_A(network)
        elif self.pruning_pattern == 'pattern_B':
            return self.prune_pattern_B(network)

    def prune_pattern_A(self, network):
        # Perform pruning with pattern A logic
        pruned_network = ...
        return pruned_network

    def prune_pattern_B(self, network):
        # Perform pruning with pattern B logic
        pruned_network = ...
        return pruned_network


pruner_A = Pruner('pattern_A')
pruned_network_A = pruner_A.prune(network)

pruner_B = Pruner('pattern_B')
pruned_network_B = pruner_B.prune(network)

从二者的对比可以看出动态函数赋值的优点在于它可以使代码更加灵活、可扩展和可维护。它使我们能够动态地改变函数的行为,从而根据不同的条件来处理数据或执行任务。这使得我们的代码更容易理解和维护,也更具可读性和可重用性。此外,动态函数赋值还可以提高代码的灵活性,使得我们可以更容易地在不同的上下文中使用相同的代码。

4.8 完整示例代码

完整的示例代码如下:

import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from itertools import permutations


model = None
optimizer = None

def reshape_1d(matrix, m):
    # If not a nice multiple of m, fill with zeros
    if matrix.shape[1] % m > 0:
        mat = torch.cuda.FloatTensor(
            matrix.shape[0], matrix.shape[1] + (m - matrix.shape[1] % m)
        ).fill_(0)
        mat[:, : matrix.shape[1]] = matrix
        shape = mat.shape
        return mat.view(-1, m), shape
    else:
        return matrix.view(-1, m), matrix.shape


def compute_valid_1d_patterns(m,n):
    patterns = torch.zeros(m)
    patterns[:n] = 1
    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
    return valid_patterns

def mn_1d_best(matrix, m, n):
    # find all possible patterns
    patterns = compute_valid_1d_patterns(m,n).cuda()

    # find the best m:n pattern
    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)
    mat, shape = reshape_1d(matrix, m)
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]
    mask = mask.view(matrix.shape)
    return mask

def m4n2_1d(mat, density):
    return mn_1d_best(mat, 4, 2)

def m4n3_1d(mat, density):
    pass

def create_mask(weight, pattern, density=0.5):
    t = weight.float().contiguous()
    shape = weight.shape
    ttype = weight.type()

    func = getattr(sys.modules[__name__], pattern, None) # automatically find the function you want, and call it
    mask = func(t, density)

    return mask.view(shape).type(ttype)

class ToyDataset(Dataset):
    def __init__(self):
        x = torch.round(torch.rand(1000) * 200) # (1000,)
        x = x.unsqueeze(1) # (1000,1)
        x = torch.cat((x, x * 2, x * 3, x * 4, x * 5, x * 6, x * 7, x * 8), 1) # (1000,8)
        self.X = x
        self.Y = self.X
    
    def __getitem__(self, index):
        return self.X[index], self.Y[index]
    
    def __len__(self):
        return len(self.X)

training_loader = DataLoader(ToyDataset(), batch_size=100, shuffle=True)

def train():
    criterion = nn.MSELoss()
    for i in range(500):
        for x, y in training_loader:
            loss = criterion(model(x.to("cuda")), y.to("cuda"))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    print("epoch #%d: loss: %f" % (i, loss.item()))

def test():
    x = torch.tensor([2, 4, 6, 8, 10, 12, 14, 16]).float()
    y_hat = model(x.to("cuda"))
    print("input: ", x, "\n", "predict: ", y_hat)

def get_model(path):
    global model, optimizer
    if os.path.exists(path):
        model = torch.load(path).cuda()
        optimizer = optim.Adam(model.parameters(), lr=0.01)
    else:
        model = nn.Sequential(
            nn.Linear(8, 16),
            nn.PReLU(),
            nn.Linear(16, 8)
        ).cuda()

        optimizer = optim.Adam(model.parameters(), lr=0.01)
        train()
        torch.save(model, path)

class ASP():
    model = None
    optimizer = None
    sparse_parameters = []
    calculate_mask = None

    @classmethod
    def init_model_for_pruning(
        cls,
        model,
        mask_calculater,
        whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d],
        custom_layer_dict={}
    ):
        
        assert cls.model is None, "ASP has initialized already"

        cls.model = model
        if isinstance(mask_calculater, str):
            def create_mask_from_pattern(param):
                return create_mask(param, mask_calculater).bool()
        
        cls.calculate_mask = create_mask_from_pattern # dynamic function assignment

        sparse_parameter_list = {
            torch.nn.Linear: ["weight"],
            torch.nn.Conv1d: ["weight"],
            torch.nn.Conv2d: ["weight"]
        }
        if (custom_layer_dict):
            # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune
            sparse_parameter_list.update(custom_layer_dict)
            whitelist += list(custom_layer_dict.keys())

        for module_type in whitelist:
            assert module_type in sparse_parameter_list, (
                "Module %s : Don't know how to sparsify module." % module_type
            )

        # find all sparse modules, extract sparse parameters and decorate
        def add_sparse_attributes(module_name, module):
            sparse_parameters = sparse_parameter_list[type(module)]
            for p_name, p in module.named_parameters():
                if p_name in sparse_parameters and p.requires_grad:
                    # check for NVIDIA's TC compatibility: we check along the horizontal direction
                    if p.dtype == torch.float32 and (
                        (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0
                    ):  # User defines FP32 and APEX internally uses FP16 math
                        print(
                            "[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )
                        continue
                    if p.dtype == torch.float16 and (
                        (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0
                    ):  # For Conv2d dim= K x CRS; we prune along C
                        print(
                            "[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )
                        continue
                    
                    mask = torch.ones_like(p).bool()
                    buffname = p_name.split(".")[-1] # buffer name cannot contain "."
                    module.register_buffer("__%s_mma_mask" % buffname, mask)
                    cls.sparse_parameters.append(
                        (module_name, module, p_name, p, mask)
                    )
            
        
        def eligible_modules(model, whitelist_layer_types):
            eligible_modules_list = []
            for name, mod in model.named_modules():
                if(isinstance(mod, whitelist_layer_types)):
                    eligible_modules_list.append((name, mod))
            return eligible_modules_list

        for name, sparse_module in eligible_modules(model, tuple(whitelist)):
            add_sparse_attributes(name, sparse_module)

    @classmethod
    def init_optimizer_for_pruning(cls, optimizer):
        assert cls.optimizer is None, "ASP has initialized optimizer already."
        
        assert (
            cls.calculate_mask is not None
        ), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."

        cls.optimizer = optimizer
        cls.optimizer.__step = optimizer.step

        def __step(opt_self, *args, **kwargs): # two pruning part: 1) grad 2) weight
            # p.grad p.data
            with torch.no_grad():
                for (module_name, module, p_name, p, mask) in cls.sparse_parameters:
                    if p.grad is not None:
                        p.grad.mul_(mask) # inplace

            # call original optimizer.step
            rval = opt_self.__step(*args, **kwargs)

            # prune parameter after step method
            with torch.no_grad():
                for (module_name, module, p_name, p, mask) in cls.sparse_parameters:
                    p.mul_(mask)
            
            return rval
    
    @classmethod
    def compute_sparse_masks():
        pass

    @classmethod
    def prune_trained_model(cls, model, optimizer):
        cls.init_model_for_pruning(
            model,
            mask_calculater = "m4n2_1d",
            whitelist = [torch.nn.Linear, torch.nn.Conv2d]
        )
        cls.init_optimizer_for_pruning(optimizer)

        cls.compute_sparse_masks()  # 2:4

if __name__ == "__main__":
    
    # ---------------- train ----------------
    get_model("./model.pt")
    print("-------orig-------")
    test()
    
    # ---------------- prune ----------------
    ASP.prune_trained_model(model, optimizer)
    print("-------pruned-------")
    test()

    # ---------------- finetune ----------------
    train()
    print("-------retrain-------")
    test()
    torch.save(model, "./model_sparse.pt")

总结

本次课程主要学习了NVIDIA的2:4 pattern稀疏方案,并手写复现了一部分重要的功能。


http://www.kler.cn/a/7424.html

相关文章:

  • Python学习笔记(2)正则表达式
  • Word_小问题解决_1
  • 高级数据结构——hash表与布隆过滤器
  • MySQL:表设计
  • JavaSE常用API-日期(计算两个日期时间差-高考倒计时)
  • MuMu模拟器安卓12安装Xposed 框架
  • 做了个springboot接口参数解密的工具,我给它命名为万能钥匙(已上传maven中央仓库,附详细使用说明)
  • 4.5--计算机网络之基础篇--1.模型分层--(复习+深入)---好好沉淀,加油呀
  • Elasticsearch:索引状态是红色还是黄色?为什么?
  • C++ 数组与字符串详解
  • 51单片机-LED篇
  • erpnext--指令
  • 多个硬盘挂载到同一个目录
  • 重新理解一个类中的forward()和__init__()函数
  • MyBatisPlus-DML编程控制
  • Muduo库源码剖析(八)——TcpServer类
  • 腾讯云轻量应用服务器价格表(2023版)
  • 前端学习:HTML基本标签
  • cgroups是linux内核中限制、记录、隔离进程组(process groups)所使用的物理资源的机制
  • 【C++从0到1】22、C++中switch语句
  • 「SQL面试题库」 No_25 统计各专业学生人数
  • 【ChatGPT】ChatGPT 能否取代程序员?
  • 英语——不定词(二)
  • 对象的比较(数据结构系列12)
  • 2023中国程序员薪酬报告出炉,你拖后腿了吗?
  • ViewBinding用法