当前位置: 首页 > article >正文

pytorch nn.Unflatten 和 nn.Flatten模块介绍

nn.Flatten 和 nn.Unflatten 是 PyTorch 中用于调整张量形状的模块。它们提供了对多维张量的简单变换,常用于神经网络模型的层之间的数据调整。


1. nn.Flatten

功能:

  • 将输入张量展平为二维张量,通常用于将卷积层的输出展平成全连接层的输入。
  • 它会将张量的指定维度范围压缩为单个维度。

构造参数:

  • start_dim: 展平的起始维度(默认值为 1)。
  • end_dim: 展平的结束维度(默认值为 -1)。

用法:

import torch
from torch import nn

# 输入张量: [batch_size, channels, height, width]
x = torch.randn(4, 3, 32, 32)

# 展平操作
flatten = nn.Flatten(start_dim=1)  # 从维度1到最后展平
y = flatten(x)

print(y.shape)  # 输出: [4, 3072] (3*32*32 被展平)

适用场景:

  • 通常用于从卷积层(或其他多维特征)到全连接层的过渡。
  • 例如:[batch_size, channels, height, width] -> [batch_size, features]

2. nn.Unflatten

功能:

  • 将展平的张量还原为多维张量。
  • 它通过指定目标维度和形状信息,反向操作 nn.Flatten

构造参数:

  • dim: 需要展开的维度。
  • unflattened_size: 展开的形状(tuple 类型)。

用法:

import torch
from torch import nn

# 输入张量: [batch_size, features]
x = torch.randn(4, 3072)

# 还原操作
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))
y = unflatten(x)

print(y.shape)  # 输出: [4, 3, 32, 32]

适用场景:

  • 通常用于从全连接层(或展平特征)还原到卷积层或其他多维表示。
  • 例如:[batch_size, features] -> [batch_size, channels, height, width]

对比

特性nn.Flattennn.Unflatten
主要操作将多个维度压缩为一个维度将一个维度展开为多个维度
输入多维张量展平的张量
输出二维张量恢复为多维张量
常用场景用于连接卷积层和全连接层用于从展平的特征恢复到多维结构
参数控制指定展平的起始和结束维度范围指定需要展开的维度和目标形状

实际应用示例

结合使用 Flatten 和 Unflatten:

import torch
from torch import nn

# 初始化 Flatten 和 Unflatten
flatten = nn.Flatten(start_dim=1)
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))

# 模拟数据
x = torch.randn(4, 3, 32, 32)  # [batch_size, channels, height, width]

# 展平
flat_x = flatten(x)
print(flat_x.shape)  # 输出: [4, 3072]

# 恢复
unflat_x = unflatten(flat_x)
print(unflat_x.shape)  # 输出: [4, 3, 32, 32]

这两个模块通过简单的接口提供了灵活的形状调整功能,是构建神经网络过程中不可或缺的工具。


http://www.kler.cn/a/452099.html

相关文章:

  • 【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
  • stm32基础(keil创建、Proteus仿真、点亮LED灯,7段数码管)
  • Java技术专家视角解读:SQL优化与批处理在大数据处理中的应用及原理
  • vue3 Proxy替换vue2 defineProperty的原因
  • 各种网站(学习资源及其他)
  • vscode插件更新特别慢的问题
  • Chrome 浏览器插件获取网页 iframe 中的 window 对象
  • 【ORB-SLAM3:相机针孔模型和相机K8模型】
  • Chapter 03 复合数据类型-1
  • RBF分类-径向基函数神经网络(Radial Basis Function Neural Network)
  • 数据库安全-redisCouchdb
  • 硬件设计-传输线匹配
  • 3D视觉坐标变换(像素坐标转换得到基于相机坐标系的坐标)
  • 以太网通信--读取物理层PHY芯片的状态
  • C++ 特殊类的设计
  • 开发微信小程序的过程与心得
  • RuoYi-ue前端分离版部署流程
  • mac中idea菜单工具栏没有git图标了
  • 【HarmonyOS NEXT】hdc环境变量配置
  • 认识计算机网络
  • CosyVoice安装过程详解
  • Java基础学习资料
  • Visual Studio - API调试与测试工具之HTTP文件
  • 《战神:诸神黄昏》游戏运行时提示找不到emp.dll怎么办?emp.dll丢失如何修复?
  • 前端开发 -- 自定义鼠标指针样式
  • 【pytorch】深度学习计算