【代码解读】阿里最新开源视频生成模型 Wan 2.1 实现解析
昨晚阿里巴巴开源了最新视频生成模型的代码和权重,从给出的 demo 效果来看还是非常惊艳的。这个模型目前也是在 VBench 榜单上排到了第一名,超越了 Sora 以及 HunyuanVideo 等一系列知名方法。
从官方给出的方法架构图来说,Wan 2.1 并没有使用 MMDiT 的架构,而是基于普通的 DiT 架构,而文本条件则是通过 Cross Attention 实现注入。在文本编码器方面,Wan 2.1 采用了支持多语言的 UMT5 作为编码器,因此 prompt 部分或许能够原生支持中文输入。图中的 Wan-Encoder 和 Wan-Decoder 实际上就是视频生成模型常用的 3D Causal VAE,根据官方的说法,其支持无损时序信息编解码任意时长1080P视频。在时间编码方面,模型的所有 block 采用了统一的时间步编码器,并采取了类似 AdaLN 的方式将时间步编码进行注入。
Wan 2.1 公布了不同尺寸的多个变体,小型的为 1.3B,想必是为了支持消费级显卡推理推出的一款模型;大型的为 14B,超过了 HunyuanVideo 的尺度,并且支持 720P 分辨率视频的生成。从表中的信息来看,不仅支持文生视频,同时也能够支持图生视频。
模型 | 支持 480P 分辨率 | 支持 720P 分辨率 |
---|---|---|
T2V-14B | 支持 | 支持 |
I2V-14B-720P | 不支持 | 支持 |
I2V-14B-480P | 支持 | 不支持 |
T2V-1.3B | 支持 | 不支持 |
同时官方也已经给出了一些定量指标,目前看到的生成质量指标是由人工评测得到的,所以暂时先不分析。个人感觉比较重要的信息是这张图里的推理成本。从表中可以看出,1.3B 模型的峰值显存占用仅为 8 GB,且在单张消费级显卡上推理约 4 分钟即可的得到一段视频(而且这个结果是将 T5 模型卸载到 CPU 上得到的,所以如果把文本提前做离线 embedding,这个性能应当还有进一步提升),还是很可观的。不过 14B 模型的推理成本就比较高了,在单卡上的显存占用已经接近 80 GB,推理时间也来到了几千秒的数量级。
代码实现分析
首先可以看到的是,和其他的方法一样,Wan 2.1 也使用了 Classifier-Free Guidance(代码链接):
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
对于图生视频任务,模型会使用 CLIP Vision Encoder 将图像进行编码作为 latents 中的第一帧,其余部分填充零,且加入一个 mask 通道(类似 inpainting 的做法,代码链接):
self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, 80, h, w)
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
除了在 latents 上的调整,图生视频还会将图像的 CLIP 特征再次进行 embedding,并在 Cross Attention 时与文本图像拼接后共同作为条件进行生成(代码链接):
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
进入模型内部,可以看到一个现象是模型的输入并不是 batched tensor,而是一个 tensor 的列表,相当于把同一个批次拆分成了多个单个视频。在推理时也是遍历整个列表,可能因为模型的推理显存比较高,通过把批次拆开来节省显存。以 Patch Embedding 为例(代码链接):
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
对于模型的每个 block,其内部由一组 self attention 与一组 cross attention 组成,并且都按照 DiT 的方式进行了 modulation 操作(代码链接):
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
freqs)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
在图生视频的 attention 中,将文本 token 与图像 token 进行了拆分,分别与 latent 计算 cross attention,然后再将两组结果相加得到最后的交叉注意力结果(代码链接):
def forward(self, x, context, context_lens):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
context_img = context[:, :257]
context = context[:, 257:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img, k_lens=None)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
在计算 self-attention 时,也使用了 RoPE(代码链接):
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
目前来说从代码里能够看到的比较有用的信息就是这些,由于具体的 report 还没有放出来,所以关于数据的细节目前不太清楚(据说用了 1.5B 视频数据和 10B 图像数据),也期待一下技术报告早日公布。