# Mirror of https://github.com/zhigang1992/PointRCNN.git
# Synced 2026-01-12 22:49:40 +08:00 (71 lines, 2.4 KiB, Python)
import torch
|
|
import torch.nn as nn
|
|
from pointnet2_lib.pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG
|
|
from lib.config import cfg
|
|
|
|
|
|
def get_model(input_channels=6, use_xyz=True):
    """Build and return a ``Pointnet2MSG`` backbone.

    Both arguments are passed through unchanged as keyword arguments to the
    ``Pointnet2MSG`` constructor.
    """
    backbone = Pointnet2MSG(input_channels=input_channels, use_xyz=use_xyz)
    return backbone
|
|
|
|
|
|
class Pointnet2MSG(nn.Module):
    """Point-cloud backbone built from set-abstraction and propagation stages.

    The layer layout is read from ``cfg.RPN``: one ``PointnetSAModuleMSG`` is
    created per entry of ``cfg.RPN.SA_CONFIG.NPOINTS`` and one
    ``PointnetFPModule`` per entry of ``cfg.RPN.FP_MLPS``.

    Args:
        input_channels: Channel width used as the input of the first SA stage
            and as the shallowest skip-connection width.
        use_xyz: Forwarded unchanged to every ``PointnetSAModuleMSG``.
    """

    def __init__(self, input_channels=6, use_xyz=True):
        super().__init__()

        sa_cfg = cfg.RPN.SA_CONFIG
        fp_mlps = cfg.RPN.FP_MLPS

        self.SA_modules = nn.ModuleList()
        channel_in = input_channels
        # skip_channel_list[k] is the feature width at level k (level 0 = input).
        skip_channel_list = [input_channels]
        for k, npoint in enumerate(sa_cfg.NPOINTS):
            # Prefix each per-scale MLP spec with the incoming width. New
            # lists are built, so the config object itself is never mutated.
            mlps = [[channel_in] + scale_mlp for scale_mlp in sa_cfg.MLPS[k]]
            # The stage's output width is the sum of each scale's last width.
            channel_out = sum(scale_mlp[-1] for scale_mlp in mlps)

            self.SA_modules.append(
                PointnetSAModuleMSG(
                    npoint=npoint,
                    radii=sa_cfg.RADIUS[k],
                    nsamples=sa_cfg.NSAMPLE[k],
                    mlps=mlps,
                    use_xyz=use_xyz,
                    bn=cfg.RPN.USE_BN,
                )
            )
            skip_channel_list.append(channel_out)
            channel_in = channel_out

        self.FP_modules = nn.ModuleList()
        for k, fp_mlp in enumerate(fp_mlps):
            if k + 1 < len(fp_mlps):
                # Input comes from the next (deeper) FP stage's last width.
                pre_channel = fp_mlps[k + 1][-1]
            else:
                # Deepest FP stage: input comes from the last SA stage.
                # channel_in equals that stage's output width and, unlike a
                # loop-local variable, is bound even with zero SA stages.
                pre_channel = channel_in
            self.FP_modules.append(
                PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + fp_mlp)
            )

    def _break_up_pc(self, pc):
        """Split a point-cloud tensor into coordinates and extra features.

        Args:
            pc: Tensor whose last dimension holds 3 coordinates followed by
                optional extra channels. Assumed to be laid out as
                (batch, points, channels) -- TODO confirm against callers.

        Returns:
            Tuple ``(xyz, features)``. ``xyz`` is the contiguous first three
            channels. ``features`` is the remaining channels with dims 1 and 2
            swapped and made contiguous, or ``None`` when ``pc`` has no more
            than three channels.
        """
        xyz = pc[..., 0:3].contiguous()
        if pc.size(-1) > 3:
            features = pc[..., 3:].transpose(1, 2).contiguous()
        else:
            features = None

        return xyz, features

    def forward(self, pointcloud: torch.Tensor):
        """Run the SA stages, then the FP stages from deepest to shallowest.

        Args:
            pointcloud: Input tensor, split by ``_break_up_pc``.

        Returns:
            Tuple ``(xyz, features)`` for level 0: the input coordinates and
            the level-0 features after the FP stages have been applied.
        """
        xyz, features = self._break_up_pc(pointcloud)

        # One entry per level; each SA stage consumes the latest level and
        # appends the next one.
        l_xyz, l_features = [xyz], [features]
        for sa_module in self.SA_modules:
            li_xyz, li_features = sa_module(l_xyz[-1], l_features[-1])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        # Walk back from the deepest level. Negative indices pair the last FP
        # module with the last two levels, the one before with the previous
        # pair, and so on. Presumably each FP module propagates the deeper
        # level's features onto the shallower level's points -- see the
        # pointnet2 library for the exact contract.
        for depth in range(1, len(self.FP_modules) + 1):
            deep, shallow = -depth, -depth - 1
            l_features[shallow] = self.FP_modules[deep](
                l_xyz[shallow], l_xyz[deep], l_features[shallow], l_features[deep]
            )

        return l_xyz[0], l_features[0]
|