当前位置: 首页 > article >正文

random_masking 函数测试

文章目录

  • 1. description
  • 2. pytorch code

1. description

  • 功能:按一定比例的随机部分样本,简单来说就是按照一定的比例将行向量从小到大的顺序提取出来。
  • 思考1: 用了均匀分布,并且按照一定比例,取前prob概率来表示
  • 思考2:用了torch.argsort 来生成idx_shuffle 来的到快速从小到大排序的
  • 思考3:用了torch.gather 来配合idx_shuffle 来找到最小的部分数据
  • 思考4:用了torch.gather+idx_restore+ones_like matrix 组合的方式生成mask矩阵
  • 小结:主要配合torch.argsort+torch.gather的方式生成相关的mask矩阵和最小矩阵和恢复矩阵,代码的巧妙运用很具备参考意义。

2. pytorch code

  • pytorch
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)
torch.manual_seed(2324)


def random_masking(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore


if __name__ == "__main__":
    run_code = 0
    bs = 2
    seq_len = 8
    seq_dim = 10
    mx_total = bs * seq_dim * seq_len
    a_matrix = torch.arange(mx_total).reshape((bs, seq_len, seq_dim))
    a_x_masked, a_mask, a_ids_restore = random_masking(a_matrix, mask_ratio=0.4)
    print(f"a_matrix=\n{a_matrix}")
    print(f"a_x_masked=\n{a_x_masked}")
    print(f"a_mask=\n{a_mask}")
    print(f"a_ids_restore=\n{a_ids_restore}")
  • result
a_matrix=
tensor([[[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
         [ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
         [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29],
         [ 30,  31,  32,  33,  34,  35,  36,  37,  38,  39],
         [ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
         [ 50,  51,  52,  53,  54,  55,  56,  57,  58,  59],
         [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69],
         [ 70,  71,  72,  73,  74,  75,  76,  77,  78,  79]],

        [[ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89],
         [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
         [110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
         [120, 121, 122, 123, 124, 125, 126, 127, 128, 129],
         [130, 131, 132, 133, 134, 135, 136, 137, 138, 139],
         [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
         [150, 151, 152, 153, 154, 155, 156, 157, 158, 159]]])
a_x_masked=
tensor([[[ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
         [  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
         [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29],
         [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69]],

        [[ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
         [120, 121, 122, 123, 124, 125, 126, 127, 128, 129],
         [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
         [ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89]]])
a_mask=
tensor([[0., 0., 0., 1., 1., 1., 0., 1.],
        [0., 0., 1., 1., 0., 1., 0., 1.]])
a_ids_restore=
tensor([[1, 0, 2, 4, 7, 6, 3, 5],
        [3, 0, 6, 4, 1, 5, 2, 7]])

http://www.kler.cn/a/592361.html

相关文章:

  • 【达梦数据库】快速加列参数ALTER_TABLE_OPT使用
  • Qt Creator入门
  • 《UNIX网络编程卷1:套接字联网API》第2章 传输层:TCP、UDP和SCTP
  • 使用 PIC 微控制器和 Adafruit IO 的基于 IoT 的 Web 控制家庭自动化
  • IvorySQL 增量备份与合并增量备份功能解析
  • 开源模型应用落地-shieldgemma-2-4b-it模型小试-多模态内容安全检测(一)
  • C++ 各种map对比
  • Fragment与React.StrictMode一起使用时有什么需要注意的?
  • 【蓝桥杯】每天一题,理解逻辑(4/90)【Leetcode 二进制求和】
  • 【HarmonyOS Next之旅】DevEco Studio使用指南(五) -> 添加/删除Module
  • Pandas完全指南:数据处理与分析从入门到实战
  • Netty源码—1.服务端启动流程一
  • 武汉临空港开发区第七批区级非物质文化遗产代表性项目和第四批非遗传承人申报条件流程和材料时间
  • 大模型如何赋能安全防御?威胁检测与漏洞挖掘的“AI革命”
  • 基于java的ssm+JSP+MYSQL的高校四六级报名管理系统(含LW+PPT+源码+系统演示视频+安装说明)
  • 【UE5 PuerTS笔记】PuerTS安装
  • 深度解析 | Android 13 Launcher3分页指示器改造:横线变圆点实战指南
  • 在 Ubuntu 下通过 Docker 部署 Nginx+PHP-FPM 服务器
  • 华为鲲鹏ARM服务器安装Docker
  • MySql面试总结(三)