Commit 3e0ff451 authored by szr712's avatar szr712

first commit

parents
*log*/
*.jpg
*.png
# compilation and distribution
__pycache__
_ext
*.pyc
*.so
build/
dist/
wheels/
# pytorch/python/numpy formats
*.pth
*.pkl
*.npy
# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/
# Editor temporaries
*.swn
*.swo
*.swp
*~
# pycharm editor settings
.idea
# vscode editor settings
.vscode
checkpoints
*tmp*
data
logs
weights
run.sh
# Traffic sign detection
## 介绍
本Repo包含了采用MegEngine实现的Faster-RCNN、FCOS、ATSS三个主流模型,并提供了在交通标志数据集(包括红灯,直行路标,向左转弯路标,禁止驶入,禁止车辆临时或长时停放5个类别)上的完整训练和测试代码
## 相关项目链接
- 本目录下代码基于最新版MegEngine,在开始运行本目录下的代码之前,请确保已经正确安装[MegEngine](https://github.com/MegEngine/MegEngine)
- [Models](https://github.com/MegEngine/Models/tree/master/official/vision/detection)
## 如何使用
script目录提供了(frcn__demo, fcos_demo, atss_demo).sh脚本,当准备工作完成之后(如数据、预训练模型等),可以一键跑通训练+测试+推理
- 克隆仓库:
`https://github.com/er-muyue/megengine-trafficsign.git`
- 安装依赖包(包含了megengine):
`pip3 install --user -r requirements.txt`
- 关于数据
- 本目录使用的是交通标志数据集,megstudio环境启动之后默认已经包含数据即,(放到当前目录的data文件夹下,待定)
- annotations 选用 `...traffic5/annotations_train_val_test`
```
/path/to/
|->traffic
| |images
| |annotations->|train.json
| | |val.json
| | |test.json
```
- 关于预训练参数
- 下载对应模型的预训练参数放到`/path/to/weights`
| 模型 | 初始化参数 |
| --- | --- |
| FRCN | https://data.megengine.org.cn/models/weights/faster_rcnn_res50_coco_3x_800size_40dot1_8682ff1a.pkl |
| FCOS | https://data.megengine.org.cn/models/weights/fcos_res50_coco_3x_800size_42dot2_b16f9c8b.pkl |
| ATSS | https://data.megengine.org.cn/models/weights/atss_res50_coco_3x_800size_42dot6_9a92ed8c.pkl |
- 训练模型
- `tools/train.py`的命令行选项如下:
- `-f`, 所需要训练的网络结构描述文件
- `-n`, 用于训练的devices(gpu)数量
- `-w`, 预训练的backbone网络权重
- `-b`,训练时采用的`batch size`, 默认2,表示每张卡训2张图
- `-d`, 数据集的上级目录,默认`/data/datasets`
- 默认情况下模型会存在 `logs/模型_gpus{}`目录下。
- 测试模型
- `tools/test.py`的命令行选项如下:
- `-f`, 所需要测试的网络结构描述文件
- `-n`, 用于测试的devices(gpu)数量
- `-w`, 需要测试的模型权重
- `-d`,数据集的上级目录,默认`/data/datasets`
- `-se`,连续测试的起始epoch数,默认为最后一个epoch,该参数的值必须大于等于0且小于模型的最大epoch数
- `-ee`,连续测试的结束epoch数,默认等于`-se`(即只测试1个epoch),该参数的值必须大于等于`-se`且小于模型的最大epoch数
- 图片推理
- `tools/inference.py`的命令行选项如下:
- `-f`, 测试的网络结构描述文件。
- `-w`, 需要测试的模型权重。
- `-i`, 需要测试的样例图片。
- 一键运行
- (frcn__demo, fcos_demo, atss_demo).sh提供了一键运行脚本,默认用户已经申请了两块GPU
- 评测结果(COCO Pretrained)—— train set 训练,val set测试,2卡
|Model|AP|AP50|AP75|APs|APm|APl|AR@1|AR@10|AR@100|ARs|ARm|ARl|注|
|--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |--- |
|FRCN |44.5 |69.4 |49.6 |30.3 |52.4 |67.9 |42.3 |56.3 |56.6 |41.6 |63.3 |76.9 |1X |
|FRCN |48.0 |71.4 |55.3 |32.7 |58.0 |74.2 |44.7 |58.6 |58.7 |42.3 |67.8 |81.5 |2X |
|FCOS |38.2 |60.4 |41.2 |18.8 |48.2 |68.3 |37.5 |51.0 |52.3 |31.7 |63.0 |80.3 |1X |
|FCOS |46.6 |66.9 |51.7 |26.3 |57.5 |75.0 |45.0 |60.0 |60.9 |40.8 |71.9 |84.7 |2X |
|ATSS |38.4 |59.6 |42.2 |20.4 |48.4 |65.7 |37.6 |51.8 |52.8 |33.3 |63.0 |77.3 |1X |
|ATSS |46.8 |67.5 |52.6 |25.7 |58.8 |75.1 |44.3 |60.5 |61.2 |40.3 |73.2 |85.7 |2X |
- 参考链接
- 暂无
#!/usr/bin/python3
# -*- coding:utf-8 -*-
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import models
class CustomerConfig(models.ATSSConfig):
def __init__(self):
super().__init__()
# ------------------------ dataset cfg ---------------------- #
self.train_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/train.json",
remove_images_without_annotations=True,
)
self.test_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/val.json",
test_final_ann_file="annotations/test.json",
remove_images_without_annotations=False,
)
self.num_classes = 5
# ------------------------ training cfg ---------------------- #
self.basic_lr = 0.02 / 16
self.max_epoch = 24
self.lr_decay_stages = [16, 21]
self.nr_images_epoch = 2226
self.warm_iters = 100
self.log_interval = 10
Net = models.ATSS
Cfg = CustomerConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import models
class CustomerConfig(models.FasterRCNNConfig):
def __init__(self):
super().__init__()
# ------------------------ dataset cfg ---------------------- #
self.train_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/train.json",
remove_images_without_annotations=True,
)
self.test_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/val.json",
test_final_ann_file="annotations/test.json",
remove_images_without_annotations=False,
)
self.num_classes = 5
# ------------------------ training cfg ---------------------- #
self.basic_lr = 0.02 / 16
self.max_epoch = 24
self.lr_decay_stages = [16, 21]
self.nr_images_epoch = 2226
self.warm_iters = 100
self.log_interval = 10
Net = models.FasterRCNN
Cfg = CustomerConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import models
class CustomerConfig(models.FasterRCNNConfig):
def __init__(self):
super().__init__()
# ------------------------ dataset cfg ---------------------- #
self.train_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/train.json",
remove_images_without_annotations=True,
)
self.test_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/val.json",
test_final_ann_file="annotations/test.json",
remove_images_without_annotations=False,
)
self.num_classes = 5
# ------------------------ training cfg ---------------------- #
# self.basic_lr = 0.02 / 16
self.basic_lr = 0.002 / 16
self.max_epoch = 35
self.lr_decay_stages = [16, 21]
self.nr_images_epoch = 2226
self.warm_iters = 100
self.log_interval = 10
Net = models.FasterRCNN
Cfg = CustomerConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import models
class CustomerConfig(models.FCOSConfig):
def __init__(self):
super().__init__()
# ------------------------ dataset cfg ---------------------- #
self.train_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/train.json",
remove_images_without_annotations=True,
)
self.test_dataset = dict(
name="traffic5",
root="images",
ann_file="annotations/val.json",
test_final_ann_file="annotations/test.json",
remove_images_without_annotations=False,
)
self.num_classes = 5
# ------------------------ training cfg ---------------------- #
self.basic_lr = 0.02 / 16
self.max_epoch = 24
self.lr_decay_stages = [16, 21]
self.nr_images_epoch = 2226
self.warm_iters = 100
self.log_interval = 10
Net = models.FCOS
Cfg = CustomerConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .basic import *
from .det import *
_EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .functional import *
from .nn import *
from .norm import *
_EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional
import numpy as np
import megengine.distributed as dist
import megengine.functional as F
from megengine import Tensor
def get_padded_tensor(
array: Tensor, multiple_number: int = 32, pad_value: float = 0
) -> Tensor:
""" pad the nd-array to multiple stride of th e
Args:
array (Tensor):
the tensor with the shape of [batch, channel, height, width]
multiple_number (int):
make the height and width can be divided by multiple_number
pad_value (int): the value to be padded
Returns:
padded_array (Tensor)
"""
batch, chl, t_height, t_width = array.shape
padded_height = (
(t_height + multiple_number - 1) // multiple_number * multiple_number
)
padded_width = (t_width + multiple_number - 1) // multiple_number * multiple_number
padded_array = F.full(
(batch, chl, padded_height, padded_width), pad_value, dtype=array.dtype
)
ndim = array.ndim
if ndim == 4:
padded_array[:, :, :t_height, :t_width] = array
elif ndim == 3:
padded_array[:, :t_height, :t_width] = array
else:
raise Exception("Not supported tensor dim: %d" % ndim)
return padded_array
def safelog(x, eps=None):
if eps is None:
eps = np.finfo(x.dtype).eps
return F.log(F.maximum(x, eps))
def batched_nms(
boxes: Tensor, scores: Tensor, idxs: Tensor, iou_thresh: float, max_output: Optional[int] = None
) -> Tensor:
r"""
Performs non-maximum suppression (NMS) on the boxes according to
their intersection-over-union (IoU).
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on;
each box is expected to be in `(x1, y1, x2, y2)` format.
:param iou_thresh: ``IoU`` threshold for overlapping.
:param idxs: tensor of shape `(N,)`, the class indexs of boxes in the batch.
:param scores: tensor of shape `(N,)`, the score of boxes.
:return: indices of the elements that have been kept by NMS.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
x = np.zeros((100,4))
np.random.seed(42)
x[:,:2] = np.random.rand(100,2) * 20
x[:,2:] = np.random.rand(100,2) * 20 + 100
scores = tensor(np.random.rand(100))
idxs = tensor(np.random.randint(0, 10, 100))
inp = tensor(x)
result = batched_nms(inp, scores, idxs, iou_thresh=0.6)
print(result.numpy())
Outputs:
.. testoutput::
[75 41 99 98 69 64 11 27 35 18]
"""
assert (
boxes.ndim == 2 and boxes.shape[1] == 4
), "the expected shape of boxes is (N, 4)"
assert scores.ndim == 1, "the expected shape of scores is (N,)"
assert idxs.ndim == 1, "the expected shape of idxs is (N,)"
assert (
boxes.shape[0] == scores.shape[0] == idxs.shape[0]
), "number of boxes, scores and idxs are not matched"
idxs = idxs.detach()
max_coordinate = boxes.max()
offsets = idxs.astype("float32") * (max_coordinate + 1)
boxes = boxes + offsets.reshape(-1, 1)
return F.nn.nms(boxes, scores, iou_thresh, max_output)
def all_reduce_mean(array: Tensor) -> Tensor:
if dist.get_world_size() > 1:
array = dist.functional.all_reduce_sum(array) / dist.get_world_size()
return array
# -*- coding: utf-8 -*-
# Copyright 2019 - present, Facebook, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
from collections import namedtuple
import megengine.module as M
class Conv2d(M.Conv2d):
"""
A wrapper around :class:`megengine.module.Conv2d`.
"""
def __init__(self, *args, **kwargs):
"""
Extra keyword arguments supported in addition to
`megengine.module.Conv2d`.
Args:
norm (M.Module, optional): a normalization layer
activation (callable(Tensor) -> Tensor): a callable activation
function
"""
norm = kwargs.pop("norm", None)
activation = kwargs.pop("activation", None)
super().__init__(*args, **kwargs)
self.norm = norm
self.activation = activation
def forward(self, x):
x = super().forward(x)
if self.norm is not None:
x = self.norm(x)
if self.activation is not None:
x = self.activation(x)
return x
class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
"""
A simple structure that contains basic shape specification about a tensor.
Useful for getting the modules output channels when building the graph.
"""
def __new__(cls, channels=None, height=None, width=None, stride=None):
return super().__new__(cls, channels, height, width, stride)
# -*- coding: utf-8 -*-
# Copyright 2019 - present, Facebook, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
from functools import partial
import megengine.module as M
from megengine.module.normalization import GroupNorm, InstanceNorm, LayerNorm
def get_norm(norm):
"""
Args:
norm (str): currently support "BN", "SyncBN", "FrozenBN", "GN", "LN" and "IN"
Returns:
M.Module or None: the normalization layer
"""
if norm is None:
return None
norm = {
"BN": M.BatchNorm2d,
"SyncBN": M.SyncBatchNorm,
"FrozenBN": partial(M.BatchNorm2d, freeze=True),
"GN": GroupNorm,
"LN": LayerNorm,
"IN": InstanceNorm,
}[norm]
return norm
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .anchor import *
from .box_head import *
from .box_utils import *
from .fpn import *
from .loss import *
from .matcher import *
from .point_head import *
from .pooler import *
from .rcnn import *
from .rpn import *
from .sampling import *
_EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from abc import ABCMeta, abstractmethod
from typing import List
import numpy as np
import megengine.functional as F
from megengine import Tensor, tensor
def meshgrid(x, y):
assert len(x.shape) == 1
assert len(y.shape) == 1
mesh_shape = (y.shape[0], x.shape[0])
mesh_x = F.broadcast_to(x, mesh_shape)
mesh_y = F.broadcast_to(y.reshape(-1, 1), mesh_shape)
return mesh_x, mesh_y
def create_anchor_grid(featmap_size, offsets, stride, device):
step_x, step_y = featmap_size
shift = offsets * stride
grid_x = F.arange(shift, step_x * stride + shift, step=stride, device=device)
grid_y = F.arange(shift, step_y * stride + shift, step=stride, device=device)
grids_x, grids_y = meshgrid(grid_y, grid_x)
return grids_x.reshape(-1), grids_y.reshape(-1)
class BaseAnchorGenerator(metaclass=ABCMeta):
"""base class for anchor generator.
"""
def __init__(self):
pass
@property
@abstractmethod
def anchor_dim(self):
pass
@abstractmethod
def generate_anchors_by_features(self, sizes, device) -> List[Tensor]:
pass
def __call__(self, featmaps):
feat_sizes = [fmap.shape[-2:] for fmap in featmaps]
return self.generate_anchors_by_features(feat_sizes, featmaps[0].device)
class AnchorBoxGenerator(BaseAnchorGenerator):
"""default anchor box generator, usually used in anchor-based methods.
This class generate anchors by feature map in level.
Args:
anchor_scales (list): anchor scales based on stride.
The practical anchor scale is anchor_scale * stride
anchor_ratios (list): anchor aspect ratios.
strides (list): strides of inputs.
offset (float): center point offset. default is 0.5.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
anchor_scales: list = [[32], [64], [128], [256], [512]],
anchor_ratios: list = [[0.5, 1, 2]],
strides: list = [4, 8, 16, 32, 64],
offset: float = 0.5,
):
super().__init__()
self.anchor_scales = np.array(anchor_scales, dtype=np.float32)
self.anchor_ratios = np.array(anchor_ratios, dtype=np.float32)
self.strides = strides
self.offset = offset
self.num_features = len(strides)
self.base_anchors = self._different_level_anchors(anchor_scales, anchor_ratios)
@property
def anchor_dim(self):
return 4
def _different_level_anchors(self, scales, ratios):
if len(scales) == 1:
scales *= self.num_features
assert len(scales) == self.num_features
if len(ratios) == 1:
ratios *= self.num_features
assert len(ratios) == self.num_features
return [
tensor(self.generate_base_anchors(scale, ratio))
for scale, ratio in zip(scales, ratios)
]
def generate_base_anchors(self, scales, ratios):
base_anchors = []
areas = [s ** 2.0 for s in scales]
for area in areas:
for ratio in ratios:
w = math.sqrt(area / ratio)
h = ratio * w
# center-based anchor
x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
base_anchors.append([x0, y0, x1, y1])
return base_anchors
def generate_anchors_by_features(self, sizes, device):
all_anchors = []
assert len(sizes) == self.num_features, (
"input features expected {}, got {}".format(self.num_features, len(sizes))
)
for size, stride, base_anchor in zip(sizes, self.strides, self.base_anchors):
grid_x, grid_y = create_anchor_grid(size, self.offset, stride, device)
grids = F.stack([grid_x, grid_y, grid_x, grid_y], axis=1)
all_anchors.append(
(F.expand_dims(grids, axis=1) + F.expand_dims(base_anchor, axis=0)).reshape(-1, 4)
)
return all_anchors
class AnchorPointGenerator(BaseAnchorGenerator):
"""default anchor point generator, usually used in anchor-free methods.
This class generate anchors by feature map in level.
Args:
num_anchors (int): number of anchors per location
strides (list): strides of inputs.
offset (float): center point offset. default is 0.5.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
num_anchors: int = 1,
strides: list = [4, 8, 16, 32, 64],
offset: float = 0.5,
):
super().__init__()
self.num_anchors = num_anchors
self.strides = strides
self.offset = offset
self.num_features = len(strides)
@property
def anchor_dim(self):
return 2
def generate_anchors_by_features(self, sizes, device):
all_anchors = []
assert len(sizes) == self.num_features, (
"input features expected {}, got {}".format(self.num_features, len(sizes))
)
for size, stride in zip(sizes, self.strides):
grid_x, grid_y = create_anchor_grid(size, self.offset, stride, device)
grids = F.stack([grid_x, grid_y], axis=1)
all_anchors.append(
F.broadcast_to(
F.expand_dims(grids, axis=1), (grids.shape[0], self.num_anchors, 2)
).reshape(-1, 2)
) # FIXME: need F.repeat
return all_anchors
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import List
import megengine.module as M
from megengine import Tensor
import layers
class BoxHead(M.Module):
"""
The head used when anchor boxes are adopted for object classification and box regression.
"""
def __init__(self, cfg, input_shape: List[layers.ShapeSpec]):
super().__init__()
in_channels = input_shape[0].channels
num_classes = cfg.num_classes
num_convs = 4
prior_prob = cfg.cls_prior_prob
num_anchors = [
len(cfg.anchor_scales[i]) * len(cfg.anchor_ratios[i])
for i in range(len(input_shape))
]
assert (
len(set(num_anchors)) == 1
), "not support different number of anchors between levels"
num_anchors = num_anchors[0]
cls_subnet = []
bbox_subnet = []
for _ in range(num_convs):
cls_subnet.append(
M.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
)
cls_subnet.append(M.ReLU())
bbox_subnet.append(
M.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
)
bbox_subnet.append(M.ReLU())
self.cls_subnet = M.Sequential(*cls_subnet)
self.bbox_subnet = M.Sequential(*bbox_subnet)
self.cls_score = M.Conv2d(
in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
)
self.bbox_pred = M.Conv2d(
in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1
)
# Initialization
for modules in [
self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred
]:
for layer in modules.modules():
if isinstance(layer, M.Conv2d):
M.init.normal_(layer.weight, mean=0, std=0.01)
M.init.fill_(layer.bias, 0)
# Use prior in model initialization to improve stability
bias_value = -math.log((1 - prior_prob) / prior_prob)
M.init.fill_(self.cls_score.bias, bias_value)
def forward(self, features: List[Tensor]):
logits, offsets = [], []
for feature in features:
logits.append(self.cls_score(self.cls_subnet(feature)))
offsets.append(self.bbox_pred(self.bbox_subnet(feature)))
return logits, offsets
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
import numpy as np
import megengine.functional as F
from megengine import Tensor
class BoxCoderBase(metaclass=ABCMeta):
"""Boxcoder class.
"""
def __init__(self):
pass
@abstractmethod
def encode(self) -> Tensor:
pass
@abstractmethod
def decode(self) -> Tensor:
pass
class BoxCoder(BoxCoderBase, metaclass=ABCMeta):
# pylint: disable=dangerous-default-value
def __init__(
self,
reg_mean=[0.0, 0.0, 0.0, 0.0],
reg_std=[1.0, 1.0, 1.0, 1.0],
):
"""
Args:
reg_mean(np.ndarray): [x0_mean, x1_mean, y0_mean, y1_mean] or None
reg_std(np.ndarray): [x0_std, x1_std, y0_std, y1_std] or None
"""
self.reg_mean = np.array(reg_mean, dtype=np.float32)[None, :]
self.reg_std = np.array(reg_std, dtype=np.float32)[None, :]
super().__init__()
@staticmethod
def _box_ltrb_to_cs_opr(bbox, addaxis=None):
""" transform the left-top right-bottom encoding bounding boxes
to center and size encodings"""
bbox_width = bbox[:, 2] - bbox[:, 0]
bbox_height = bbox[:, 3] - bbox[:, 1]
bbox_ctr_x = bbox[:, 0] + 0.5 * bbox_width
bbox_ctr_y = bbox[:, 1] + 0.5 * bbox_height
if addaxis is None:
return bbox_width, bbox_height, bbox_ctr_x, bbox_ctr_y
else:
return (
F.expand_dims(bbox_width, addaxis),
F.expand_dims(bbox_height, addaxis),
F.expand_dims(bbox_ctr_x, addaxis),
F.expand_dims(bbox_ctr_y, addaxis),
)
def encode(self, bbox: Tensor, gt: Tensor) -> Tensor:
bbox_width, bbox_height, bbox_ctr_x, bbox_ctr_y = self._box_ltrb_to_cs_opr(bbox)
gt_width, gt_height, gt_ctr_x, gt_ctr_y = self._box_ltrb_to_cs_opr(gt)
target_dx = (gt_ctr_x - bbox_ctr_x) / bbox_width
target_dy = (gt_ctr_y - bbox_ctr_y) / bbox_height
target_dw = F.log(gt_width / bbox_width)
target_dh = F.log(gt_height / bbox_height)
target = F.stack([target_dx, target_dy, target_dw, target_dh], axis=1)
target -= self.reg_mean
target /= self.reg_std
return target
def decode(self, anchors: Tensor, deltas: Tensor) -> Tensor:
deltas *= self.reg_std
deltas += self.reg_mean
(
anchor_width,
anchor_height,
anchor_ctr_x,
anchor_ctr_y,
) = self._box_ltrb_to_cs_opr(anchors, 1)
pred_ctr_x = anchor_ctr_x + deltas[:, 0::4] * anchor_width
pred_ctr_y = anchor_ctr_y + deltas[:, 1::4] * anchor_height
pred_width = anchor_width * F.exp(deltas[:, 2::4])
pred_height = anchor_height * F.exp(deltas[:, 3::4])
pred_x1 = pred_ctr_x - 0.5 * pred_width
pred_y1 = pred_ctr_y - 0.5 * pred_height
pred_x2 = pred_ctr_x + 0.5 * pred_width
pred_y2 = pred_ctr_y + 0.5 * pred_height
pred_box = F.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=2)
pred_box = pred_box.reshape(pred_box.shape[0], -1)
return pred_box
class PointCoder(BoxCoderBase, metaclass=ABCMeta):
def encode(self, point: Tensor, gt: Tensor) -> Tensor:
return F.concat([point - gt[..., :2], gt[..., 2:] - point], axis=-1)
def decode(self, anchors: Tensor, deltas: Tensor) -> Tensor:
return F.stack([
F.expand_dims(anchors[:, 0], axis=1) - deltas[:, 0::4],
F.expand_dims(anchors[:, 1], axis=1) - deltas[:, 1::4],
F.expand_dims(anchors[:, 0], axis=1) + deltas[:, 2::4],
F.expand_dims(anchors[:, 1], axis=1) + deltas[:, 3::4],
], axis=2).reshape(deltas.shape)
def get_iou(boxes1: Tensor, boxes2: Tensor, return_ioa=False) -> Tensor:
"""
Given two lists of boxes of size N and M,
compute the IoU (intersection over union)
between __all__ N x M pairs of boxes.
The box order must be (xmin, ymin, xmax, ymax).
Args:
boxes1 (Tensor): boxes tensor with shape (N, 4)
boxes2 (Tensor): boxes tensor with shape (M, 4)
return_ioa (Bool): wheather return Intersection over Boxes1 or not, default: False
Returns:
iou (Tensor): IoU matrix, shape (N,M).
"""
b_box1 = F.expand_dims(boxes1, axis=1)
b_box2 = F.expand_dims(boxes2, axis=0)
iw = F.minimum(b_box1[:, :, 2], b_box2[:, :, 2]) - F.maximum(
b_box1[:, :, 0], b_box2[:, :, 0]
)
ih = F.minimum(b_box1[:, :, 3], b_box2[:, :, 3]) - F.maximum(
b_box1[:, :, 1], b_box2[:, :, 1]
)
inter = F.maximum(iw, 0) * F.maximum(ih, 0)
area_box1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
area_box2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
union = F.expand_dims(area_box1, axis=1) + F.expand_dims(area_box2, axis=0) - inter
overlaps = F.maximum(inter / union, 0)
if return_ioa:
ioa = F.maximum(inter / area_box1, 0)
return overlaps, ioa
return overlaps
def get_clipped_boxes(boxes, hw):
""" Clip the boxes into the image region."""
# x1 >=0
box_x1 = F.clip(boxes[:, 0::4], lower=0, upper=hw[1])
# y1 >=0
box_y1 = F.clip(boxes[:, 1::4], lower=0, upper=hw[0])
# x2 < im_info[1]
box_x2 = F.clip(boxes[:, 2::4], lower=0, upper=hw[1])
# y2 < im_info[0]
box_y2 = F.clip(boxes[:, 3::4], lower=0, upper=hw[0])
clip_box = F.concat([box_x1, box_y1, box_x2, box_y2], axis=1)
return clip_box
def filter_boxes(boxes, size=0):
width = boxes[:, 2] - boxes[:, 0]
height = boxes[:, 3] - boxes[:, 1]
keep = (width > size) & (height > size)
return keep
# -*- coding: utf-8 -*-
# Copyright 2019 - present, Facebook, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import math
from typing import List
import megengine.functional as F
import megengine.module as M
import layers
class FPN(M.Module):
"""
This module implements Feature Pyramid Network.
It creates pyramid features built on top of some input feature maps which
are produced by the backbone networks like ResNet.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
bottom_up: M.Module,
in_features: List[str],
out_channels: int = 256,
norm: str = None,
top_block: M.Module = None,
strides: List[int] = [8, 16, 32],
channels: List[int] = [512, 1024, 2048],
):
"""
Args:
bottom_up (M.Module): module representing the bottom up sub-network.
it generates multi-scale feature maps which formatted as a
dict like {'res3': res3_feature, 'res4': res4_feature}
in_features (list[str]): list of input feature maps keys coming
from the `bottom_up` which will be used in FPN.
e.g. ['res3', 'res4', 'res5']
out_channels (int): number of channels used in the output
feature maps.
norm (str): the normalization type.
top_block (nn.Module or None): the module build upon FPN layers.
"""
super(FPN, self).__init__()
in_strides = strides
in_channels = channels
norm = layers.get_norm(norm)
use_bias = norm is None
self.lateral_convs = list()
self.output_convs = list()
for idx, in_channels in enumerate(in_channels):
lateral_norm = None if norm is None else norm(out_channels)
output_norm = None if norm is None else norm(out_channels)
lateral_conv = layers.Conv2d(
in_channels,
out_channels,
kernel_size=1,
bias=use_bias,
norm=lateral_norm,
)
output_conv = layers.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=use_bias,
norm=output_norm,
)
M.init.msra_normal_(lateral_conv.weight, mode="fan_in")
M.init.msra_normal_(output_conv.weight, mode="fan_in")
if use_bias:
M.init.fill_(lateral_conv.bias, 0)
M.init.fill_(output_conv.bias, 0)
stage = int(math.log2(in_strides[idx]))
setattr(self, "fpn_lateral{}".format(stage), lateral_conv)
setattr(self, "fpn_output{}".format(stage), output_conv)
self.lateral_convs.insert(0, lateral_conv)
self.output_convs.insert(0, output_conv)
self.top_block = top_block
self.in_features = in_features
self.bottom_up = bottom_up
# follow the common practices, FPN features are named to "p<stage>",
# like ["p2", "p3", ..., "p6"]
self._out_feature_strides = {
"p{}".format(int(math.log2(s))): s for s in in_strides
}
# top block output feature maps.
if self.top_block is not None:
for s in range(stage, stage + self.top_block.num_levels):
self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
self._out_features = list(sorted(self._out_feature_strides.keys()))
self._out_feature_channels = {k: out_channels for k in self._out_features}
def forward(self, x):
bottom_up_features = self.bottom_up.extract_features(x)
x = [bottom_up_features[f] for f in self.in_features[::-1]]
results = []
prev_features = self.lateral_convs[0](x[0])
results.append(self.output_convs[0](prev_features))
for features, lateral_conv, output_conv in zip(
x[1:], self.lateral_convs[1:], self.output_convs[1:]
):
top_down_features = F.nn.interpolate(
prev_features, features.shape[2:], mode="BILINEAR"
)
lateral_features = lateral_conv(features)
prev_features = lateral_features + top_down_features
results.insert(0, output_conv(prev_features))
if self.top_block is not None:
top_block_in_feature = bottom_up_features.get(
self.top_block.in_feature, None
)
if top_block_in_feature is None:
top_block_in_feature = results[
self._out_features.index(self.top_block.in_feature)
]
results.extend(self.top_block(top_block_in_feature))
return dict(zip(self._out_features, results))
def output_shape(self):
return {
name: layers.ShapeSpec(
channels=self._out_feature_channels[name],
stride=self._out_feature_strides[name],
)
for name in self._out_features
}
class FPNP6(M.Module):
"""
used in FPN, generate a downsampled P6 feature from P5.
"""
def __init__(self, in_feature="p5"):
super().__init__()
self.num_levels = 1
self.in_feature = in_feature
self.pool = M.MaxPool2d(kernel_size=1, stride=2, padding=0)
def forward(self, x):
return [self.pool(x)]
class LastLevelP6P7(M.Module):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7 from
C5 feature.
"""
def __init__(self, in_channels: int, out_channels: int, in_feature="res5"):
super().__init__()
self.num_levels = 2
if in_feature == "p5":
assert in_channels == out_channels
self.in_feature = in_feature
self.p6 = M.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = M.Conv2d(out_channels, out_channels, 3, 2, 1)
def forward(self, x):
p6 = self.p6(x)
p7 = self.p7(F.relu(p6))
return [p6, p7]
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine.functional as F
from megengine import Tensor
def binary_cross_entropy(logits: Tensor, targets: Tensor) -> Tensor:
r"""Binary Cross Entropy
Args:
logits (Tensor):
the predicted logits
targets (Tensor):
the assigned targets with the same shape as logits
Returns:
the calculated binary cross entropy.
"""
return -(targets * F.logsigmoid(logits) + (1 - targets) * F.logsigmoid(-logits))
def sigmoid_focal_loss(
logits: Tensor, targets: Tensor, alpha: float = -1, gamma: float = 0,
) -> Tensor:
r"""Focal Loss for Dense Object Detection:
<https://arxiv.org/pdf/1708.02002.pdf>
.. math::
FL(p_t) = -\alpha_t(1-p_t)^\gamma \log(p_t)
Args:
logits (Tensor):
the predicted logits
targets (Tensor):
the assigned targets with the same shape as logits
alpha (float):
parameter to mitigate class imbalance. Default: -1
gamma (float):
parameter to mitigate easy/hard loss imbalance. Default: 0
Returns:
the calculated focal loss.
"""
scores = F.sigmoid(logits)
loss = binary_cross_entropy(logits, targets)
if gamma != 0:
loss *= (targets * (1 - scores) + (1 - targets) * scores) ** gamma
if alpha >= 0:
loss *= targets * alpha + (1 - targets) * (1 - alpha)
return loss
def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
r"""Smooth L1 Loss
Args:
pred (Tensor):
the predictions
target (Tensor):
the assigned targets with the same shape as pred
beta (int):
the parameter of smooth l1 loss.
Returns:
the calculated smooth l1 loss.
"""
x = pred - target
abs_x = F.abs(x)
if beta < 1e-5:
loss = abs_x
else:
in_loss = 0.5 * x ** 2 / beta
out_loss = abs_x - 0.5 * beta
loss = F.where(abs_x < beta, in_loss, out_loss)
return loss
def iou_loss(
pred: Tensor, target: Tensor, box_mode: str = "xyxy", loss_type: str = "iou", eps: float = 1e-8,
) -> Tensor:
if box_mode == "ltrb":
pred = F.concat([-pred[..., :2], pred[..., 2:]], axis=-1)
target = F.concat([-target[..., :2], target[..., 2:]], axis=-1)
elif box_mode != "xyxy":
raise NotImplementedError
pred_area = F.maximum(pred[..., 2] - pred[..., 0], 0) * F.maximum(
pred[..., 3] - pred[..., 1], 0
)
target_area = F.maximum(target[..., 2] - target[..., 0], 0) * F.maximum(
target[..., 3] - target[..., 1], 0
)
w_intersect = F.maximum(
F.minimum(pred[..., 2], target[..., 2]) - F.maximum(pred[..., 0], target[..., 0]), 0
)
h_intersect = F.maximum(
F.minimum(pred[..., 3], target[..., 3]) - F.maximum(pred[..., 1], target[..., 1]), 0
)
area_intersect = w_intersect * h_intersect
area_union = pred_area + target_area - area_intersect
ious = area_intersect / F.maximum(area_union, eps)
if loss_type == "iou":
loss = -F.log(F.maximum(ious, eps))
elif loss_type == "linear_iou":
loss = 1 - ious
elif loss_type == "giou":
g_w_intersect = F.maximum(pred[..., 2], target[..., 2]) - F.minimum(
pred[..., 0], target[..., 0]
)
g_h_intersect = F.maximum(pred[..., 3], target[..., 3]) - F.minimum(
pred[..., 1], target[..., 1]
)
ac_union = g_w_intersect * g_h_intersect
gious = ious - (ac_union - area_union) / F.maximum(ac_union, eps)
loss = 1 - gious
return loss
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine.functional as F
class Matcher:
def __init__(self, thresholds, labels, allow_low_quality_matches=False):
assert len(thresholds) + 1 == len(labels), "thresholds and labels are not matched"
assert all(low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:]))
thresholds.append(float("inf"))
thresholds.insert(0, -float("inf"))
self.thresholds = thresholds
self.labels = labels
self.allow_low_quality_matches = allow_low_quality_matches
def __call__(self, matrix):
"""
matrix(tensor): A two dim tensor with shape of (N, M). N is number of GT-boxes,
while M is the number of anchors in detection.
"""
assert len(matrix.shape) == 2
max_scores = matrix.max(axis=0)
match_indices = F.argmax(matrix, axis=0)
# default ignore label: -1
labels = F.full_like(match_indices, -1)
for label, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
mask = (max_scores >= low) & (max_scores < high)
labels[mask] = label
if self.allow_low_quality_matches:
mask = (matrix == F.max(matrix, axis=1, keepdims=True)).sum(axis=0) > 0
labels[mask] = 1
return match_indices, labels
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import List
import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.module.normalization import GroupNorm
import layers
class PointHead(M.Module):
"""
The head used when anchor points are adopted for object classification and box regression.
"""
def __init__(self, cfg, input_shape: List[layers.ShapeSpec]):
super().__init__()
self.stride_list = cfg.stride
in_channels = input_shape[0].channels
num_classes = cfg.num_classes
num_convs = 4
prior_prob = cfg.cls_prior_prob
num_anchors = [cfg.num_anchors] * len(input_shape)
assert (
len(set(num_anchors)) == 1
), "not support different number of anchors between levels"
num_anchors = num_anchors[0]
cls_subnet = []
bbox_subnet = []
for _ in range(num_convs):
cls_subnet.append(
M.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
)
cls_subnet.append(GroupNorm(32, in_channels))
cls_subnet.append(M.ReLU())
bbox_subnet.append(
M.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
)
bbox_subnet.append(GroupNorm(32, in_channels))
bbox_subnet.append(M.ReLU())
self.cls_subnet = M.Sequential(*cls_subnet)
self.bbox_subnet = M.Sequential(*bbox_subnet)
self.cls_score = M.Conv2d(
in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
)
self.bbox_pred = M.Conv2d(
in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1
)
self.ctrness = M.Conv2d(
in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1
)
# Initialization
for modules in [
self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred,
self.ctrness
]:
for layer in modules.modules():
if isinstance(layer, M.Conv2d):
M.init.normal_(layer.weight, mean=0, std=0.01)
M.init.fill_(layer.bias, 0)
# Use prior in model initialization to improve stability
bias_value = -math.log((1 - prior_prob) / prior_prob)
M.init.fill_(self.cls_score.bias, bias_value)
self.scale_list = mge.Parameter(np.ones(len(self.stride_list), dtype=np.float32))
def forward(self, features: List[Tensor]):
logits, offsets, ctrness = [], [], []
for feature, scale, stride in zip(features, self.scale_list, self.stride_list):
logits.append(self.cls_score(self.cls_subnet(feature)))
bbox_subnet = self.bbox_subnet(feature)
offsets.append(F.relu(self.bbox_pred(bbox_subnet) * scale) * stride)
ctrness.append(self.ctrness(bbox_subnet))
return logits, offsets, ctrness
# -*- coding:utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import numpy as np
import megengine.functional as F
def roi_pool(
rpn_fms, rois, stride, pool_shape, pooler_type="roi_align",
):
rois = rois.detach()
assert len(stride) == len(rpn_fms)
canonical_level = 4
canonical_box_size = 224
min_level = int(math.log2(stride[0]))
max_level = int(math.log2(stride[-1]))
num_fms = len(rpn_fms)
box_area = (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2])
assigned_level = F.floor(
canonical_level + F.log(F.sqrt(box_area) / canonical_box_size) / np.log(2).astype("float32")
).astype("int32")
assigned_level = F.minimum(assigned_level, max_level)
assigned_level = F.maximum(assigned_level, min_level)
assigned_level = assigned_level - min_level
# avoid empty assignment
assigned_level = F.concat(
[assigned_level, F.arange(num_fms, dtype="int32", device=assigned_level.device)],
)
rois = F.concat([rois, F.zeros((num_fms, rois.shape[-1]))])
pool_list, inds_list = [], []
for i in range(num_fms):
_, inds = F.cond_take(assigned_level == i, assigned_level)
level_rois = rois[inds]
if pooler_type == "roi_pool":
pool_fm = F.nn.roi_pooling(
rpn_fms[i], level_rois, pool_shape, mode="max", scale=1.0 / stride[i]
)
elif pooler_type == "roi_align":
pool_fm = F.nn.roi_align(
rpn_fms[i],
level_rois,
pool_shape,
mode="average",
spatial_scale=1.0 / stride[i],
sample_points=2,
aligned=True,
)
pool_list.append(pool_fm)
inds_list.append(inds)
fm_order = F.argsort(F.concat(inds_list, axis=0))
pool_feature = F.concat(pool_list, axis=0)
pool_feature = pool_feature[fm_order][:-num_fms]
return pool_feature
# -*- coding:utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine.functional as F
import megengine.module as M
import layers
class RCNN(M.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.box_coder = layers.BoxCoder(cfg.rcnn_reg_mean, cfg.rcnn_reg_std)
# roi head
self.in_features = cfg.rcnn_in_features
self.stride = cfg.rcnn_stride
self.pooling_method = cfg.pooling_method
self.pooling_size = cfg.pooling_size
self.fc1 = M.Linear(256 * self.pooling_size[0] * self.pooling_size[1], 1024)
self.fc2 = M.Linear(1024, 1024)
for l in [self.fc1, self.fc2]:
M.init.normal_(l.weight, std=0.01)
M.init.fill_(l.bias, 0)
# box predictor
self.pred_cls = M.Linear(1024, cfg.num_classes + 1)
self.pred_delta = M.Linear(1024, cfg.num_classes * 4)
M.init.normal_(self.pred_cls.weight, std=0.01)
M.init.normal_(self.pred_delta.weight, std=0.001)
for l in [self.pred_cls, self.pred_delta]:
M.init.fill_(l.bias, 0)
def forward(self, fpn_fms, rcnn_rois, im_info=None, gt_boxes=None):
rcnn_rois, labels, bbox_targets = self.get_ground_truth(
rcnn_rois, im_info, gt_boxes
)
fpn_fms = [fpn_fms[x] for x in self.in_features]
pool_features = layers.roi_pool(
fpn_fms, rcnn_rois, self.stride, self.pooling_size, self.pooling_method,
)
flatten_feature = F.flatten(pool_features, start_axis=1)
roi_feature = F.relu(self.fc1(flatten_feature))
roi_feature = F.relu(self.fc2(roi_feature))
pred_logits = self.pred_cls(roi_feature)
pred_offsets = self.pred_delta(roi_feature)
if self.training:
# loss for rcnn classification
loss_rcnn_cls = F.loss.cross_entropy(pred_logits, labels, axis=1)
# loss for rcnn regression
pred_offsets = pred_offsets.reshape(-1, self.cfg.num_classes, 4)
num_samples = labels.shape[0]
fg_mask = labels > 0
loss_rcnn_bbox = layers.smooth_l1_loss(
pred_offsets[fg_mask, labels[fg_mask] - 1],
bbox_targets[fg_mask],
self.cfg.rcnn_smooth_l1_beta,
).sum() / F.maximum(num_samples, 1)
loss_dict = {
"loss_rcnn_cls": loss_rcnn_cls,
"loss_rcnn_bbox": loss_rcnn_bbox,
}
return loss_dict
else:
# slice 1 for removing background
pred_scores = F.softmax(pred_logits, axis=1)[:, 1:]
pred_offsets = pred_offsets.reshape(-1, 4)
target_shape = (rcnn_rois.shape[0], self.cfg.num_classes, 4)
# rois (N, 4) -> (N, 1, 4) -> (N, 80, 4) -> (N * 80, 4)
base_rois = F.broadcast_to(
F.expand_dims(rcnn_rois[:, 1:5], axis=1), target_shape).reshape(-1, 4)
pred_bbox = self.box_coder.decode(base_rois, pred_offsets)
return pred_bbox, pred_scores
def get_ground_truth(self, rpn_rois, im_info, gt_boxes):
if not self.training:
return rpn_rois, None, None
return_rois = []
return_labels = []
return_bbox_targets = []
# get per image proposals and gt_boxes
for bid in range(gt_boxes.shape[0]):
num_valid_boxes = im_info[bid, 4].astype("int32")
gt_boxes_per_img = gt_boxes[bid, :num_valid_boxes, :]
batch_inds = F.full((gt_boxes_per_img.shape[0], 1), bid)
gt_rois = F.concat([batch_inds, gt_boxes_per_img[:, :4]], axis=1)
batch_roi_mask = rpn_rois[:, 0] == bid
# all_rois : [batch_id, x1, y1, x2, y2]
all_rois = F.concat([rpn_rois[batch_roi_mask], gt_rois])
overlaps = layers.get_iou(all_rois[:, 1:], gt_boxes_per_img)
max_overlaps = overlaps.max(axis=1)
gt_assignment = F.argmax(overlaps, axis=1).astype("int32")
labels = gt_boxes_per_img[gt_assignment, 4]
# ---------------- get the fg/bg labels for each roi ---------------#
fg_mask = (max_overlaps >= self.cfg.fg_threshold) & (labels >= 0)
bg_mask = (
(max_overlaps >= self.cfg.bg_threshold_low)
& (max_overlaps < self.cfg.bg_threshold_high)
)
num_fg_rois = int(self.cfg.num_rois * self.cfg.fg_ratio)
fg_inds_mask = layers.sample_labels(fg_mask, num_fg_rois, True, False)
num_bg_rois = int(self.cfg.num_rois - fg_inds_mask.sum())
bg_inds_mask = layers.sample_labels(bg_mask, num_bg_rois, True, False)
labels[bg_inds_mask] = 0
keep_mask = fg_inds_mask | bg_inds_mask
labels = labels[keep_mask].astype("int32")
rois = all_rois[keep_mask]
target_boxes = gt_boxes_per_img[gt_assignment[keep_mask], :4]
bbox_targets = self.box_coder.encode(rois[:, 1:], target_boxes)
bbox_targets = bbox_targets.reshape(-1, 4)
return_rois.append(rois)
return_labels.append(labels)
return_bbox_targets.append(bbox_targets)
return (
F.concat(return_rois, axis=0).detach(),
F.concat(return_labels, axis=0).detach(),
F.concat(return_bbox_targets, axis=0).detach(),
)
This diff is collapsed.
# -*- coding:utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine.functional as F
import megengine.module as M
import layers
class RPN(M.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.box_coder = layers.BoxCoder(cfg.rpn_reg_mean, cfg.rpn_reg_std)
# check anchor settings
assert len(set(len(x) for x in cfg.anchor_scales)) == 1
assert len(set(len(x) for x in cfg.anchor_ratios)) == 1
self.num_cell_anchors = len(cfg.anchor_scales[0]) * len(cfg.anchor_ratios[0])
self.stride_list = np.array(cfg.rpn_stride).astype(np.float32)
rpn_channel = cfg.rpn_channel
self.in_features = cfg.rpn_in_features
self.anchor_generator = layers.AnchorBoxGenerator(
anchor_scales=cfg.anchor_scales,
anchor_ratios=cfg.anchor_ratios,
strides=cfg.rpn_stride,
offset=self.cfg.anchor_offset,
)
self.matcher = layers.Matcher(
cfg.match_thresholds, cfg.match_labels, cfg.match_allow_low_quality
)
self.rpn_conv = M.Conv2d(256, rpn_channel, kernel_size=3, stride=1, padding=1)
self.rpn_cls_score = M.Conv2d(
rpn_channel, self.num_cell_anchors, kernel_size=1, stride=1
)
self.rpn_bbox_offsets = M.Conv2d(
rpn_channel, self.num_cell_anchors * 4, kernel_size=1, stride=1
)
for l in [self.rpn_conv, self.rpn_cls_score, self.rpn_bbox_offsets]:
M.init.normal_(l.weight, std=0.01)
M.init.fill_(l.bias, 0)
def forward(self, features, im_info, boxes=None):
# prediction
features = [features[x] for x in self.in_features]
# get anchors
anchors_list = self.anchor_generator(features)
pred_cls_logit_list = []
pred_bbox_offset_list = []
for x in features:
t = F.relu(self.rpn_conv(x))
scores = self.rpn_cls_score(t)
pred_cls_logit_list.append(
scores.reshape(
scores.shape[0],
self.num_cell_anchors,
scores.shape[2],
scores.shape[3],
)
)
bbox_offsets = self.rpn_bbox_offsets(t)
pred_bbox_offset_list.append(
bbox_offsets.reshape(
bbox_offsets.shape[0],
self.num_cell_anchors,
4,
bbox_offsets.shape[2],
bbox_offsets.shape[3],
)
)
# get rois from the predictions
rpn_rois = self.find_top_rpn_proposals(
pred_cls_logit_list, pred_bbox_offset_list, anchors_list, im_info
)
if self.training:
rpn_labels, rpn_offsets = self.get_ground_truth(
anchors_list, boxes, im_info[:, 4].astype(np.int32)
)
pred_cls_logits, pred_bbox_offsets = self.merge_rpn_score_box(
pred_cls_logit_list, pred_bbox_offset_list
)
fg_mask = rpn_labels > 0
valid_mask = rpn_labels >= 0
num_valid = valid_mask.sum()
# rpn classification loss
loss_rpn_cls = F.loss.binary_cross_entropy(
pred_cls_logits[valid_mask], rpn_labels[valid_mask]
)
# rpn regression loss
loss_rpn_bbox = layers.smooth_l1_loss(
pred_bbox_offsets[fg_mask],
rpn_offsets[fg_mask],
self.cfg.rpn_smooth_l1_beta,
).sum() / F.maximum(num_valid, 1)
loss_dict = {"loss_rpn_cls": loss_rpn_cls, "loss_rpn_bbox": loss_rpn_bbox}
return rpn_rois, loss_dict
else:
return rpn_rois
def find_top_rpn_proposals(
self, rpn_cls_score_list, rpn_bbox_offset_list, anchors_list, im_info
):
prev_nms_top_n = (
self.cfg.train_prev_nms_top_n
if self.training
else self.cfg.test_prev_nms_top_n
)
post_nms_top_n = (
self.cfg.train_post_nms_top_n
if self.training
else self.cfg.test_post_nms_top_n
)
return_rois = []
for bid in range(im_info.shape[0]):
batch_proposal_list = []
batch_score_list = []
batch_level_list = []
for l, (rpn_cls_score, rpn_bbox_offset, anchors) in enumerate(
zip(rpn_cls_score_list, rpn_bbox_offset_list, anchors_list)
):
# get proposals and scores
offsets = rpn_bbox_offset[bid].transpose(2, 3, 0, 1).reshape(-1, 4)
proposals = self.box_coder.decode(anchors, offsets)
scores = rpn_cls_score[bid].transpose(1, 2, 0).flatten()
scores.detach()
# prev nms top n
scores, order = F.topk(scores, descending=True, k=prev_nms_top_n)
proposals = proposals[order]
batch_proposal_list.append(proposals)
batch_score_list.append(scores)
batch_level_list.append(F.full_like(scores, l))
# gather proposals, scores, level
proposals = F.concat(batch_proposal_list, axis=0)
scores = F.concat(batch_score_list, axis=0)
levels = F.concat(batch_level_list, axis=0)
proposals = layers.get_clipped_boxes(proposals, im_info[bid])
# filter invalid proposals and apply total level nms
keep_mask = layers.filter_boxes(proposals)
proposals = proposals[keep_mask]
scores = scores[keep_mask]
levels = levels[keep_mask]
nms_keep_inds = layers.batched_nms(
proposals, scores, levels, self.cfg.rpn_nms_threshold, post_nms_top_n
)
# generate rois to rcnn head, rois shape (N, 5), info [batch_id, x1, y1, x2, y2]
rois = F.concat([proposals, scores.reshape(-1, 1)], axis=1)
rois = rois[nms_keep_inds]
batch_inds = F.full((rois.shape[0], 1), bid)
batch_rois = F.concat([batch_inds, rois[:, :4]], axis=1)
return_rois.append(batch_rois)
return_rois = F.concat(return_rois, axis=0)
return return_rois.detach()
def merge_rpn_score_box(self, rpn_cls_score_list, rpn_bbox_offset_list):
final_rpn_cls_score_list = []
final_rpn_bbox_offset_list = []
for bid in range(rpn_cls_score_list[0].shape[0]):
batch_rpn_cls_score_list = []
batch_rpn_bbox_offset_list = []
for i in range(len(self.in_features)):
rpn_cls_scores = rpn_cls_score_list[i][bid].transpose(1, 2, 0).flatten()
rpn_bbox_offsets = (
rpn_bbox_offset_list[i][bid].transpose(2, 3, 0, 1).reshape(-1, 4)
)
batch_rpn_cls_score_list.append(rpn_cls_scores)
batch_rpn_bbox_offset_list.append(rpn_bbox_offsets)
batch_rpn_cls_scores = F.concat(batch_rpn_cls_score_list, axis=0)
batch_rpn_bbox_offsets = F.concat(batch_rpn_bbox_offset_list, axis=0)
final_rpn_cls_score_list.append(batch_rpn_cls_scores)
final_rpn_bbox_offset_list.append(batch_rpn_bbox_offsets)
final_rpn_cls_scores = F.concat(final_rpn_cls_score_list, axis=0)
final_rpn_bbox_offsets = F.concat(final_rpn_bbox_offset_list, axis=0)
return final_rpn_cls_scores, final_rpn_bbox_offsets
def get_ground_truth(self, anchors_list, batched_gt_boxes, batched_num_gts):
anchors = F.concat(anchors_list, axis=0)
labels_list = []
offsets_list = []
for bid in range(batched_gt_boxes.shape[0]):
gt_boxes = batched_gt_boxes[bid, :batched_num_gts[bid]]
overlaps = layers.get_iou(gt_boxes[:, :4], anchors)
matched_indices, labels = self.matcher(overlaps)
offsets = self.box_coder.encode(anchors, gt_boxes[matched_indices, :4])
# sample positive labels
num_positive = int(self.cfg.num_sample_anchors * self.cfg.positive_anchor_ratio)
labels = layers.sample_labels(labels, num_positive, 1, -1)
# sample negative labels
num_positive = (labels == 1).sum().astype(np.int32)
num_negative = self.cfg.num_sample_anchors - num_positive
labels = layers.sample_labels(labels, num_negative, 0, -1)
labels_list.append(labels)
offsets_list.append(offsets)
return (
F.concat(labels_list, axis=0).detach(),
F.concat(offsets_list, axis=0).detach(),
)
# -*- coding:utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine.functional as F
from megengine.random import uniform
def sample_labels(labels, num_samples, label_value, ignore_label=-1):
"""sample N labels with label value = sample_labels
Args:
labels(Tensor): shape of label is (N,)
num_samples(int):
label_value(int):
Returns:
label(Tensor): label after sampling
"""
assert labels.ndim == 1, "Only tensor of dim 1 is supported."
mask = (labels == label_value)
num_valid = mask.sum()
if num_valid <= num_samples:
return labels
random_tensor = F.zeros_like(labels).astype("float32")
random_tensor[mask] = uniform(size=num_valid)
_, invalid_inds = F.topk(random_tensor, k=num_samples - num_valid)
labels[invalid_inds] = ignore_label
return labels
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .faster_rcnn import *
from .fcos import *
from .atss import *
_EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
This diff is collapsed.
# -*- coding:utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine.functional as F
import megengine.module as M
import layers
from layers.det import resnet
class FasterRCNN(M.Module):
"""
Implement Faster R-CNN (https://arxiv.org/abs/1506.01497).
"""
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
# ----------------------- build backbone ------------------------ #
bottom_up = getattr(resnet, cfg.backbone)(
norm=layers.get_norm(cfg.backbone_norm), pretrained=cfg.backbone_pretrained
)
del bottom_up.fc
# ----------------------- build FPN ----------------------------- #
self.backbone = layers.FPN(
bottom_up=bottom_up,
in_features=cfg.fpn_in_features,
out_channels=cfg.fpn_out_channels,
norm=cfg.fpn_norm,
top_block=layers.FPNP6(),
strides=cfg.fpn_in_strides,
channels=cfg.fpn_in_channels,
)
# ----------------------- build RPN ----------------------------- #
self.rpn = layers.RPN(cfg)
# ----------------------- build RCNN head ----------------------- #
self.rcnn = layers.RCNN(cfg)
def preprocess_image(self, image):
padded_image = layers.get_padded_tensor(image, 32, 0.0)
normed_image = (
padded_image
- np.array(self.cfg.img_mean, dtype=np.float32)[None, :, None, None]
) / np.array(self.cfg.img_std, dtype=np.float32)[None, :, None, None]
return normed_image
def forward(self, image, im_info, gt_boxes=None):
image = self.preprocess_image(image)
features = self.backbone(image)
if self.training:
return self._forward_train(features, im_info, gt_boxes)
else:
return self.inference(features, im_info)
def _forward_train(self, features, im_info, gt_boxes):
rpn_rois, rpn_losses = self.rpn(features, im_info, gt_boxes)
rcnn_losses = self.rcnn(features, rpn_rois, im_info, gt_boxes)
loss_rpn_cls = rpn_losses["loss_rpn_cls"]
loss_rpn_bbox = rpn_losses["loss_rpn_bbox"]
loss_rcnn_cls = rcnn_losses["loss_rcnn_cls"]
loss_rcnn_bbox = rcnn_losses["loss_rcnn_bbox"]
total_loss = loss_rpn_cls + loss_rpn_bbox + loss_rcnn_cls + loss_rcnn_bbox
loss_dict = {
"total_loss": total_loss,
"rpn_cls": loss_rpn_cls,
"rpn_bbox": loss_rpn_bbox,
"rcnn_cls": loss_rcnn_cls,
"rcnn_bbox": loss_rcnn_bbox,
}
self.cfg.losses_keys = list(loss_dict.keys())
return loss_dict
def inference(self, features, im_info):
rpn_rois = self.rpn(features, im_info)
pred_boxes, pred_score = self.rcnn(features, rpn_rois)
pred_boxes = pred_boxes.reshape(-1, 4)
scale_w = im_info[0, 1] / im_info[0, 3]
scale_h = im_info[0, 0] / im_info[0, 2]
pred_boxes = pred_boxes / F.concat([scale_w, scale_h, scale_w, scale_h], axis=0)
clipped_boxes = layers.get_clipped_boxes(
pred_boxes, im_info[0, 2:4]
).reshape(-1, self.cfg.num_classes, 4)
return pred_score, clipped_boxes
class FasterRCNNConfig:
# pylint: disable=too-many-statements
def __init__(self):
self.backbone = "resnext101_32x8d"
self.backbone_pretrained = True
self.backbone_norm = "FrozenBN"
self.backbone_freeze_at = 2
self.fpn_norm = None
self.fpn_in_features = ["res2", "res3", "res4", "res5"]
self.fpn_in_strides = [4, 8, 16, 32]
self.fpn_in_channels = [256, 512, 1024, 2048]
self.fpn_out_channels = 256
# ------------------------ data cfg -------------------------- #
self.train_dataset = dict(
name="coco",
root="train2017",
ann_file="annotations/instances_train2017.json",
remove_images_without_annotations=True,
)
self.test_dataset = dict(
name="coco",
root="val2017",
ann_file="annotations/instances_val2017.json",
remove_images_without_annotations=False,
)
self.num_classes = 80
self.img_mean = [103.530, 116.280, 123.675] # BGR
self.img_std = [57.375, 57.120, 58.395]
# ----------------------- rpn cfg ------------------------- #
self.rpn_stride = [4, 8, 16, 32, 64]
self.rpn_in_features = ["p2", "p3", "p4", "p5", "p6"]
self.rpn_channel = 256
self.rpn_reg_mean = [0.0, 0.0, 0.0, 0.0]
self.rpn_reg_std = [1.0, 1.0, 1.0, 1.0]
self.anchor_scales = [[x] for x in [32, 64, 128, 256, 512]]
self.anchor_ratios = [[0.5, 1, 2]]
self.anchor_offset = 0.5
self.match_thresholds = [0.3, 0.7]
self.match_labels = [0, -1, 1]
self.match_allow_low_quality = True
self.rpn_nms_threshold = 0.7
self.num_sample_anchors = 256
self.positive_anchor_ratio = 0.5
# ----------------------- rcnn cfg ------------------------- #
self.rcnn_stride = [4, 8, 16, 32]
self.rcnn_in_features = ["p2", "p3", "p4", "p5"]
self.rcnn_reg_mean = [0.0, 0.0, 0.0, 0.0]
self.rcnn_reg_std = [0.1, 0.1, 0.2, 0.2]
self.pooling_method = "roi_align"
self.pooling_size = (7, 7)
self.num_rois = 512
self.fg_ratio = 0.5
self.fg_threshold = 0.5
self.bg_threshold_high = 0.5
self.bg_threshold_low = 0.0
self.class_aware_box = True
# ------------------------ loss cfg -------------------------- #
self.rpn_smooth_l1_beta = 0 # use L1 loss
self.rcnn_smooth_l1_beta = 0 # use L1 loss
self.num_losses = 5
# ------------------------ training cfg ---------------------- #
self.train_image_short_size = (640, 672, 704, 736, 768, 800)
self.train_image_max_size = 1333
self.train_prev_nms_top_n = 2000
self.train_post_nms_top_n = 1000
self.basic_lr = 0.02 / 16 # The basic learning rate for single-image
self.momentum = 0.9
self.weight_decay = 1e-4
self.log_interval = 20
self.nr_images_epoch = 80000
self.max_epoch = 54
self.warm_iters = 500
self.lr_decay_rate = 0.1
self.lr_decay_stages = [42, 50]
# ------------------------ testing cfg ----------------------- #
self.test_image_short_size = 800
self.test_image_max_size = 1333
self.test_prev_nms_top_n = 1000
self.test_post_nms_top_n = 1000
self.test_max_boxes_per_image = 100
self.test_vis_threshold = 0.3
self.test_cls_threshold = 0.05
self.test_nms = 0.5
This diff is collapsed.
#!/usr/bin/env bash
gpu=2
WORK_DIR=$(cd "$(dirname "$0")"; pwd)
export PYTHONPATH=${WORK_DIR}/
# train
python3 tools/train.py -n ${gpu} -b 2 \
-f configs/atss_res50_800size_trafficdet_demo.py -d . \
-w weights/atss_res50_coco_3x_800size_42dot6_9a92ed8c.pkl
# test
# 1X
python3 tools/test.py -n ${gpu} -se 11 \
-f configs/atss_res50_800size_trafficdet_demo.py -d .
# 2X
python3 tools/test.py -n ${gpu} -se 23 \
-f configs/atss_res50_800size_trafficdet_demo.py -d .
\ No newline at end of file
#!/usr/bin/env bash
gpu=2
WORK_DIR=$(cd "$(dirname "$0")"; pwd)
export PYTHONPATH=${WORK_DIR}/
# train
python3 tools/train.py -n ${gpu} -b 2 \
-f configs/fcos_res50_800size_trafficdet_demo.py -d . \
-w weights/fcos_res50_coco_3x_800size_42dot2_b16f9c8b.pkl
# test
# 1X
python3 tools/test.py -n ${gpu} -se 11 \
-f configs/fcos_res50_800size_trafficdet_demo.py -d .
# 2X
python3 tools/test.py -n ${gpu} -se 23 \
-f configs/fcos_res50_800size_trafficdet_demo.py -d .
\ No newline at end of file
#!/usr/bin/env bash
gpu=2
WORK_DIR=$(cd "$(dirname "$0")"; pwd)
export PYTHONPATH=${WORK_DIR}/
# train
python3 tools/train.py -n ${gpu} -b 2 \
-f configs/faster_rcnn_res50_800size_trafficdet_demo.py -d . \
-w weights/faster_rcnn_res50_coco_3x_800size_40dot1_8682ff1a.pkl
# test
# 1X
python3 tools/test.py -n ${gpu} -se 11 \
-f configs/faster_rcnn_res50_800size_trafficdet_demo.py -d .
# 2X
python3 tools/test.py -n ${gpu} -se 23 \
-f configs/faster_rcnn_res50_800size_trafficdet_demo.py -d .
\ No newline at end of file
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine.data.dataset import COCO, Objects365, PascalVOC
from tools.dataset import Traffic5
data_mapper = dict(
coco=COCO,
objects365=Objects365,
voc=PascalVOC,
traffic5=Traffic5,
)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import os
import json
from collections import defaultdict
import cv2
import numpy as np
from megengine.data.dataset.vision.meta_vision import VisionDataset
def has_valid_annotation(anno, order):
# if it"s empty, there is no annotation
if len(anno) == 0:
return False
if "boxes" in order or "boxes_category" in order:
if "bbox" not in anno[0]:
return False
return True
class Traffic5(VisionDataset):
r"""
Traffic Detection Challenge Dataset.
"""
supported_order = (
"image",
"boxes",
"boxes_category",
"info",
)
def __init__(
self, root, ann_file, remove_images_without_annotations=False, *, order=None
):
super().__init__(root, order=order, supported_order=self.supported_order)
with open(ann_file, "r") as f:
dataset = json.load(f)
self.imgs = dict()
for img in dataset["images"]:
self.imgs[img["id"]] = img
self.img_to_anns = defaultdict(list)
for ann in dataset["annotations"]:
# for saving memory
if (
"boxes" not in self.order
and "boxes_category" not in self.order
and "bbox" in ann
):
del ann["bbox"]
if "polygons" not in self.order and "segmentation" in ann:
del ann["segmentation"]
self.img_to_anns[ann["image_id"]].append(ann)
self.cats = dict()
for cat in dataset["categories"]:
self.cats[cat["id"]] = cat
self.ids = list(sorted(self.imgs.keys()))
# filter images without detection annotations
if remove_images_without_annotations:
ids = []
for img_id in self.ids:
anno = self.img_to_anns[img_id]
# filter crowd annotations
anno = [obj for obj in anno if obj["iscrowd"] == 0]
anno = [
obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
]
if has_valid_annotation(anno, order):
ids.append(img_id)
self.img_to_anns[img_id] = anno
else:
del self.imgs[img_id]
del self.img_to_anns[img_id]
self.ids = ids
self.json_category_id_to_contiguous_id = {
v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
}
self.contiguous_category_id_to_json_id = {
v: k for k, v in self.json_category_id_to_contiguous_id.items()
}
def __getitem__(self, index):
img_id = self.ids[index]
anno = self.img_to_anns[img_id]
target = []
for k in self.order:
if k == "image":
file_name = self.imgs[img_id]["file_name"]
path = os.path.join(self.root, file_name)
# print(path)
image = cv2.imread(path, cv2.IMREAD_COLOR)
target.append(image)
elif k == "boxes":
boxes = [obj["bbox"] for obj in anno]
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
# transfer boxes from xywh to xyxy
boxes[:, 2:] += boxes[:, :2]
target.append(boxes)
elif k == "boxes_category":
boxes_category = [obj["category_id"] for obj in anno]
boxes_category = [
self.json_category_id_to_contiguous_id[c] for c in boxes_category
]
boxes_category = np.array(boxes_category, dtype=np.int32)
target.append(boxes_category)
elif k == "info":
info = self.imgs[img_id]
info = [info["height"], info["width"], info["file_name"], img_id]
target.append(info)
else:
raise NotImplementedError
return tuple(target)
def __len__(self):
return len(self.ids)
def get_img_info(self, index):
img_id = self.ids[index]
img_info = self.imgs[img_id]
return img_info
class_names = (
"red_tl",
"arr_s",
"arr_l",
"no_driving_mark_allsort",
"no_parking_mark",
)
classes_originID = {
"red_tl": 0,
"arr_s": 1,
"arr_l": 2,
"no_driving_mark_allsort": 3,
"no_parking_mark": 4,
}
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import cv2
import megengine as mge
from tools.data_mapper import data_mapper
from tools.utils import DetEvaluator, import_from_file
logger = mge.get_logger(__name__)
logger.setLevel("INFO")
def make_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-f", "--file", default="net.py", type=str, help="net description file"
)
parser.add_argument(
"-w", "--weight_file", default=None, type=str, help="weights file",
)
parser.add_argument("-i", "--image", type=str)
return parser
def main():
parser = make_parser()
args = parser.parse_args()
current_network = import_from_file(args.file)
cfg = current_network.Cfg()
cfg.backbone_pretrained = False
model = current_network.Net(cfg)
model.eval()
state_dict = mge.load(args.weight_file)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict)
evaluator = DetEvaluator(model)
ori_img = cv2.imread(args.image)
image, im_info = DetEvaluator.process_inputs(
ori_img.copy(), model.cfg.test_image_short_size, model.cfg.test_image_max_size,
)
pred_res = evaluator.predict(
image=mge.tensor(image),
im_info=mge.tensor(im_info)
)
res_img = DetEvaluator.vis_det(
ori_img,
pred_res,
is_show_label=True,
classes=data_mapper[cfg.test_dataset["name"]].class_names,
)
cv2.imwrite("results.jpg", res_img)
if __name__ == "__main__":
main()
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
def py_cpu_nms(dets, thresh):
x1 = np.ascontiguousarray(dets[:, 0])
y1 = np.ascontiguousarray(dets[:, 1])
x2 = np.ascontiguousarray(dets[:, 2])
y2 = np.ascontiguousarray(dets[:, 3])
areas = (x2 - x1) * (y2 - y1)
order = dets[:, 4].argsort()[::-1]
keep = list()
while order.size > 0:
pick_idx = order[0]
keep.append(pick_idx)
order = order[1:]
xx1 = np.maximum(x1[pick_idx], x1[order])
yy1 = np.maximum(y1[pick_idx], y1[order])
xx2 = np.minimum(x2[pick_idx], x2[order])
yy2 = np.minimum(y2[pick_idx], y2[order])
inter = np.maximum(xx2 - xx1, 0) * np.maximum(yy2 - yy1, 0)
iou = inter / np.maximum(areas[pick_idx] + areas[order] - inter, 1e-5)
order = order[iou <= thresh]
return keep
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import json
import os
from multiprocessing import Process, Queue
from tqdm import tqdm
import megengine as mge
import megengine.distributed as dist
from megengine.data import DataLoader
from tools.data_mapper import data_mapper
from tools.utils import DetEvaluator, InferenceSampler, import_from_file
logger = mge.get_logger(__name__)
logger.setLevel("INFO")
def make_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-f", "--file", default="net.py", type=str, help="net description file"
)
parser.add_argument(
"-w", "--weight_file", default=None, type=str, help="weights file",
)
parser.add_argument(
"-n", "--devices", default=1, type=int, help="total number of gpus for testing",
)
parser.add_argument(
"-d", "--dataset_dir", default="/data/datasets", type=str,
)
parser.add_argument("-se", "--start_epoch", default=-1, type=int)
parser.add_argument("-ee", "--end_epoch", default=-1, type=int)
parser.add_argument(
"-mn","--model_name", default="", type=str,
)
return parser
def main():
# pylint: disable=import-outside-toplevel,too-many-branches,too-many-statements
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
parser = make_parser()
args = parser.parse_args()
current_network = import_from_file(args.file)
cfg = current_network.Cfg()
if args.weight_file:
args.start_epoch = args.end_epoch = -1
else:
if args.start_epoch == -1:
args.start_epoch = cfg.max_epoch - 1
if args.end_epoch == -1:
args.end_epoch = args.start_epoch
assert 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch
for epoch_num in range(args.start_epoch, args.end_epoch + 1):
if args.weight_file:
weight_file = args.weight_file
else:
weight_file = "logs/{}/epoch_{}.pkl".format(
os.path.basename(args.file).split(".")[0] + f'_gpus{args.devices}', epoch_num
)
result_list = []
if args.devices > 1:
result_queue = Queue(2000)
master_ip = "localhost"
server = dist.Server()
port = server.py_server_port
procs = []
for i in range(args.devices):
proc = Process(
target=worker,
args=(
current_network,
weight_file,
args.dataset_dir,
result_queue,
master_ip,
port,
args.devices,
i,
),
)
proc.start()
procs.append(proc)
# num_imgs = dict(coco=5000, objects365=30000, traffic5=584) # test set
num_imgs = dict(coco=5000, objects365=30000, traffic5=299) # val set
for _ in tqdm(range(num_imgs[cfg.test_dataset["name"]])):
result_list.append(result_queue.get())
for p in procs:
p.join()
else:
worker(current_network, weight_file, args.dataset_dir, result_list)
all_results = DetEvaluator.format(result_list, cfg)
json_path = "logs/{}/epoch_{}.json".format(
os.path.basename(args.file).split(".")[0] + f'_gpus{args.devices}', epoch_num
)
all_results = json.dumps(all_results)
with open(json_path, "w") as fo:
fo.write(all_results)
logger.info("Save to %s finished, start evaluation!", json_path)
eval_gt = COCO(
os.path.join(
args.dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]
)
)
eval_dt = eval_gt.loadRes(json_path)
cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
metrics = [
"AP",
"AP@0.5",
"AP@0.75",
"APs",
"APm",
"APl",
"AR@1",
"AR@10",
"AR@100",
"ARs",
"ARm",
"ARl",
]
logger.info("mmAP".center(32, "-"))
for i, m in enumerate(metrics):
logger.info("|\t%s\t|\t%.03f\t|", m, cocoEval.stats[i])
logger.info("%.04f", (cocoEval.stats[0]*0.5+cocoEval.stats[2]*0.3+cocoEval.stats[3]*0.2) )
logger.info("-" * 32)
def worker(
current_network, weight_file, dataset_dir, result_list,
master_ip=None, port=None, world_size=1, rank=0
):
if world_size > 1:
dist.init_process_group(
master_ip=master_ip,
port=port,
world_size=world_size,
rank=rank,
device=rank,
)
cfg = current_network.Cfg()
cfg.backbone_pretrained = False
model = current_network.Net(cfg)
model.eval()
state_dict = mge.load(weight_file)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict)
evaluator = DetEvaluator(model)
test_loader = build_dataloader(dataset_dir, model.cfg)
if dist.get_world_size() == 1:
test_loader = tqdm(test_loader)
for data in test_loader:
image, im_info = DetEvaluator.process_inputs(
data[0][0],
model.cfg.test_image_short_size,
model.cfg.test_image_max_size,
)
pred_res = evaluator.predict(
image=mge.tensor(image),
im_info=mge.tensor(im_info)
)
result = {
"det_res": pred_res,
"image_id": int(data[1][3][0]),
}
if dist.get_world_size() > 1:
result_list.put_nowait(result)
else:
result_list.append(result)
def build_dataloader(dataset_dir, cfg):
val_dataset = data_mapper[cfg.test_dataset["name"]](
os.path.join(dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["root"]),
os.path.join(dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]),
order=["image", "info"],
)
val_sampler = InferenceSampler(val_dataset, 1)
val_dataloader = DataLoader(val_dataset, sampler=val_sampler, num_workers=2)
return val_dataloader
if __name__ == "__main__":
main()
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import json
import os
from multiprocessing import Process, Queue
from tqdm import tqdm
import megengine as mge
import megengine.distributed as dist
from megengine.data import DataLoader
from tools.data_mapper import data_mapper
from tools.utils import DetEvaluator, InferenceSampler, import_from_file
logger = mge.get_logger(__name__)
logger.setLevel("INFO")
def make_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-f", "--file", default="net.py", type=str, help="net description file"
)
parser.add_argument(
"-w", "--weight_file", default=None, type=str, help="weights file",
)
parser.add_argument(
"-n", "--devices", default=1, type=int, help="total number of gpus for testing",
)
parser.add_argument(
"-d", "--dataset_dir", default="/data/datasets", type=str,
)
parser.add_argument("-se", "--start_epoch", default=-1, type=int)
parser.add_argument("-ee", "--end_epoch", default=-1, type=int)
return parser
def main():
# pylint: disable=import-outside-toplevel,too-many-branches,too-many-statements
parser = make_parser()
args = parser.parse_args()
current_network = import_from_file(args.file)
cfg = current_network.Cfg()
if args.weight_file:
args.start_epoch = args.end_epoch = -1
else:
if args.start_epoch == -1:
args.start_epoch = cfg.max_epoch - 1
if args.end_epoch == -1:
args.end_epoch = args.start_epoch
assert 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch
for epoch_num in range(args.start_epoch, args.end_epoch + 1):
if args.weight_file:
weight_file = args.weight_file
else:
weight_file = "logs/{}/epoch_{}.pkl".format(
os.path.basename(args.file).split(".")[0] + f'_gpus{args.devices}', epoch_num
)
result_list = []
if args.devices > 1:
result_queue = Queue(2000)
master_ip = "localhost"
server = dist.Server()
port = server.py_server_port
procs = []
for i in range(args.devices):
proc = Process(
target=worker,
args=(
current_network,
weight_file,
args.dataset_dir,
result_queue,
master_ip,
port,
args.devices,
i,
),
)
proc.start()
procs.append(proc)
num_imgs = dict(coco=5000, objects365=30000, traffic5=584) # test set
for _ in tqdm(range(num_imgs[cfg.test_dataset["name"]])):
result_list.append(result_queue.get())
for p in procs:
p.join()
else:
worker(current_network, weight_file, args.dataset_dir, result_list)
all_results = DetEvaluator.format(result_list, cfg)
json_path = "logs/{}/test_final_epoch_{}.json".format(
os.path.basename(args.file).split(".")[0] + f'_gpus{args.devices}', epoch_num
)
all_results = json.dumps(all_results)
with open(json_path, "w") as fo:
fo.write(all_results)
logger.info("Save to %s finished, start evaluation!", json_path)
def worker(
current_network, weight_file, dataset_dir, result_list,
master_ip=None, port=None, world_size=1, rank=0
):
if world_size > 1:
dist.init_process_group(
master_ip=master_ip,
port=port,
world_size=world_size,
rank=rank,
device=rank,
)
cfg = current_network.Cfg()
cfg.backbone_pretrained = False
model = current_network.Net(cfg)
model.eval()
state_dict = mge.load(weight_file)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict)
evaluator = DetEvaluator(model)
test_loader = build_dataloader(dataset_dir, model.cfg)
if dist.get_world_size() == 1:
test_loader = tqdm(test_loader)
for data in test_loader:
image, im_info = DetEvaluator.process_inputs(
data[0][0],
model.cfg.test_image_short_size,
model.cfg.test_image_max_size,
)
pred_res = evaluator.predict(
image=mge.tensor(image),
im_info=mge.tensor(im_info)
)
result = {
"det_res": pred_res,
"image_id": int(data[1][3][0]),
}
if dist.get_world_size() > 1:
result_list.put_nowait(result)
else:
result_list.append(result)
def build_dataloader(dataset_dir, cfg):
val_dataset = data_mapper[cfg.test_dataset["name"]](
os.path.join(dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["root"]),
os.path.join(dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["test_final_ann_file"]),
order=["image", "info"],
)
val_sampler = InferenceSampler(val_dataset, 1)
val_dataloader = DataLoader(val_dataset, sampler=val_sampler, num_workers=2)
return val_dataloader
if __name__ == "__main__":
main()
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment