当前位置: 首页 > article >正文

源码分析之blip2的ITC和ITM的具体实现

引言:很久之前读blip2,对ITC和ITM大致有个印象,一个对比学习,一个图文匹配的二分类,咋一听好像没什么难理解的,最近好好看了一下源码,觉得实现上很巧妙,值得与诸君共享

这里小编没有一句一句分析,直接源码+注释,觉得这样看比较方便,因为只分析ITC和ITM,所以这里只放了blip2里面的Blip2Qformer的forward函数内容,如有出入,还请各位小伙伴留言斧正!

Image-text Contrastive

###============== Image-text Contrastive ===================###
    
    """
    因为在多张卡上训练,所以这里需要将所有卡上的图像特征收集起来,维度为[batch_size*num_gpu, num_query_tokens, embed_dim],
    其中,num_query_tokens是视觉tokens数量,embed_dim是维度
    """
    image_feats_all = concat_all_gather(
        image_feats
    )  # [batch_size*num_gpu, num_query_tokens, embed_dim]

    # 文本这一步操作与上述同理
    text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]

    """
    求图像与所有文本的相似度
    这里image_feats.unsqueeze(1)之后的维度是[batch_size,1, num_query_tokens, embed_dim]
    text_feat_all.unsqueeze(-1)之后的维度是[batch_size*num_gpu, embed_dim,1]
    为了求每个图像跟所有文本的相似度,图像特征[batch_size,1, num_query_tokens, embed_dim]第2个维度会被广播到batch_size*num_gpu变成[batch_size*,batch_size*num_gpu, num_query_tokens, embed_dim]
    然后矩阵乘法会沿着image_feats和text_feat_all最后两个维度进行相乘,embed_dim维度相乘消失,所以得到的结果为[batch_size,batch_size*num_gpu, num_query_tokens,1]
    相乘之后的结果再squeeze()就得到了[batch_size,batch_size*num_gpu, num_query_tokens]
    """
    sim_q2t = torch.matmul(
        image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
    ).squeeze()  # [batch_size, batch_size*num_gpu, num_query_tokens]

    """
    max(-1)表示在最后一个维度上,寻找最大值
    也就是说,对每个图像到文本的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
    """
    sim_i2t, _ = sim_q2t.max(-1)

    # 通过温度参数self.temp进行相似度的缩放控制
    sim_i2t = sim_i2t / self.temp

    """
    求文本与所有图像的相似度
    text_feat.unsqueeze(1).unsqueeze(1)之后的维度为[batch_size,1,1,embed_dim]
    image_feats_all.permute(0, 2, 1)交换后面两个维度之后的特征维度为[batch_size*num_gpu, embed_dim, num_query_token]
    同理,文本特征[batch_size,1,1,embed_dim]会广播第2个维度到batch_size*num_gpu,变成[batch_size,batch_size*num_gpu,1,embed_dim]
    然后最后两个维度做矩阵乘法得到[batch_size,batch_size*num_gpu,1,num_query_token]
    squeeze()之后的特征为[batch_size,batch_size*num_gpu,num_query_token]
    """
    sim_t2q = torch.matmul(
        text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
    ).squeeze()

    # 对每个文本到图像的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
    sim_t2i, _ = sim_t2q.max(-1)
    sim_t2i = sim_t2i / self.temp

    rank = dist.get_rank()
    bs = image.size(0)

    """
    torch.linspace(start, end, steps, dtype=int)的作用是生成从 start 到 end 之间的 steps 个数值,并返回一个 1D 张量
    这里用来生成多 GPU 训练中的标签(targets)索引,targets维度维[batch_size]
    每个 GPU 进程(或 rank)负责处理自己的 batch,并为它分配唯一的索引序列
    """
    targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
        image.device
    )

    if "image_id" in samples.keys():  # coco retrieval finetuning
        # 对于包含图像 ID 的样本,使用基于相似度的目标分布计算损失
        image_ids = samples["image_id"].view(-1, 1)
        image_ids_all = concat_all_gather(image_ids)
        pos_idx = torch.eq(image_ids, image_ids_all.t()).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
        sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(1)

        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
        loss_itc = (loss_t2i + loss_i2t) / 2
    else:
        """
        否则,使用交叉熵计算损失
        sim_i2t维度为[batch_size, batch_size*num_gpu],targets维度维[batch_size]
        对于sim_i2t每个batch,targets都有唯一一个在 0 到 batch_size * num_gpu - 1 之间真实值,因此可以计算交叉熵,从而达到让正例更接近,负例更远的效果
        """
        loss_itc = (
                           F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
                           + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
                   ) / 2

Image-text Matching

