当前位置: 首页 > article >正文

4.grid_sample理解与使用

pytorch中的grid_sample

文章目录

  • pytorch中的grid_sample
    • grid_sample
    • `grid_sample`函数原型
    • 实例


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


grid_sample

直译为网格采样,给定一个mask patch,根据在目标图像上的坐标网格,将mask变换到目标图像上。

如上图,是将一个2x2mask根据坐标网格grid变换到6x6目标图像x0 y0 x1 y1 = 1,1,3,3的位置上,值得注意的是grid是经过运算得到的坐标网格,masktarget image对应位置的左上角处坐标应该为-1,-1,右下角处坐标应该为1,1,目标图像对应位置的像素值由mask通过插值得到。

知道了grid_sample的原理,再来看下torch中的函数。

grid_sample函数原型

torch.nn.functional.grid_sample(input,
                                grid, 
                                mode='bilinear',                
                                padding_mode='zeros', align_corners=None)
  • input输入image patch,支持4d5d输入。为4dshape N , C , H i n , W i n N,C,H_{in},W_{in} N,C,Hin,Win
  • grid坐标网格,当input4d时其shape N , H o u t , W o u t , 2 N,H_{out},W_{out},2 N,Hout,Wout,2,输出的shapeN,C,H_{out},W_{out},对于输出的位置output[n, :, h, w],‵grid[n, h, w]是二维向量,指定了其对应的input上的位置。output[n, :, h, w]根据‵grid[n, h, w]指定的对应input位置上的像素插值得到。grid指定了在input输入维度上标准化后的坐标大小,input左上角对应的应该是-1,-1,右下角对应的是1,1
  • mode插值方式,'bilinear' | 'nearest' | 'bicubic'
  • padding_mode,在(-1,1)外的输出图像上的像素值处理方式'zeros' | 'border' | 'reflection'
  • align_corners:是否对齐角

实例

以将一个100x100mask,网格采样到500x300的图像上(x,y,w,h)=(100, 100, 100, 200)为例,看一下grid_sample是如何使用的。

先计算grid,


import torch
import numpy as np
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plt

h, w = 300, 500
x0, y0, x1, y1 = torch.tensor([[100]]), torch.tensor([[100]]), torch.tensor([[200]]), torch.tensor([[300]])
N = 1
x0_int, y0_int = 0, 0
x1_int, y1_int = 500, 300
img_y = torch.arange(y0_int, y1_int, dtype=torch.float32) + 0.5
img_x = torch.arange(x0_int, x1_int, dtype=torch.float32) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1

gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)

这里使用的是mask在目标图像上的大小来对grid归一化的。

mask = np.zeros((100, 100), dtype=np.uint8)
ct = np.array([[50, 0],[99, 50], [50, 99], [0, 50]], dtype=np.int32)
mask = cv2.drawContours(mask, [ct], -1, 255,  cv2.FILLED)
plt.figure(1)
plt.imshow(mask)
mask = torch.from_numpy(mask)
masks = mask[None, None, :]

if not torch.jit.is_scripting():
    if not masks.dtype.is_floating_point:
        masks = masks.float()
        
img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)
plt.figure(2)
plt.imshow(img_masks.squeeze().numpy().astype(np.uint8))

根据gridmask映射到目标图像上的指定区域指定大小。


1.https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html


http://www.kler.cn/a/157514.html

相关文章:

  • Hutool工具包的常用工具类的使用介绍
  • 基于字节大模型的论文翻译(含免费源码)
  • 【Rust自学】4.4. 引用与借用
  • 数字时代的医疗挂号变革:SSM+Vue 系统设计与实现之道
  • 写SQL太麻烦?免费搭建 Text2SQL 应用,智能写 SQL | OceanBase AI 实践
  • ios 混合开发应用白屏问题
  • 【模电】基本共射放大电路的工作原理及波形分析
  • TCP/IP的体系结构
  • SCTransform normalization seurat
  • C++学习之路(十八)C++ 用Qt5实现一个工具箱(点击按钮以新窗口打开功能面板)- 示例代码拆分讲解
  • 深度学习模型部署与优化:关键考虑与实践策略
  • 重新定义页面滚动条
  • 常见的几种计算机编码格式
  • Oracle(2-9) Oracle Recovery Manager Overview and Configuration
  • IDEA构建springBoot新项目时JDK只有17和21,无法选择JDK8解决方案
  • 采用驱动IC和NMOS的防反电路设计
  • 【问题总结】Docker环境下,将Nacos版本2.0.4升级到2.2.3,操作留档 以及 踩坑记录
  • 【数据结构实验】排序(二)希尔排序算法的详细介绍与性能分析
  • pbootcms建站
  • 记录66666
  • oracle数据库 实例名是区分大小写的
  • nodejs+vue+微信小程序+python+PHP就业求职招聘信息平台的设计与实现-计算机毕业设计推荐
  • pyecharts可视化作图4:行业分布-条形图
  • 网络可信空间|探讨现有网络安全中可信空间建设问题,以及建设可信空间的关键要素
  • node运行报错:error:0308010C:digital envelope routines::unsupported
  • 07、pytest指定要运行哪些用例