Contents
- 1. description
- 2. pytorch code
1. description
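- Purpose: for each sample, randomly keep a fixed fraction of its tokens (rows) and mask the rest. Each row is paired with a uniform random noise value, the rows are ordered by ascending noise, and the first len_keep = int(L * (1 - mask_ratio)) rows are kept. The rows are not sorted by their own values; the ordering comes only from the random noise, so the kept set is a uniformly random subset.
- Note 1: the noise is drawn from a uniform distribution (torch.rand), and the tokens whose noise lies in the smallest (1 - mask_ratio) fraction are kept.
- Note 2: torch.argsort on the noise produces ids_shuffle, the token indices ordered from smallest to largest noise, which is a random permutation per sample. Applying torch.argsort to ids_shuffle again gives ids_restore, the inverse permutation.
- Note 3: torch.gather with ids_keep (the first len_keep entries of ids_shuffle) picks out the tokens with the smallest noise values, producing x_masked.
- Note 4: the mask is built from a torch.ones matrix whose first len_keep columns (shuffled order) are set to 0, then torch.gather with ids_restore maps it back to the original token order (0 = kept, 1 = masked).
- Summary: combining torch.argsort and torch.gather yields the kept tokens, the binary mask, and the restore indices with no Python loops; the trick is well worth reusing.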
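A minimal sketch of the argsort/gather mechanics listed above, on a single sample of length 5. The noise values are fixed by hand (an assumption made only so the result is deterministic) instead of being drawn with torch.rand, and len_keep = 3 corresponds to mask_ratio = 0.4. It shows that ids_restore is the inverse permutation of ids_shuffle, that ids_keep holds the indices with the smallest noise, and that gathering the shuffled mask with ids_restore zeroes exactly those positions.

```python
import torch

# Hypothetical fixed noise for one sample (L = 5), used instead of torch.rand
noise = torch.tensor([[0.9, 0.1, 0.5, 0.3, 0.7]])
len_keep = 3  # int(5 * (1 - 0.4)) == 3

# indices ordered by ascending noise: smallest noise first
ids_shuffle = torch.argsort(noise, dim=1)        # [[1, 3, 2, 4, 0]]
# inverse permutation: position of each original index inside ids_shuffle
ids_restore = torch.argsort(ids_shuffle, dim=1)  # [[4, 0, 2, 1, 3]]

# composing the two permutations gives back the identity
identity = torch.gather(ids_shuffle, dim=1, index=ids_restore)  # [[0, 1, 2, 3, 4]]

# kept tokens = the len_keep indices with the smallest noise
ids_keep = ids_shuffle[:, :len_keep]             # [[1, 3, 2]]

# mask in shuffled order (0 = keep, 1 = drop), then mapped back to original order
mask = torch.ones_like(noise)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)  # [[1., 0., 0., 0., 1.]]

print(ids_shuffle, ids_restore, identity, ids_keep, mask, sep="\n")
```

Reading the result: tokens 1, 2 and 3 (noise 0.1, 0.5, 0.3) are kept and tokens 0 and 4 are masked, whatever the tokens themselves contain.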
2. pytorch code
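import torch

torch.set_printoptions(precision=3, sci_mode=False)
torch.manual_seed(2324)


def random_masking(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    Returns x_masked [N, len_keep, D], mask [N, L] (0 = keep, 1 = masked),
    and ids_restore [N, L].
    """
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))  # number of tokens kept per sample

    # one uniform noise value in [0, 1) per token
    noise = torch.rand(N, L, device=x.device)

    # token indices sorted by ascending noise: a random permutation per sample
    ids_shuffle = torch.argsort(noise, dim=1)
    # inverse permutation: ids_restore[n, i] is the position of token i in ids_shuffle[n]
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the len_keep tokens with the smallest noise
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # binary mask in shuffled order: first len_keep slots kept (0), the rest masked (1)
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # map the mask back to the original token order
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore


if __name__ == "__main__":
    bs = 2
    seq_len = 8
    seq_dim = 10
    mx_total = bs * seq_len * seq_dim
    # every row holds distinct values, so it is easy to see which rows were kept
    a_matrix = torch.arange(mx_total).reshape((bs, seq_len, seq_dim))
    # len_keep = int(8 * (1 - 0.4)) = 4 tokens kept, 4 masked per sample
    a_x_masked, a_mask, a_ids_restore = random_masking(a_matrix, mask_ratio=0.4)
    print(f"a_matrix=\n{a_matrix}")
    print(f"a_x_masked=\n{a_x_masked}")
    print(f"a_mask=\n{a_mask}")
    print(f"a_ids_restore=\n{a_ids_restore}")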
Output:
a_matrix=
tensor([[[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
         [ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
         [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29],
         [ 30,  31,  32,  33,  34,  35,  36,  37,  38,  39],
         [ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
         [ 50,  51,  52,  53,  54,  55,  56,  57,  58,  59],
         [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69],
         [ 70,  71,  72,  73,  74,  75,  76,  77,  78,  79]],

        [[ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89],
         [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
         [110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
         [120, 121, 122, 123, 124, 125, 126, 127, 128, 129],
         [130, 131, 132, 133, 134, 135, 136, 137, 138, 139],
         [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
         [150, 151, 152, 153, 154, 155, 156, 157, 158, 159]]])
a_x_masked=
tensor([[[ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
         [  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
         [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29],
         [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69]],

        [[ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
         [120, 121, 122, 123, 124, 125, 126, 127, 128, 129],
         [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
         [ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89]]])
a_mask=
tensor([[0., 0., 0., 1., 1., 1., 0., 1.],
        [0., 0., 1., 1., 0., 1., 0., 1.]])
a_ids_restore=
tensor([[1, 0, 2, 4, 7, 6, 3, 5],
        [3, 0, 6, 4, 1, 5, 2, 7]])
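The example returns ids_restore but never uses it. Below is a sketch of one way to use it, assuming the goal is to put the kept tokens back at their original positions and fill the masked slots with a placeholder (roughly the unshuffling step an MAE-style decoder performs, where the placeholder is a learnable mask token rather than a constant). unshuffle and fill_value are hypothetical names introduced here; random_masking is repeated so the block runs on its own.

```python
import torch


def random_masking(x, mask_ratio):
    # same logic as the function above, repeated so this block is self-contained
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x_masked, mask, ids_restore


def unshuffle(x_masked, ids_restore, fill_value):
    """Hypothetical helper: put kept tokens back at their original positions
    and fill the masked positions with fill_value."""
    N, len_keep, D = x_masked.shape
    L = ids_restore.shape[1]
    # placeholder tokens for the L - len_keep masked positions
    fill = torch.full((N, L - len_keep, D), fill_value,
                      dtype=x_masked.dtype, device=x_masked.device)
    # shuffled order: kept tokens occupy the first len_keep slots, placeholders the rest
    x_shuffled = torch.cat([x_masked, fill], dim=1)
    # undo the shuffle: token i is read from slot ids_restore[n, i]
    index = ids_restore.unsqueeze(-1).repeat(1, 1, D)
    return torch.gather(x_shuffled, dim=1, index=index)


torch.manual_seed(2324)
x = torch.arange(2 * 8 * 10).reshape(2, 8, 10)
x_masked, mask, ids_restore = random_masking(x, mask_ratio=0.4)
x_full = unshuffle(x_masked, ids_restore, fill_value=-1)

keep = mask == 0  # [N, L] boolean, True where the token was kept
# kept rows match the input exactly; masked rows hold the placeholder
print(torch.equal(x_full[keep], x[keep]))  # True
print(bool((x_full[~keep] == -1).all()))   # True
```

Appending the placeholders after x_masked works because, in shuffled order, the kept tokens always occupy the first len_keep slots, so a single gather with ids_restore routes every token (kept or placeholder) to its original position.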