当前位置: 首页 > article >正文

【代码解读】阿里最新开源视频生成模型 Wan 2.1 实现解析

昨晚阿里巴巴开源了最新视频生成模型的代码和权重,从给出的 demo 效果来看还是非常惊艳的。这个模型目前也是在 VBench 榜单上排到了第一名,超越了 Sora 以及 HunyuanVideo 等一系列知名方法。
截至写文章时的 VBench 榜单
从官方给出的方法架构图来说,Wan 2.1 并没有使用 MMDiT 的架构,而是基于普通的 DiT 架构,而文本条件则是通过 Cross Attention 实现注入。在文本编码器方面,Wan 2.1 采用了支持多语言的 UMT5 作为编码器,因此 prompt 部分或许能够原生支持中文输入。图中的 Wan-Encoder 和 Wan-Decoder 实际上就是视频生成模型常用的 3D Causal VAE,根据官方的说法,其支持无损时序信息编解码任意时长1080P视频。在时间编码方面,模型的所有 block 采用了统一的时间步编码器,并采取了类似 AdaLN 的方式将时间步编码进行注入。
Wan 2.1 模型架构
Wan 2.1 公布了不同尺寸的多个变体,小型的为 1.3B,想必是为了支持消费级显卡推理推出的一款模型;大型的为 14B,超过了 HunyuanVideo 的尺度,并且支持 720P 分辨率视频的生成。从表中的信息来看,不仅支持文生视频,同时也能够支持图生视频。

模型支持 480P 分辨率支持 720P 分辨率
T2V-14B支持支持
I2V-14B-720P不支持支持
I2V-14B-480P支持不支持
T2V-1.3B支持不支持

同时官方也已经给出了一些定量指标,目前看到的生成质量指标是由人工评测得到的,所以暂时先不分析。个人感觉比较重要的信息是这张图里的推理成本。从表中可以看出,1.3B 模型的峰值显存占用仅为 8 GB,且在单张消费级显卡上推理约 4 分钟即可的得到一段视频(而且这个结果是将 T5 模型卸载到 CPU 上得到的,所以如果把文本提前做离线 embedding,这个性能应当还有进一步提升),还是很可观的。不过 14B 模型的推理成本就比较高了,在单卡上的显存占用已经接近 80 GB,推理时间也来到了几千秒的数量级。
Wan 2.1 的推理成本

代码实现分析

首先可以看到的是,和其他的方法一样,Wan 2.1 也使用了 Classifier-Free Guidance(代码链接):

noise_pred_cond = self.model(
    latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
    latent_model_input, t=timestep, **arg_null)[0]

noise_pred = noise_pred_uncond + guide_scale * (
    noise_pred_cond - noise_pred_uncond)

对于图生视频任务,模型会使用 CLIP Vision Encoder 将图像进行编码作为 latents 中的第一帧,其余部分填充零,且加入一个 mask 通道(类似 inpainting 的做法,代码链接):

self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
    self.clip.model.cpu()

y = self.vae.encode([
    torch.concat([
        torch.nn.functional.interpolate(
            img[None].cpu(), size=(h, w), mode='bicubic').transpose(
                0, 1),
        torch.zeros(3, 80, h, w)
    ],
                    dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])

除了在 latents 上的调整,图生视频还会将图像的 CLIP 特征再次进行 embedding,并在 Cross Attention 时与文本图像拼接后共同作为条件进行生成(代码链接):

if clip_fea is not None:
    context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
    context = torch.concat([context_clip, context], dim=1)

进入模型内部,可以看到一个现象是模型的输入并不是 batched tensor,而是一个 tensor 的列表,相当于把同一个批次拆分成了多个单个视频。在推理时也是遍历整个列表,可能因为模型的推理显存比较高,通过把批次拆开来节省显存。以 Patch Embedding 为例(代码链接):

x = [self.patch_embedding(u.unsqueeze(0)) for u in x]

对于模型的每个 block,其内部由一组 self attention 与一组 cross attention 组成,并且都按照 DiT 的方式进行了 modulation 操作(代码链接):

# self-attention
y = self.self_attn(
    self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
    freqs)
with amp.autocast(dtype=torch.float32):
    x = x + y * e[2]

# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
    x = x + self.cross_attn(self.norm3(x), context, context_lens)
    y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
    with amp.autocast(dtype=torch.float32):
        x = x + y * e[5]
    return x

x = cross_attn_ffn(x, context, context_lens, e)

在图生视频的 attention 中,将文本 token 与图像 token 进行了拆分,分别与 latent 计算 cross attention,然后再将两组结果相加得到最后的交叉注意力结果(代码链接):

def forward(self, x, context, context_lens):
    r"""
    Args:
        x(Tensor): Shape [B, L1, C]
        context(Tensor): Shape [B, L2, C]
        context_lens(Tensor): Shape [B]
    """
    context_img = context[:, :257]
    context = context[:, 257:]
    b, n, d = x.size(0), self.num_heads, self.head_dim

    # compute query, key, value
    q = self.norm_q(self.q(x)).view(b, -1, n, d)
    k = self.norm_k(self.k(context)).view(b, -1, n, d)
    v = self.v(context).view(b, -1, n, d)
    k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
    v_img = self.v_img(context_img).view(b, -1, n, d)
    img_x = flash_attention(q, k_img, v_img, k_lens=None)
    # compute attention
    x = flash_attention(q, k, v, k_lens=context_lens)

    # output
    x = x.flatten(2)
    img_x = img_x.flatten(2)
    x = x + img_x
    x = self.o(x)
    return x

在计算 self-attention 时,也使用了 RoPE(代码链接):

x = flash_attention(
    q=rope_apply(q, grid_sizes, freqs),
    k=rope_apply(k, grid_sizes, freqs),
    v=v,
    k_lens=seq_lens,
    window_size=self.window_size)

目前来说从代码里能够看到的比较有用的信息就是这些,由于具体的 report 还没有放出来,所以关于数据的细节目前不太清楚(据说用了 1.5B 视频数据和 10B 图像数据),也期待一下技术报告早日公布。


http://www.kler.cn/a/563282.html

相关文章:

  • 锂电池保护板测试仪:电池安全的守护者与创新驱动力
  • JUC并发—14.Future模式和异步编程分析二
  • go-zero中定时任务的用法
  • 神经网络参数量计算
  • 云图库平台(五)——后端图片模块开发
  • 2025/2/25,字节跳动后端开发一面面经
  • 3D格式转换工具HOOPS Exchange在PMI处理中的关键作用与优势解析
  • 互联网核心技术概念笔记
  • NLP学习记录十:多头注意力
  • 【react】react Native
  • [免单统计]
  • 使用前端 html css 和js 开发一个AI智能平台官网模板-前端静态页面项目
  • 机器学习数学基础:34.克隆巴赫α系数
  • 热更新-arthas + jenkins/Bshell实现
  • 免费PDF工具
  • 【iOS】小蓝书学习(四)
  • Qwen2.5-VL技术报告:多模态大模型的新SOTA!视觉理解能力全面超越GPT-4o
  • 公共数据授权运营模式研究(总体框架、主要模式及发展趋势)
  • 解锁C# XML编程:从新手到实战高手的蜕变之路
  • 【学习方法】学习软件专业课程的思考方式