I am currently working on an instance segmentation project, so this post walks through the SOLO code built on mmdetection, focusing on the head.
detectors/solo.py
This file defines the overall structure: SOLO directly inherits from the SingleStageInsDetector
class. The config file shows that SOLO uses a ResNet backbone, an FPN neck, and the SOLO head, so the head is where the interesting work happens (a trimmed config sketch follows the class definition below).
@DETECTORS.register_module
class SOLO(SingleStageInsDetector):
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(SOLO, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)
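For orientation, here is a trimmed config sketch. The strides, scale ranges, grid numbers, and sigma match the values quoted later in this post; the class count and ResNet depth are typical COCO settings I am assuming, not copied verbatim from the repo:

model = dict(
    type='SOLO',
    backbone=dict(type='ResNet', depth=50),       # assumed: ResNet-50
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    bbox_head=dict(
        type='SOLOHead',
        num_classes=81,                           # assumed: 80 COCO classes + background
        in_channels=256,
        stacked_convs=7,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12]))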
anchor_heads/solo_head.py
This is the core of SOLO: the head definition and the loss computation.
def _init_layers(self):
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
self.ins_convs = nn.ModuleList()
self.cate_convs = nn.ModuleList()
    # In SOLO self.stacked_convs = 7: both the category and mask branches run through 7 conv blocks before their final output conv
for i in range(self.stacked_convs):
chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
self.ins_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
chn = self.in_channels if i == 0 else self.seg_feat_channels
self.cate_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
    # Since the mask branch's output channel count (S*S) depends on the grid size,
    # this is a ModuleList: each level gets its own output conv
self.solo_ins_list = nn.ModuleList()
for seg_num_grid in self.seg_num_grids:
self.solo_ins_list.append(
nn.Conv2d(
self.seg_feat_channels, seg_num_grid**2, 1))
    # the category branch's output conv has (number of classes - 1) channels and is shared across levels
self.solo_cate = nn.Conv2d(
self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
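A quick sanity check of the per-level output channels (a minimal sketch; the COCO class count is an assumption on my part):

seg_num_grids = [40, 36, 24, 16, 12]
# mask branch: one 1x1 conv per level, with S*S output channels
print([g ** 2 for g in seg_num_grids])  # [1600, 1296, 576, 256, 144]
# category branch: a single shared 3x3 conv with num_classes - 1 channels
num_classes = 81  # assumed: 80 COCO classes + background
print(num_classes - 1)  # 80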
With the network structure defined, we move on to the forward pass.
def forward(self, feats, eval=False):
    # self.split_feats rescales the five levels of FPN feature maps (see split_feats below)
new_feats = self.split_feats(feats)
    # record each level's spatial size
featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
    # the main forward routine; multi_apply (sketched after this function) calls self.forward_single once per level
ins_pred, cate_pred = multi_apply(self.forward_single, new_feats,
list(range(len(self.seg_num_grids))),
eval=eval, upsampled_size=upsampled_size)
return ins_pred, cate_pred
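multi_apply is an mmdetection helper that maps a function over per-level (or per-image) inputs and transposes the collected results; it is essentially the following (a sketch of the mmdet utility):

from functools import partial

def multi_apply(func, *args, **kwargs):
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # transpose: list of per-call result tuples -> tuple of per-output lists
    return tuple(map(list, zip(*map_results)))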
# Why resize the first and last feature maps? This aligns the effective strides with the self.strides = [8, 8, 16, 32, 32] used later for target assignment: level 0 is downsampled by 2 and level 4 is upsampled to level 3's size. The shape check after split_feats makes this concrete.
def split_feats(self, feats):
return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
feats[1],
feats[2],
feats[3],
F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
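A quick shape check (a sketch assuming a 512 x 512 input): the FPN outputs have strides 4, 8, 16, 32, 64, and split_feats turns them into effective strides 8, 8, 16, 32, 32, exactly matching self.strides.

import torch
import torch.nn.functional as F

feats = [torch.randn(1, 256, 512 // s, 512 // s) for s in (4, 8, 16, 32, 64)]
new_feats = (
    F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),          # stride 4 -> 8
    feats[1], feats[2], feats[3],
    F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))  # stride 64 -> 32
print([tuple(f.shape[-2:]) for f in new_feats])
# [(64, 64), (64, 64), (32, 32), (16, 16), (16, 16)]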
self.forward_single below is the body of the forward pass. It is not complicated; the one thing to note is that the mask branch concatenates x and y coordinate maps onto each level's feature map (the CoordConv trick).
def forward_single(self, x, idx, eval=False, upsampled_size=None):
'''
    :param x: feature map of one FPN level, [N, C, H, W]
    :param idx: one of [0, 1, 2, 3, 4], indicating the current level
    :param eval: False during training
    :param upsampled_size: 2x the (h, w) of the largest feature map, i.e. the C1 scale
:return:
'''
    # as usual, split into the two branches
ins_feat = x
cate_feat = x
    # Mask branch first: concatenate normalized x, y coordinates onto the
    # feature map, adding 2 channels (CoordConv-style)
x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device) # w --> x
y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device) # h --> y
    # build the 2-D coordinate grids from x_range and y_range
y, x = torch.meshgrid(y_range, x_range)
    # expand both coordinate maps to 4-D: [N, 1, H, W]
y = y.expand([ins_feat.shape[0], 1, -1, -1])
x = x.expand([ins_feat.shape[0], 1, -1, -1])
    # concat the coordinates onto the feature map channels
coord_feat = torch.cat([x, y], 1)
ins_feat = torch.cat([ins_feat, coord_feat], 1)
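    # e.g. with in_channels = 256 (an assumed typical setting), ins_feat goes from
    # [N, 256, H, W] to [N, 258, H, W]; this is why the first ins_conv above takes
    # self.in_channels + 2 input channels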
    # feed the augmented feature map through ins_convs
for i, ins_layer in enumerate(self.ins_convs):
ins_feat = ins_layer(ins_feat)
    # upsample the feature map to 2H x 2W, presumably to sharpen the predicted masks
ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
    # the mask branch output: [N, S*S, 2H, 2W]
ins_pred = self.solo_ins_list[idx](ins_feat)
    # now the category branch
for i, cate_layer in enumerate(self.cate_convs):
        # at the first conv (cate_down_pos) resize to S x S, since the category branch predicts on the S x S grid
if i == self.cate_down_pos:
seg_num_grid = self.seg_num_grids[idx]
cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
cate_feat = cate_layer(cate_feat)
    # the category branch output: [N, num_classes - 1, S, S]
cate_pred = self.solo_cate(cate_feat)
    # In eval mode:
    # apply sigmoid to the mask output and upsample it to the original C1 scale;
    # run points_nms on the category output (see the sketch after this function)
if eval:
ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear')
cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
return ins_pred, cate_pred
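For completeness, points_nms is a max-pooling-style NMS on the S x S category heatmap: a cell survives only if it equals the local maximum of its neighborhood. A sketch consistent with the SOLO repo:

import torch.nn.functional as F

def points_nms(heat, kernel=2):
    # kernel must be 2: max-pool with padding 1, then crop back to [S, S]
    hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=1)
    # keep only the cells that equal their local maximum
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    return heat * keep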
That completes the forward pass; the framework itself is quite clean. The loss computation that follows is somewhat more involved.
def loss(self,
ins_preds,
cate_preds,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
    # featmap_sizes holds the mask-branch output size of each of the five levels
featmap_sizes = [featmap.size()[-2:] for featmap in
ins_preds]
    # multi_apply again, meaning self.solo_target_single processes one image of the batch at a time
    # (the config file also shows that SOLO makes use of bbox information);
    # let's jump straight into self.solo_target_single
ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
self.solo_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
featmap_sizes=featmap_sizes)
self.solo_target_single assigns gt labels to the category and mask branches.
def solo_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
featmap_sizes=None):
'''
    :param gt_bboxes_raw: [num_objects, 4]
    :param gt_labels_raw: [num_objects]
    :param gt_masks_raw: [num_objects, H, W]
    :param featmap_sizes: list of (h, w), one per level
:return:
'''
device = gt_labels_raw[0].device
    # the "area" of every object in the image, actually sqrt(w * h), i.e. the
    # object scale; shape [num_objects]
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
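    # e.g. (illustrative numbers): a 100 x 64 box gives sqrt(100 * 64) = 80, which
    # falls in both (1, 96) and (48, 192), so the object is assigned to both level 0
    # and level 1; the overlapping scale ranges are intentional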
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
    # iterate over one level at a time
for (lower_bound, upper_bound), stride, featmap_size, num_grid \
in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
        # self.scale_ranges: ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
        # the object-scale range each level handles; assigning objects of different
        # scales to different levels is one way to mitigate object overlap
        # self.strides: [8, 8, 16, 32, 32], the downsampling ratios (2**n) w.r.t. the input image;
        # this matches split_feats above: the first level is downsampled to 1/2 and the last upsampled by 2
        # featmap_sizes: the spatial size of each level
        # self.seg_num_grids: [40, 36, 24, 16, 12]
        # ins_label has the same shape as the mask output: [S*S, H, W]
ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
        # cate_label is [S, S]; each grid cell stores the gt class label directly
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
        # whether each of the S*S mask channels has a matched gt
ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
        # indices of the objects whose scale falls within this level's range;
        # the bbox-derived scale is what routes objects of different sizes to different levels
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
        # if no object falls in this range, skip to the next level
if len(hit_indices) == 0:
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
continue
        # select the objects that satisfy the scale range
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
        # why .cpu().numpy()? Debugging shows gt_masks is a numpy array rather than
        # a tensor (this version of mmdetection keeps masks as numpy arrays), so the
        # indices must be numpy as well
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
        # self.sigma (0.2 in the config) shrinks the gt box to the paper's "center region", so half_w = 0.1 * w, half_h = 0.1 * h
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
        # since the mask branch output is 2x the level's h, w, the effective stride is halved
output_stride = stride / 2
        # iterate over the objects; first assign grid labels based on the shrunk box and the mask's center of mass
for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
            # ignore objects whose mask covers fewer than 10 pixels
if seg_mask.sum() < 10:
continue
# mass center
            # featmap_sizes[0][0] * 4 is (roughly) the input image size; it may not be exact, since image sizes are not uniform
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
            # the object's center of mass
center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
            # map the object center onto the S x S grid
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
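            # worked example (illustrative numbers): with upsampled_size = (800, 800)
            # and num_grid = 40, center_w = 300 gives
            # coord_w = int((300 / 800) // (1 / 40)) = int(15.0) = 15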
# left, top, right, down
            # these four values delimit, on the grid, the center region (the gt box shrunk by sigma);
            # my reading: a sufficiently large object can therefore occupy several grid cells
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
            # constrain the region to the 3x3 neighborhood of the center cell, intersected with the box above;
            # when top == down and left == right, the object occupies a single cell of the S*S grid
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
            # mark these grid cells with gt_label
cate_label[top:(down+1), left:(right+1)] = gt_label
            # rescale seg_mask to this level's mask-branch output size
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.Tensor(seg_mask)
            # assign the gt mask to every matched grid channel
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_ind_label[label] = True
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
return ins_label_list, cate_label_list, ins_ind_label_list
- ins_label_list: mask-branch gt labels for the whole batch
- cate_label_list: category-branch gt labels for the whole batch
- ins_ind_label_list: flags marking which of the S*S channels contain an object, for the whole batch
Now let's go back to the loss function and continue.
    # zip(*ins_label_list): [num_levels, N, S*S, 2H, 2W]
    # each group holds the whole batch's mask-branch gt labels at one level
    # zip(*ins_ind_label_list)
    # same grouping for the indicator flags
    # zip(ins_labels_level, ins_ind_labels_level)
    # pairs each image's labels with its indicators at the same level
    # this rearranges the gt to match the layout of the predictions, for the loss computation below
ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
for ins_labels_level_img, ins_ind_labels_level_img in
zip(ins_labels_level, ins_ind_labels_level)], 0)
for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]
ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
for ins_preds_level_img, ins_ind_labels_level_img in
zip(ins_preds_level, ins_ind_labels_level)], 0)
for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]
    # ins_ind_labels: [S*S*batch_size] per level, e.g. 40*40*2 = 3200
ins_ind_labels = [
torch.cat([ins_ind_labels_level_img.flatten()
for ins_ind_labels_level_img in ins_ind_labels_level])
for ins_ind_labels_level in zip(*ins_ind_label_list)
]
    # flatten_ins_ind_labels: [7744] = (40**2 + 36**2 + 24**2 + 16**2 + 12**2) * 2 for a batch of 2
flatten_ins_ind_labels = torch.cat(ins_ind_labels)
num_ins = flatten_ins_ind_labels.int().sum()
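The nested comprehensions above are dense; an equivalent loop form for ins_labels (same logic, written out for readability) looks like this:

ins_labels = []
for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list)):
    # per level: for each image keep only the grid channels that matched a gt,
    # then concatenate the kept masks over the whole batch
    kept = [img_labels[img_inds, ...]
            for img_labels, img_inds in zip(ins_labels_level, ins_ind_labels_level)]
    ins_labels.append(torch.cat(kept, 0))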
Finally, the actual loss computation.
# dice loss
    # the mask branch uses Dice loss
loss_ins = []
for input, target in zip(ins_preds, ins_labels):
if input.size()[0] == 0:
continue
input = torch.sigmoid(input)
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean()
loss_ins = loss_ins * self.ins_loss_weight
# cate
    # the category branch uses Focal Loss
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
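For reference, dice_loss implements the Dice loss from the paper, L_Dice = 1 - 2 * sum(p * q) / (sum(p^2) + sum(q^2)), computed per instance mask. A minimal sketch consistent with that formulation:

import torch

def dice_loss(input, target):
    # flatten each instance mask to a vector
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()
    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001   # small eps for numerical stability
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return 1 - d  # one loss value per instance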