mirror of
https://github.com/zhigang1992/PointRCNN.git
synced 2026-06-10 07:10:15 +08:00
113 lines
4.3 KiB
Python
113 lines
4.3 KiB
Python
import torch
|
|
import roipool3d_cuda
|
|
import numpy as np
|
|
import lib.utils.kitti_utils as kitti_utils
|
|
|
|
|
|
def roipool3d_gpu(pts, pts_feature, boxes3d, pool_extra_width, sampled_pt_num=512):
    """
    Run the roipool3d_cuda.forward kernel on boxes enlarged by pool_extra_width.

    :param pts: (B, N, 3)
    :param pts_feature: (B, N, C)
    :param boxes3d: (B, M, 7)
    :param pool_extra_width: float, forwarded to kitti_utils.enlarge_box3d
    :param sampled_pt_num: int, size of the per-box point dimension of the output
    :return:
        pooled_features: (B, M, sampled_pt_num, 3 + C), float32
        pooled_empty_flag: (B, M), int32
    """
    batch_size, boxes_num, feature_len = pts.shape[0], boxes3d.shape[1], pts_feature.shape[2]

    # enlarge_box3d is applied to a flat (B * M, 7) view; the batch dimension is restored afterwards.
    pooled_boxes3d = kitti_utils.enlarge_box3d(boxes3d.view(-1, 7), pool_extra_width).view(batch_size, -1, 7)

    # Allocate the zero-filled output buffers on the device that holds the inputs. The legacy
    # torch.cuda.FloatTensor / torch.cuda.IntTensor constructors are deprecated and always use
    # the *current* CUDA device, which is not necessarily the device of `pts`.
    pooled_features = torch.zeros((batch_size, boxes_num, sampled_pt_num, 3 + feature_len),
                                  dtype=torch.float32, device=pts.device)
    pooled_empty_flag = torch.zeros((batch_size, boxes_num), dtype=torch.int32, device=pts.device)

    # The extension fills both output buffers in place.
    roipool3d_cuda.forward(pts.contiguous(), pooled_boxes3d.contiguous(),
                           pts_feature.contiguous(), pooled_features, pooled_empty_flag)

    return pooled_features, pooled_empty_flag
|
|
|
|
|
|
def pts_in_boxes3d_cpu(pts, boxes3d):
    """
    Build one per-point mask for every box using roipool3d_cuda.pts_in_boxes3d_cpu.

    :param pts: (N, 3) in rect-camera coords, must be a CPU tensor
    :param boxes3d: (M, 7)
    :return: boxes_pts_mask_list: (M), list with [(N), (N), ..]; entry k is `pts_flag[k] > 0`
    :raises NotImplementedError: if pts is a CUDA tensor
    """
    if pts.is_cuda:
        raise NotImplementedError

    pts = pts.float().contiguous()
    boxes3d = boxes3d.float().contiguous()

    # Zero-initialise the flag buffer: the legacy torch.LongTensor(size) constructor returns
    # uninitialised memory, so any entry the extension left untouched would produce an
    # arbitrary mask value under the `> 0` test below.
    # NOTE(review): whether the extension writes every entry cannot be confirmed from this file.
    pts_flag = torch.zeros((boxes3d.size(0), pts.size(0)), dtype=torch.int64)  # (M, N)
    roipool3d_cuda.pts_in_boxes3d_cpu(pts_flag, pts, boxes3d)

    return [pts_flag[k] > 0 for k in range(boxes3d.shape[0])]
