您当前的位置: 首页 >  目标检测
  • 3浏览

    0关注

    417博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

detectron2(目标检测框架)无死角玩转-10:源码详解(6)-anchor的使用,loss计算

江南才尽,年少无知! 发布时间:2020-03-07 14:14:29 ,浏览量:3

以下链接是个人关于detectron2(目标检测框架),所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。

detectron2(目标检测框架)无死角玩转-00:目录

前言

通过上一篇的博客,我们已经知道anchor是如何生成的了,这里再提一下,每个特征特的每个网格,都会对应生成多个anchor。下面我们看看这些anchor是如何使用的,或者说它有什么作用。

我们先进入detectron2\modeling\meta_arch\retinanet.py文件,找类class RetinaNet(nn.Module),然后找到def forward(self, batched_inputs):函数的如下内容:

        # 生成anchor
        anchors = self.anchor_generator(features)

        # 如果是训练,则结合 ground_truth 计算loss
        if self.training:
            gt_classes, gt_anchors_reg_deltas = self.get_ground_truth(anchors, gt_instances)
            return self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta)

        # 如果是预测,则返回预测的结果
        else:
            results = self.inference(box_cls, box_delta, anchors, images.image_sizes)
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(
                results, batched_inputs, images.image_sizes
            ):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})
            return processed_results

可以看到无论是训练过程,还是预测过程,都使用到了生成的anchor。那么我们就先深入了解一下训练的anchor。

训练过程

从上面可以知道,训练过程只执行了如下代码:

        # 如果是训练,则结合 ground_truth 计算loss
        if self.training:
            gt_classes, gt_anchors_reg_deltas = self.get_ground_truth(anchors, gt_instances)
            return self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta)

可以明显的知道,核心要点为self.get_ground_truth(anchors, gt_instances)与self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta) 两个函数,先来看看self.get_ground_truth(),从函数的名字可以知道,该函数的主要作用是根据anchors,结合 gt_instances,去获得训练样本对应的ground truth。

在分析函数之前,我们先来回忆一下gt_instances是什么东西,本人做一份截图如下: 在这里插入图片描述 从图示,我们可以看出,每个gt_instance包含了一张图片的大小尺寸,以及目标物体的box,和每个物体的类别。要注意的是,这些图片的大小尺寸,以及box都已经进行了预处理。并非最开始原图的boxs。那么下面我们就分析def get_ground_truth(self, anchors, targets)函数吧。

    def get_ground_truth(self, anchors, targets):
        """
		为了代码简洁好看,本人没有粘贴英文注释了,有兴趣的朋友可以在源码中查看
        """
        gt_classes = []
        gt_anchors_deltas = []
        anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
        # list[Tensor(R, 4)], one for each image

        # 循环处理每张图片,获得每张图片对应的gt_classe以及gt_anchors_delta
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            # 用每张图片的gt_boxes的,和所有anchor进行匹配,计算出对应的iou值。
            # 如targets_per_image.gt_boxes包含了N个box,anchors_per_image包含了M个anchor
            # 那么得到 match_quality_matrix 的形状为[N,M],存储为匹配的IOU值
            match_quality_matrix = pairwise_iou(targets_per_image.gt_boxes, anchors_per_image)

            # 通过阈值进行筛选,把低于阈值的去除掉,并且以下标所以和对应的anchor进行表示
            gt_matched_idxs, anchor_labels = self.matcher(match_quality_matrix)

            has_gt = len(targets_per_image) > 0

            # 如果 has_gt>0 说明存在ground truth
            if has_gt:
                # ground truth box regression
                # 通过索引获得anchor对应的匹配到的gt_boxes
                matched_gt_boxes = targets_per_image.gt_boxes[gt_matched_idxs]

                # 根据gt_boxes 以及 anchor 计算他们的偏移值,同时这个偏移值就是网络需要学习的东西
                gt_anchors_reg_deltas_i = self.box2box_transform.get_deltas(
                    anchors_per_image.tensor, matched_gt_boxes.tensor
                )

                # 对每个匹配到的anchor 与 gt_boxes 进行类别标记
                gt_classes_i = targets_per_image.gt_classes[gt_matched_idxs]

                # Anchors with label 0 are treated as background.
                gt_classes_i[anchor_labels == 0] = self.num_classes

                # Anchors with label -1 are ignored.
                gt_classes_i[anchor_labels == -1] = -1

            # 如果 has_gt            
关注
打赏
1592542134
查看更多评论
0.0385s