当前位置: 首页 > article >正文

相对位置2d矩阵和kron运算的思考

文章目录

  • 1. 相对位置矩阵2d
  • 2. kron运算

1. 相对位置矩阵2d

在swin-transformer中,我们会计算每个patch之间的相对位置,那么我们看到有一连串的拉伸和相减,直接贴代码:

import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False,threshold=torch.inf)

if __name__ == "__main__":
    run_code = 2
    x_len = 5
    y_len = 5
    x_tensor = torch.arange(x_len)
    y_tensor = torch.arange(y_len)
    x_meshgrid, y_meshgrid = torch.meshgrid(x_tensor, y_tensor)
    print(f"x_tensor=\n{x_tensor}")
    print(f"y_tensor=\n{y_tensor}")
    print(f"x_meshgrid=\n{x_meshgrid}")
    print(f"x_meshgrid.shape=\n{x_meshgrid.shape}")
    print(f"y_meshgrid.shape=\n{y_meshgrid.shape}")
    print(f"y_meshgrid=\n{y_meshgrid}")
    stack_meshgrid = torch.stack(torch.meshgrid(x_tensor, y_tensor))
    print(f"stack_meshgrid.shape=\n{stack_meshgrid.shape}")
    print(f"stack_meshgrid=\n{stack_meshgrid}")
    stack_meshgrid_flatten = torch.flatten(stack_meshgrid, 1)
    print(f"stack_meshgrid_flatten.shape=\n{stack_meshgrid_flatten.shape}")
    print(f"stack_meshgrid_flatten=\n{stack_meshgrid_flatten}")
    stack_meshgrid_flatten_1 = stack_meshgrid_flatten[:, None, :]
    stack_meshgrid_flatten_2 = stack_meshgrid_flatten[:, :, None]
    relative_coords_bias = stack_meshgrid_flatten_2 - stack_meshgrid_flatten_1
    print(f"stack_meshgrid_flatten_1=\n{stack_meshgrid_flatten_1}")
    print(f"stack_meshgrid_flatten_2=\n{stack_meshgrid_flatten_2}")
    print(f"relative_coords_bias=\n{relative_coords_bias}")
    relative_coords_bias[0, :, :] += x_len
    relative_coords_bias[1, :, :] += y_len
    print(f"relative_coords_bias=\n{relative_coords_bias}")
  • result:
x_tensor=
tensor([0, 1, 2, 3, 4])
y_tensor=
tensor([0, 1, 2, 3, 4])
x_meshgrid=
tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4]])
x_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid=
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])
stack_meshgrid.shape=
torch.Size([2, 5, 5])
stack_meshgrid=
tensor([[[0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
stack_meshgrid_flatten.shape=
torch.Size([2, 25])
stack_meshgrid_flatten=
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
         4]])
stack_meshgrid_flatten_1=
tensor([[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4,
          4, 4]],

        [[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2,
          3, 4]]])
stack_meshgrid_flatten_2=
tensor([[[0],
         [0],
         [0],
         [0],
         [0],
         [1],
         [1],
         [1],
         [1],
         [1],
         [2],
         [2],
         [2],
         [2],
         [2],
         [3],
         [3],
         [3],
         [3],
         [3],
         [4],
         [4],
         [4],
         [4],
         [4]],

        [[0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4]]])
relative_coords_bias=
tensor([[[ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0]],

        [[ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0]]])
relative_coords_bias=
tensor([[[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5]],

        [[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5]]])

2. kron运算

在结果中,我们发现很多重复的值,这就让我联想到kron运算。

  • step1:形成子矩阵
    在这里插入图片描述
  • step2: kron
    在这里插入图片描述
  • pytorch
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == '__main__':
    run_code = 0
    height = 5
    width = 5
    a_vector = torch.arange(width).to(torch.float).reshape(-1, 1)
    a_ones = torch.ones(1, width)
    a_matrix = a_vector @ a_ones
    print(f"a_matrix=\n{a_matrix}")
    b_matrix = a_matrix - a_matrix.T
    print(f"b_matrix=\n{b_matrix}")
    b_matrix_ones = torch.ones_like(b_matrix)
    ab_kron = torch.kron(b_matrix,b_matrix_ones)
    print(f"ab_kron=\n{ab_kron}")
    final_ab = ab_kron+5
    print(f"final_ab=\n{final_ab}")
  • result:
a_matrix=
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.]])
b_matrix=
tensor([[ 0., -1., -2., -3., -4.],
        [ 1.,  0., -1., -2., -3.],
        [ 2.,  1.,  0., -1., -2.],
        [ 3.,  2.,  1.,  0., -1.],
        [ 4.,  3.,  2.,  1.,  0.]])
ab_kron=
tensor([[ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.]])
final_ab=
tensor([[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.]])

http://www.kler.cn/a/613461.html

相关文章:

  • MFC中的窗口线程安全性与CWnd类
  • 从 YOLO11 模型格式导出到TF.js 模型格式 ,环境爬坑,依赖关系已经贴出来了
  • 智慧养老时代:老年人慢性病预防与生活方式优化
  • 【今日EDA行业分析】2025年3月28日
  • 基于扩散模型的光照编辑新突破:IC-Light方法解析与优化
  • DeepSeek大模型应用开发新模式
  • 智能舵机:AI融合下的自动化新纪元
  • ADZS-ICE-2000和AD-ICE2000仿真器在线升级固件
  • Error:Flash Download failed
  • AIGC-广告助手创作智能体完整指令(DeepSeek,豆包,千问,Kimi,GPT)
  • Ubuntu与CentOS操作指令的主要区别详解
  • 【力扣hot100题】(004)盛水最多的容器
  • 【go微服务】如何快速掌握grpc开发
  • 计算机二级WPS Office第十二套WPS演示
  • ETL中数据转换的三种处理方式
  • 职场新人面对不懂的问题应该如何寻求帮助?
  • 基于Dockerfile以docker运行java(可快速替换jar包实现工程更新)
  • 2007-2019年各省地方财政一般公共服务支出数据
  • Proxmox配置显卡直通
  • 算法基础_基础算法【快速排序 + 归并排序 + 二分查找】