PyTorch实现非极大值抑制(NMS)

拓朴露台
• 阅读 1074

NMS即non maximum suppression即非极大抑制,顾名思义就是抑制不是极大值的元素,搜索局部的极大值。在最近几年常见的物体检测算法(包括rcnn、sppnet、fast-rcnn、faster-rcnn等)中,最终都会从一张图片中找出很多个可能是物体的矩形框,然后为每个矩形框为做类别分类概率。本文来通过Pytorch实现NMS算法。

如果你在做计算机视觉(特别是目标检测),你肯定会听说过非极大值抑制(nms)。网上有很多不错的文章给出了适当的概述。简而言之,非最大抑制使用一些启发式方法减少了输出边界框的数量,例如交叉除以并集(iou)。

在PyTorch的文档中说:NMS 迭代地删除与另一个(得分较高)框的 IoU 大于 iou_threshold 的得分较低的框。

为了研究其如何工作,让我们加载一个图像并创建边界框

 from PIL import Image
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
 
 # credit https://i0.wp.com/craffic.co.in/wp-content/uploads/2021/02/ai-remastered-rick-astley-never-gonna-give-you-up.jpg?w=1600&ssl=1
 img = Image.open("./samples/never-gonna-give-you-up.webp")
 img

我们手动创建 两个框,一个人脸,一个话筒

 original_bboxes = torch.tensor([
     # head
     [ 565, 73, 862, 373],
     # mic
     [807, 309, 865, 434]
 ]).float()
 
 w, h = img.size
 # we need them in range [0, 1]
 original_bboxes[...,0] /= h
 original_bboxes[...,1] /= w
 original_bboxes[...,2] /= h
 original_bboxes[...,3] /= w

这些bboxes 都是在[0,1]范围内的,虽然这不是必需的,但当有多个类时,这是非常有用的(我们稍后将看到为什么)。

 from torchvision.utils import draw_bounding_boxes
 from torchvision.transforms.functional import to_tensor
 from typing import List
 
 def plot_bboxes(img : Image.Image, bboxes: torch.Tensor, *args, **kwargs) -> plt.Figure:
     w, h = img.size
     # from [0, 1] to image size
     bboxes = bboxes.clone()
     bboxes[...,0] *= h
     bboxes[...,1] *= w
     bboxes[...,2] *= h
     bboxes[...,3] *= w
     fig = plt.figure()
     img_with_bboxes = draw_bounding_boxes((to_tensor(img) * 255).to(torch.uint8), bboxes, *args, **kwargs, width=4)
     return plt.imshow(img_with_bboxes.permute(1,2,0).numpy())
 
 plot_bboxes(img, original_bboxes, labels=["head", "mic"])

PyTorch实现非极大值抑制(NMS)

为了说明,我们添加一些重叠的框

 max_bboxes = 3
 scaling = torch.tensor([1, .96, .97, 1.02])
 shifting = torch.tensor([0, 0.001, 0.002, -0.002])
 
 # broadcasting magic (2, 1, 4) * (1, 3, 1)
 bboxes = (original_bboxes[:,None,:] * scaling[..., None] + shifting[..., None]).view(-1, 4)
 
 plot_bboxes(img, bboxes, colors=[*["yellow"] * 4, *["blue"] * 4], labels=[*["head"] * 4, *["mic"] * 4])

PyTorch实现非极大值抑制(NMS)

现在可以看到,有6个bboxes ,这里我们还需要定义一个分数,这通常由模型输出。

 scores = torch.tensor([
     0.98, 0.85, 0.5, 0.2, # for head
     1, 0.92, 0.3, 0.1 # for mic
 ])

我们标签的分类,0代表人脸,1代表麦克风

 labels = torch.tensor([0,0,0,0,1,1,1,1])

最后,让我们排列一下这些数据

 perm = torch.randperm(scores.shape[0])
 bboxes = bboxes[perm]
 scores = scores[perm]
 labels = labels[perm]

让我们看看结果

 plot_bboxes(img, bboxes, 
             colors=["yellow" if el.item() == 0 else "blue" for el in labels], 
             labels=["head" if el.item()  == 0 else "mic" for el in labels]
            )

PyTorch实现非极大值抑制(NMS)

好了,这样我们模拟了模型的输出了,下面进入正题。

NMS是通过迭代删除低分数重叠的边界框来工作的。步骤如下。

bboxes are sorted by score in decreasing order
init a vector keep with ones
for i in len(bboxes):
    # was suppressed
    if keep[i] == 0:
        continue
    # compare with all the others
    for j in len(bbox):
        if keep[j]:
            if (iou(bboxes[i], bboxes[j]) > iou_threshold):
                keep[j] = 0

return keep

我们的Pytorch实现,采用三个参数(这实际上是从pytorch的文档中复制和粘贴的):

  • box (Tensor[N, 4])) – 用于执行 NMS 的框。它们应该是 (x1, y1, x2, y2) 格式,0 <= x1 < x2 和 0 <= y1 < y2。
  • score (Tensor[N]) – 每个box 的得分
  • iou_threshold (float) – 丢弃所有 IoU > iou_threshold 的框
  • 返回值是非抑制边界框的索引
from torchvision.ops.boxes import box_iou

