# Mirror of https://github.com/zhigang1992/PointRCNN.git
# Synced 2026-01-12 22:49:40 +08:00 (326 lines, 14 KiB, Python)
import argparse
import os
import pickle
import shutil

import numpy as np
import torch

# NOTE(review): _init_path presumably extends sys.path so that `lib` resolves;
# keep it ahead of every lib.* import.
import _init_path
import lib.utils.kitti_utils as kitti_utils
import lib.utils.roipool3d.roipool3d_utils as roipool3d_utils
import lib.utils.iou3d.iou3d_utils as iou3d_utils
from lib.datasets.kitti_dataset import KittiDataset

# Seed NumPy's global RNG so the random object sampling below is repeatable.
np.random.seed(1024)

# Command-line options; parsed at import time because the definitions below
# read `args` directly.
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='generator')
parser.add_argument('--class_name', type=str, default='Car')
parser.add_argument('--save_dir', type=str, default='./../data/KITTI/aug_scene/training')
parser.add_argument('--split', type=str, default='train')
parser.add_argument('--gt_database_dir', type=str, default='gt_database/train_gt_database_3level_Car.pkl')
parser.add_argument('--include_similar', action='store_true', default=False)
parser.add_argument('--aug_times', type=int, default=4)
args = parser.parse_args()

# When True, points and sampled objects outside PC_AREA_SCOPE are dropped.
PC_REDUCE_BY_RANGE = True
if args.class_name == 'Car':
    PC_AREA_SCOPE = np.array([[-40, 40], [-1, 3], [0, 70.4]])  # x, y, z scope in rect camera coords
else:
    PC_AREA_SCOPE = np.array([[-30, 30], [-1, 3], [0, 50]])


def log_print(info, fp=None):
    """Print *info* to stdout and, when *fp* is given, also to that file.

    :param info: message to emit (anything accepted by ``print``)
    :param fp: optional writable text file object; skipped when ``None``
    """
    print(info)
    if fp is not None:
        print(info, file=fp)


def save_kitti_format(calib, bbox3d, obj_list, img_shape, save_fp):
    """Append one KITTI-style label line per valid 3D box to *save_fp*.

    :param calib: calibration object providing ``corners3d_to_img_boxes``
    :param bbox3d: (N, 7) array; columns are read here as
        [x, y, z, h, w, l, ry] (same layout the generator fills in)
    :param obj_list: per-box label objects; ``trucation`` and ``occlusion``
        are copied into the output line
    :param img_shape: image shape indexed as ``[0]`` = height, ``[1]`` = width
    :param save_fp: writable text file object
    """
    if bbox3d.shape[0] == 0:
        # Nothing to write; also avoids handing empty arrays to the helpers.
        return

    corners3d = kitti_utils.boxes3d_to_corners3d(bbox3d)
    img_boxes, _ = calib.corners3d_to_img_boxes(corners3d)

    # Clamp the projected 2D boxes (x1, y1, x2, y2) to the image bounds.
    img_boxes[:, 0] = np.clip(img_boxes[:, 0], 0, img_shape[1] - 1)
    img_boxes[:, 1] = np.clip(img_boxes[:, 1], 0, img_shape[0] - 1)
    img_boxes[:, 2] = np.clip(img_boxes[:, 2], 0, img_shape[1] - 1)
    img_boxes[:, 3] = np.clip(img_boxes[:, 3], 0, img_shape[0] - 1)

    # Discard boxes that are larger than 80% of the image width OR height
    img_boxes_w = img_boxes[:, 2] - img_boxes[:, 0]
    img_boxes_h = img_boxes[:, 3] - img_boxes[:, 1]
    box_valid_mask = np.logical_and(img_boxes_w < img_shape[1] * 0.8, img_boxes_h < img_shape[0] * 0.8)

    for k in range(bbox3d.shape[0]):
        if box_valid_mask[k] == 0:
            continue

        # Observation angle derived from the box position in the x-z plane
        # and its rotation ry.
        x, z, ry = bbox3d[k, 0], bbox3d[k, 2], bbox3d[k, 6]
        beta = np.arctan2(z, x)
        alpha = -np.sign(beta) * np.pi / 2 + beta + ry

        # NOTE: `trucation` (sic) is the attribute name this code reads from
        # the label objects; do not "correct" the spelling here.
        print('%s %.2f %d %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f' %
              (args.class_name, obj_list[k].trucation, int(obj_list[k].occlusion), alpha,
               img_boxes[k, 0], img_boxes[k, 1], img_boxes[k, 2], img_boxes[k, 3],
               bbox3d[k, 3], bbox3d[k, 4], bbox3d[k, 5], bbox3d[k, 0], bbox3d[k, 1], bbox3d[k, 2],
               bbox3d[k, 6]), file=save_fp)


