当前位置: 首页 > article >正文

【代码实现】torch实现F.pixel_shuffle和F.pixel_unshuffle

原理

pixel_shuffle 和 pixel_unshuffle 常用于神经网络减少特征图尺寸以减少计算量,由于有些硬件不支持这两个算子,可以根据原理使用torch实现。

代码实现

import torch.nn.functional as F
import torch
def pixelshuffle_inv(tensor, scale=2):
    N, ch, height, width = tensor.shape
    new_ch = ch * (scale * scale)
    new_height = height // scale
    new_width = width // scale
    
    tensor = tensor.view(N, ch, new_height, scale, new_width, scale)
    tensor = tensor.permute(0, 1, 3, 5, 2, 4).contiguous()
    tensor = tensor.view(N, ch * (scale ** 2), new_height, new_width)

    return tensor

def pixelshuffle(tensor, scale=2):
    N, ch, height, width = tensor.shape
    new_ch = ch // (scale * scale)
    new_height = height * scale
    new_width = width * scale
    # 重新排列张量
    output_tensor = tensor.view(N, new_ch, scale, scale, height, width)
    output_tensor = output_tensor.permute(0, 1, 4, 2, 5, 3).contiguous()
    output_tensor = output_tensor.view(N, new_ch, new_height, new_width)
    return output_tensor


if __name__ == '__main__':
    input = torch.randn(1, 3, 256, 256)
    scale = 2
    unshuffle_ = pixelshuffle_inv(input,scale)
    unshuffle_F = F.pixel_unshuffle(input,scale)
    print(torch.equal(unshuffle_,unshuffle_F))
    print(torch.max(unshuffle_-unshuffle_F))


    shuffle_ = pixelshuffle(unshuffle_F,scale)
    shuffle_F = F.pixel_shuffle(unshuffle_F,scale)
    print(torch.equal(shuffle_,shuffle_F))
    print(torch.max(shuffle_-shuffle_F))

运行结果,与官方结果完全一致。
在这里插入图片描述


http://www.kler.cn/news/336720.html

相关文章:

  • guava里常用功能
  • 强化学习笔记之【Q-learning算法和DQN算法】
  • 区块链的编程语言有那些?
  • 基于STM32的智能垃圾桶控制系统设计
  • wordpress父分类和归档页调用子分类名称和链接
  • can 总线入门———can简介硬件电路
  • html+css+js实现Collapse 折叠面板
  • [运维]5.镜像本地存在但仍然从网络拉取的问题
  • Qt 6 相比 Qt 5 的主要提升与更新
  • Java基础-单例模式的实现
  • Android Codec2 CCodec(十六)C2AllocatorGralloc
  • 241006-Gradio中Chatbot通过CSS自适应调整高度
  • 黑名单与ip禁令是同一个东西吗
  • CSS中的class与id
  • VirtulBOX Ubuntu22安装dpdk23.11
  • 计算机网络——p2p
  • Prometheus监控MySQL主从数据库
  • simple c++ 无锁队列
  • Mybatis测试案例
  • 什么是 Angular 的 @HostBinding 注解