
FQ-GAN Code Walkthrough

We mainly look at how the model, loss, and data parts are implemented and handled.

  • model—VQ_models
    • VQModel
    • Encoder
    • VectorQuantizer
    • Decoder
  • loss—VQLoss_triple_codebook

model—VQ_models

Creating vq_model: based on the requested compression factor (8 or 16), the corresponding VQ_8 / VQ_16 constructor is called. Both just instantiate a VQModel and differ only in ch_mult (the same convention as ch_mult in a UNet: it is the per-level channel multiplier, and every level except the last halves the spatial resolution, so the overall down-sampling factor is 2^(len(ch_mult)-1), i.e. 8 for 4 levels and 16 for 5 levels; see the small sketch after the code below).

	# create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim,
        commit_loss_beta=args.commit_loss_beta,
        entropy_loss_ratio=args.entropy_loss_ratio,
        dropout_p=args.dropout_p,
        with_clip_supervision=args.with_clip_supervision,
        with_disentanglement=args.with_disentanglement,
        disentanglement_ratio=args.disentanglement_ratio,
    )
    
def VQ_8(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))


def VQ_16(**kwargs):
    return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))


VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
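Where the 8/16 comes from can be seen with a tiny sketch (toy code, not from the repo): every encoder level except the last halves the spatial resolution, so only the number of entries in ch_mult determines the down-sampling factor, while the values themselves only set the channel widths.

# Toy sketch: spatial down-sampling factor implied by ch_mult.
# Each conv_block except the last halves H and W, so the factor is 2 ** (len(ch_mult) - 1);
# the ch_mult values only scale the channel width (ch * ch_mult[i], with ch = 128 by default).
def down_factor(ch_mult):
    return 2 ** (len(ch_mult) - 1)

print(down_factor([1, 2, 2, 4]))      # 8  -> VQ-8
print(down_factor([1, 1, 2, 2, 4]))   # 16 -> VQ-16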

VQModel

The structure of VQModel, which holds 3 codebooks, is as follows:

  • Encoder: one Encoder (progressively downsamples the spatial dimensions and outputs features with z_channels channels).
  • VectorQuantizer: 3 VectorQuantizers (pixel level with no teacher, mid-semantic level with a DINO teacher, high-semantic level with a CLIP teacher), each paired with a quant_conv (mapping z_channels to codebook_embed_dim).
  • Decoder: one post_quant_conv (mapping the concatenated embedding dim, 3*codebook_embed_dim, back to z_channels) and one Decoder (progressively upsampling back to the original spatial resolution).
  • FeatPredHead: 2 FeatPredHeads (MLP heads that align the VQ features to the CLIP and DINO features, used for distillation supervision).
class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config

        # Three-head encoder (visual + mid-semantic + high-semantic)
        self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        # Quantizer for visual detail head
        self.quantize_vis = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                            config.commit_loss_beta, config.entropy_loss_ratio,
                                            config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_vis = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for mid-level semantic head
        self.quantize_sem_mid = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                config.commit_loss_beta, config.entropy_loss_ratio,
                                                config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_mid = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        # Quantizer for high-level semantic head
        self.quantize_sem_high = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                                 config.commit_loss_beta, config.entropy_loss_ratio,
                                                 config.codebook_l2_norm, config.codebook_show_usage)
        self.quant_conv_sem_high = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)

        print("Visual codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("Mid Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))
        print("High Semantic codebook: [{} x {}]".format(config.codebook_size, config.codebook_embed_dim))

        # Pixel decoder
        input_dim = config.codebook_embed_dim * 3
        self.post_quant_conv = nn.Conv2d(input_dim, config.z_channels, 1)
        self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels,
                               dropout=config.dropout_p)

        # Down-sample factor in encoder channel multiplier
        self.num_resolutions = len(config.encoder_ch_mult)
        if self.num_resolutions == 5:  # encoder_ch_mult=[1, 1, 2, 2, 4]
            down_factor = 16
        elif self.num_resolutions == 4:  # encoder_ch_mult=[1, 2, 2, 4]
            down_factor = 8
        else:
            raise NotImplementedError

        # Semantic feature prediction
        if self.config.with_clip_supervision:
            print("Include feature prediction head for representation supervision")
            self.mid_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=384, down_factor=down_factor)
            self.high_sem_feat_pred = FeatPredHead(input_dim=config.codebook_embed_dim, out_dim=768, down_factor=down_factor)
        else:
            print("NO representation supervision")

        if self.config.with_disentanglement:
            print("Disentangle Ratio: ", self.config.disentanglement_ratio)
        else:
            print("No Disentangle Regularization")

The forward pass consists of three main stages, encode, vq and decode, plus an extra feature-alignment step needed for the knowledge distillation:

  1. (encode) The input goes through the encoder to produce 3 different features (h_vis, h_sem_mid, h_sem_high), which then pass through the 3 quant_convs to map z_channels to codebook_embed_dim.
  2. (vq) The features of the different levels are fed into their own VectorQuantizer, producing three quant_features and emb_losses (each emb_loss consists of vq_loss, commit_loss and entropy_loss).
  3. Because of the knowledge distillation, the FeatPredHeads (mid_sem_feat_pred, high_sem_feat_pred) additionally project the quantized features to the dimensions of the CLIP and DINO features.
  4. To make the 3 codebooks mutually orthogonal (i.e. strongly disentangled), a disentangle_loss is constructed so that the VQ features of the 3 levels differ from each other (the mean squared dot product of the normalized features).
  5. (decode) The quant_features are concatenated, passed through post_quant_conv and the decoder, and decoded back to the original image's pixel values (dec).
    def forward(self, input):
        # 1. encode
        h_vis, h_sem_mid, h_sem_high = self.encoder(input)
        h_vis = self.quant_conv_vis(h_vis)
        h_sem_mid = self.quant_conv_sem_mid(h_sem_mid)
        h_sem_high = self.quant_conv_sem_high(h_sem_high)

        # 2. vq
        quant_vis, emb_loss_vis, _ = self.quantize_vis(h_vis)
        quant_sem_mid, emb_loss_sem_mid, _ = self.quantize_sem_mid(h_sem_mid)
        quant_sem_high, emb_loss_sem_high, _ = self.quantize_sem_high(h_sem_high)

        # for knowledge distillation
        if self.config.with_clip_supervision:
            mid_lvl_sem_feat = self.mid_sem_feat_pred(quant_sem_mid)
            high_lvl_sem_feat = self.high_sem_feat_pred(quant_sem_high)
        else:
            mid_lvl_sem_feat = None
            high_lvl_sem_feat = None

        # for disentangle vq feature of 3 codebook
        if self.config.with_disentanglement:
            disentangle_loss = (self.compute_disentangle_loss(quant_vis, quant_sem_mid) +
                                self.compute_disentangle_loss(quant_vis, quant_sem_high) +
                                self.compute_disentangle_loss(quant_sem_mid, quant_sem_high)) / 3.0
        else:
            disentangle_loss = 0

        # 3. decode
        quant = torch.cat([quant_vis, quant_sem_mid, quant_sem_high], dim=1)
        dec = self.decode(quant)

        return dec, \
            emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high, \
            disentangle_loss, \
            mid_lvl_sem_feat, high_lvl_sem_feat

The key novelty that gives FQ-GAN its name is the design of this disentangle_loss, which drives the 3 codebooks to be mutually orthogonal and decoupled. The idea is that if two features are disentangled, their dot product should be close to zero, since they should be orthogonal; minimizing this loss therefore encourages the model to learn decoupled features at different levels (a toy numeric check follows the code).

    def compute_disentangle_loss(self, quant_vis, quant_sem):
        quant_vis = rearrange(quant_vis, 'b c h w -> (b h w) c')
        quant_sem = rearrange(quant_sem, 'b c h w -> (b h w) c')

        quant_vis = F.normalize(quant_vis, p=2, dim=-1)
        quant_sem = F.normalize(quant_sem, p=2, dim=-1)

        dot_product = torch.sum(quant_vis * quant_sem, dim=1)
        loss = torch.mean(dot_product ** 2) * self.config.disentanglement_ratio

        return loss
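As a quick sanity check of the intuition (toy 2-D vectors, not repo code), orthogonal features give zero loss while identical features give the maximum:

import torch
import torch.nn.functional as F

a = F.normalize(torch.tensor([[1., 0.]]), dim=-1)
b = F.normalize(torch.tensor([[0., 1.]]), dim=-1)
print(((a * b).sum(dim=1) ** 2).mean())   # tensor(0.) -> orthogonal, fully disentangled
print(((a * a).sum(dim=1) ** 2).mean())   # tensor(1.) -> identical, maximally entangled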

Encoder

The Encoder takes the input image, runs it through shared downsampling conv_blocks and mid blocks, and then sends the result into 3 different adapters that output 3 different features:

  1. conv_in: the input image is first mapped by conv_in to 128 channels.
  2. downsampling: the conv_blocks are built from ResnetBlock, AttnBlock and Downsample modules according to ch_mult=(1,1,2,2,4) or ch_mult=(1, 2, 2, 4); ch_mult controls how much each conv_block widens the channels (channels grow while h and w shrink). The output channels of block i are 128*ch_mult[i] (e.g. with ch_mult=(1, 2, 2, 4) there are 4 blocks and the channels go 128 -> 128 -> 256 -> 256 -> 512; see the small trace after this list).
  3. mid: composed of ResnetBlock + AttnBlock + ResnetBlock; neither the channel count nor the spatial resolution changes here.
  4. adapter: 3 different FactorizedAdapters convert the shared encoder feature into 3 different features, which feed the 3 codebooks' VQ operations later on.
  5. conv_out: since the previous step split the feature into 3 branches (h_vis, h_sem_mid, h_sem_high), 3 separate norm + conv2d heads are used here to map each branch's channel dimension to z_channels.
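A toy trace of the shapes (assuming ch=128, ch_mult=(1, 2, 2, 4) and a 256x256 input; not repo code):

# Channel / spatial progression through the downsampling conv_blocks.
ch_mult = (1, 2, 2, 4)
ch, res = 128, 256                       # after conv_in: 128 channels, 256x256
for i, m in enumerate(ch_mult):
    ch_out = 128 * m                     # the res blocks widen the channels
    res_out = res // 2 if i != len(ch_mult) - 1 else res   # every level but the last downsamples
    print(f"level {i}: {ch}->{ch_out} ch, {res}->{res_out} px")
    ch, res = ch_out, res_out
# channels: 128 -> 128 -> 256 -> 256 -> 512; spatial: 256 -> 128 -> 64 -> 32 (factor 8)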
class Encoder(nn.Module):
    def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, 
                 norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions-1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))


        if self.num_resolutions == 5:
            down_factor = 16
        elif self.num_resolutions == 4:
            down_factor = 8
        else:
            raise NotImplementedError

        # semantic head mid-level
        self.semantic_head_mid = nn.ModuleList()
        self.semantic_head_mid.append(FactorizedAdapter(down_factor))

        # semantic head high-level
        self.semantic_head_high = nn.ModuleList()
        self.semantic_head_high.append(FactorizedAdapter(down_factor))

        # visual details head
        self.visual_head = nn.ModuleList()
        self.visual_head.append(FactorizedAdapter(down_factor))

        # end
        self.norm_out_sem_mid = Normalize(block_in, norm_type)
        self.conv_out_sem_mid = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

        self.norm_out_sem_high = Normalize(block_in, norm_type)
        self.conv_out_sem_high = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

        self.norm_out_vis = Normalize(block_in, norm_type)
        self.conv_out_vis = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)
        
        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        h_vis = h
        h_sem_mid = h
        h_sem_high = h

        # semantic head mid-level
        for blk in self.semantic_head_mid:
            h_sem_mid = blk(h_sem_mid)
        h_sem_mid = self.norm_out_sem_mid(h_sem_mid)
        h_sem_mid = nonlinearity(h_sem_mid)
        h_sem_mid = self.conv_out_sem_mid(h_sem_mid)

        # semantic head high-level
        for blk in self.semantic_head_high:
            h_sem_high = blk(h_sem_high)
        h_sem_high = self.norm_out_sem_high(h_sem_high)
        h_sem_high = nonlinearity(h_sem_high)
        h_sem_high = self.conv_out_sem_high(h_sem_high)

        # visual head
        for blk in self.visual_head:
            h_vis = blk(h_vis)
        h_vis = self.norm_out_vis(h_vis)
        h_vis = nonlinearity(h_vis)
        h_vis = self.conv_out_vis(h_vis)

        return h_vis, h_sem_mid, h_sem_high

VectorQuantizer

The initialization of VectorQuantizer mainly creates a codebook embedding (self.embedding) of size [codebook_size, codebook_embed_dim].

class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e  # codebook_size
        self.e_dim = e_dim  # codebook_embed_dim
        self.beta = beta  # commitment_loss scale
        self.entropy_loss_ratio = entropy_loss_ratio  # entropy_loss scale
        self.l2_norm = l2_norm  # l2_norm for codebook embeddings
        self.show_usage = show_usage  # show codebook usage

        # create codebook embedding and initialize
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        if self.l2_norm:  # normalize embeddings
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:  # initialize codebook usage
            self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))     # 1048576

The forward pass works the same as in VQGAN: every token embedding of the image feature z is quantized, via a codebook lookup at argmin(distances), to its nearest codebook embedding, yielding the quantized image feature z_q; at the same time 3 losses are computed (used to train the codebook and the encoder).

  • l2 norm: z and the codebook embeddings are both L2-normalized, rescaling every vector to the same length, i.e. projecting them onto the unit sphere. Each vector then plays an equal role in the distance metric, which makes different vectors easier to compare and match and improves training stability and reconstruction quality.
  • argmin(distances): the classic VQ distance computation, expanding the squared distance into two squared norms and a cross term, $(z - e)^2 = z^2 + e^2 - 2\,e \cdot z$. argmin(distances) then gives, for each embedding in z, the index of its nearest codebook entry, and those entries are gathered from the codebook to form z_q (a standalone check of this expansion follows the code below).
  • codebook usage: measures what fraction of the codebook embeddings is actually used.
    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = torch.einsum('b c h w -> b h w c', z).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        
        if self.l2_norm:  # normalize z and the codebook embeddings, mapping the vectors onto the unit sphere
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
        # argmin(distances)
        min_encoding_indices = torch.argmin(d, dim=1)
        # replace each z_i with its closest embedding e_j
        z_q = embedding[min_encoding_indices].view(z.shape)

        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None
        codebook_usage = 0

        # compute codebook usage
        if self.show_usage and self.training:
            cur_len = min_encoding_indices.shape[0]
            self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()  # copy last cur_len elements to front
            self.codebook_used[-cur_len:] = min_encoding_indices  # set last cur_len elements as min_encoding_indices
            codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
  • embedding loss
    • vq_loss is the mean squared error (MSE) between the quantized vector z_q and the original input z. z.detach() means z is cut out of the computation graph, so this term's gradient only updates the codebook entries, pulling the selected embeddings toward the encoder output.
    • commit_loss is also an MSE, but here z_q is detached, so the gradient only flows into the encoder output z; it commits the encoder to produce features close to the codebook entries they are mapped to. The hyperparameter self.beta scales its weight in the total loss.
    • entropy_loss encourages uniform usage of the codebook and thereby better generalization. compute_entropy_loss(-d) uses the negative squared distances from each token to all codebook entries as affinities, applies a softmax at a low temperature, and returns sample_entropy - avg_entropy (sharp per-token assignments, but uniform usage on average). The hyperparameter self.entropy_loss_ratio scales its weight in the total loss.
		# compute 3 loss for embedding
        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2) 
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) 
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # preserve gradients (straight-through estimator: the backward pass treats quantization as identity, so gradients on z_q flow into z)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = torch.einsum('b h w c -> b c h w', z_q)

        return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss
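A standalone check of this expansion (toy tensors, not repo code): the expanded form matches the squared pairwise Euclidean distance, and argmin over it picks the nearest codebook entry for every token.

import torch

z = torch.randn(5, 8)                    # 5 flattened tokens, dim 8
e = torch.randn(16, 8)                   # a toy codebook with 16 entries
d = (z ** 2).sum(1, keepdim=True) + (e ** 2).sum(1) - 2 * z @ e.t()
assert torch.allclose(d, torch.cdist(z, e) ** 2, atol=1e-4)
idx = d.argmin(dim=1)                    # nearest codebook index per token
z_q = e[idx]                             # quantized tokens, shape (5, 8)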

get_codebook_entry is used after an autoregressive Transformer has predicted an index sequence: the indices are looked up in the codebook and converted back into the corresponding embeddings (a small usage sketch follows the code).

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q
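A hypothetical usage sketch (the 16384-entry / 8-dim codebook and the 16x16 grid are assumptions for illustration, not values taken from the repo): after an AR transformer has sampled a flat index sequence, the entries are gathered back into a (B, C, H, W) feature map.

import torch

# a toy quantizer just to demonstrate the call, not a trained one; beta / entropy ratio are placeholders
quantizer = VectorQuantizer(n_e=16384, e_dim=8, beta=0.25, entropy_loss_ratio=0.1,
                            l2_norm=True, show_usage=False)
indices = torch.randint(0, 16384, (1 * 16 * 16,))   # e.g. sampled by an AR transformer
z_q = quantizer.get_codebook_entry(indices, shape=(1, 8, 16, 16), channel_first=True)
print(z_q.shape)                                     # torch.Size([1, 8, 16, 16])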

Decoder

The VQ step from z to z_q does not change the feature shape; after the 3 branches are concatenated, post_quant_conv maps the channel dimension back to z_channels=256, so the Decoder receives a feature with z_channels channels. The Decoder then decodes z_q into the image's pixel values as follows:

  1. conv_in: a Conv2d maps the channel dimension from z_channels to block_in (determined by ch=128 and the last entry of ch_mult).
  2. middle block: same as in the Encoder, ResnetBlock + AttnBlock + ResnetBlock, with the channel count unchanged.
  3. upsampling conv_blocks: the mirror image of the Encoder; blocks are built from ch_mult in reverse order and progressively upsample (spatial dimensions grow, channels shrink).
  4. conv_out: the final conv_out maps the channel dimension from block_in to out_channels=3, producing the image pixel values.
class Decoder(nn.Module):
    def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
                 dropout=0.0, resamp_with_conv=True, out_channels=3):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch*ch_mult[self.num_resolutions-1]
        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    @property
    def last_layer(self):
        return self.conv_out.weight
    
    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)
        
        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)

        return h

loss—VQLoss_triple_codebook

The forward pass of the VQ_Model above already produces 3 embed_losses and 1 disentangle_loss:

  • codebook embedding loss: since there are 3 codebooks, the 3 VQ operations yield 3 emb_losses (emb_loss_vis, emb_loss_sem_mid, emb_loss_sem_high); each is made up of 3 parts (vq_loss, commit_loss, entropy_loss) and is used to optimize its codebook.
  • disentangle loss: one of the paper's contributions; the (normalized) dot products between the z_q of different codebooks are squared and averaged, pushing the codebooks to be mutually orthogonal.

In addition, during training VQLoss_triple_codebook also computes the reconstruction_loss, perceptual_loss, the adversarial losses, and the kd_teacher_loss:

  • pixel loss
    • reconstruction_loss: the l1_loss or l2_loss between the pixel values of the input and the VQ_Model reconstruction.
    • perceptual_loss: the vgg-based LPIPS score between input and reconstruction.

  • discriminator loss: used to optimize the discriminator, which can be a PatchGAN or a StyleGAN discriminator (it takes a real or reconstructed image and outputs logits predicting real vs. fake). The discriminator loss type can be hinge, vanilla, or non-saturating.

  • gen_adv_loss: used to optimize the generator, whose goal is to make reconstructions realistic enough to fool the discriminator. It comes in hinge and non_saturating variants; both push the logits_fake of the reconstructed image toward "real" (a sketch of the hinge variants follows this list).

  • semantic loss (kd_teacher_loss): the 2 FeatPredHeads produce 2 projected VQ features that are compared against the CLIP and DINO features respectively, distilling general-purpose understanding representations.
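The hinge variants referenced above are not shown in this excerpt; as a reference, the standard VQGAN-style formulation (which hinge_d_loss / hinge_gen_loss in the repo presumably follow) looks like this:

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # discriminator: push real logits above +1 and fake (reconstruction) logits below -1
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def hinge_gen_loss(logits_fake):
    # generator: raise the discriminator's score on reconstructions
    return -torch.mean(logits_fake)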

The initialization of VQLoss_triple_codebook simply prepares the parameters and models needed to compute the losses listed above:

class VQLoss_triple_codebook(nn.Module):
    def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
                 disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight=False,
                 gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
                 codebook_weight=1.0, perceptual_weight=1.0,
                 with_clip_supervision=False, semantic_weight=0.5,
                 ):
        super().__init__()
        # 1. discriminator loss
        assert disc_type in ["patchgan", "stylegan"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        # discriminator
        if disc_type == "patchgan":
            print("Using patchgan D")
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            print("Using stylegan D")
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
        # disc_loss type
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight


        assert gen_adv_loss in ["hinge", "non-saturating"]
        # 2. gen_adv_loss
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # 3. perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # 4. semantic loss
        self.with_clip_supervision = with_clip_supervision
        if with_clip_supervision:
            self.clip_model = CLIPVisionTower("/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/clip-vit-base-patch16").eval()
            self.dinov2_model = DinoVisionTower("/mnt/workspace/Project/UnderGenTokenizer/FQGAN/models/dinov2-small").eval()
            self.clip_model.requires_grad_(False)
            self.dinov2_model.requires_grad_(False)
            self.semantic_weight = semantic_weight
        else:
            self.clip_model = None
            self.dinov2_model = None
            self.semantic_weight = None

        # 5. reconstruction loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # 6. codebook loss
        self.codebook_weight = codebook_weight

The forward pass of VQLoss_triple_codebook runs in 2 modes depending on optimizer_idx. Both modes run back-to-back on the same batch, i.e. each training step calls the loss's forward twice, once to compute the generator loss and once the discriminator loss, and the generator and discriminator are updated by 2 separate optimizers (optimizer and optimizer_disc):

  • When optimizer_idx == 0, the generator is optimized: the reconstruction loss, perceptual loss, semantic loss and gen_adv_loss are computed and combined by weighted sum with the codebook_embed_loss and disentangle_loss returned earlier by the VQModel forward (the adversarial term can additionally be scaled by an adaptive weight, see the sketch below). The total loss updates the VQ_Model.
  • When optimizer_idx == 1, the discriminator is optimized: the discriminator loss is computed and used to update the Discriminator (a sketch of the resulting two-pass training step follows the forward code).
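calculate_adaptive_weight is not shown in this excerpt either; the usual VQGAN formulation (a sketch, the repo may differ in detail) balances the adversarial term against the pixel losses via the ratio of their gradient norms at the decoder's last layer:

import torch

def calculate_adaptive_weight(nll_loss, g_loss, last_layer):
    # ratio of gradient norms at the decoder's conv_out weight (passed in as last_layer)
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, 1e4).detach()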
    def forward(self,
                codebook_loss_vis, codebook_loss_sem_mid, codebook_loss_sem_high,
                inputs, reconstructions,
                disentangle_loss,
                semantic_feat_mid, semantic_feat_high,
                optimizer_idx, global_step, last_layer=None,
                logger=None, log_every=100):
        # generator update
        if optimizer_idx == 0:
            # reconstruction loss
            rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

            # semantic loss
            if semantic_feat_mid is not None:
                assert semantic_feat_high is not None
                semantic_loss_mid = self.dinov2_model(inputs.contiguous(), semantic_feat_mid)  # how to compute semantic loss?
                semantic_loss_mid = torch.mean(semantic_loss_mid)

                semantic_loss_high = self.clip_model(inputs.contiguous(), semantic_feat_high)
                semantic_loss_high = torch.mean(semantic_loss_high)
            else:
                assert self.with_clip_supervision == False
                semantic_loss_mid = torch.mean(torch.zeros_like(rec_loss))
                semantic_loss_high = torch.mean(torch.ones_like(rec_loss))

            # perceptual loss
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            p_loss = torch.mean(p_loss)

            # discriminator loss
            logits_fake = self.discriminator(reconstructions.contiguous())
            generator_adv_loss = self.gen_adv_loss(logits_fake)

            if self.disc_adaptive_weight:
                null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss  # pixel loss
                disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, generator_adv_loss,
                                                                      last_layer=last_layer)
            else:
                disc_adaptive_weight = 1
            disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)

            loss = self.rec_weight * rec_loss + \
                   self.perceptual_weight * p_loss + \
                   disc_adaptive_weight * disc_weight * generator_adv_loss + \
                   codebook_loss_vis[0] + codebook_loss_vis[1] + codebook_loss_vis[2] + \
                   codebook_loss_sem_mid[0] + codebook_loss_sem_mid[1] + codebook_loss_sem_mid[2] + \
                   codebook_loss_sem_high[0] + codebook_loss_sem_high[1] + codebook_loss_sem_high[2] + \
                   self.semantic_weight * semantic_loss_mid + self.semantic_weight * semantic_loss_high + disentangle_loss

            if global_step % log_every == 0:
                rec_loss = self.rec_weight * rec_loss
                p_loss = self.perceptual_weight * p_loss
                generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss
                logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
                            
                            f"vq_loss_sem_mid: {codebook_loss_sem_mid[0]:.4f}, "
                            f"commit_loss_sem_mid: {codebook_loss_sem_mid[1]:.4f}, "
                            f"entropy_loss_sem_mid: {codebook_loss_sem_mid[2]:.4f}, "
                            f"codebook_usage_sem_mid: {codebook_loss_sem_mid[3]:.4f}, "
                            
                            f"vq_loss_sem_high: {codebook_loss_sem_high[0]:.4f}, "
                            f"commit_loss_sem_high: {codebook_loss_sem_high[1]:.4f}, "
                            f"entropy_loss_sem_high: {codebook_loss_sem_high[2]:.4f}, "
                            f"codebook_usage_sem_high: {codebook_loss_sem_high[3]:.4f}, "
                            
                            f"vq_loss_vis: {codebook_loss_vis[0]:.4f}, "
                            f"commit_loss_vis: {codebook_loss_vis[1]:.4f}, "
                            f"entropy_loss_vis: {codebook_loss_vis[2]:.4f}, "
                            f"codebook_usage_vis: {codebook_loss_vis[3]:.4f}, "
                            
                            f"disentangle_loss: {disentangle_loss: .4f}"
                            
                            f"generator_adv_loss: {generator_adv_loss:.4f}, "
                            f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}, "
                            f"semantic_loss_mid: {semantic_loss_mid:.4f}, semantic_loss_high: {semantic_loss_high:.4f}")
                if dist.get_rank() == 0:
                    wandb.log({
                        "rec_loss": rec_loss,
                        "perceptual_loss": p_loss,

                        "disentangle_loss": disentangle_loss,

                        "codebook_loss_sem_mid": codebook_loss_sem_mid[0],
                        "commit_loss_sem_mid": codebook_loss_sem_mid[1],
                        "entropy_loss_sem_mid": codebook_loss_sem_mid[2],
                        "codebook_usage_sem_mid": codebook_loss_sem_mid[3],

                        "codebook_loss_sem_high": codebook_loss_sem_high[0],
                        "commit_loss_sem_high": codebook_loss_sem_high[1],
                        "entropy_loss_sem_high": codebook_loss_sem_high[2],
                        "codebook_usage_sem_high": codebook_loss_sem_high[3],

                        "codebook_loss_vis": codebook_loss_vis[0],
                        "commit_loss_vis": codebook_loss_vis[1],
                        "entropy_loss_vis": codebook_loss_vis[2],
                        "codebook_usage_vis": codebook_loss_vis[3],
                        "generator_adv_loss": generator_adv_loss,
                        "disc_adaptive_weight": disc_adaptive_weight,
                        "disc_weight": disc_weight,
                        "semantic_loss_mid": semantic_loss_mid,
                        "semantic_loss_high": semantic_loss_high,
                    })
            return loss

        # discriminator update
        if optimizer_idx == 1:
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
            d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

            if global_step % log_every == 0:
                logits_real = logits_real.detach().mean()
                logits_fake = logits_fake.detach().mean()
                logger.info(f"(Discriminator) "
                            f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                            f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}")
                if dist.get_rank() == 0:
                    wandb.log({
                        "discriminator_adv_loss": d_adversarial_loss,
                        "disc_weight": disc_weight,
                        "logits_real": logits_real,
                        "logits_fake": logits_fake,
                    })

            return d_adversarial_loss
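Putting it together, the two-pass update described above looks roughly like this (a hypothetical training-loop excerpt; names such as loader, optimizer, optimizer_disc and logger follow the description, not the exact repo script):

for step, (x, _) in enumerate(loader):
    dec, emb_vis, emb_mid, emb_high, disent, feat_mid, feat_high = vq_model(x)

    # pass 1: generator / VQModel update
    loss_g = vq_loss(emb_vis, emb_mid, emb_high, x, dec, disent, feat_mid, feat_high,
                     optimizer_idx=0, global_step=step,
                     last_layer=vq_model.decoder.last_layer, logger=logger)
    optimizer.zero_grad()
    loss_g.backward()
    optimizer.step()

    # pass 2: discriminator update (reconstructions are detached inside the loss)
    loss_d = vq_loss(emb_vis, emb_mid, emb_high, x, dec, disent, feat_mid, feat_high,
                     optimizer_idx=1, global_step=step, logger=logger)
    optimizer_disc.zero_grad()
    loss_d.backward()
    optimizer_disc.step()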
