from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import lib.utils.loss_utils as loss_utils
from lib.config import cfg


def model_joint_fn_decorator():
    ModelReturn = namedtuple("ModelReturn", ['loss', 'tb_dict', 'disp_dict'])
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()

    def model_fn(model, data):
        if cfg.RPN.ENABLED:
            pts_rect, pts_features, pts_input = data['pts_rect'], data['pts_features'], data['pts_input']
            gt_boxes3d = data['gt_boxes3d']

            if not cfg.RPN.FIXED:
                rpn_cls_label, rpn_reg_label = data['rpn_cls_label'], data['rpn_reg_label']
                rpn_cls_label = torch.from_numpy(rpn_cls_label).cuda(non_blocking=True).long()
                rpn_reg_label = torch.from_numpy(rpn_reg_label).cuda(non_blocking=True).float()

            inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
            gt_boxes3d = torch.from_numpy(gt_boxes3d).cuda(non_blocking=True).float()
            input_data = {'pts_input': inputs, 'gt_boxes3d': gt_boxes3d}
        else:
            input_data = {}
            for key, val in data.items():
                if key != 'sample_id':
                    input_data[key] = torch.from_numpy(val).contiguous().cuda(non_blocking=True).float()
            if not cfg.RCNN.ROI_SAMPLE_JIT:
                pts_input = torch.cat((input_data['pts_input'], input_data['pts_features']), dim=-1)
                input_data['pts_input'] = pts_input

        ret_dict = model(input_data)

        tb_dict = {}
        disp_dict = {}
        loss = 0
        if cfg.RPN.ENABLED and not cfg.RPN.FIXED:
            rpn_cls, rpn_reg = ret_dict['rpn_cls'], ret_dict['rpn_reg']
            rpn_loss = get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label, tb_dict)
            loss += rpn_loss
            disp_dict['rpn_loss'] = rpn_loss.item()

        if cfg.RCNN.ENABLED:
            rcnn_loss = get_rcnn_loss(model, ret_dict, tb_dict)
            disp_dict['reg_fg_sum'] = tb_dict['rcnn_reg_fg']
            loss += rcnn_loss

        disp_dict['loss'] = loss.item()

        return ModelReturn(loss, tb_dict, disp_dict)

    def get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label, tb_dict):
        if isinstance(model, nn.DataParallel):
            rpn_cls_loss_func = model.module.rpn.rpn_cls_loss_func
        else:
            rpn_cls_loss_func = model.rpn.rpn_cls_loss_func

        rpn_cls_label_flat = rpn_cls_label.view(-1)
        rpn_cls_flat = rpn_cls.view(-1)
        fg_mask = (rpn_cls_label_flat > 0)

        # RPN classification loss
        if cfg.RPN.LOSS_CLS == 'DiceLoss':
            rpn_loss_cls = rpn_cls_loss_func(rpn_cls, rpn_cls_label_flat)

        elif cfg.RPN.LOSS_CLS == 'SigmoidFocalLoss':
            rpn_cls_target = (rpn_cls_label_flat > 0).float()
            pos = (rpn_cls_label_flat > 0).float()
            neg = (rpn_cls_label_flat == 0).float()
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)

            rpn_loss_cls = rpn_cls_loss_func(rpn_cls_flat, rpn_cls_target, cls_weights)
            rpn_loss_cls_pos = (rpn_loss_cls * pos).sum()
            rpn_loss_cls_neg = (rpn_loss_cls * neg).sum()
            rpn_loss_cls = rpn_loss_cls.sum()
            tb_dict['rpn_loss_cls_pos'] = rpn_loss_cls_pos.item()
            tb_dict['rpn_loss_cls_neg'] = rpn_loss_cls_neg.item()

        elif cfg.RPN.LOSS_CLS == 'BinaryCrossEntropy':
            weight = rpn_cls_flat.new(rpn_cls_flat.shape[0]).fill_(1.0)
            weight[fg_mask] = cfg.RPN.FG_WEIGHT
            rpn_cls_label_target = (rpn_cls_label_flat > 0).float()
            batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rpn_cls_flat), rpn_cls_label_target,
                                                    weight=weight, reduction='none')
            cls_valid_mask = (rpn_cls_label_flat >= 0).float()
            rpn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)
        else:
            raise NotImplementedError
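        # Illustrative sketch (toy values assumed, not taken from this file):
        # the SigmoidFocalLoss branch above divides the per-point weights by
        # the foreground count, so the summed loss reads as "loss per
        # foreground point":
        #
        #   label = torch.tensor([1., 1., 0., 0., 0., 0., 0., -1.])
        #   pos, neg = (label > 0).float(), (label == 0).float()
        #   w = (pos + neg) / torch.clamp(pos.sum(), min=1.0)
        #   # w == [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.0]
        #   # points labelled -1 ("ignore") receive weight 0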
        # RPN regression loss
        point_num = rpn_reg.size(0) * rpn_reg.size(1)
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rpn_reg.view(point_num, -1)[fg_mask],
                                        rpn_reg_label.view(point_num, 7)[fg_mask],
                                        loc_scope=cfg.RPN.LOC_SCOPE,
                                        loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RPN.NUM_HEAD_BIN,
                                        anchor_size=MEAN_SIZE,
                                        get_xz_fine=cfg.RPN.LOC_XZ_FINE,
                                        get_y_by_bin=False,
                                        get_ry_fine=False)

            loss_size = 3 * loss_size  # consistent with old codes
            rpn_loss_reg = loss_loc + loss_angle + loss_size
        else:
            # no foreground points: emit a zero loss that stays in the graph
            loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0

        rpn_loss = rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] + rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1]

        tb_dict.update({'rpn_loss_cls': rpn_loss_cls.item(), 'rpn_loss_reg': rpn_loss_reg.item(),
                        'rpn_loss': rpn_loss.item(), 'rpn_fg_sum': fg_sum,
                        'rpn_loss_loc': loss_loc.item(), 'rpn_loss_angle': loss_angle.item(),
                        'rpn_loss_size': loss_size.item()})

        return rpn_loss

    def get_rcnn_loss(model, ret_dict, tb_dict):
        rcnn_cls, rcnn_reg = ret_dict['rcnn_cls'], ret_dict['rcnn_reg']

        cls_label = ret_dict['cls_label'].float()
        reg_valid_mask = ret_dict['reg_valid_mask']
        roi_boxes3d = ret_dict['roi_boxes3d']
        roi_size = roi_boxes3d[:, 3:6]
        gt_boxes3d_ct = ret_dict['gt_of_rois']
        pts_input = ret_dict['pts_input']

        # rcnn classification loss
        if isinstance(model, nn.DataParallel):
            cls_loss_func = model.module.rcnn_net.cls_loss_func
        else:
            cls_loss_func = model.rcnn_net.cls_loss_func

        cls_label_flat = cls_label.view(-1)

        if cfg.RCNN.LOSS_CLS == 'SigmoidFocalLoss':
            rcnn_cls_flat = rcnn_cls.view(-1)
            cls_target = (cls_label_flat > 0).float()
            pos = (cls_label_flat > 0).float()
            neg = (cls_label_flat == 0).float()
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)

            rcnn_loss_cls = cls_loss_func(rcnn_cls_flat, cls_target, cls_weights)
            rcnn_loss_cls_pos = (rcnn_loss_cls * pos).sum()
            rcnn_loss_cls_neg = (rcnn_loss_cls * neg).sum()
            rcnn_loss_cls = rcnn_loss_cls.sum()
            tb_dict['rcnn_loss_cls_pos'] = rcnn_loss_cls_pos.item()
            tb_dict['rcnn_loss_cls_neg'] = rcnn_loss_cls_neg.item()

        elif cfg.RCNN.LOSS_CLS == 'BinaryCrossEntropy':
            rcnn_cls_flat = rcnn_cls.view(-1)
            batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat), cls_label_flat,
                                                    reduction='none')
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)

        elif cfg.RCNN.LOSS_CLS == 'CrossEntropy':
            rcnn_cls_reshape = rcnn_cls.view(rcnn_cls.shape[0], -1)
            cls_target = cls_label_flat.long()
            cls_valid_mask = (cls_label_flat >= 0).float()

            batch_loss_cls = cls_loss_func(rcnn_cls_reshape, cls_target)
            normalizer = torch.clamp(cls_valid_mask.sum(), min=1.0)
            rcnn_loss_cls = (batch_loss_cls.mean(dim=1) * cls_valid_mask).sum() / normalizer

        else:
            raise NotImplementedError
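        # Illustrative sketch (toy values assumed, not taken from this file):
        # every branch above gives zero weight to entries labelled -1
        # ("ignore"), either through pos + neg or through cls_valid_mask, so
        # the normalized loss averages only over valid entries:
        #
        #   label = torch.tensor([1., 0., -1.])
        #   valid = (label >= 0).float()                # [1., 1., 0.]
        #   per_elem = torch.tensor([0.2, 0.4, 9.9])    # pretend per-entry loss
        #   loss = (per_elem * valid).sum() / torch.clamp(valid.sum(), min=1.0)
        #   # -> (0.2 + 0.4) / 2 = 0.3; the ignored entry contributes nothing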
        # rcnn regression loss
        batch_size = pts_input.shape[0]
        fg_mask = (reg_valid_mask > 0)
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            all_anchor_size = roi_size
            anchor_size = all_anchor_size[fg_mask] if cfg.RCNN.SIZE_RES_ON_ROI else MEAN_SIZE

            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg.view(batch_size, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True,
                                        get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE,
                                        loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)

            loss_size = 3 * loss_size  # consistent with old codes
            rcnn_loss_reg = loss_loc + loss_angle + loss_size
            tb_dict.update(reg_loss_dict)
        else:
            loss_loc = loss_angle = loss_size = rcnn_loss_reg = rcnn_loss_cls * 0

        rcnn_loss = rcnn_loss_cls + rcnn_loss_reg
        tb_dict['rcnn_loss_cls'] = rcnn_loss_cls.item()
        tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
        tb_dict['rcnn_loss'] = rcnn_loss.item()
        tb_dict['rcnn_loss_loc'] = loss_loc.item()
        tb_dict['rcnn_loss_angle'] = loss_angle.item()
        tb_dict['rcnn_loss_size'] = loss_size.item()
        tb_dict['rcnn_cls_fg'] = (cls_label > 0).sum().item()
        tb_dict['rcnn_cls_bg'] = (cls_label == 0).sum().item()
        tb_dict['rcnn_reg_fg'] = reg_valid_mask.sum().item()

        return rcnn_loss

    return model_fn
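
# Usage sketch (assumed trainer-side names; `model`, `optimizer`, and
# `train_loader` are not defined in this file):
#
#   model_fn = model_joint_fn_decorator()
#   for data in train_loader:
#       optimizer.zero_grad()
#       loss, tb_dict, disp_dict = model_fn(model, data)
#       loss.backward()
#       optimizer.step()
#
# `tb_dict` collects scalar summaries (e.g. for TensorBoard logging);
# `disp_dict` holds the smaller set of values meant for progress display.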