图解 RoIAlign 以及在 PyTorch 中的使用(含代码示例)

RoIAlign 的用处

RoIAlign 用于将任意尺寸感兴趣区域的特征图,都转换为具有固定尺寸 H×W 的小特征图。

与RoI pooling一样,其基本原理是将

h

×

w

h×w

h×w 的特征划分为

H

×

W

H×W

H×W 网格,每个格子是大小近似为

h

/

H

×

w

/

W

h/H×w/W

h/H×w/W 的子窗口 ,然后将每个子窗口中的值最大池化到相应的输出网格单元中。想复习RoI pooling概念的可以看这篇

RoIAlign 其实就是更精确版本的 RoIPooling,用双线性插值取代了RoIPooling中的直接取整的操作。

下面用一个具体图例看下 RoIAlign 计算原理。

RoIAlign 计算原理

输入一个feature map,对于每个不同尺寸的proposed region,需要转换成固定大小

H

×

W

H×W

H×W的 feature map,H和W是这一层的超参数。
在这里插入图片描述
黑色粗框部分是一个

7

×

5

7×5

7×5 大小的 proposed region,首先切分成

H

×

W

H×W

H×W 个sections(这里以2x2为例):
在这里插入图片描述
对每个section采样四个区域,用红色×表示其中心位置:
在这里插入图片描述
每个section中四个红色×的值,由双线性插值计算:
在这里插入图片描述
对每个 section 中四个值进行 max pooling,输出结果:
在这里插入图片描述
就是我们所需要的固定大小输出了。

这个固定大小输出可以通过全连接的层,用于边界框回归和分类,常用于检测和分割模型中。

双线性插值(Bilinear Interpolation)

借用下图从视觉上来理解双线性插值,黑点上的双线期插值是附近四个点的加权和,权值是四个点对应的颜色矩形在总面积中的占比。比如左上角黄点

(

x

1

,

y

2

)

(x_1,y_2)

(x1,y2) 对应的是右下较大的黄色矩阵面积。
在这里插入图片描述
在这里插入图片描述

pytorch中的实现

RoIAlign在pytorch中的实现是torchvision.ops.RoIAlign,torchvision.ops中实现的是计算机视觉中特定的operators。

class: torchvision.ops.RoIAlign(output_size, spatial_scale, sampling_ratio)

  • output_size (int or Tuple[int, int]) – 输出大小,用 (height, width) 表示。
  • spatial_scale (float) – 将输入坐标映射到框坐标的比例因子。默认值1.0。
  • sampling_ratio (int) – 插值网格中用于计算每个合并输出bin的输出值的采样点数目。如果> 0,则恰好使用sampling_ratio x sampling_ratio网格点。如果<= 0,则使用自适应数量的网格点(计算为cell (roi_width / pooled_w),同样计算高度)。默认值1。

torchvision.ops.roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1)

  • input (Tensor[N, C, H, W]) – 输入张量
  • boxes (Tensor[K, 5] or List[Tensor[L, 4]]) – 区域包围框以

    (

    x

    1

    ,

    y

    1

    ,

    x

    2

    ,

    y

    2

    )

    (x1, y1, x2, y2)

    (x1,y1,x2,y2) 形式表示。如果输入的是单个tensor,第一列表示batch index;如果输入是一个tensor List,每个tensor对应batch中的第

    i

    i

    i个元素的方框。

简单示例

import torch
import torchvision

# 创建RoIAlign层
pooler = torchvision.ops.RoIAlign(output_size=2,sampling_ratio=2,spatial_scale=5)

# 输入一个 8x8 的feature:
inputTensor = torch.rand(1,1,8,8)

inputTensor类似如下:
在这里插入图片描述

再创建一个box:

box =  torch.tensor([[0.0,0.375,0.875,0.625]]) 

output = pooler(inputTensor,[box])#shape:[1, 1, 2, 2]

输出结果:
在这里插入图片描述

在FasterRCNN中的使用示例

import torchvision
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280

# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# OrderedDict[Tensor], and in featmap_names you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)

参考链接

https://zhuanlan.zhihu.com/p/59692298
https://zhuanlan.zhihu.com/p/73138740
https://pytorch.org/docs/1.2.0/torchvision/ops.html
https://pytorch.org/docs/1.2.0/_modules/torchvision/ops/roi_align.html

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇

)">
下一篇>>