当前位置: 首页 > article >正文

如何快速看懂并修改神经网络


前言:个人之见,一个神经网络网络源码出现,你先看数据集的输入和输出,而这数据集肯定要包括数据增加和制作数据集,第二 看模型的输入和输出(至于模型内部可以自己看论文 无非就是加了几个组件),然后根据输出选择的损失函数。至于学习率和优化器 差不多都是余弦退火和admw的优化器


1.数据集

直接实战,首先你看它的readme,它一般由标注文件的格式(一般都是 文件路径 + 对应的标签数字)(要求自己制作)
输入一般都是这个标注文件,输出一般都是元组或者字典。
数据增强一般包含在数据集的制作当中

actionclip

数据增强(空间剪裁)

数据增强源码

from datasets.transforms_ss import *
from RandAugment import RandAugment

class GroupTransform(object):
    def __init__(self, transform):
        self.worker = transform

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]

def get_augmentation(training, config):
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]
    scale_size = config.data.input_size * 256 // 224
    if training:

        unique = torchvision.transforms.Compose([GroupMultiScaleCrop(config.data.input_size, [1, .875, .75, .66]),
                                                 GroupRandomHorizontalFlip(is_sth='some' in config.data.dataset),
                                                 GroupRandomColorJitter(p=0.8, brightness=0.4, contrast=0.4,
                                                                        saturation=0.2, hue=0.1),
                                                 GroupRandomGrayscale(p=0.2),
                                                 GroupGaussianBlur(p=0.0),
                                                 GroupSolarization(p=0.0)]
                                                )
    else:
        unique = torchvision.transforms.Compose([GroupScale(scale_size),
                                                 GroupCenterCrop(config.data.input_size)])

    common = torchvision.transforms.Compose([Stack(roll=False),
                                             ToTorchFormatTensor(div=True),
                                             GroupNormalize(input_mean,
                                                            input_std)])
    return torchvision.transforms.Compose([unique, common])

def randAugment(transform_train,config):
    print('Using RandAugment!')
    transform_train.transforms.insert(0, GroupTransform(RandAugment(config.data.randaug.N, config.data.randaug.M)))
    return transform_train

这个数据增强 你可以直接 参考()
一般直接蕴含在数据集

 def __init__(self, list_file, labels_file,
                 num_segments=1, new_length=1,
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False, index_bias=1):
 
    def get(self, record, indices):
        images = list()
        for i, seg_ind in enumerate(indices):
            p = int(seg_ind)
            try:
                seg_imgs = self._load_image(record.path, p)
            except OSError:
                print('ERROR: Could not read image "{}"'.format(record.path))
                print('invalid indices: {}'.format(indices))
                raise
            images.extend(seg_imgs)
        process_data = self.transform(images)
        return process_data, record.label
  • 空间剪裁 无疑就是进行多少词crop 你得了解一手 ranaugment函数
数据集的制作(时间剪裁以及帧数实现)
  • 输入
    actionclip的标注文件为:
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/HfI4vN2vbHU_000000_000010 289 31
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/B8FXlmO5zk4_000079_000089 240 29
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/XsEw1vd32l8_000052_000062 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/r61D2lDCHsM_000268_000278 240 18
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/4sCQ-EX6cIg_000021_000031 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/N9mQC7MeZCk_000008_000018 300 31
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/fzVhIrMnY-E_000322_000332 250 1
/public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/6dLNI2BPTY0_000057_000067 250 23
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/othYtMhFdOU_000020_000030 250 29
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/JVSxlojnBYk_000047_000057 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/8jO9DeYLruU_000003_000013 300 1
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/pU12_c-XvU_000045_000055 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/x6rP9b1V7sQ_000060_000070 250 18
/public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/jqC2SnFAvoM_000092_000102 300 23
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri6AwOp59yA_000009_000019 250 31
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/wRaacvxMoc8_000014_000024 150 1
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/7kbO0v4hag_000107_000117 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/GjtR9KZbV3Y_000494_000504 300 29
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/hwUQqFadvE_000048_000058 250 0
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/vXmgE41UnBk_000844_000854 300 29
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/dglCzcubsw_000246_000256 159 1
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri1H0ygN3Us_000768_000778 300 31
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/n24zV9OtorU_000257_000267 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/nKoqxSJcZn8_000071_000081 250 0
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/pT2byS0qiZM_000001_000011 150 1
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/CMo6AJhtZo_000075_000085 250 29

视频提起帧 视频总帧数 对应的标签数字

  • 输出
    一般看__getitem_
    def __getitem__(self, index):
        record = self.video_list[index]
        segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        return self.get(record, segment_indices)

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]


    def get(self, record, indices):
        images = list()
        for i, seg_ind in enumerate(indices):
            p = int(seg_ind)
            try:
                seg_imgs = self._load_image(record.path, p)
            except OSError:
                print('ERROR: Could not read image "{}"'.format(record.path))
                print('invalid indices: {}'.format(indices))
                raise
            images.extend(seg_imgs)
        process_data = self.transform(images)
        return process_data, record.label

