对于anchor-based模型来说,anchor无疑是模型设计中至关重要的一环。之前有了解过2d网络中anchor的生成逻辑,这次借助训练3D模型的机会来了解一下OpenPCDet中3d网络模型中anchor的生成逻辑。这里主要是通过解读models/dense_heads/target_assigner/anchor_generator.py中generate_anchors函数来了解这一过程。为方便理解代码中各种变量的具体含义,我准备了一个小型的配置文件。
训练集中总共包含'Truck'和'Car'两个类别,这两个类别的anchor分别分配和两种尺寸以及两个旋转角度。代码中涉及到的具体变量的具体数值以訪配置文件为参考进行计算。
1 def generate_anchors(self, grid_sizes): 2 assert len(grid_sizes) == self.num_of_anchor_sets 3 all_anchors = [] 4 num_anchors_per_location = [] 5 for grid_size, anchor_size, anchor_rotation, anchor_height, align_center in zip( 6 grid_sizes, self.anchor_sizes, self.anchor_rotations, self.anchor_heights, self.align_center): 7 8 num_anchors_per_location.append(len(anchor_rotation) * len(anchor_size) * len(anchor_height)) 9 if align_center: 10 x_stride = (self.anchor_range[3] - self.anchor_range[0]) / grid_size[0] 11 y_stride = (self.anchor_range[4] - self.anchor_range[1]) / grid_size[1] 12 x_offset, y_offset = x_stride / 2, y_stride / 2 13 else: 14 x_stride = (self.anchor_range[3] - self.anchor_range[0]) / (grid_size[0] - 1) 15 y_stride = (self.anchor_range[4] - self.anchor_range[1]) / (grid_size[1] - 1) 16 x_offset, y_offset = 0, 0 17 18 x_shifts = torch.arange( 19 self.anchor_range[0] + x_offset, self.anchor_range[3] + 1e-5, step=x_stride, dtype=torch.float32, 20 ).cuda() 21 y_shifts = torch.arange( 22 self.anchor_range[1] + y_offset, self.anchor_range[4] + 1e-5, step=y_stride, dtype=torch.float32, 23 ).cuda() 24 z_shifts = x_shifts.new_tensor(anchor_height) 25 26 num_anchor_size, num_anchor_rotation = anchor_size.__len__(), anchor_rotation.__len__() 27 anchor_rotation = x_shifts.new_tensor(anchor_rotation) 28 anchor_size = x_shifts.new_tensor(anchor_size) 29 x_shifts, y_shifts, z_shifts = torch.meshgrid([ 30 x_shifts, y_shifts, z_shifts 31 ]) # [x_grid, y_grid, z_grid] 32 anchors = torch.stack((x_shifts, y_shifts, z_shifts), dim=-1) 33 anchors = anchors[:, :, :, None, :].repeat(1, 1, 1, anchor_size.shape[0], 1) #x,y,z 34 35 anchor_size = anchor_size.view(1, 1, 1, -1, 3).repeat([*anchors.shape[0:3], 1, 1]) 36 anchors = torch.cat((anchors, anchor_size), dim=-1) #x,y,z,l,w,h 37 38 anchors = anchors[:, :, :, :, None, :].repeat(1, 1, 1, 1, num_anchor_rotation, 1) 39 anchor_rotation = anchor_rotation.view(1, 1, 1, 1, -1, 1).repeat([*anchors.shape[0:3], num_anchor_size, 1, 1]) 40 anchors = torch.cat((anchors, anchor_rotation), dim=-1) #x,y,z,l,w,h,rz 41 42 anchors = anchors.permute(2, 1, 0, 3, 4, 5).contiguous() 43 anchors[..., 2] += anchors[..., 5] / 2 # shift to box centers 44 all_anchors.append(anchors) 45 return all_anchors, num_anchors_per_locationfor循环的意思就是按照配置文件中ANCHOR_GENERATOR_CONFIG规定的一个配置项一个配置项的来生成anchor。在正式进入主体for循环之前,先来瞅瞅各变量的内容。
(Pdb) grid_sizes #bev视角下的输入网格大小,这个是根据你配置文件中的PINT_CLOUD_RANGE和VOXEL_SIZE计算出来的
[array([216, 248]), array([216, 248])]
(Pdb) self.anchor_sizes
[[[14.2606669, 8.87147485, 8.42032782]], [[5.37165288, 2.36589975, 2.51961153]]]
(Pdb) self.anchor_rotations
[[0, 1.57], [0, 1.57]]
(Pdb) self.anchor_heights #Anchor底部高度,注意是底部高度
[[5.51], [1.58]]
(Pdb) self.align_center
[False, False]
行8:计算出当前配置项下,一个网格单元的anchor数量,例如: 2*2*1;
行9-行16:计算出网格单元在x,y方向上的步长以及偏移量,例如:
(Pdb) x_stride
0.46511627906976744
(Pdb) y_stride
0.4048582995951417
(Pdb) x_offset
0
(Pdb) y_offset
0
行18-行24:计算出所有网格在x,y,z方向上的偏移量,这个正好是根据上一步计算出来的x_stride,y_stride以及x_offset,y_offset计算得到,例如:
(Pdb) x_shifts.shape
torch.Size([216])
(Pdb) y_shifts.shape
torch.Size([248])
(Pdb) z_shifts.shape
torch.Size([1])
行29-行32:进一步,通过torch.meshgrid构造出偏移矩阵,例如:
(Pdb) x_shifts.shape
torch.Size([216, 248, 1])
(Pdb) y_shifts.shape
torch.Size([216, 248, 1])
(Pdb) z_shifts.shape
torch.Size([216, 248, 1])
行32:
(Pdb) anchors.shape
torch.Size([216, 248, 1, 3])
行33:
(Pdb) anchors.shape
torch.Size([216, 248, 1, 1, 3])
行35:
(Pdb) anchor_size.shape
torch.Size([216, 248, 1, 1, 3])
行36:
(Pdb) anchors.shape
torch.Size([216, 248, 1, 1, 6]) #6: x,y,z,l,w,h
行38-行40:
(Pdb) anchors.shape
torch.Size([216, 248, 1, 1, 2, 7]) #7: x,y,z,l,w,h,rotation
行42:
(Pdb) anchors.shape
torch.Size([1, 248, 216, 1, 2, 7])
行43:偏移至中心高度
(Pdb) anchors.shape
torch.Size([1, 248, 216, 1, 2, 7])
至此,计算出来每一项配置项中所有网格点的anchor的实际位置,尺寸,旋转角度。