当前位置: 首页 > article >正文

pytorch--流水线并行

流水线并行(pipelining )部署实施起来非常困难,因为这需要根据模型的weights把模型分块(通常涉及到对源码的修改),此外,分布式的调度和数据流的依赖也是要考虑的点;
pipelining 库可以让部署变得更加简单;
这个库包含两个部分:
splitting frontend:此部分用于把模型分块,并且捕捉到数据流之间的关系;
distributed runtime:并行地执行pipeline stage在不同的设备上,同时处理好batch的划分、调度、通信和梯度回传;
所以这个库支持以下操作:
1.对于模型的简单划分;
2.丰富的流水线调度策略,包括GPipe, 1F1B, Interleaved 1F1B and Looped BFS;
3.支持跨主机的并行;
4.支持一些常规的并行操作,比如data parallel (DDP, FSDP) or tensor parallel;
关于模型的splitting:
为了构建PipelineStage,需要提供包含了nn.Parameters and nn.Buffers的nn.Module,同时定义了能够执行对应stage的forward函数

class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(...)

        # Using a ModuleDict lets us delete layers witout affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(...)

        self.output = nn.Linear(...)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h).float() if self.output else h
        return output

用这种方式定义的模型可以很容易配置stage和初始化,(为了防止OMM error使用meta device),删除对应stage不需要的层,然后构造PipelineStage 来wrap model;

with torch.device("meta"):
    assert num_stages == 2, "This is a simple 2-stage example"

    # we construct the entire model, then delete the parts we do not need for this stage
    # in practice, this can be done using a helper function that automatically divides up layers across stages.
    model = Transformer()

    if stage_index == 0:
        # prepare the first stage model
        del model.layers["1"]
        model.norm = None
        model.output = None

    elif stage_index == 1:
        # prepare the second stage model
        model.tok_embeddings = None
        del model.layers["0"]

    from torch.distributed.pipelining import PipelineStage
    stage = PipelineStage(
        model,
        stage_index,
        num_stages,
        device,
        input_args=example_input_microbatch,
    )

这里还提供自动切分模型的接口函数,这里不做细致赘述;
其中input_args 代表执行时候的input data samples,这个要拿去经过forward去确定输入输出的shapes;当同时使用其他并行trick的时候,output_args 也需要的,因为模型输出大小可能会受到影响;
第一步:构建一个执行的PipelineStage
PipelineStage用于分配通信内存,创造发送、接受操作去通信;它用来存储还未被consume的forward的缓存,同时为stage model执行backward;
PipelineStage需要知道输入输出的shape大小,方便创建通信缓存,shapes必须是固定大小的,也就是训练执行的时候它不能是变化的;
每一个stage model必须是nn.Module的格式;(所以第一步要做的事情就是手动分割模型);
当然也有其他替代方式,可以用图分割去把你的模型自动分割为一系列的nn.Module,这个要求模型必须是torch.Export traceable ;所以能手动更改模型代码是最方便的;
第二步:用PipelineSchedule 去执行
以下是执行的示例代码:

from torch.distributed.pipelining import ScheduleGPipe

# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)

# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
    schedule.step(x)
else:
    output = schedule.step()

以上代码的rank应该指的是进程号,也就是在0进程中,执行stage1,在1进程中执行stage 2;

以下是官方给的关于llama流水线并行的示例代码,会更加清晰明了;

# $ torchrun --nproc-per-node 4 pippy_llama.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import SplitPoint, pipeline, ScheduleGPipe

# Grab the model
llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
print(llama)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
mb_prompts = (
    "How do you", "I like to",
)  # microbatch size = 2

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.distributed.init_process_group(rank=rank, world_size=world_size)

llama.to(device).eval()

# Cut model by equal number of layers per rank
layers_per_rank = llama.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
split_spec = {
    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}

# Create a pipeline representation from the model
mb_inputs = tokenizer(mb_prompts, return_tensors="pt", padding=True).to(device)
pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],))

# Create pipeline stage for each rank
stage = pipe.build_stage(rank, device=device)

# Run time inputs
full_batch_prompts = (
    "How do you", "I like to", "Can I help", "You need to",
    "The weather is", "I found a", "What is your", "You are so",
)  # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True).to(device)

# Attach to a schedule
# number of microbatches = 8 // 2 = 4
num_mbs = 4
schedule = ScheduleGPipe(stage, num_mbs)

# Run
if rank == 0:
    args = inputs["input_ids"]
else:
    args = None

output = schedule.step(args)

# Decode
if output is not None:
    next_token_logits = output[0][:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    print(tokenizer.batch_decode(next_token))

torchrun --nproc-per-node 8 pippy_llama.py


http://www.kler.cn/a/317730.html

相关文章:

  • 代码随想录算法训练营第六十天|KM94.城市间货物运输Ⅰ|KM95.城市间货物运输Ⅱ|KM96.城市间货物运输Ⅲ
  • kubeneters-循序渐进Cilium网络(二)
  • 前端 图片上鼠标画矩形框,标注文字,任意删除
  • 在iStoreOS上安装Tailscale
  • 【Ubuntu】 Ubuntu22.04搭建NFS服务
  • maven的简单介绍
  • pandas外文文档快速入门
  • UNet 眼底血管分割实战教程
  • Python Flask网页开发基本框架
  • 大数据新视界 --大数据大厂之 Vue.js 与大数据可视化:打造惊艳的数据界面
  • 【Java面向对象高级06】static的应用知识:代码块
  • java开发jmeter采样器
  • 【AI写作】介绍 Docker 的基本概念和优势,以及在应用程序开发中的实际应用
  • 目标检测任务中xml标签文件修改
  • 【学习笔记】Transformer架构探讨
  • [ IDE ] SEGGER Embedded Studio for RISC-V
  • C++初阶学习——探索STL奥秘——反向迭代器
  • [Leetcode 543][Easy]-二叉树的直径-递归
  • ubuntu安装StarQuant
  • 【Verilog学习日常】—牛客网刷题—Verilog快速入门—VL22
  • 【Linux】生产者消费者模型:基于阻塞队列,使用互斥锁和条件变量维护互斥与同步关系
  • 高级java每日一道面试题-2024年9月20日-分布式篇-什么是CAP理论?
  • 【Java】Java开发全攻略:从环境搭建到高效编程
  • vulnhub-prime1
  • Android 检测图片抓拍, 聚焦图片后自动完成拍照,未对准图片的提示请将摄像头对准要拍照的图片
  • 红书 API 接口:笔记详情数据接口的接入与使用