当前位置: 首页 > article >正文

F.interpolate函数

F.interpolate 是 PyTorch 中用于对张量(通常是图像数据)进行插值操作的函数,常用于调整张量的大小,例如改变图像的分辨率。它支持多种插值方法,包括最近邻插值、双线性插值和三次插值等。

语法

torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)

参数

  1. input:

    • 输入的张量,形状通常为 (N, C, H, W)(N, C, D, H, W)(批次、通道数、高度、宽度 或深度、高度、宽度)。
  2. size:

    • 调整后张量的目标大小,可以是整数元组,例如 (height, width)
    • 优先级高于 scale_factor
  3. scale_factor:

    • 用于调整大小的比例因子,可以是浮点数或元组(对于高度和宽度分别指定比例)。
    • 如果指定了 size,此参数会被忽略。
  4. mode:

    • 指定插值方法,常用选项:
      • 'nearest':最近邻插值。
      • 'linear':线性插值(仅适用于 3D 输入)。
      • 'bilinear':双线性插值(常用于 2D 图像)。
      • 'bicubic':双三次插值(适用于 2D 图像)。
      • 'trilinear':三线性插值(适用于 3D 输入)。
      • 'area':区域插值,用于下采样。
  5. align_corners:

    • 仅在 mode'linear', 'bilinear', 'bicubic''trilinear' 时使用。
    • 如果为 True,则输入和输出的角像素对齐。

返回值

调整大小后的张量。


示例代码

1. 将图像从 640x640 调整为 832x832
import torch
import torch.nn.functional as F

# 创建一个随机图像张量,形状为 (batch_size=1, channels=3, height=640, width=640)
img = torch.randn(1, 3, 640, 640)

# 使用 F.interpolate 调整分辨率为 832x832
resized_img = F.interpolate(img, size=(832, 832), mode='bilinear', align_corners=False)

print("Original shape:", img.shape)
print("Resized shape:", resized_img.shape)
2. 使用比例调整图像大小
# 使用 scale_factor=1.3 对图像尺寸放大 1.3 倍
scaled_img = F.interpolate(img, scale_factor=1.3, mode='bilinear', align_corners=False)

print("Scaled shape:", scaled_img.shape)
3. 下采样为一半大小
# 使用 scale_factor=0.5 对图像尺寸缩小 50%
downsampled_img = F.interpolate(img, scale_factor=0.5, mode='area')

print("Downsampled shape:", downsampled_img.shape)

注意事项

  1. align_corners 的影响
    align_corners=True 时,插值会在输入和输出张量的角像素之间进行对齐;否则,计算比例时不对齐角像素。通常推荐 align_corners=False,避免形变或偏移。

  2. 选择插值方法

    • 双线性插值(bilinear)和双三次插值(bicubic)通常适用于图像重采样,生成更平滑的结果。
    • 最近邻插值(nearest)速度快,但结果不够平滑。
  3. 处理多通道输入
    F.interpolate 可直接处理多通道(如 RGB、IR 数据)的张量,不需要额外操作。


http://www.kler.cn/a/469136.html

相关文章:

  • VTK 鼠标+键盘重构
  • 【LC】2469. 温度转换
  • OpenKit 介绍
  • 如何利用PHP爬虫按关键字搜索淘宝商品
  • reactor中的并发
  • RabbitMq的Java项目实践
  • [Linux]redis5.0.x升级至7.x完整操作流程
  • 使用MySQL APT源在Linux上安装MySQL
  • spring mvc源码学习笔记之五
  • 【华为OD-E卷 - 九宫格按键输入 100分(python、java、c++、js、c)】
  • Linux系统常用命令详解
  • 怎么找回电脑所有连接过的WiFi密码
  • 【论文阅读笔记】LTX-Video: Realtime Video Latent Diffusion
  • 如何让编码更加高效专注?——程序员编程选明基 RD280U显示器
  • Django Swagger文档库drf-spectacular
  • 【Rust 知识点杂记】
  • 微信小程序提示 miniprogram-recycle-view 引入失败
  • leetcode hot 100 最长递增子序列
  • 智能体语言 Shire 1.2 发布:自定义多文件编辑、Sketch 视图、流式 diff、智能上下文感知...
  • AI生成PPT,效率与创意的双重升级
  • 【开源免费】基于SpringBoot+Vue.JS精品在线试题库系统(JAVA毕业设计)
  • 开发小技巧分享 01:JSON解析工具
  • 入手51单片机的学习路径
  • Linux中的tcpdump抓包命令详解:抓取TCP和UDP数据包并按小时输出文件
  • 【MyBatis-Plus 进阶功能】开发中常用场景剖析
  • C++之STL