def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    order = torch.argsort(-scores)
    indices = torch.arange(bboxes.shape[0])
    keep = torch.ones_like(indices, dtype=torch.bool)
    for i in indices:
        if keep[i]:
            bbox = bboxes[order[i]]
            iou = box_iou(bbox[None,...],(bboxes[order[i + 1:]]) * keep[i + 1:][...,None])
            overlapped = torch.nonzero(iou > iou_threshold)
            keep[overlapped + i + 1] = 0
    return order[keep]

让我们详细说明下这个参数:

order = scores.argsort()

根据分数得到排序的指标

indices = torch.arange(bboxes.shape[0])

创建用于迭代bboxes的索引 indices

keep = torch.ones_like(indices, dtype=torch.bool)

keep是用于判断一个bbox是否应该保留的向量,如果Keep [i] == 1,则bboxes[order[i]]不被抑制

for i in indices:
    ...

for循环遍历所有的box,如果当前box未被抑制,则keep[i] = 1

bbox = bboxes[order[i]]]

来通过已排序的位置获取bbox

iou = box_iou(bbox[None,...], (bboxes[order[i + 1:]]) * keep[i + 1:][...,None])

计算当前bbox和所有其他候选bbox之间的iou。这将把所有抑制框设置为零(因为keep将等于0)

(bboxes ...)[order[i + 1:]]

在排序的顺序中与后面所有的框进行比较,因为需要跳过当前的框,所以这里是i+ 1,

overlapped = torch.nonzero(iou > iou_threshold)
keep[overlapped + i + 1] = 0

计算和选择iou大于iou_threshold的索引。

我们之前对bboxes进行了切片,(bboxes…)[i + 1:]),所以我们需要添加这些索引的偏移量,这就是后面+ i + 1的原因。

最后返回order[keep],这样映射回原始的box索引(未排序),这样一个简单的函数就执行完成了。

让我们看看结果

nms_indices = nms(bboxes, scores, .45)
plot_bboxes(img, 
            bboxes[nms_indices],
            colors=["yellow" if el.item() == 0 else "blue" for el in labels[nms_indices]], 
            labels=["head" if el.item()  == 0 else "mic" for el in labels[nms_indices]]
           )

PyTorch实现非极大值抑制(NMS)

因为有多个类,所以需要让nms在同一个类中计算iou。还记得上面我们提到的在[0,1]之间吗?可以给它们添加标签,把不同类的框区分开。

nms_indices = nms(bboxes + labels[..., None], scores, .45)
plot_bboxes(img, 
            bboxes[nms_indices],
            colors=["yellow" if el.item() == 0 else "blue" for el in labels[nms_indices]], 
            labels=["head" if el.item()  == 0 else "mic" for el in labels[nms_indices]]
           )

PyTorch实现非极大值抑制(NMS)

如果我们将阈值更改为0.1,就得到了下图

PyTorch实现非极大值抑制(NMS)

让我们对比下pytorch官方的实现:

from torchvision.ops.boxes import nms as torch_nms
nms_indices = torch_nms(bboxes + labels[..., None], scores, .45)
plot_bboxes(img, 
            bboxes[nms_indices],
            colors=["yellow" if el.item() == 0 else "blue" for el in labels[nms_indices]], 
            labels=["head" if el.item()  == 0 else "mic" for el in labels[nms_indices]]
           )

PyTorch实现非极大值抑制(NMS)

结果是一样的。然我们看看时间:

%%timeit
nms(bboxes + labels[..., None], scores, .45)
#534 µs ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%%timeit
torch_nms(bboxes + labels[..., None], scores, .45)
#54.4 µs ± 3.29 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

我们的实现慢了大约10倍,哈,这个结果很正常,因为我们我们没有使用自定义的cpp内核!但是这并不代表我们的实现没有用,因为手写代码我们完全了解了NMS的工作原理,这是本文的真正意义,总之在这篇文章中我们看到了如何在PyTorch中实现非最大抑制,这对你了解目标检测的相关知识是非常有帮助的。

https://avoid.overfit.cn/post/1ffeb08f8ea4494cb992b0ad05db174b

作者:Francesco Zuppichini

点赞
收藏
评论区
推荐文章
blmius blmius
4年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
美凌格栋栋酱 美凌格栋栋酱
7个月前
Oracle 分组与拼接字符串同时使用
SELECTT.,ROWNUMIDFROM(SELECTT.EMPLID,T.NAME,T.BU,T.REALDEPART,T.FORMATDATE,SUM(T.S0)S0,MAX(UPDATETIME)CREATETIME,LISTAGG(TOCHAR(
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
Jacquelyn38 Jacquelyn38
4年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Wesley13 Wesley13
3年前
FLV文件格式
1.        FLV文件对齐方式FLV文件以大端对齐方式存放多字节整型。如存放数字无符号16位的数字300(0x012C),那么在FLV文件中存放的顺序是:|0x01|0x2C|。如果是无符号32位数字300(0x0000012C),那么在FLV文件中的存放顺序是:|0x00|0x00|0x00|0x01|0x2C。2.  
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
3年前
PHP创建多级树型结构
<!lang:php<?php$areaarray(array('id'1,'pid'0,'name''中国'),array('id'5,'pid'0,'name''美国'),array('id'2,'pid'1,'name''吉林'),array('id'4,'pid'2,'n
Wesley13 Wesley13
3年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Python进阶者 Python进阶者
1年前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这