当前位置: 首页 > article >正文

【PyTorch】4.张量拼接操作

       

个人主页:Icomi

在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。

         前面我们学习了张量和 numpy 数组的相互转换,这是我们在深度学习数据处理中非常实用的技能。

        今天,咱们要讲讲张量的拼接操作,这可是在神经网络搭建过程中极为常用的方法,就好比搭建一座宏伟建筑时不可或缺的连接工艺。想象一下,我们构建神经网络就像搭建一座复杂的大厦,张量就是构成大厦的各种预制构件,而拼接操作就像是把这些构件精准连接在一起的关键技术。

        比如说,在后面将要学习到的残差网络里,张量的拼接起到了至关重要的作用。残差网络能够有效解决深度神经网络训练过程中的梯度消失和梯度爆炸问题,让网络可以更深层次地学习数据特征。这里面,通过巧妙地拼接不同层的张量,就像是把不同功能的建筑模块合理组合,从而构建出了强大的深层网络结构。

        还有注意力机制,这也是深度学习领域的一个重要概念。在注意力机制中,张量的拼接帮助我们对不同的信息进行整合与聚焦,就像在纷繁复杂的信息海洋中,通过拼接操作找到最关键的信息片段并组合起来,让模型能够更加 “聪明” 地处理数据。

        所以,掌握张量的拼接操作,对于我们理解和构建先进的神经网络模型至关重要。接下来,咱们就深入学习一下张量的拼接到底是怎么实现的,以及在不同场景下该如何灵活运用它。

1. torch.cat 函数的使用¶

torch.cat 函数可以将两个张量根据指定的维度拼接起来.
 

import torch


def tensor_concatenation():
    # 创建第一个三维随机整数张量
    tensor_1 = torch.randint(0, 10, [3, 5, 4])
    # 创建第二个三维随机整数张量
    tensor_2 = torch.randint(0, 10, [3, 5, 4])

    print(tensor_1)
    print(tensor_2)
    print('-' * 50)

    # 1. 按 0 维度拼接
    concatenated_tensor_dim0 = torch.cat([tensor_1, tensor_2], dim=0)
    print(concatenated_tensor_dim0.shape)
    print('-' * 50)

    # 2. 按 1 维度拼接
    concatenated_tensor_dim1 = torch.cat([tensor_1, tensor_2], dim=1)
    print(concatenated_tensor_dim1.shape)
    print('-' * 50)

    # 3. 按 2 维度拼接
    concatenated_tensor_dim2 = torch.cat([tensor_1, tensor_2], dim=2)
    print(concatenated_tensor_dim2)


if __name__ == '__main__':
    tensor_concatenation()

程序输出结果:

tensor([[[6, 8, 3, 5],
         [1, 1, 3, 8],
         [9, 0, 4, 4],
         [1, 4, 7, 0],
         [5, 1, 4, 8]],

        [[0, 1, 4, 4],
         [4, 1, 8, 7],
         [5, 2, 6, 6],
         [2, 6, 1, 6],
         [0, 7, 8, 9]],

        [[0, 6, 8, 8],
         [5, 4, 5, 8],
         [3, 5, 5, 9],
         [3, 5, 2, 4],
         [3, 8, 1, 1]]])
tensor([[[4, 6, 8, 1],
         [0, 1, 8, 2],
         [4, 9, 9, 8],
         [5, 1, 5, 9],
         [9, 4, 3, 0]],

        [[7, 6, 3, 3],
         [4, 3, 3, 2],
         [2, 1, 1, 1],
         [3, 0, 8, 2],
         [8, 6, 6, 5]],

        [[0, 7, 2, 4],
         [4, 3, 8, 3],
         [4, 2, 1, 9],
         [4, 2, 8, 9],
         [3, 7, 0, 8]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
tensor([[[6, 8, 3, 5, 4, 6, 8, 1],
         [1, 1, 3, 8, 0, 1, 8, 2],
         [9, 0, 4, 4, 4, 9, 9, 8],
         [1, 4, 7, 0, 5, 1, 5, 9],
         [5, 1, 4, 8, 9, 4, 3, 0]],

        [[0, 1, 4, 4, 7, 6, 3, 3],
         [4, 1, 8, 7, 4, 3, 3, 2],
         [5, 2, 6, 6, 2, 1, 1, 1],
         [2, 6, 1, 6, 3, 0, 8, 2],
         [0, 7, 8, 9, 8, 6, 6, 5]],

        [[0, 6, 8, 8, 0, 7, 2, 4],
         [5, 4, 5, 8, 4, 3, 8, 3],
         [3, 5, 5, 9, 4, 2, 1, 9],
         [3, 5, 2, 4, 4, 2, 8, 9],
         [3, 8, 1, 1, 3, 7, 0, 8]]])

2. torch.stack 函数的使用¶

torch.stack 函数可以将两个张量根据指定的维度叠加起来.

import torch


def test():

    data1= torch.randint(0, 10, [2, 3])
    data2= torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)

    new_data = torch.stack([data1, data2], dim=0)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=1)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

3. 总结¶

张量的拼接操作也是在后面我们经常使用一种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。


http://www.kler.cn/a/520607.html

相关文章:

  • STM32-时钟树
  • 用深度学习优化供应链管理:让算法成为商业决策的引擎
  • RNN实现阿尔茨海默症的诊断识别
  • RabbitMQ 架构分析
  • Mono里运行C#脚本36—加载C#类定义的成员变量和方法的数量
  • 基于微信小程序的网上订餐管理系统
  • linux 内核学习方向以及职位
  • 论文笔记(六十三)Understanding Diffusion Models: A Unified Perspective(四)
  • shiro学习五:使用springboot整合shiro。在前面学习四的基础上,增加shiro的缓存机制,源码讲解:认证缓存、授权缓存。
  • Go语言入门指南(二): 数据类型
  • JAVA:利用 Content Negotiation 实现多样式响应格式的技术指南
  • 深入解析ncnn::Net类——高效部署神经网络的核心组件
  • 文献阅读 250125-Accurate predictions on small data with a tabular foundation model
  • SQL Server 使用SELECT INTO实现表备份
  • JWT 实战:在 Spring Boot 中的使用
  • 网络模型简介:OSI七层模型与TCP/IP模型
  • Learning Vue 读书笔记 Chapter 2
  • 【React+ts】 react项目中引入bootstrap、ts中的接口
  • JavaScript使用toFixed保留一位小数的踩坑记录:TypeError: xxx.toFixed is not a function
  • vue3中customRef的用法以及使用场景
  • LeetCode题练习与总结:两个字符串的删除操作--583
  • 9.4 GPT Action 开发实践:从设计到实现的实战指南
  • PoolingHttpClient试验
  • 独立游戏开发赚钱吗?
  • 从0到1:C++ 开启游戏开发奇幻之旅(一)
  • 重构(1)if-else