###============== Image-text Matching ===================###
    # 同上述
    text_input_ids_world = concat_all_gather(text_tokens.input_ids)
    text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
    image_embeds_world = all_gather_with_grad(image_embeds)

    with torch.no_grad():
        # 当有image_id时,作者把相似度矩阵里面image_ids相匹配的都mask掉了,即在后面计算的时候忽略样本自身的匹配
        if "image_id" in samples.keys():
            mask = torch.eq(image_ids, image_ids_all.t())
            sim_t2i.masked_fill_(mask, -10000)
            sim_i2t.masked_fill_(mask, -10000)
        else:
            # 与上面同理,将当前 GPU 进程处理的样本的索引范围填充为 -10000,即在后面计算的时候忽略样本自身的匹配
            sim_t2i[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)
            sim_i2t[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)

        # 被masked的值和被fill_diagonal_(-10000),经过softmax之后都会接近于0
        weights_t2i = F.softmax(sim_t2i, dim=1)
        weights_i2t = F.softmax(sim_i2t, dim=1)

    # 为每个文本选择一个负样本图像
    image_embeds_neg = []
    for b in range(bs):
        """
        对每个batch的数据随机选择一个负样本
        torch.multinomial从给定的概率分布中进行多项式分布抽样
        weights_t2i[b]中值大的数,被采样的概率就大,上述对sim_t2i自身样本进行mask就是为了这里自身样本作为正样本不会被选择
        """
        neg_idx = torch.multinomial(weights_t2i[b], 1).item()
        image_embeds_neg.append(image_embeds_world[neg_idx])
    image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

    # 为每个图像选择一个负样本文本
    text_ids_neg = []
    text_atts_neg = []
    for b in range(bs):
        neg_idx = torch.multinomial(weights_i2t[b], 1).item()
        text_ids_neg.append(text_input_ids_world[neg_idx])
        text_atts_neg.append(text_attention_mask_world[neg_idx])

    text_ids_neg = torch.stack(text_ids_neg, dim=0)
    text_atts_neg = torch.stack(text_atts_neg, dim=0)

    """
    这一步很妙!
    将文本的两个正样本一个负样本进行拼接,为后续二分类做准备
    至于为什么这么拼接,后面你就知道了
    """
    text_ids_all = torch.cat(
        [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
    )  # pos, pos, neg
    text_atts_all = torch.cat(
        [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
        dim=0,
    )

    # 这一步是对query_tokens进行一些处理
    query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
    query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
        image.device
    )
    attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

    """
    将图像的两个正样本一个负样本进行拼接,为后续二分类做准备
    注意:文本拼接的顺序是:正样本,正样本,负样本
    图像拼接的顺序是:正样本,负样本,正样本
    它们只有第一个位置都是正样本,也即第一个位置是一对匹配的正例,后面两个位置都是一正一负是不匹配的,这样我们就可以通过判断它们匹不匹配来进行二分类学习,妙哉!
    """
    image_embeds_all = torch.cat(
        [image_embeds, image_embeds_neg, image_embeds], dim=0
    )  # pos, neg, pos
    image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
        image.device
    )

    # 将拼接后的文本特征,图像特征以及相应的query_tokens输入到bert中进行分类预测
    output_itm = self.Qformer.bert(
        text_ids_all,
        query_embeds=query_tokens_itm,
        attention_mask=attention_mask_all,
        encoder_hidden_states=image_embeds_all,
        encoder_attention_mask=image_atts_all,
        return_dict=True,
    )

    # 取分类预测的结果
    vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
    vl_output = self.itm_head(vl_embeddings)
    logits = vl_output.mean(dim=1)

    # 生成对应的真实标签,只有第一个batch文本对是匹配的,所以第一个batch的标签设置为1,其他都是0
    itm_labels = torch.cat(
        [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
        dim=0,
    ).to(image.device)

    # 将预测结果和真实值输入到交叉熵损失函数中,进行二分类损失计算
    loss_itm = F.cross_entropy(logits, itm_labels)


http://www.kler.cn/news/336792.html

相关文章:

  • 需求管理工具Jama Connect:与Jira/Slack/GitHub无缝集成,一站式解决复杂产品开发中的协作难题
  • 单调队列应用介绍
  • 2024四非保研回忆录(天大、华师、东南、华科)
  • 10.7每日作业
  • 数据工程师岗位常见面试问题-2(附回答)
  • 力扣 简单 100.相同的树
  • Linux数据备份
  • GSLAM——一个通用的SLAM架构和基准
  • 【强训笔记】day27
  • 【Qt笔记】QFrame控件详解
  • AtCoder ABC373 A-D题解
  • YOLO11改进|上采样篇|引入CARAFE上采样模块
  • Leecode热题100-560.和为k的子数组
  • Golang | Leetcode Golang题解之第449题序列化和反序列化二叉搜索树
  • 如何查看NVIDIA Container Toolkit是否配置成功
  • 《数据结构》学习系列——树
  • ssh连接阿里云长连接
  • Django学习笔记十四:系统框架总结
  • 日常记账:解锁生活财务管理的秘密钥匙
  • 大模型应用新领域:探寻终端侧 AI 竞争核心|智于终端