返回元组 (images,labes)

  • 帧数 一般num_segment由这个决定 为什么?
    因为我看顶刊 基本上 一个片段抽一政数,这个无疑由片段决定
  • 时间剪裁
    时间剪裁指的是从视频的时间维度上选取特定的帧(验证数据集)
  def _get_val_indices(self, record):
        if self.num_segments == 1:
            return np.array([record.num_frames //2], dtype=np.int) + self.index_bias
        
        if record.num_frames <= self.total_length:
            if self.loop:
                return np.mod(np.arange(self.total_length), record.num_frames) + self.index_bias
            return np.array([i * record.num_frames // self.total_length
                             for i in range(self.total_length)], dtype=np.int) + self.index_bias
        offset = (record.num_frames / self.num_segments - self.seg_length) / 2.0
        return np.array([i * record.num_frames / self.num_segments + offset + j
                         for i in range(self.num_segments)
                         for j in range(self.seg_length)], dtype=np.int) + self.index_bias

帧数不足时
当 self.loop 为 True 时,通过 np.mod(np.arange(self.total_length), record.num_frames) 循环选取视频帧,确保选取的帧数达到 self.total_length,这是一种时间剪裁方式,通过循环利用现有帧来满足所需的帧数。
当 self.loop 为 False 时,使用 i * record.num_frames // self.total_length 均匀地从视频中选取 self.total_length 帧,同样实现了时间维度上的剪裁。

在视频帧数充足的情况下,先根据 self.num_segments 划分片段,然后在每个片段内选取连续的 self.seg_length 帧。offset 确保每个片段内选取的帧在片段中处于相对居中的位置,通过这种方式实现了在每个片段内的时间剪裁。

x-clip

数据集

1.参考一下这一篇 关于数据集的输入输出
2 讲一下时间剪裁

val_pipeline = [
        dict(type='DecordInit'),
        dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, test_mode=True),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, scale_resize)),
        dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs'])
    ]
    if config.TEST.NUM_CROP == 3:
        val_pipeline[3] = dict(type='Resize', scale=(-1, config.DATA.INPUT_SIZE))
        val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)
    if config.TEST.NUM_CLIP > 1:
        val_pipeline[1] = dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, multiview=config.TEST.NUM_CLIP)
    

multiview=config.TEST.NUM_CLIP)无疑是控制为时间剪裁的数量
3 空间剪裁
val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)
这个更加直观了直接剪了三次 所以为3

2 模型

action-clip

从输入而言:

  • 文本
    classes, num_text_aug, text_dict = text_prompt(train_data)
    class为( num_text_augxnum_class,context)
    text_dict为(num_class,context)
    num_text_aug为填充内容长度
text_id = numpy.random.randint(num_text_aug,size=len(list_id))
            texts = torch.stack([text_dict[j][i,:] for i,j in zip(list_id,text_id)])

分为了(B,context)

  • 图片
images = images.view((-1,config.data.num_segments,3)+images.size()[-2:])
           b,t,c,h,w = images.size()
  images= images.to(device).view(-1,c,h,w ) 

这个论文严格意义上 是借用 clip的编码器 所以它压缩了

输出也简单

  • 文件
    text_embedding = model_text(texts)(b,d)
  • 图片
 image_embedding = model_image(images)
            image_embedding = image_embedding.view(b,t,-1)
            image_embedding = fusion_model(image_embedding)
            

关于这个fusion输出x.mean(dim=1, keepdim=False)
会把t压缩 x 变成了 (b,d)

x-clip

  • 文本
    text_labels = generate_text(train_data) 这个为(num_class(k),77)
    (和上面同理),但是它没有转为样本数
  • 图片
    images = images.view((-1, config.DATA.NUM_FRAMES, 3) + images.size()[-2:])
    它内部实现了一个编码器
    def encode_video(self, image):
        b,t,c,h,w = image.size()
        image = image.reshape(-1,c,h,w)

        cls_features, img_features = self.encode_image(image)
        img_features = self.prompts_visual_ln(img_features)
        img_features = img_features @ self.prompts_visual_proj
        
        cls_features = cls_features.view(b, t, -1)
        img_features = img_features.view(b,t,-1,cls_features.shape[-1])
        
        video_features = self.mit(cls_features)

        return video_features, img_features

image = image.reshape(-1,c,h,w) 内部化了

输出:

logit_scale = self.logit_scale.exp()
 logits = torch.einsum("bd,bkd->bk", video_features, logit_scale * text_features)
        
  return logits

返回了一个b k 相似度得分


http://www.kler.cn/a/613314.html

相关文章:

  • MySQL数据库和表的操作
  • 【开源宝藏】用 JavaScript 手写一个丝滑的打字机动画效果
  • Netty——零拷贝
  • Java 大视界 -- 基于 Java 的大数据隐私计算在医疗影像数据共享中的实践探索(158)
  • 批量将多个 XPS 文档转换为 PDF 格式
  • 洛谷题单1-B2005 字符三角形-python-流程图重构
  • 安全性测试(Security Testing)
  • Manus AI 与多语言手写识别技术解析
  • 科技与人文的交融——当代科技对文化、艺术与社会伦理的深度影响
  • 提示词工程 — 科研论文笔记
  • 计算机视觉算法实战——半监督学习:技术与应用全景
  • 菜鸡前端计算机强基计划之CS50 第七课 python 入门—— Python 中文件操作专题学习
  • 配置基于接口的二层协议透明传输
  • Skynet 中 snlua 服务 init 细节
  • NX二次开发刻字功能——布尔运算
  • Matlab进阶绘图第73期-双组堆叠图
  • Python入门学习笔记 - 从环境搭建到基础语法
  • uni-app:自定义键盘
  • Leetcode 二叉树剪枝
  • 开源测试用例管理平台