|
|
|
|
|
|
def roipool_pc_cpu(pts, pts_feature, boxes3d, sampled_pt_num):
    """
    Call roipool3d_cuda.roipool3d_cpu on CPU float copies of the inputs.

    :param pts: (N, 3)
    :param pts_feature: (N, C)
    :param boxes3d: (M, 7)
    :param sampled_pt_num: int
    :return:
        pooled_pts: (M, sampled_pt_num, 3), float32
        pooled_features: (M, sampled_pt_num, C), float32
        pooled_empty_flag: (M), int64
    """
    # Normalise every input to a contiguous float32 tensor on the CPU.
    pts, pts_feature, boxes3d = (t.cpu().float().contiguous() for t in (pts, pts_feature, boxes3d))

    assert pts.shape[0] == pts_feature.shape[0] and pts.shape[1] == 3, '%s %s' % (pts.shape, pts_feature.shape)
    assert not pts.is_cuda

    num_boxes = boxes3d.shape[0]
    feature_dim = pts_feature.shape[1]

    # Zero-filled CPU output buffers (same dtype/device as `pts` unless overridden);
    # the extension call below writes into them in place.
    pooled_pts = pts.new_zeros((num_boxes, sampled_pt_num, 3))
    pooled_features = pts.new_zeros((num_boxes, sampled_pt_num, feature_dim))
    pooled_empty_flag = pts.new_zeros(num_boxes, dtype=torch.int64)

    roipool3d_cuda.roipool3d_cpu(pts, boxes3d, pts_feature, pooled_pts, pooled_features, pooled_empty_flag)
    return pooled_pts, pooled_features, pooled_empty_flag
|
|
|
|
|
|
def roipool3d_cpu(boxes3d, pts, pts_feature, pts_extra_input, pool_extra_width, sampled_pt_num=512,
                  canonical_transform=True):
    """
    Pool points per box on the CPU through roipool_pc_cpu, optionally moving them to each box's frame.

    All array arguments are numpy arrays (they are passed through torch.from_numpy).

    :param boxes3d: (M, 7), columns 0:3 are used as the box center and column 6 as the rotation ry
    :param pts: (N, 3)
    :param pts_feature: (N, C)
    :param pts_extra_input: (N, C2)
    :param pool_extra_width: constant, forwarded to kitti_utils.enlarge_box3d
    :param sampled_pt_num: constant
    :param canonical_transform: if true, shift the pooled xyz by the box center and rotate by ry
    :return:
        canonical_transform is true:  (sampled_pts_input, sampled_pts_feature)
        canonical_transform is false: (sampled_pts_input, sampled_pts_feature, pooled_empty_flag)
        sampled_pts_input is (M, sampled_pt_num, 3 + C2), sampled_pts_feature is (M, sampled_pt_num, C)
    """
    enlarged_boxes = kitti_utils.enlarge_box3d(boxes3d, pool_extra_width)

    # Extra inputs go first so they can be split off again by column index after pooling.
    stacked_feature = np.concatenate((pts_extra_input, pts_feature), axis=1)

    # Note: if empty_flag[i] > 0, box_pts[i] and box_feats[i] will be zero
    box_pts, box_feats, empty_flag = roipool_pc_cpu(
        torch.from_numpy(pts),
        torch.from_numpy(stacked_feature),
        torch.from_numpy(enlarged_boxes),
        sampled_pt_num,
    )

    num_extra = pts_extra_input.shape[1]
    extra_part = box_feats[:, :, 0:num_extra]
    feature_part = box_feats[:, :, num_extra:]
    sampled_pts_input = torch.cat((box_pts, extra_part), dim=2).numpy()
    sampled_pts_feature = feature_part.numpy()

    if not canonical_transform:
        # Only this path also hands back the empty flag (three values instead of two).
        return sampled_pts_input, sampled_pts_feature, empty_flag.numpy()

    # Translate to the roi coordinates
    roi_ry = boxes3d[:, 6] % (2 * np.pi)  # 0~2pi
    roi_center = boxes3d[:, 0:3]

    # shift to center, then rotate each box's points by that box's ry
    centered_xyz = sampled_pts_input[:, :, 0:3] - roi_center[:, np.newaxis, :]
    sampled_pts_input[:, :, 0:3] = centered_xyz
    for box_idx, box_points in enumerate(sampled_pts_input):
        sampled_pts_input[box_idx] = kitti_utils.rotate_pc_along_y(box_points, roi_ry[box_idx])

    return sampled_pts_input, sampled_pts_feature
|
|
|
|
|
|
if __name__ == '__main__':
    # Running this module directly does nothing; it only provides helpers for import.
    pass
|