当前位置: 首页 > article >正文

TransVG 代码配置及一些小细节

TransVG代码配置

  1. File “/home/wyq/TransVG/utils/misc.py”, line 22, in <module>
    from torchvision.ops import _new_empty_tensor
    ImportError: cannot import name ‘_new_empty_tensor’

    if float(torchvision.__version__[:3]) &lt; 0.7: # torchvision.__version__[:3]=0.1
        from torchvision.ops import _new_empty_tensor
        from torchvision.ops.misc import _output_size
    

    解决:注释掉这一部分代码

  2. 计算参数量

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
  3. dataloader

    m = re.match(r"^(.*) \|\|\| (.*)$", line)
    text_a = m.group(1)
    text_b = m.group(2)
    
  4. DIOR_RSVG(2022)构造在TransVG之上(2021)

  5. numpy.array() 函数期望其参数是一个序列(如列表、元组等),或者是一个单独的数组式对象。所以box = np.array(x1, y1,x2, y2, dtype=np.float32)不对

  6. unbind() 是 PyTorch 张量(torch.Tensor)的一个方法,用于移除最后一个维度的尺寸并返回一个新的张量列表,其中包含了沿着该维度解包的所有元素

  7. missing_keys, unexpected_keys = model_without_ddp.visumodel.load_state_dict(checkpoint['model'], strict=False)

    • 将一个检查点(checkpoint)中的模型状态字典(state dict)加载到一个没有分布式数据并行(DDP)封装的模型(model_without_ddp)的 visumodel属性上
    • missing_keys:这些是检查点中存在但模型状态字典中不存在的键。这通常发生在检查点包含了一些模型当前版本不再需要的额外权重(例如,由于模型架构的更改)。
    • unexpected_keys:这些是模型状态字典中存在但检查点中不存在的键。这通常发生在模型包含了一些新的权重,而这些权重在检查点中没有对应的值(例如,由于模型架构的扩展)。
  8. torch.tensor(images) & torch.stack(images)

    • 当 images 是 NumPy 数组(ndarray)的列表时,torch.tensor(images) 会创建一个新的 PyTorch 张量,其形状为 (batch_size, …),其中 batch_size 是列表中数组的数量。
    • 当 images 是 PyTorch 张量(tensor)的列表时,更推荐使用 torch.stack(images) 而不是 torch.tensor(images),以确保张量沿着一个新的维度被正确堆叠。
  9. 重定向:2>&1 | tee ./models/refcoco/output

    • 2>&1:这是一个重定向操作,它将标准错误(stderr,文件描述符为 2)重定向到标准输出(stdout,文件描述符为 1)。这通常用于将错误消息和正常输出一起发送到同一个地方。
    • |:这是管道符号,它将一个命令的输出作为另一个命令的输入。
    • tee:tee 命令从标准输入读取数据,并将其内容写入一个或多个文件,同时还将数据复制到标准输出。
  10. 报错:

    Traceback (most recent call last):
      File "train.py", line 15, in <module>
        import datasets
      File "/home/wyq/TransVG/datasets/__init__.py", line 8, in <module>
        from .DIOR_data_loader import RSVGDataset
      File "/home/wyq/TransVG/datasets/DIOR_data_loader.py", line 16, in <module>
        import cv2
    ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.26' not found
    (required by /root/anaconda3/envs/py36/lib/python3.6/site-packages/cv2.cpython-36m-x86_64-linux-gnu.so)
    

    解决:解决问题链接

    strings /usr/lib/x86_64-linux-gnu/libstdc++.so.6 | grep GLIBCXX # 查看当前有哪些版本
    sudo find / -name "libstdc++.so.6*" #查看其它的版本,选择合适版本进行替换
    
    # strings 命令主要用于提取文件中的可打印字符串,
    # 而 GLIBCXX 通常是与 GCC 的标准 C++ 库(libstdc++)相关的版本标识符
    strings  /root/anaconda3/envs/py36/lib/libstdc++.so.6.0.33 | grep GLIBCXX # 查看新版本是否有需要的文件
    cp  /root/anaconda3/envs/py36/lib/libstdc++.so.6.0.33 /usr/lib/x86_64-linux-gnu/ # 复制新版本
    rm  /usr/lib/x86_64-linux-gnu/libstdc++.so.6  # 删除旧版本
    ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.33 /usr/lib/x86_64-linux-gnu/libstdc++.so.6 #建立新连接
    # 命令 ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.33 /usr/lib/x86_64-linux-gnu/libstdc++.so.6 是在
    # Linux 系统中创建一个符号链接(也称为软链接)。这个命令的目的是将 libstdc++.so.6.0.33 这个具体的库文件链接
    # 到一个更通用的名称 libstdc++.so.6,这样其他程序或库在请求 libstdc++.so.6 时,
    # 实际上会加载到 libstdc++.so.6.0.33    -s:这个选项告诉 ln 命令创建一个符号链接,而不是硬链接
    
  11. 开始时结果很不好,通过调整学习率(变为之前的1/8)提高了效果。原因:之前采用的是8卡训练,现在是单卡训练,所以考虑学习率应该降低为原来的1/8

  12. GIOU IoU、GIoU、DIoU、CIoU、EIoU 5大评价指标
    在这里插入图片描述

    def generalized_box_iou(boxes1, boxes2):
        """
        Generalized IoU from https://giou.stanford.edu/
    
        The boxes should be in [x0, y0, x1, y1] format
    
        Returns a [N, M] pairwise matrix, where N = len(boxes1)
        and M = len(boxes2)
        """
        # degenerate boxes gives inf / nan results
        # so do an early check
        assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
        assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
        iou, union = box_iou(boxes1, boxes2)
    
        lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
        rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    
        wh = (rb - lt).clamp(min=0)  # [N,M,2]
        area = wh[:, :, 0] * wh[:, :, 1]
    
        return iou - (area - union) / area
    
  13. 正弦位置编码 Transformer架构:位置编码(sin/cos编码) 正弦、余弦三角函数位置编码讲解、代码实现
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

  14. bert处理文本的结果

    tokens,#  分词结果
    input_ids, # 分词转换成词向量
    input_mask, # 输入向量的掩码,1表示该位置有输入,0表示该位置没有输入
    input_type_ids # 输入向量的类型,0表示text_a,1表示text_b
    
  15. TransVG (1+L+N)xBxC

    class TransVG(nn.Module):
        def __init__(self, args):
            super(TransVG, self).__init__()
            hidden_dim = args.vl_hidden_dim
            divisor = 16 if args.dilation else 32
            self.num_visu_token = int((args.imsize / divisor) ** 2)
            self.num_text_token = args.max_query_len
    
            self.visumodel = build_detr(args)
            self.textmodel = build_bert(args)
    
            num_total = self.num_visu_token + self.num_text_token + 1
            self.vl_pos_embed = nn.Embedding(num_total, hidden_dim)
            self.reg_token = nn.Embedding(1, hidden_dim)
    
            self.visu_proj = nn.Linear(self.visumodel.num_channels, hidden_dim)
            self.text_proj = nn.Linear(self.textmodel.num_channels, hidden_dim)
    
            self.vl_transformer = build_vl_transformer(args)
            self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
    
    
        def forward(self, img_data, text_data):
            bs = img_data.tensors.shape[0]
    
            # visual backbone
            visu_mask, visu_src = self.visumodel(img_data) 
            visu_src = self.visu_proj(visu_src) # (N*B)xC
    
            # language bert
            text_fea = self.textmodel(text_data)
            text_src, text_mask = text_fea.decompose()
            assert text_mask is not None
            text_src = self.text_proj(text_src)
            # permute BxLenxC to LenxBxC
            text_src = text_src.permute(1, 0, 2)
            text_mask = text_mask.flatten(1)
    
            # target regression token
            tgt_src = self.reg_token.weight.unsqueeze(1).repeat(1, bs, 1)
            tgt_mask = torch.zeros((bs, 1)).to(tgt_src.device).to(torch.bool)
            
            vl_src = torch.cat([tgt_src, text_src, visu_src], dim=0)
            vl_mask = torch.cat([tgt_mask, text_mask, visu_mask], dim=1)
            vl_pos = self.vl_pos_embed.weight.unsqueeze(1).repeat(1, bs, 1)
    
            vg_hs = self.vl_transformer(vl_src, vl_mask, vl_pos) # (1+L+N)xBxC
            vg_hs = vg_hs[0]
    
            pred_box = self.bbox_embed(vg_hs).sigmoid()
    
            return pred_box
    
  16. nn.Embedding参数:num_embeddings:表示词汇表或者类别集合的大小,也就是总共有多少个不同的类别需要进行嵌入操作 (例如上面词汇表中有 5 个单词,那么 num_embeddings 就设为 5)。embedding_dim:指定每个嵌入向量的维度。比如想把每个单词映射成维度为 3 的向量,那么 embedding_dim 就设为 3。

    import torch
    import torch.nn as nn
    
    embedding_layer = nn.Embedding(num_embeddings=5, embedding_dim=3)
    
  17. MLP self.layers= nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    class MLP(nn.Module):
        """ Very simple multi-layer perceptron (also called FFN)"""
    
        def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
            super().__init__()
            self.num_layers = num_layers
            h = [hidden_dim] * (num_layers - 1)
            self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
    
        def forward(self, x):
            for i, layer in enumerate(self.layers):
                x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
            return x
    
  18. DETR

    class DETR(nn.Module):
        """ This is the DETR module that performs object detection """
        def __init__(self, backbone, transformer, num_queries, train_backbone, train_transformer, aux_loss=False):
            """ Initializes the model.
            Parameters:
                backbone: torch module of the backbone to be used. See backbone.py
                transformer: torch module of the transformer architecture. See transformer.py
                num_classes: number of object classes
                num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                             DETR can detect in a single image. For COCO, we recommend 100 queries.
                aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            """
            super().__init__()
            self.num_queries = num_queries
            self.transformer = transformer
            self.backbone = backbone
    
            if self.transformer is not None:
                hidden_dim = transformer.d_model
                self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
            else:
                hidden_dim = backbone.num_channels
    
            if not train_backbone:
                for p in self.backbone.parameters():
                    p.requires_grad_(False)
            
            if self.transformer is not None and not train_transformer:
                for m in [self.transformer, self.input_proj]:
                    for p in m.parameters():
                        p.requires_grad_(False)
    
            self.num_channels = hidden_dim
    
        def forward(self, samples: NestedTensor):
            """The forward expects a NestedTensor, which consists of:
                   - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
                   - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
    
                It returns a dict with the following elements:
                   - "pred_logits": the classification logits (including no-object) for all queries.
                                    Shape= [batch_size x num_queries x (num_classes + 1)]
                   - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                                   (center_x, center_y, height, width). These values are normalized in [0, 1],
                                   relative to the size of each individual image (disregarding possible padding).
                                   See PostProcess for information on how to retrieve the unnormalized bounding box.
                   - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                    dictionnaries containing the two above keys for each decoder layer.
            """
            if isinstance(samples, (list, torch.Tensor)):
                samples = nested_tensor_from_tensor_list(samples)
            features, pos = self.backbone(samples)
    
            src, mask = features[-1].decompose()
            assert mask is not None
    
            if self.transformer is not None:
                out = self.transformer(self.input_proj(src), mask, pos[-1], query_embed=None)
            else:
                out = [mask.flatten(1), src.flatten(2).permute(2, 0, 1)]
                 
            return out
    
  19. IntermediateLayerGetter是一个实用工具,它允许你从预训练的模型中提取中间层的特征,这在特征提取或迁移学习中非常有用。

    self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
    
  20. 导入BertModel的两种方式

    from pytorch_pretrained_bert.modeling import BertModel
    from transformers import BertModel
    

