# FQ-GAN Code Walkthrough
This walkthrough focuses on how the model, loss, and data parts are implemented and processed.
- model—VQ_models
  - VQModel
    - Encoder
    - VectorQuantizer
    - Decoder
- loss—VQLoss_triple_codebook
## model—VQ_models
vq_model is created directly from the requested compression factor (8 or 16), which selects VQ_8 or VQ_16. Both simply instantiate a VQModel and differ only in ch_mult (the same ch_mult as in UNet, i.e. the per-block channel multiplier; the number of stages it defines determines the overall compression factor, since every stage except the last down/up-samples by 2).
```python
# create and load model
vq_model = VQ_models[args.vq_model](
    codebook_size=args.codebook_size,
    codebook_embed_dim=args.codebook_embed_dim,
    commit_loss_beta=args.commit_loss_beta,
    entropy_loss_ratio=args.entropy_loss_ratio,
    dropout_p=args.dropout_p,
    with_clip_supervision=args.with_clip_supervision,
    with_disentanglement=args.with_disentanglement,
    disentanglement_ratio=args.disentanglement_ratio,
)

def VQ_8(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))

def VQ_16(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))

VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
```
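As a quick sanity check of the relationship between ch_mult and the compression factor (a minimal sketch, not from the repo):

```python
def downsample_factor(ch_mult):
    # every resolution stage except the last halves H and W
    # (see the Encoder code below), so the spatial compression
    # factor is 2 ** (len(ch_mult) - 1)
    return 2 ** (len(ch_mult) - 1)

assert downsample_factor([1, 2, 2, 4]) == 8       # VQ-8
assert downsample_factor([1, 1, 2, 2, 4]) == 16   # VQ-16
```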
### VQModel
A VQModel with 3 codebooks is structured as follows:

- Encoder: one Encoder (progressively compresses the spatial dimensions into embed_dim features).
- VectorQuantizer: 3 VectorQuantizers (pixel level with no teacher, mid-semantic level with a DINO teacher, high-semantic level with a CLIP teacher), each paired with a quant_conv (mapping z_channels to codebook_embed_dim).
- Decoder: one post_quant_conv (mapping the embedding dim from 3*codebook_embed_dim back to z_channels) and one Decoder (progressively restores embed_dim features to the spatial resolution).
- FeatPredHead: 2 FeatPredHeads (MLP heads that align the vq features to the CLIP and DINO features, used for distillation supervision).
```python
class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        # Two head encoder
        self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        # Quantizer for visual detail head
        self.quantize_vis = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                            config.commit_loss_beta, config.entropy_loss_ratio,
                                            config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_vis = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for mid-level semantic head
        self.quantize_sem_mid = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                config.commit_loss_beta, config.entropy_loss_ratio,
                                                config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_mid = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for high-level semantic head
        self.quantize_sem_high = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                 config.commit_loss_beta, config.entropy_loss_ratio,
                                                 config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_high = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        print("Visual codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("Mid Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("High Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))

        # Pixel decoder
        input_dim = config.codebook_embed_dim * 3
        self.post_quant_conv = nn.Conv2d(input_dim, config.z_channels, 1)
        self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels,
                               dropout=config.dropout_p)

        # Down-sample factor from the encoder channel multiplier
        self.num_resolutions = len(config.encoder_ch_mult)
        if self.num_resolutions == 5:    # encoder_ch_mult=[1, 1, 2, 2, 4]
            down_factor = 16
        elif self.num_resolutions == 4:  # encoder_ch_mult=[1, 2, 2, 4]
            down_factor = 8
        else:
            raise NotImplementedError

        # Semantic feature prediction
        if self.config.with_clip_supervision:
            print("Include feature prediction head for representation supervision")
            self.mid_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=384, down_factor=down_factor)
            self.high_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=768, down_factor=down_factor)
        else:
            print("NO representation supervision")

        if self.config.with_disentanglement:
            print("Disentangle Ratio: ", self.config.disentanglement_ratio)
        else:
            print("No Disentangle Regularization")
```
The forward pass has three main stages: encode, vq, and decode (plus an extra feature-alignment step needed for knowledge distillation):

- ① The input goes through the encoder, producing 3 different features (h_vis, h_sem_mid, h_sem_high), each then passed through its own quant_conv to map the embedding dim to codebook_embed_dim.
- ② The features of the different levels are fed into their own VectorQuantizer, yielding three different quant_features and emb_losses (each emb_loss consists of vq_loss, commit_loss, and entropy_loss).
  - For knowledge distillation, two FeatPredHeads (mid_sem_feat_pred and high_sem_feat_pred) additionally project the quantized semantic features to the dimensions of the CLIP and DINO features.
  - To make the 3 codebooks mutually orthogonal (strongly disentangled), a disentangle_loss is constructed so that the vq features of the 3 levels differ from each other (the mean squared dot product of the L2-normalized embeddings, i.e. disentangle_loss).
- ③ The quant_features are concatenated and passed through post_quant_conv and the decoder, decoding back to the pixel_values of the original image (dec).
```python
def forward(self, input):
    # 1. encode
    h_vis, h_sem_mid, h_sem_high = self.encoder(input)
    h_vis = self.quant_conv_vis(h_vis)
    h_sem_mid = self.quant_conv_sem_mid(h_sem_mid)
    h_sem_high = self.quant_conv_sem_high(h_sem_high)

    # 2. vq
    quant_vis, emb_loss_vis, _ = self.quantize_vis(h_vis)
    quant_sem_mid, emb_loss_sem_mid, _ = self.quantize_sem_mid(h_sem_mid)
    quant_sem_high, emb_loss_sem_high, _ = self.quantize_sem_high(h_sem_high)

    # for knowledge distillation
    if self.config.with_clip_supervision:
        mid_lvl_sem_feat = self.mid_sem_feat_pred(quant_sem_mid)
        high_lvl_sem_feat = self.high_sem_feat_pred(quant_sem_high)
    else:
        mid_lvl_sem_feat = None
        high_lvl_sem_feat = None

    # for disentangling the vq features of the 3 codebooks
    if self.config.with_disentanglement:
        disentangle_loss = (self.compute_disentangle_loss(quant_vis, quant_sem_mid) +
                            self.compute_disentangle_loss(quant_vis, quant_sem_high) +
                            self.compute_disentangle_loss(quant_sem_mid, quant_sem_high)) / 3.0
    else:
        disentangle_loss = 0

    # 3. decode
    quant = torch.cat([quant_vis, quant_sem_mid, quant_sem_high], dim=1)
    dec = self.decode(quant)

    return dec, \
        emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high, \
        disentangle_loss, \
        mid_lvl_sem_feat, high_lvl_sem_feat
```
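A minimal usage sketch of this forward pass (assuming vq_model was built as above and 256×256 inputs; the names mirror the return signature):

```python
import torch

x = torch.randn(2, 3, 256, 256)  # dummy batch
(dec,
 emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high,
 disentangle_loss,
 mid_lvl_sem_feat, high_lvl_sem_feat) = vq_model(x)

print(dec.shape)  # torch.Size([2, 3, 256, 256]) -- reconstructed pixel values
# each emb_loss_* is a tuple: (vq_loss, commit_loss, entropy_loss, codebook_usage)
```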
The disentangle_loss is exactly the FQ novelty of this paper: it pushes the 3 codebooks toward mutual orthogonality and disentanglement. The design idea is that if 2 features are disentangled, their dot product should be close to zero, because they should be orthogonal. By minimizing this loss, the model is encouraged to learn disentangled features at the different levels.
```python
def compute_disentangle_loss(self, quant_vis, quant_sem):
    quant_vis = rearrange(quant_vis, 'b c h w -> (b h w) c')
    quant_sem = rearrange(quant_sem, 'b c h w -> (b h w) c')
    quant_vis = F.normalize(quant_vis, p=2, dim=-1)
    quant_sem = F.normalize(quant_sem, p=2, dim=-1)
    dot_product = torch.sum(quant_vis * quant_sem, dim=1)
    loss = torch.mean(dot_product ** 2) * self.config.disentanglement_ratio
    return loss
```
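A small numerical check of the orthogonality intuition (a sketch; the helper below re-implements the method above without the class):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def disentangle_loss(a, b, ratio=1.0):
    # same computation as compute_disentangle_loss above
    a = F.normalize(rearrange(a, 'b c h w -> (b h w) c'), p=2, dim=-1)
    b = F.normalize(rearrange(b, 'b c h w -> (b h w) c'), p=2, dim=-1)
    return torch.mean(torch.sum(a * b, dim=1) ** 2) * ratio

x = torch.zeros(1, 2, 1, 1); x[:, 0] = 1.0  # (1, 0) at every position
y = torch.zeros(1, 2, 1, 1); y[:, 1] = 1.0  # (0, 1) at every position
print(disentangle_loss(x, y))  # tensor(0.) -- orthogonal features, no penalty
print(disentangle_loss(x, x))  # tensor(1.) -- identical features, maximal penalty
```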
### Encoder
The Encoder takes the input image feature, passes it through shared downsampling conv_blocks and mid blocks, then through 3 different adapters, and outputs 3 different features.

- conv_in: the input image feature is first mapped by conv_in to 128 channels.
- downsampling: the conv_blocks are composed of ResnetBlock, AttnBlock, and Downsample modules built according to ch_mult=(1,1,2,2,4) or ch_mult=(1, 2, 2, 4), where ch_mult controls the channel multiplier of each conv_block (channels grow while h and w shrink). The output channels of block i are 128*ch_mult[i] (e.g. with ch_mult=(1, 2, 2, 4) there are 4 blocks and the channels evolve as 128->128->256->256->512).
- mid: composed of ResnetBlock+AttnBlock+ResnetBlock; the convolutions here keep the channel count unchanged.
- adapter: 3 different FactorizedAdapters turn the shared encoder feature into 3 different features, one for each codebook's VQ step.
- conv_out: since the previous step produced 3 copies of the feature (h_vis, h_sem_mid, h_sem_high), 3 separate conv2d output layers align each feature's channel dim (converting it to z_channels).
```python
class Encoder(nn.Module):
    def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
                 norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions - 1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        if self.num_resolutions == 5:
            down_factor = 16
        elif self.num_resolutions == 4:
            down_factor = 8
        else:
            raise NotImplementedError

        # semantic head mid-level
        self.semantic_head_mid = nn.ModuleList()
        self.semantic_head_mid.append(FactorizedAdapter(down_factor))
        # semantic head high-level
        self.semantic_head_high = nn.ModuleList()
        self.semantic_head_high.append(FactorizedAdapter(down_factor))
        # visual details head
        self.visual_head = nn.ModuleList()
        self.visual_head.append(FactorizedAdapter(down_factor))

        # end
        self.norm_out_sem_mid = Normalize(block_in, norm_type)
        self.conv_out_sem_mid = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
        self.norm_out_sem_high = Normalize(block_in, norm_type)
        self.conv_out_sem_high = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
        self.norm_out_vis = Normalize(block_in, norm_type)
        self.conv_out_vis = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        h_vis = h
        h_sem_mid = h
        h_sem_high = h

        # semantic head mid-level
        for blk in self.semantic_head_mid:
            h_sem_mid = blk(h_sem_mid)
        h_sem_mid = self.norm_out_sem_mid(h_sem_mid)
        h_sem_mid = nonlinearity(h_sem_mid)
        h_sem_mid = self.conv_out_sem_mid(h_sem_mid)

        # semantic head high-level
        for blk in self.semantic_head_high:
            h_sem_high = blk(h_sem_high)
        h_sem_high = self.norm_out_sem_high(h_sem_high)
        h_sem_high = nonlinearity(h_sem_high)
        h_sem_high = self.conv_out_sem_high(h_sem_high)

        # visual head
        for blk in self.visual_head:
            h_vis = blk(h_vis)
        h_vis = self.norm_out_vis(h_vis)
        h_vis = nonlinearity(h_vis)
        h_vis = self.conv_out_vis(h_vis)

        return h_vis, h_sem_mid, h_sem_high
```
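A shape sanity check for the encoder (a sketch, assuming the repo's ResnetBlock/AttnBlock/Downsample/FactorizedAdapter modules are importable):

```python
import torch

enc = Encoder(ch_mult=(1, 2, 2, 4), z_channels=256)  # VQ-8 configuration
x = torch.randn(1, 3, 256, 256)
h_vis, h_sem_mid, h_sem_high = enc(x)
# 3 downsampling stages halve H and W each time: 256 / 2**3 = 32
print(h_vis.shape)  # torch.Size([1, 256, 32, 32]); same for the two semantic heads
```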
### VectorQuantizer
The initialization of VectorQuantizer mainly creates the codebook embedding (embedding) of size [codebook_size, codebook_embed_dim].
```python
class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e                                # codebook_size
        self.e_dim = e_dim                            # codebook_embed_dim
        self.beta = beta                              # commitment_loss scale
        self.entropy_loss_ratio = entropy_loss_ratio  # entropy_loss scale
        self.l2_norm = l2_norm                        # l2_norm for codebook embeddings
        self.show_usage = show_usage                  # show codebook usage

        # create codebook embedding and initialize
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:  # normalize embeddings
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:  # initialize codebook usage
            self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))  # 1048576
```
The forward operation is the same as in VQGAN: every token embedding of the image feature z is quantized by a codebook lookup to the embedding at argmin(distances), giving the quantized image feature zq, and 3 losses are computed at the same time (to optimize the codebook embeddings).

- l2 norm: L2-normalizing both z and the codebook embeddings rescales all vectors to the same length, i.e. maps them onto the unit sphere, so that every vector plays an equal role in the distance metric; vectors become more comparable and easier to match, which improves training stability and reconstruction quality.
- argmin(distances): the classic VQ distance computation, expanded into two squared terms and one inner product, $(z - e)^2 = z^2 + e^2 - 2\,e \cdot z$. Then argmin(distances) gives, for each embedding in z, the index of the nearest embedding in the codebook, and those codebook embeddings are gathered to form zq.
- codebook usage: measures the utilization rate of the codebook embeddings.
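A quick check (a sketch) that this expanded form equals the direct squared euclidean distance:

```python
import torch

z = torch.randn(5, 8)   # 5 flattened tokens, e_dim = 8
e = torch.randn(16, 8)  # a 16-entry codebook

d = torch.sum(z ** 2, dim=1, keepdim=True) + \
    torch.sum(e ** 2, dim=1) - 2 * torch.einsum('bd,dn->bn', z, e.t())

assert torch.allclose(d, torch.cdist(z, e) ** 2, atol=1e-4)
```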
```python
def forward(self, z):
    # reshape z -> (batch, height, width, channel) and flatten
    z = torch.einsum('b c h w -> b h w c', z).contiguous()
    z_flattened = z.view(-1, self.e_dim)

    if self.l2_norm:  # normalize z and the codebook embeddings (map vectors onto the unit sphere)
        z = F.normalize(z, p=2, dim=-1)
        z_flattened = F.normalize(z_flattened, p=2, dim=-1)
        embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
    else:
        embedding = self.embedding.weight

    # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
    d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
        torch.sum(embedding ** 2, dim=1) - 2 * \
        torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))

    # argmin(distances)
    min_encoding_indices = torch.argmin(d, dim=1)
    # replace each z_i with its closest embedding e_j
    z_q = embedding[min_encoding_indices].view(z.shape)

    perplexity = None
    min_encodings = None
    vq_loss = None
    commit_loss = None
    entropy_loss = None
    codebook_usage = 0

    # compute codebook usage
    if self.show_usage and self.training:
        cur_len = min_encoding_indices.shape[0]
        self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()  # copy last cur_len elements to front
        self.codebook_used[-cur_len:] = min_encoding_indices  # set last cur_len elements as min_encoding_indices
        codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
```
- embedding loss:
  - vq_loss is the mean squared error (MSE) between the quantized vectors z_q and the original input vectors z. z.detach() means z is detached from the computation graph, so no gradient flows into z from vq_loss; this loss encourages the codebook embeddings to move as close as possible to the encoder outputs.
  - commit_loss is also an MSE, but here z_q is detached, so no gradient flows into z_q; this loss encourages the encoder to commit to the codebook during quantization, i.e. the encoder output z should stay close to its quantized version z_q. The hyperparameter self.beta scales the weight of this term in the total loss.
  - entropy_loss encourages uniform usage of the codebook, improving generalization. compute_entropy_loss(-d) is computed on the negated squared distances from each input vector z to all embeddings: a softmax over -d gives soft assignment probabilities, and the loss combines a per-sample entropy term (pushing each token toward a confident assignment) with the entropy of the average assignment distribution (pushing usage to spread across the codebook). self.entropy_loss_ratio scales this term.
```python
    # compute the 3 losses for the codebook embedding
    if self.training:
        vq_loss = torch.mean((z_q - z.detach()) ** 2)
        commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
        entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

    # preserve gradients (straight-through estimator)
    z_q = z + (z_q - z).detach()

    # reshape back to match original input shape
    z_q = torch.einsum('b h w c -> b c h w', z_q)

    return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
```
```python
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss
```
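The loss is sample_entropy minus avg_entropy: it is lowest when each token is confidently assigned to one code while, on average, all codes get used. A small demonstration (a sketch):

```python
import torch

# 4 tokens, 4 codes. Identity affinities: each token confidently picks a
# different code -> low sample entropy, high average entropy -> low loss.
spread = torch.eye(4)
# All tokens confidently pick the same code -> avg_entropy is also low -> higher loss.
collapsed = torch.zeros(4, 4); collapsed[:, 0] = 1.0

print(compute_entropy_loss(spread))     # about -1.386 = -ln(4), the best case
print(compute_entropy_loss(collapsed))  # about 0, codebook collapse is penalized
```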
get_codebook_entry is used after a Transformer has autoregressively predicted an index sequence: it looks up the codebook to convert the indices back into the corresponding embeddings.
```python
def get_codebook_entry(self, indices, shape=None, channel_first=True):
    # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
    if self.l2_norm:
        embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
    else:
        embedding = self.embedding.weight
    z_q = embedding[indices]  # (b*h*w, c)
    if shape is not None:
        if channel_first:
            z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        else:
            z_q = z_q.view(shape)
    return z_q
```
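For example (a sketch with hypothetical sizes: codebook_size=16384, codebook_embed_dim=8, a 16×16 token grid):

```python
import torch

quantizer = VectorQuantizer(n_e=16384, e_dim=8, beta=0.25,
                            entropy_loss_ratio=0.0, l2_norm=True, show_usage=False)
indices = torch.randint(0, 16384, (1 * 16 * 16,))  # flattened AR predictions
z_q = quantizer.get_codebook_entry(indices, shape=(1, 8, 16, 16), channel_first=True)
print(z_q.shape)  # torch.Size([1, 8, 16, 16]) -- ready for post_quant_conv + decoder
```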
### Decoder
The whole VQ step from z to zq does not change the shape of the image feature, and after post_quant_conv the channel dim is back to z_channels=256. The Decoder then decodes zq into the image's pixel values as follows:

- conv_in: a Conv2d maps zq's channel dim from z_channels to block_in (determined by ch=128 and ch_mult).
- middle block: same as in the Encoder, composed of ResnetBlock, AttnBlock, ResnetBlock; the channel dim is unchanged.
- upsampling conv_blocks: the exact mirror of the Encoder; multiple blocks are built from ch_mult in reverse and progressively upsample, growing the spatial dims and shrinking the channel dims.
- conv_out: the final conv_out maps the channel dim from block_in to out_channels=3, yielding the image pixel values.
```python
class Decoder(nn.Module):
    def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
                 dropout=0.0, resamp_with_conv=True, out_channels=3):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    @property
    def last_layer(self):
        return self.conv_out.weight

    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
```
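For reference, the VQModel.decode helper called in the forward pass earlier is not shown in this post; presumably (an assumption based on the usual LlamaGen-style layout) it is just:

```python
def decode(self, quant):
    # quant: (B, 3 * codebook_embed_dim, h, w) -- concatenated quantized features
    quant = self.post_quant_conv(quant)  # -> (B, z_channels, h, w)
    dec = self.decoder(quant)            # -> (B, 3, H, W) pixel values
    return dec
```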
## loss—VQLoss_triple_codebook
The forward pass of VQ_Model above already yields 3 embed_losses and 1 disentangle_loss:

- codebook embedding loss: since there are 3 codebooks, the 3 VQ operations produce 3 embed_losses (emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high); each emb_loss is composed of 3 terms (vq_loss, commit_loss, entropy_loss) and is used to optimize its codebook.
- disentangle loss: one of the paper's contributions; the mean squared dot products between the zq of different codebooks are summed over the three pairs to form disentangle_loss, pushing the codebooks toward mutual orthogonality.
On top of these, VQLoss_triple_codebook also computes reconstruction_loss, perceptual_loss, and the kd_teacher_loss during training:

- pixel loss:
  - reconstruction_loss: the l1_loss or l2_loss between the pixel values of VQ_Model's input and its reconstruction.
  - perceptual_loss: the vgg-based LPIPS score between input and output, used as a loss.
- discriminator loss: used to optimize the discriminator, which can be a PatchGAN or a StyleGAN discriminator (it takes a real or reconstructed image and outputs real/fake logits). The discriminator_loss type can be hinge, vanilla, or non-saturating.
- gen_adv_loss: used to optimize the generator, whose goal is to produce fake data close enough to the real data to fool the discriminator. It comes in hinge and non_saturating variants; both push logits_fake toward classifying the reconstructed image as real. (A sketch of these helper losses follows this list.)
- semantic loss (kd_teacher_loss): the 2 FeatPredHead outputs (image vq features) are compared against the CLIP and DINO features respectively, distilling general understanding representations.
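The helper losses referenced above (hinge_d_loss, non_saturating_d_loss, hinge_gen_loss, non_saturating_gen_loss) are not reproduced in this post; a typical taming-transformers-style implementation (an assumption, not verified against the FQGAN repo) looks like:

```python
import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # real images should score above 1, fakes below -1
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def non_saturating_d_loss(logits_real, logits_fake):
    # binary cross-entropy with real=1, fake=0
    loss_real = F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
    loss_fake = F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake))
    return loss_real + loss_fake

def hinge_gen_loss(logits_fake):
    # the generator wants the discriminator to score fakes high
    return -torch.mean(logits_fake)

def non_saturating_gen_loss(logits_fake):
    # the generator wants fakes classified as real (label 1)
    return F.binary_cross_entropy_with_logits(logits_fake, torch.ones_like(logits_fake))
```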
The initialization of VQLoss_triple_codebook simply prepares the parameters and models needed to compute the losses above:
```python
class VQLoss_triple_codebook(nn.Module):
    def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
                 disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight=False,
                 gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
                 codebook_weight=1.0, perceptual_weight=1.0,
                 with_clip_supervision=False, semantic_weight=0.5,
                 ):
        super().__init__()
        # 1. discriminator loss
        assert disc_type in ["patchgan", "stylegan"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        # discriminator
        if disc_type == "patchgan":
            print("Using patchgan D")
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            print("Using stylegan D")
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
        # disc_loss type
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        # 2. gen_adv_loss
        assert gen_adv_loss in ["hinge", "non-saturating"]
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # 3. perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # 4. semantic loss
        self.with_clip_supervision = with_clip_supervision
        if with_clip_supervision:
            self.clip_model = CLIPVisionTower("/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/clip-vit-base-patch16").eval()
            self.dinov2_model = DinoVisionTower("/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/dinov2-small").eval()
            self.clip_model.requires_grad_(False)
            self.dinov2_model.requires_grad_(False)
            self.semantic_weight = semantic_weight
        else:
            self.clip_model = None
            self.dinov2_model = None
            self.semantic_weight = None

        # 5. reconstruction loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # 6. codebook loss
        self.codebook_weight = codebook_weight
```
The forward pass of the VQLoss_triple_codebook class runs in 2 modes depending on optimizer_idx. The two modes run back to back on the same batch: each training step calls the vq_loss forward twice, once to compute the generator loss and once to compute the discriminator loss, with the generator and discriminator driven by 2 separate optimizers (optimizer and optimizer_disc); a minimal training-step sketch follows the forward code below.

- When optimizer_idx == 0, the generator is optimized: reconstruction loss, perceptual loss, semantic loss, and gen_adv_loss are computed and linearly combined with the codebook_embed_loss and disentangle_loss obtained earlier from the VQModel forward, forming the total loss that optimizes VQ_Model.
- When optimizer_idx == 1, the discriminator is optimized using the discriminator loss.
```python
def forward(self,
            codebook_loss_vis, codebook_loss_sem_mid, codebook_loss_sem_high,
            inputs, reconstructions,
            disentangle_loss,
            semantic_feat_mid, semantic_feat_high,
            optimizer_idx, global_step, last_layer=None,
            logger=None, log_every=100):
    # generator update
    if optimizer_idx == 0:
        # reconstruction loss
        rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

        # semantic loss
        if semantic_feat_mid is not None:
            assert semantic_feat_high is not None
            semantic_loss_mid = self.dinov2_model(inputs.contiguous(), semantic_feat_mid)  # how to compute semantic loss?
            semantic_loss_mid = torch.mean(semantic_loss_mid)
            semantic_loss_high = self.clip_model(inputs.contiguous(), semantic_feat_high)
            semantic_loss_high = torch.mean(semantic_loss_high)
        else:
            assert self.with_clip_supervision == False
            semantic_loss_mid = torch.mean(torch.zeros_like(rec_loss))
            semantic_loss_high = torch.mean(torch.ones_like(rec_loss))

        # perceptual loss
        p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
        p_loss = torch.mean(p_loss)

        # discriminator loss
        logits_fake = self.discriminator(reconstructions.contiguous())
        generator_adv_loss = self.gen_adv_loss(logits_fake)
        if self.disc_adaptive_weight:
            null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss  # pixel loss
            disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, generator_adv_loss,
                                                                  last_layer=last_layer)
        else:
            disc_adaptive_weight = 1
        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)

        loss = self.rec_weight * rec_loss + \
            self.perceptual_weight * p_loss + \
            disc_adaptive_weight * disc_weight * generator_adv_loss + \
            codebook_loss_vis[0] + codebook_loss_vis[1] + codebook_loss_vis[2] + \
            codebook_loss_sem_mid[0] + codebook_loss_sem_mid[1] + codebook_loss_sem_mid[2] + \
            codebook_loss_sem_high[0] + codebook_loss_sem_high[1] + codebook_loss_sem_high[2] + \
            self.semantic_weight * semantic_loss_mid + self.semantic_weight * semantic_loss_high + disentangle_loss

        if global_step % log_every == 0:
            rec_loss = self.rec_weight * rec_loss
            p_loss = self.perceptual_weight * p_loss
            generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss
            logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
                        f"vq_loss_sem_mid: {codebook_loss_sem_mid[0]:.4f}, "
                        f"commit_loss_sem_mid: {codebook_loss_sem_mid[1]:.4f}, "
                        f"entropy_loss_sem_mid: {codebook_loss_sem_mid[2]:.4f}, "
                        f"codebook_usage_sem_mid: {codebook_loss_sem_mid[3]:.4f}, "
                        f"vq_loss_sem_high: {codebook_loss_sem_high[0]:.4f}, "
                        f"commit_loss_sem_high: {codebook_loss_sem_high[1]:.4f}, "
                        f"entropy_loss_sem_high: {codebook_loss_sem_high[2]:.4f}, "
                        f"codebook_usage_sem_high: {codebook_loss_sem_high[3]:.4f}, "
                        f"vq_loss_vis: {codebook_loss_vis[0]:.4f}, "
                        f"commit_loss_vis: {codebook_loss_vis[1]:.4f}, "
                        f"entropy_loss_vis: {codebook_loss_vis[2]:.4f}, "
                        f"codebook_usage_vis: {codebook_loss_vis[3]:.4f}, "
                        f"disentangle_loss: {disentangle_loss:.4f}, "
                        f"generator_adv_loss: {generator_adv_loss:.4f}, "
                        f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}, "
                        f"semantic_loss_mid: {semantic_loss_mid:.4f}, semantic_loss_high: {semantic_loss_high:.4f}")
            if dist.get_rank() == 0:
                wandb.log({
                    "rec_loss": rec_loss,
                    "perceptual_loss": p_loss,
                    "disentangle_loss": disentangle_loss,
                    "codebook_loss_sem_mid": codebook_loss_sem_mid[0],
                    "commit_loss_sem_mid": codebook_loss_sem_mid[1],
                    "entropy_loss_sem_mid": codebook_loss_sem_mid[2],
                    "codebook_usage_sem_mid": codebook_loss_sem_mid[3],
                    "codebook_loss_sem_high": codebook_loss_sem_high[0],
                    "commit_loss_sem_high": codebook_loss_sem_high[1],
                    "entropy_loss_sem_high": codebook_loss_sem_high[2],
                    "codebook_usage_sem_high": codebook_loss_sem_high[3],
                    "codebook_loss_vis": codebook_loss_vis[0],
                    "commit_loss_vis": codebook_loss_vis[1],
                    "entropy_loss_vis": codebook_loss_vis[2],
                    "codebook_usage_vis": codebook_loss_vis[3],
                    "generator_adv_loss": generator_adv_loss,
                    "disc_adaptive_weight": disc_adaptive_weight,
                    "disc_weight": disc_weight,
                    "semantic_loss_mid": semantic_loss_mid,
                    "semantic_loss_high": semantic_loss_high,
                })
        return loss

    # discriminator update
    if optimizer_idx == 1:
        logits_real = self.discriminator(inputs.contiguous().detach())
        logits_fake = self.discriminator(reconstructions.contiguous().detach())
        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
        d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

        if global_step % log_every == 0:
            logits_real = logits_real.detach().mean()
            logits_fake = logits_fake.detach().mean()
            logger.info(f"(Discriminator) "
                        f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                        f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}")
            if dist.get_rank() == 0:
                wandb.log({
                    "discriminator_adv_loss": d_adversarial_loss,
                    "disc_weight": disc_weight,
                    "logits_real": logits_real,
                    "logits_fake": logits_fake,
                })
        return d_adversarial_loss
```
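Putting it together, a minimal training-step sketch (an assumption about the surrounding training loop, which is not shown in this post; optimizer updates the VQModel, optimizer_disc updates the discriminator; the second call works on detached reconstructions inside the loss forward, so no retain_graph is needed):

```python
for step, (x, _) in enumerate(loader):
    dec, cb_vis, cb_mid, cb_high, dis_loss, feat_mid, feat_high = vq_model(x)

    # 1) generator / tokenizer update (optimizer_idx = 0)
    optimizer.zero_grad()
    loss_gen = vq_loss(cb_vis, cb_mid, cb_high, x, dec, dis_loss,
                       feat_mid, feat_high,
                       optimizer_idx=0, global_step=step,
                       last_layer=vq_model.decoder.last_layer, logger=logger)
    loss_gen.backward()
    optimizer.step()

    # 2) discriminator update (optimizer_idx = 1)
    optimizer_disc.zero_grad()
    loss_disc = vq_loss(cb_vis, cb_mid, cb_high, x, dec, dis_loss,
                        feat_mid, feat_high,
                        optimizer_idx=1, global_step=step, logger=logger)
    loss_disc.backward()
    optimizer_disc.step()
```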