mirror of https://github.com/zhigang1992/PointRCNN.git, synced 2026-01-12 22:49:40 +08:00

import torch
import torch.nn as nn
import torch.nn.functional as F
import lib.utils.loss_utils as loss_utils
from lib.config import cfg
from collections import namedtuple


def model_joint_fn_decorator():
    ModelReturn = namedtuple("ModelReturn", ['loss', 'tb_dict', 'disp_dict'])
    # Mean object size; used below as the anchor size for the bin-based box regression loss.
    MEAN_SIZE = torch.from_numpy(cfg.CLS_MEAN_SIZE[0]).cuda()
    def model_fn(model, data):
        if cfg.RPN.ENABLED:
            pts_rect, pts_features, pts_input = data['pts_rect'], data['pts_features'], data['pts_input']
            gt_boxes3d = data['gt_boxes3d']

            if not cfg.RPN.FIXED:
                # RPN is being trained: move its per-point labels to the GPU.
                rpn_cls_label, rpn_reg_label = data['rpn_cls_label'], data['rpn_reg_label']
                rpn_cls_label = torch.from_numpy(rpn_cls_label).cuda(non_blocking=True).long()
                rpn_reg_label = torch.from_numpy(rpn_reg_label).cuda(non_blocking=True).float()

            inputs = torch.from_numpy(pts_input).cuda(non_blocking=True).float()
            gt_boxes3d = torch.from_numpy(gt_boxes3d).cuda(non_blocking=True).float()
            input_data = {'pts_input': inputs, 'gt_boxes3d': gt_boxes3d}
        else:
            # RCNN-only training: forward every array in the batch except the sample id.
            input_data = {}
            for key, val in data.items():
                if key != 'sample_id':
                    input_data[key] = torch.from_numpy(val).contiguous().cuda(non_blocking=True).float()
            if not cfg.RCNN.ROI_SAMPLE_JIT:
                # Concatenate point coordinates and extra point features into one input tensor.
                pts_input = torch.cat((input_data['pts_input'], input_data['pts_features']), dim=-1)
                input_data['pts_input'] = pts_input

        ret_dict = model(input_data)

        tb_dict = {}
        disp_dict = {}
        loss = 0
        if cfg.RPN.ENABLED and not cfg.RPN.FIXED:
            rpn_cls, rpn_reg = ret_dict['rpn_cls'], ret_dict['rpn_reg']
            rpn_loss = get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label, tb_dict)
            loss += rpn_loss
            disp_dict['rpn_loss'] = rpn_loss.item()

        if cfg.RCNN.ENABLED:
            rcnn_loss = get_rcnn_loss(model, ret_dict, tb_dict)
            disp_dict['reg_fg_sum'] = tb_dict['rcnn_reg_fg']
            loss += rcnn_loss

        disp_dict['loss'] = loss.item()

        return ModelReturn(loss, tb_dict, disp_dict)
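
    # Illustrative sketch (not in the original file): the `data` dict consumed by
    # model_fn on the RPN path holds numpy arrays roughly shaped as below; the exact
    # shapes are assumptions inferred from how the arrays are used in this file.
    #
    #   data = {
    #       'pts_input':     (B, N, C)  point cloud fed to the network,
    #       'gt_boxes3d':    (B, M, 7)  ground-truth boxes,
    #       'rpn_cls_label': (B, N)     per-point foreground/background labels,
    #       'rpn_reg_label': (B, N, 7)  per-point box regression targets,
    #       'pts_rect', 'pts_features': auxiliary point arrays, read but not
    #                                   otherwise used on this path,
    #   }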

    def get_rpn_loss(model, rpn_cls, rpn_reg, rpn_cls_label, rpn_reg_label, tb_dict):
        if isinstance(model, nn.DataParallel):
            rpn_cls_loss_func = model.module.rpn.rpn_cls_loss_func
        else:
            rpn_cls_loss_func = model.rpn.rpn_cls_loss_func

        rpn_cls_label_flat = rpn_cls_label.view(-1)
        rpn_cls_flat = rpn_cls.view(-1)
        fg_mask = (rpn_cls_label_flat > 0)

        # RPN classification loss
        if cfg.RPN.LOSS_CLS == 'DiceLoss':
            rpn_loss_cls = rpn_cls_loss_func(rpn_cls, rpn_cls_label_flat)

        elif cfg.RPN.LOSS_CLS == 'SigmoidFocalLoss':
            rpn_cls_target = (rpn_cls_label_flat > 0).float()
            pos = (rpn_cls_label_flat > 0).float()
            neg = (rpn_cls_label_flat == 0).float()
            # Weight positives and negatives equally, then normalize by the number of positives.
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)
            rpn_loss_cls = rpn_cls_loss_func(rpn_cls_flat, rpn_cls_target, cls_weights)
            rpn_loss_cls_pos = (rpn_loss_cls * pos).sum()
            rpn_loss_cls_neg = (rpn_loss_cls * neg).sum()
            rpn_loss_cls = rpn_loss_cls.sum()
            tb_dict['rpn_loss_cls_pos'] = rpn_loss_cls_pos.item()
            tb_dict['rpn_loss_cls_neg'] = rpn_loss_cls_neg.item()

        elif cfg.RPN.LOSS_CLS == 'BinaryCrossEntropy':
            # Plain BCE with an extra weight on foreground points; points with a negative label are ignored.
            weight = rpn_cls_flat.new(rpn_cls_flat.shape[0]).fill_(1.0)
            weight[fg_mask] = cfg.RPN.FG_WEIGHT
            rpn_cls_label_target = (rpn_cls_label_flat > 0).float()
            batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rpn_cls_flat), rpn_cls_label_target,
                                                    weight=weight, reduction='none')
            cls_valid_mask = (rpn_cls_label_flat >= 0).float()
            rpn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)
        else:
            raise NotImplementedError

        # RPN regression loss (computed on foreground points only)
        point_num = rpn_reg.size(0) * rpn_reg.size(1)
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rpn_reg.view(point_num, -1)[fg_mask],
                                        rpn_reg_label.view(point_num, 7)[fg_mask],
                                        loc_scope=cfg.RPN.LOC_SCOPE,
                                        loc_bin_size=cfg.RPN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RPN.NUM_HEAD_BIN,
                                        anchor_size=MEAN_SIZE,
                                        get_xz_fine=cfg.RPN.LOC_XZ_FINE,
                                        get_y_by_bin=False,
                                        get_ry_fine=False)

            loss_size = 3 * loss_size  # consistent with old codes
            rpn_loss_reg = loss_loc + loss_angle + loss_size
        else:
            # No foreground points in this batch: keep the graph connected with a zero regression loss.
            loss_loc = loss_angle = loss_size = rpn_loss_reg = rpn_loss_cls * 0

        rpn_loss = rpn_loss_cls * cfg.RPN.LOSS_WEIGHT[0] + rpn_loss_reg * cfg.RPN.LOSS_WEIGHT[1]

        tb_dict.update({'rpn_loss_cls': rpn_loss_cls.item(), 'rpn_loss_reg': rpn_loss_reg.item(),
                        'rpn_loss': rpn_loss.item(), 'rpn_fg_sum': fg_sum, 'rpn_loss_loc': loss_loc.item(),
                        'rpn_loss_angle': loss_angle.item(), 'rpn_loss_size': loss_size.item()})

        return rpn_loss
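
    # Illustrative sketch (assumption, not taken from this repo's config files): the
    # RPN settings referenced above live under cfg.RPN; a config fragment might look
    # like the following, with placeholder values chosen only for illustration.
    #
    #   RPN:
    #     LOSS_CLS: SigmoidFocalLoss   # or DiceLoss / BinaryCrossEntropy
    #     FG_WEIGHT: 15                # only used by the BinaryCrossEntropy branch
    #     LOSS_WEIGHT: [1.0, 1.0]      # [classification, regression] weights
    #     LOC_SCOPE: 3.0
    #     LOC_BIN_SIZE: 0.5
    #     NUM_HEAD_BIN: 12
    #     LOC_XZ_FINE: True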

    def get_rcnn_loss(model, ret_dict, tb_dict):
        rcnn_cls, rcnn_reg = ret_dict['rcnn_cls'], ret_dict['rcnn_reg']

        cls_label = ret_dict['cls_label'].float()
        reg_valid_mask = ret_dict['reg_valid_mask']
        roi_boxes3d = ret_dict['roi_boxes3d']
        # RoI sizes, used as anchors below when SIZE_RES_ON_ROI is enabled.
        roi_size = roi_boxes3d[:, 3:6]
        gt_boxes3d_ct = ret_dict['gt_of_rois']
        pts_input = ret_dict['pts_input']

        # rcnn classification loss
        if isinstance(model, nn.DataParallel):
            cls_loss_func = model.module.rcnn_net.cls_loss_func
        else:
            cls_loss_func = model.rcnn_net.cls_loss_func

        cls_label_flat = cls_label.view(-1)

        if cfg.RCNN.LOSS_CLS == 'SigmoidFocalLoss':
            rcnn_cls_flat = rcnn_cls.view(-1)

            cls_target = (cls_label_flat > 0).float()
            pos = (cls_label_flat > 0).float()
            neg = (cls_label_flat == 0).float()
            # Weight positives and negatives equally, then normalize by the number of positives.
            cls_weights = pos + neg
            pos_normalizer = pos.sum()
            cls_weights = cls_weights / torch.clamp(pos_normalizer, min=1.0)

            rcnn_loss_cls = cls_loss_func(rcnn_cls_flat, cls_target, cls_weights)
            rcnn_loss_cls_pos = (rcnn_loss_cls * pos).sum()
            rcnn_loss_cls_neg = (rcnn_loss_cls * neg).sum()
            rcnn_loss_cls = rcnn_loss_cls.sum()
            # Positive/negative breakdown of the RCNN classification loss for TensorBoard.
            tb_dict['rcnn_loss_cls_pos'] = rcnn_loss_cls_pos.item()
            tb_dict['rcnn_loss_cls_neg'] = rcnn_loss_cls_neg.item()

        elif cfg.RCNN.LOSS_CLS == 'BinaryCrossEntropy':
            rcnn_cls_flat = rcnn_cls.view(-1)
            batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat), cls_label, reduction='none')
            # RoIs with a negative label are ignored in the classification loss.
            cls_valid_mask = (cls_label_flat >= 0).float()
            rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)

        elif cfg.RCNN.LOSS_CLS == 'CrossEntropy':
            rcnn_cls_reshape = rcnn_cls.view(rcnn_cls.shape[0], -1)
            cls_target = cls_label_flat.long()
            cls_valid_mask = (cls_label_flat >= 0).float()

            batch_loss_cls = cls_loss_func(rcnn_cls_reshape, cls_target)
            normalizer = torch.clamp(cls_valid_mask.sum(), min=1.0)
            rcnn_loss_cls = (batch_loss_cls.mean(dim=1) * cls_valid_mask).sum() / normalizer

        else:
            raise NotImplementedError

        # rcnn regression loss (computed on valid foreground RoIs only)
        batch_size = pts_input.shape[0]
        fg_mask = (reg_valid_mask > 0)
        fg_sum = fg_mask.long().sum().item()
        if fg_sum != 0:
            # Use each RoI's own size as the anchor if enabled, otherwise the class mean size.
            all_anchor_size = roi_size
            anchor_size = all_anchor_size[fg_mask] if cfg.RCNN.SIZE_RES_ON_ROI else MEAN_SIZE

            loss_loc, loss_angle, loss_size, reg_loss_dict = \
                loss_utils.get_reg_loss(rcnn_reg.view(batch_size, -1)[fg_mask],
                                        gt_boxes3d_ct.view(batch_size, 7)[fg_mask],
                                        loc_scope=cfg.RCNN.LOC_SCOPE,
                                        loc_bin_size=cfg.RCNN.LOC_BIN_SIZE,
                                        num_head_bin=cfg.RCNN.NUM_HEAD_BIN,
                                        anchor_size=anchor_size,
                                        get_xz_fine=True, get_y_by_bin=cfg.RCNN.LOC_Y_BY_BIN,
                                        loc_y_scope=cfg.RCNN.LOC_Y_SCOPE, loc_y_bin_size=cfg.RCNN.LOC_Y_BIN_SIZE,
                                        get_ry_fine=True)

            loss_size = 3 * loss_size  # consistent with old codes
            rcnn_loss_reg = loss_loc + loss_angle + loss_size
            tb_dict.update(reg_loss_dict)
        else:
            # No valid foreground RoIs in this batch: fall back to a zero regression loss.
            loss_loc = loss_angle = loss_size = rcnn_loss_reg = rcnn_loss_cls * 0

        rcnn_loss = rcnn_loss_cls + rcnn_loss_reg
        tb_dict['rcnn_loss_cls'] = rcnn_loss_cls.item()
        tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
        tb_dict['rcnn_loss'] = rcnn_loss.item()

        tb_dict['rcnn_loss_loc'] = loss_loc.item()
        tb_dict['rcnn_loss_angle'] = loss_angle.item()
        tb_dict['rcnn_loss_size'] = loss_size.item()

        tb_dict['rcnn_cls_fg'] = (cls_label > 0).sum().item()
        tb_dict['rcnn_cls_bg'] = (cls_label == 0).sum().item()
        tb_dict['rcnn_reg_fg'] = reg_valid_mask.sum().item()

        return rcnn_loss

    return model_fn
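
# Hedged usage sketch (not part of the original file): a minimal consumer of the
# closure returned above, assuming a constructed `model`, an `optimizer`, and a
# `train_loader` yielding batch dicts in the layout model_fn expects.
#
#     model_fn = model_joint_fn_decorator()
#     for data in train_loader:
#         optimizer.zero_grad()
#         loss, tb_dict, disp_dict = model_fn(model, data)
#         loss.backward()
#         optimizer.step()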