mirror of
https://github.com/zhigang1992/PointRCNN.git
synced 2026-01-13 17:22:54 +08:00
71 lines
2.7 KiB
Python
71 lines
2.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from lib.net.rpn import RPN
|
|
from lib.net.rcnn_net import RCNNNet
|
|
from lib.config import cfg
|
|
|
|
|
|
class PointRCNN(nn.Module):
    """Wrapper module that chains an RPN stage and an RCNN stage.

    Which stages exist is decided by the global ``cfg`` at construction time:
    ``self.rpn`` is created when ``cfg.RPN.ENABLED`` is truthy and
    ``self.rcnn_net`` when ``cfg.RCNN.ENABLED`` is truthy. At least one of the
    two must be enabled.
    """

    def __init__(self, num_classes, use_xyz=True, mode='TRAIN'):
        """Build the sub-networks that are enabled in ``cfg``.

        Args:
            num_classes: forwarded to ``RCNNNet``.
            use_xyz: forwarded to both ``RPN`` and ``RCNNNet``.
            mode: forwarded to ``RPN``.

        Raises:
            AssertionError: if neither ``cfg.RPN.ENABLED`` nor
                ``cfg.RCNN.ENABLED`` is set.
            NotImplementedError: if the RCNN stage is enabled and
                ``cfg.RCNN.BACKBONE`` is anything other than ``'pointnet'``.
        """
        super().__init__()

        assert cfg.RPN.ENABLED or cfg.RCNN.ENABLED, \
            'at least one of cfg.RPN.ENABLED / cfg.RCNN.ENABLED must be set'

        if cfg.RPN.ENABLED:
            self.rpn = RPN(use_xyz=use_xyz, mode=mode)

        if cfg.RCNN.ENABLED:
            rcnn_input_channels = 128  # channels of rpn features
            if cfg.RCNN.BACKBONE == 'pointnet':
                self.rcnn_net = RCNNNet(num_classes=num_classes,
                                        input_channels=rcnn_input_channels,
                                        use_xyz=use_xyz)
            elif cfg.RCNN.BACKBONE == 'pointsift':
                # Fail fast: silently skipping here would leave the module
                # without ``rcnn_net`` and make every forward() call die with
                # an unrelated-looking AttributeError.
                raise NotImplementedError(
                    "RCNN backbone 'pointsift' is not implemented")
            else:
                raise NotImplementedError(
                    'unsupported RCNN backbone: {!r}'.format(cfg.RCNN.BACKBONE))

    def forward(self, input_data):
        """Run the enabled stage(s) on ``input_data``.

        Args:
            input_data: passed unchanged to ``self.rpn`` (RPN enabled) or to
                ``self.rcnn_net`` (RCNN-only). When both stages are enabled
                and the module is in training mode it must also support
                ``input_data['gt_boxes3d']``.

        Returns:
            With the RPN enabled: a dict holding the RPN outputs and, if the
            RCNN stage is enabled too, the extra keys ``'rois'``,
            ``'roi_scores_raw'``, ``'seg_result'`` plus everything returned by
            ``self.rcnn_net``. With only the RCNN enabled: whatever
            ``self.rcnn_net`` returns.

        Raises:
            NotImplementedError: if neither stage is enabled in ``cfg``.
        """
        if not cfg.RPN.ENABLED:
            if cfg.RCNN.ENABLED:
                # RCNN-only mode: the caller supplies the RCNN input directly.
                return self.rcnn_net(input_data)
            raise NotImplementedError(
                'neither cfg.RPN.ENABLED nor cfg.RCNN.ENABLED is set')

        output = {}

        # rpn inference; gradients are tracked only when the RPN is not
        # fixed and the module is in training mode
        with torch.set_grad_enabled((not cfg.RPN.FIXED) and self.training):
            if cfg.RPN.FIXED:
                # a fixed RPN is kept in eval mode even if the parent module
                # has been switched to train()
                self.rpn.eval()
            rpn_output = self.rpn(input_data)
            output.update(rpn_output)

        if not cfg.RCNN.ENABLED:
            return output

        # rcnn inference: proposals and masks are computed without gradients
        with torch.no_grad():
            rpn_cls, rpn_reg = rpn_output['rpn_cls'], rpn_output['rpn_reg']
            backbone_xyz = rpn_output['backbone_xyz']
            backbone_features = rpn_output['backbone_features']

            rpn_scores_raw = rpn_cls[:, :, 0]
            rpn_scores_norm = torch.sigmoid(rpn_scores_raw)
            # 1.0 where the sigmoid score exceeds the threshold, else 0.0
            seg_mask = (rpn_scores_norm > cfg.RPN.SCORE_THRESH).float()
            # L2 norm over dim 2 of backbone_xyz
            # NOTE(review): presumably the per-point distance from the
            # origin -- confirm against the RPN output layout
            pts_depth = torch.norm(backbone_xyz, p=2, dim=2)

            # proposal layer
            rois, roi_scores_raw = self.rpn.proposal_layer(
                rpn_scores_raw, rpn_reg, backbone_xyz)  # (B, M, 7)

            output['rois'] = rois
            output['roi_scores_raw'] = roi_scores_raw
            output['seg_result'] = seg_mask

        # Built outside no_grad so that the permuted backbone features keep
        # their autograd history when the RPN is trainable.
        rcnn_input_info = {
            'rpn_xyz': backbone_xyz,
            'rpn_features': backbone_features.permute((0, 2, 1)),
            'seg_mask': seg_mask,
            'roi_boxes3d': rois,
            'pts_depth': pts_depth,
        }
        if self.training:
            rcnn_input_info['gt_boxes3d'] = input_data['gt_boxes3d']

        rcnn_output = self.rcnn_net(rcnn_input_info)
        output.update(rcnn_output)

        return output
|