http://www.kler.cn/a/415502.html

相关文章:

  • 深度学习3:数据预处理使用Pandas与PyTorch的实践
  • C++类的自动转换和强制类型转换
  • C# 中的接口:定义行为契约与实现多态性
  • uniapp定义new plus.nativeObj.View实现APP端全局弹窗
  • 【Git】Git 完全指南:从入门到精通
  • 利用Python爬虫获取1688商品类目:技术解析
  • 《 C++ 点滴漫谈: 二 》编程语言之争:从 C 到 C++,两代语言的技术传承与演化,谁更适合你的项目?
  • 青训营-豆包MarsCode技术训练营试题解析九
  • 软件设计模式复习
  • 火语言RPA流程组件介绍--键盘按键
  • Scala学习记录,统计成绩
  • ADAM优化算法与学习率调度器:深度学习中的关键工具
  • 深入学习MapReduce:原理解析与基础实战
  • 认识redis 及 Ubuntu安装redis
  • Figma入门-约束与对齐
  • 【前端开发】小程序无感登录验证
  • windows下使用WSL
  • AI智算-正式上架GPU资源监控概览 Grafana Dashboard
  • 小程序-基于java+SpringBoot+Vue的戏曲文化苑小程序设计与实现
  • tomcat 8.5.35安装及配置
  • 【Leetcode Top 100】206. 反转链表
  • 消息传递神经网络(Message Passing Neural Networks, MPNN)
  • Unity类银河战士恶魔城学习总结(P150 End Screen结束重启按钮)
  • 学习threejs,使用specularMap设置高光贴图
  • 实习冲刺第三十四天
  • 基于单片机的仓库环境无线监测系统(论文+源码)