class AugSceneGenerator(KittiDataset):
    """Offline generator of augmented scenes.

    For every sample of the split it inserts objects drawn from a ground-truth
    database into the point cloud, then writes the new points and labels to
    disk. It is not meant to be indexed: ``__len__`` and ``__getitem__`` raise
    ``NotImplementedError``.
    """

    def __init__(self, root_dir, gt_database=None, split='train', classes=args.class_name):
        """
        :param root_dir: dataset root forwarded to ``KittiDataset``
        :param gt_database: indexable collection of dicts with the keys
            'gt_box3d', 'points', 'intensity' and 'obj'
        :param split: split name forwarded to ``KittiDataset``
        :param classes: one of 'Car', 'People', 'Pedestrian', 'Cyclist'
        :raises AssertionError: if *classes* is not one of the names above
        """
        super().__init__(root_dir, split=split)
        if classes == 'Car':
            self.classes = ('Background', 'Car')
        elif classes == 'People':
            self.classes = ('Background', 'Pedestrian', 'Cyclist')
        elif classes == 'Pedestrian':
            self.classes = ('Background', 'Pedestrian')
        elif classes == 'Cyclist':
            self.classes = ('Background', 'Cyclist')
        else:
            # Raised explicitly (same exception type as the former
            # `assert False`) so the check is not stripped under `python -O`.
            raise AssertionError("Invalid classes: %s" % classes)

        self.gt_database = gt_database

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, item):
        raise NotImplementedError

    def filtrate_dc_objects(self, obj_list):
        """Return the objects of *obj_list* whose ``cls_type`` is not 'DontCare'."""
        return [obj for obj in obj_list if obj.cls_type != 'DontCare']

    def filtrate_objects(self, obj_list):
        """Return the objects whose ``cls_type`` belongs to the target classes.

        With ``--include_similar`` the whitelist is extended by 'Van' (when
        'Car' is a target) and 'Person_sitting' (when 'Pedestrian' or
        'Cyclist' is a target).
        """
        type_whitelist = self.classes
        if args.include_similar:
            type_whitelist = list(self.classes)
            if 'Car' in self.classes:
                type_whitelist.append('Van')
            if 'Pedestrian' in self.classes or 'Cyclist' in self.classes:
                type_whitelist.append('Person_sitting')

        return [obj for obj in obj_list if obj.cls_type in type_whitelist]

    @staticmethod
    def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape):
        """
        Valid point should be in the image (and in the PC_AREA_SCOPE)

        :param pts_rect: (N, >=3) points; columns 0..2 are read as x, y, z
        :param pts_img: (N, >=2) image coordinates; column 0 is compared with
            ``img_shape[1]`` and column 1 with ``img_shape[0]``
        :param pts_rect_depth: (N,) depths; negative depths are rejected
        :param img_shape: image shape indexed as ``[0]`` and ``[1]``
        :return: (N,) boolean mask of valid points
        """
        val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1])
        val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0])
        val_flag_merge = np.logical_and(val_flag_1, val_flag_2)
        pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0)

        if PC_REDUCE_BY_RANGE:
            x_range, y_range, z_range = PC_AREA_SCOPE
            pts_x, pts_y, pts_z = pts_rect[:, 0], pts_rect[:, 1], pts_rect[:, 2]
            range_flag = (pts_x >= x_range[0]) & (pts_x <= x_range[1]) \
                         & (pts_y >= y_range[0]) & (pts_y <= y_range[1]) \
                         & (pts_z >= z_range[0]) & (pts_z <= z_range[1])
            pts_valid_flag = pts_valid_flag & range_flag
        return pts_valid_flag

    @staticmethod
    def check_pc_range(xyz):
        """
        :param xyz: [x, y, z]
        :return: True when the point lies inside PC_AREA_SCOPE (bounds
            inclusive), otherwise False; always a plain Python bool
        """
        x_range, y_range, z_range = PC_AREA_SCOPE
        return bool(
            (x_range[0] <= xyz[0] <= x_range[1])
            and (y_range[0] <= xyz[1] <= y_range[1])
            and (z_range[0] <= xyz[2] <= z_range[1])
        )

    def aug_one_scene(self, sample_id, pts_rect, pts_intensity, all_gt_boxes3d):
        """Insert randomly sampled database objects into one scene.

        :param sample_id: sample index, used to look up the road plane
        :param pts_rect: (N, 3) scene points
        :param pts_intensity: (N,) intensities matching *pts_rect*
        :param all_gt_boxes3d: (M, 7) boxes already in the scene, used for the
            overlap check; M may be 0
        :return: tuple ``(aug_flag, pts_rect, pts_intensity, extra_gt_boxes3d,
            extra_gt_obj_list)``. When nothing could be inserted, *aug_flag*
            is False, the inputs are returned untouched and the last two
            entries are None. Otherwise *extra_gt_boxes3d* is (K, 7) and
            *extra_gt_obj_list* holds the K matching label objects.
        """
        assert self.gt_database is not None
        extra_gt_num = np.random.randint(10, 15)
        try_times = 50
        cnt = 0
        cur_gt_boxes3d = all_gt_boxes3d.copy()
        cur_gt_boxes3d[:, 4] += 0.5
        cur_gt_boxes3d[:, 5] += 0.5  # enlarge existing boxes to avoid too nearby boxes

        extra_gt_obj_list = []
        extra_gt_boxes3d_list = []
        new_pts_list, new_pts_intensity_list = [], []
        src_pts_flag = np.ones(pts_rect.shape[0], dtype=np.int32)

        road_plane = self.get_road_plane(sample_id)
        a, b, c, d = road_plane

        database_size = len(self.gt_database)

        while try_times > 0:
            try_times -= 1

            # randint's upper bound is exclusive, so passing the size itself
            # makes every entry (including the last one) selectable.
            rand_idx = np.random.randint(0, database_size)

            new_gt_dict = self.gt_database[rand_idx]
            new_gt_box3d = new_gt_dict['gt_box3d'].copy()
            new_gt_points = new_gt_dict['points'].copy()
            new_gt_intensity = new_gt_dict['intensity'].copy()
            new_gt_obj = new_gt_dict['obj']
            center = new_gt_box3d[0:3]
            if PC_REDUCE_BY_RANGE and not self.check_pc_range(center):
                continue
            # cnt counts placement attempts (it is bumped before the overlap
            # test), so this caps attempts rather than successful insertions.
            if cnt > extra_gt_num:
                break
            if len(new_gt_points) < 5:  # too few points
                continue

            # put it on the road plane: solve a*x + b*y + c*z + d = 0 for y
            cur_height = (-d - a * center[0] - c * center[2]) / b
            move_height = new_gt_box3d[1] - cur_height
            new_gt_box3d[1] -= move_height
            new_gt_points[:, 1] -= move_height

            cnt += 1

            # With no boxes to compare against there is nothing to overlap;
            # skipping the call also avoids max() on a zero-size array.
            if cur_gt_boxes3d.shape[0] > 0:
                iou3d = iou3d_utils.boxes_iou3d_gpu(torch.from_numpy(new_gt_box3d.reshape(1, 7)).cuda(),
                                                    torch.from_numpy(cur_gt_boxes3d).cuda()).cpu().numpy()
                valid_flag = iou3d.max() < 1e-8
                if not valid_flag:
                    continue

            enlarged_box3d = new_gt_box3d.copy()
            enlarged_box3d[3] += 2  # remove the points above and below the object
            boxes_pts_mask_list = roipool3d_utils.pts_in_boxes3d_cpu(torch.from_numpy(pts_rect),
                                                                     torch.from_numpy(enlarged_box3d.reshape(1, 7)))
            pt_mask_flag = (boxes_pts_mask_list[0].numpy() == 1)
            src_pts_flag[pt_mask_flag] = 0  # remove the original points which are inside the new box

            new_pts_list.append(new_gt_points)
            new_pts_intensity_list.append(new_gt_intensity)
            enlarged_box3d = new_gt_box3d.copy()
            enlarged_box3d[4] += 0.5
            enlarged_box3d[5] += 0.5  # enlarge new added box to avoid too nearby boxes
            cur_gt_boxes3d = np.concatenate((cur_gt_boxes3d, enlarged_box3d.reshape(1, 7)), axis=0)
            extra_gt_boxes3d_list.append(new_gt_box3d.reshape(1, 7))
            extra_gt_obj_list.append(new_gt_obj)

        if not new_pts_list:
            return False, pts_rect, pts_intensity, None, None

        extra_gt_boxes3d = np.concatenate(extra_gt_boxes3d_list, axis=0)
        # remove original points and add new points
        keep_mask = src_pts_flag == 1
        pts_rect = pts_rect[keep_mask]
        pts_intensity = pts_intensity[keep_mask]
        new_pts_rect = np.concatenate(new_pts_list, axis=0)
        new_pts_intensity = np.concatenate(new_pts_intensity_list, axis=0)
        pts_rect = np.concatenate((pts_rect, new_pts_rect), axis=0)
        pts_intensity = np.concatenate((pts_intensity, new_pts_intensity), axis=0)

        return True, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list

    def aug_one_epoch_scene(self, base_id, data_save_dir, label_save_dir, split_list, log_fp=None):
        """Augment every sample of the split once and save the results.

        :param base_id: offset added to each sample id to form the output id
        :param data_save_dir: directory receiving the ``%06d.bin`` point files
        :param label_save_dir: directory receiving the ``%06d.txt`` label files
        :param split_list: list extended in place with the new sample ids
        :param log_fp: optional log file passed to ``log_print``
        """
        for sample_id in self.image_idx_list:
            sample_id = int(sample_id)
            print('process gt sample (%s, id=%06d)' % (args.split, sample_id))

            pts_lidar = self.get_lidar(sample_id)
            calib = self.get_calib(sample_id)
            pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
            pts_img, pts_rect_depth = calib.rect_to_img(pts_rect)
            img_shape = self.get_image_shape(sample_id)

            pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape)
            pts_rect = pts_rect[pts_valid_flag][:, 0:3]
            pts_intensity = pts_lidar[pts_valid_flag][:, 3]

            # Read the labels once and filter the same list twice.
            labels = list(self.get_label(sample_id))

            # all labels for checking overlapping
            all_obj_list = self.filtrate_dc_objects(labels)
            all_gt_boxes3d = np.zeros((len(all_obj_list), 7), dtype=np.float32)
            for k, obj in enumerate(all_obj_list):
                all_gt_boxes3d[k, 0:3], all_gt_boxes3d[k, 3], all_gt_boxes3d[k, 4], all_gt_boxes3d[k, 5], \
                    all_gt_boxes3d[k, 6] = obj.pos, obj.h, obj.w, obj.l, obj.ry

            # gt_boxes3d of current label
            obj_list = self.filtrate_objects(labels)
            if args.class_name != 'Car' and len(obj_list) == 0:
                continue

            # augment one scene
            aug_flag, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list = \
                self.aug_one_scene(sample_id, pts_rect, pts_intensity, all_gt_boxes3d)

            # save augment result to file
            pts_info = np.concatenate((pts_rect, pts_intensity.reshape(-1, 1)), axis=1)
            bin_file = os.path.join(data_save_dir, '%06d.bin' % (base_id + sample_id))
            pts_info.astype(np.float32).tofile(bin_file)

            # save filtered original gt_boxes3d
            label_save_file = os.path.join(label_save_dir, '%06d.txt' % (base_id + sample_id))
            with open(label_save_file, 'w') as f:
                for obj in obj_list:
                    print(obj.to_kitti_format(), file=f)

                if aug_flag:
                    # augment successfully
                    save_kitti_format(calib, extra_gt_boxes3d, extra_gt_obj_list, img_shape=img_shape, save_fp=f)
                else:
                    # Empty placeholder so the log line below reports 0 new objects.
                    extra_gt_boxes3d = np.zeros((0, 7), dtype=np.float32)
            log_print('Save to file (new_obj: %s): %s' % (len(extra_gt_boxes3d), label_save_file), fp=log_fp)
            split_list.append('%06d' % (base_id + sample_id))

    def generate_aug_scene(self, aug_times, log_fp=None):
        """Run *aug_times* augmentation passes and write the extended split file.

        Output ids of pass ``e`` (0-based) are offset by ``(e + 1) * 10000``.
        The split file lists the original ids followed by all generated ids,
        one per line without a trailing newline, and is then copied into
        ``../data/KITTI/ImageSets/``.

        :param aug_times: number of passes over the split
        :param log_fp: optional log file passed to ``log_print``
        """
        data_save_dir = os.path.join(args.save_dir, 'rectified_data')
        label_save_dir = os.path.join(args.save_dir, 'aug_label')
        os.makedirs(data_save_dir, exist_ok=True)
        os.makedirs(label_save_dir, exist_ok=True)

        split_file = os.path.join(args.save_dir, '%s_aug.txt' % args.split)
        split_list = self.image_idx_list.copy()
        for epoch in range(aug_times):
            base_id = (epoch + 1) * 10000
            self.aug_one_epoch_scene(base_id, data_save_dir, label_save_dir, split_list, log_fp=log_fp)

        with open(split_file, 'w') as f:
            f.write('\n'.join(str(sample_id) for sample_id in split_list))
        log_print('Save split file to %s' % split_file, fp=log_fp)

        target_dir = '../data/KITTI/ImageSets/'
        # shutil.copy instead of a shell `cp` string: works with paths that
        # contain spaces/metacharacters and reports failures.
        try:
            shutil.copy(split_file, target_dir)
        except OSError as err:
            log_print('Failed to copy split file from %s to %s: %s' % (split_file, target_dir, err), fp=log_fp)
        else:
            log_print('Copy split file from %s to %s' % (split_file, target_dir), fp=log_fp)


if __name__ == '__main__':
    os.makedirs(args.save_dir, exist_ok=True)
    info_file = os.path.join(args.save_dir, 'log_info.txt')

    # Only the 'generator' mode does anything; other modes are a no-op.
    if args.mode == 'generator':
        # Context managers close the log and the database file even if
        # generation fails part-way.
        with open(info_file, 'w') as log_fp:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point --gt_database_dir at database files you trust.
            with open(args.gt_database_dir, 'rb') as db_fp:
                gt_database = pickle.load(db_fp)
            log_print('Loading gt_database(%d) from %s' % (len(gt_database), args.gt_database_dir), fp=log_fp)

            dataset = AugSceneGenerator(root_dir='../data', gt_database=gt_database, split=args.split)
            dataset.generate_aug_scene(aug_times=args.aug_times, log_fp=log_fp)