
experiment: our model #3109

Draft
wants to merge 12 commits into main
172 changes: 172 additions & 0 deletions configs/mvxnet/SqueezeFPN_FireRPFNet_README.md
@@ -0,0 +1,172 @@
# MVXNet with Custom Backbones (SqueezeFPN + FireRPFNet): Efficient Multi-Modal 3D Object Detection

## Abstract

This project develops a computationally efficient, lightweight 3D object detection model for autonomous vehicles. By modifying the backbone architectures of existing fusion-based models, we aim to reduce computational demand while maintaining or improving detection accuracy. For the image branch we explored lightweight architectures such as SqueezeNet, EfficientNet, and MobileNet; for LiDAR processing we evaluated VoxelNeXt alongside two custom models, the Residual Pillar Feature Network (RPFNet) and the Fire Residual Pillar Feature Network (FireRPFNet). RPFNet and FireRPFNet are our proposed LiDAR point cloud backbones, inspired by the Fire module from SqueezeNet and the Convolutional Block Attention Module (CBAM). The potential impact of this project extends to safer and more efficient autonomous driving, contributing to the advancement of intelligent transportation systems.

## Overview

Our model introduces a novel, computationally efficient approach to 3D object detection through several key innovations:

### Lightweight Image Processing
- **SqueezeFPN**: A highly efficient feature pyramid network adapted from SqueezeNet's architecture
- **Optimized Feature Extraction**: Carefully balanced network depth and width for optimal performance
- **Memory-Efficient Design**: Reduced parameter count while maintaining high feature quality

### Advanced LiDAR Processing
- **FireRPFNet**: Our custom-designed LiDAR backbone (see the sketch after this list), which combines:
  - Fire modules from SqueezeNet for efficient feature extraction
  - A Convolutional Block Attention Module (CBAM) for enhanced feature focus
  - Residual connections for improved gradient flow
- **Dynamic Voxel Feature Encoder**: Adaptive point cloud feature learning that adjusts to input density
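
The exact FireRPFNet implementation lives in the model code of this PR; as a rough, hypothetical PyTorch sketch of the building block described above (Fire module + CBAM-style attention + residual shortcut), with illustrative channel counts rather than the released layer configuration:

```python
import torch
import torch.nn as nn


class Fire(nn.Module):
    """SqueezeNet-style Fire module: squeeze (1x1), then expand (1x1 + 3x3)."""

    def __init__(self, in_ch, squeeze_ch, expand_ch):
        super().__init__()
        self.squeeze = nn.Sequential(
            nn.Conv2d(in_ch, squeeze_ch, 1), nn.ReLU(inplace=True))
        self.expand1x1 = nn.Conv2d(squeeze_ch, expand_ch // 2, 1)
        self.expand3x3 = nn.Conv2d(squeeze_ch, expand_ch // 2, 3, padding=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze(x)
        return self.act(torch.cat([self.expand1x1(x), self.expand3x3(x)], dim=1))


class CBAM(nn.Module):
    """Channel attention followed by spatial attention."""

    def __init__(self, ch, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Conv2d(ch, ch // reduction, 1), nn.ReLU(inplace=True),
            nn.Conv2d(ch // reduction, ch, 1))
        self.spatial = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        # channel attention from globally avg- and max-pooled descriptors
        ca = torch.sigmoid(self.mlp(x.mean((2, 3), keepdim=True)) +
                           self.mlp(x.amax((2, 3), keepdim=True)))
        x = x * ca
        # spatial attention from per-location channel avg and max
        sa = torch.sigmoid(self.spatial(
            torch.cat([x.mean(1, keepdim=True), x.amax(1, keepdim=True)], dim=1)))
        return x * sa


class FireResidualBlock(nn.Module):
    """Illustrative FireRPFNet-style block: Fire -> CBAM with a residual shortcut."""

    def __init__(self, ch):
        super().__init__()
        self.fire = Fire(ch, ch // 4, ch)
        self.cbam = CBAM(ch)

    def forward(self, x):
        return x + self.cbam(self.fire(x))


# e.g. a 256-channel BEV feature map from the sparse middle encoder
out = FireResidualBlock(256)(torch.randn(1, 256, 200, 176))
```

Stacking a few such blocks over the BEV features from the sparse encoder is the general pattern; the real backbone additionally manages stage transitions (see `layer_channels` in the config).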

### Efficient Multi-Modal Fusion
- Early fusion strategy for optimal feature integration (a simplified point-wise fusion sketch follows this list)
- Balanced computational resource allocation between modalities
- Adaptive feature aggregation for robust object detection
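
In the attached config, this early fusion is performed by the `PointFusion` layer inside the voxel encoder (img_channels=512, pts_channels=64, out_channels=128). The sketch below is a deliberately simplified stand-in for the point-wise fusion idea; it assumes points have already been projected into the image and ignores the multi-level feature sampling the real layer performs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimplePointFusion(nn.Module):
    """Toy point-wise early fusion: sample image features at projected point
    locations, concatenate with per-point LiDAR features, and mix with an MLP.
    Channel sizes mirror the config; everything else is simplified."""

    def __init__(self, img_channels=512, pts_channels=64, out_channels=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(img_channels + pts_channels, out_channels),
            nn.ReLU(inplace=True))

    def forward(self, img_feats, pts_feats, pts_uv):
        # img_feats: (1, C_img, H, W) feature map from the image branch
        # pts_feats: (N, C_pts) per-point LiDAR features
        # pts_uv:    (N, 2) point locations already projected into the image
        #            and normalized to [-1, 1] (calibration handled upstream)
        grid = pts_uv.view(1, -1, 1, 2)                            # (1, N, 1, 2)
        sampled = F.grid_sample(img_feats, grid, align_corners=False)
        sampled = sampled.squeeze(-1).squeeze(0).t()               # (N, C_img)
        return self.mlp(torch.cat([sampled, pts_feats], dim=1))    # (N, C_out)
```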

## Architecture

The model consists of three main components:

1. **Image Branch**:
   - SqueezeFPN backbone
   - Efficient feature pyramid with squeeze and expand layers
   - Multi-scale feature maps [512, 512, 512, 512]

2. **LiDAR Branch**:
   - Dynamic voxel feature encoder (see the voxelization sketch after this list)
   - FireRPFNet backbone with CBAM attention
   - Efficient feature processing with Fire modules

3. **Multi-modal Fusion**:
   - Early fusion of image and point cloud features
   - Point-wise feature fusion
   - Adaptive feature aggregation
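
The dynamic voxel feature encoder in the LiDAR branch groups points by voxel coordinate without a fixed per-voxel point cap (the config sets `max_num_points=-1`) and pools features per occupied voxel. Below is a minimal sketch of that grouping step; the actual `DynamicVFE` additionally appends cluster-center and voxel-center offsets (`with_cluster_center`, `with_voxel_center`) and applies learned feature layers:

```python
import torch


def dynamic_voxelize(points, voxel_size, pc_range):
    """Toy dynamic voxelization: mean-pool point features per occupied voxel.

    points:     (N, 4) tensor of [x, y, z, intensity]
    voxel_size: [dx, dy, dz]
    pc_range:   [x_min, y_min, z_min, x_max, y_max, z_max]
    """
    lo = torch.tensor(pc_range[:3])
    hi = torch.tensor(pc_range[3:])
    size = torch.tensor(voxel_size)

    # keep only points inside the configured range
    mask = ((points[:, :3] >= lo) & (points[:, :3] < hi)).all(dim=1)
    pts = points[mask]

    # integer voxel coordinates; no cap on points per voxel (max_num_points=-1)
    coords = ((pts[:, :3] - lo) / size).long()
    voxels, inverse = torch.unique(coords, dim=0, return_inverse=True)

    # mean-pool raw point features within each occupied voxel
    feats = torch.zeros(len(voxels), pts.shape[1])
    feats.index_add_(0, inverse, pts)
    counts = torch.bincount(inverse, minlength=len(voxels)).clamp(min=1)
    return feats / counts.unsqueeze(1), voxels
```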

## Results on KITTI Dataset

### AP@40 IoU Results (Car)

| Metric | Easy | Moderate | Hard |
|:------:|:----:|:--------:|:----:|
| 3D Detection | 97.37 | 91.93 | 89.53 |
| Bird's Eye View | 97.48 | 92.29 | 89.92 |
| 2D Detection | 95.39 | 88.64 | 84.56 |
| AOS | 94.92 | 87.22 | 82.64 |

### Comparison with State-of-the-Art

#### 3D Detection AP@40 (%)

| Model | Easy | Moderate | Hard |
|:-----:|:----:|:--------:|:----:|
| **Ours (SqueezeFPN+FireRPFNet)** | **97.37** | **91.93** | **89.53** |
| [TRTConv](https://www.cvlibs.net/datasets/kitti/eval_object_detail.php?&result=30bdb9fd69e93886221650a590744f76bbb3d773) | 91.90 | 85.04 | 80.38 |
| [GLENet-VR](https://paperswithcode.com/paper/glenet-boosting-3d-object-detectors-with/review/?hl=60125) | 91.67 | 83.23 | 78.43 |
| [SE-SSD](https://paperswithcode.com/paper/cia-ssd-confident-iou-aware-single-stage/review/?hl=34051) | 91.49 | 82.54 | 77.15 |

#### Bird's Eye View AP@40 (%)

| Model | Easy | Moderate | Hard |
|:-----:|:----:|:--------:|:----:|
| **Ours (SqueezeFPN+FireRPFNet)** | **97.48** | **92.29** | **89.92** |
| [SE-SSD](https://paperswithcode.com/paper/cia-ssd-confident-iou-aware-single-stage/review/?hl=34051) | 96.59 | 92.28 | 89.72 |
| [TRTConv](https://www.cvlibs.net/datasets/kitti/eval_object_detail.php?&result=30bdb9fd69e93886221650a590744f76bbb3d773) | 95.55 | 92.04 | 87.23 |
| [PV-RCNN](https://paperswithcode.com/paper/pv-rcnn-point-voxel-feature-set-abstraction/review/?hl=13979) | 94.98 | 90.65 | 86.14 |

## Key Features

1. **Efficient Architecture**
   - Lightweight SqueezeFPN for image feature extraction
   - Memory-efficient FireRPFNet for point cloud processing
   - Dynamic voxel encoding for adaptive feature learning

2. **State-of-the-Art Performance**
   - Achieves top performance in both 3D detection and Bird's Eye View
   - Significant improvements over existing methods
   - Robust performance across different difficulty levels

3. **Multi-modal Fusion**
   - Early fusion strategy for better feature integration
   - Effective combination of image and LiDAR information
   - Enhanced feature representation for accurate detection

## Training and Testing

### Training
```shell
# Single GPU training
python tools/train.py configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py

# Multi-GPU training
TORCH_DISTRIBUTED_DEBUG=INFO ./tools/dist_train.sh configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py 8
```

### Testing
```shell
# Single GPU testing
python tools/test.py configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py ${CHECKPOINT_FILE}

# Multi-GPU testing
./tools/dist_test.sh configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py ${CHECKPOINT_FILE} 8
```

### Inference Command Example
Update the sample file paths below to match your data before running the command:
```shell
python demo/multi_modality_demo.py \
data/kitti/testing/velodyne/000068.bin \
data/kitti/testing/image_2/000068.png \
data/kitti/kitti_infos_test.pkl \
work_dirs/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py \
work_dirs/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class/epoch_25.pth
```

### Pre-trained Model

We provide a pre-trained model for 25 epochs on the KITTI dataset:
- [Download Pre-trained Model (25 epochs)](https://drive.google.com/file/d/19h9m7tb1bX-W1ZViocx4J8knanz6PuLz/view?usp=drive_link)



#### Usage Instructions
```shell
# Directly test the pre-trained model
python tools/test.py configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py /path/to/downloaded/checkpoint.pth

# Resume training from the pre-trained model
python tools/train.py configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py --resume /path/to/downloaded/checkpoint.pth
```

## Model Configuration

The model configuration can be found in:
```
configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py
```

Key configuration details:
- Backbone: SqueezeFPN + FireRPFNet
- Training schedule: 25 epochs with cosine learning rate
- Input: Multi-modal (LiDAR point cloud + Camera image)
- Classes: Pedestrian, Cyclist, Car
- Voxel size: [0.05, 0.05, 0.1]
- Point cloud range: [0, -40, -3, 70.4, 40, 1]
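
The voxel grid implied by the last two settings can be checked with a few lines of Python; this is also where the `sparse_shape=[41, 1600, 1408]` used by the sparse middle encoder in the config comes from:

```python
# Derive the voxel grid shape from the configured geometry.
point_cloud_range = [0, -40, -3, 70.4, 40, 1]  # [x_min, y_min, z_min, x_max, y_max, z_max]
voxel_size = [0.05, 0.05, 0.1]                 # [dx, dy, dz] in metres

nx = round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size[0])  # 1408
ny = round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size[1])  # 1600
nz = round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size[2])  # 40

# The sparse encoder pads one extra voxel along z, giving [41, 1600, 1408] in (z, y, x) order.
print([nz + 1, ny, nx])
```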

## Additional Resources

### Repository and Model Files
- [Complete Project Resources](https://drive.google.com/drive/folders/161GQlhavUursme3voNuQaF1YGtYi9E6a)
- Pre-trained models
- Logs
- Additional configuration files
- Supplementary materials

### Quick Access
- [25-Epoch Pre-trained Model](https://drive.google.com/file/d/19h9m7tb1bX-W1ZViocx4J8knanz6PuLz/view?usp=drive_link)
- Configuration File: `configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py`

**Note**: All resources are subject to the project's licensing terms. Please review and comply with usage guidelines.
220 changes: 220 additions & 0 deletions configs/mvxnet/mvxnet_sqeezefpn_fire_rpfnet_kitti-3d-3class.py
@@ -0,0 +1,220 @@
# MVX-Net (SqueezeFPN camera branch + FireRPFNet lidar)
# for KITTI 3-class.

_base_ = ['../_base_/schedules/cosine.py', '../_base_/default_runtime.py']

# -----------------------------------------------------------------------------
# Geometry
# -----------------------------------------------------------------------------
voxel_size = [0.05, 0.05, 0.1]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
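# With this geometry the voxel grid is 1408 x 1600 x 40 cells in (x, y, z);
# the sparse middle encoder below therefore uses sparse_shape = [41, 1600, 1408]
# (one extra cell of padding along z, listed in (z, y, x) order).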

# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
model = dict(
type='DynamicMVXFasterRCNN',
# --------------------------------------------------
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='dynamic',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=point_cloud_range,
voxel_size=voxel_size,
max_voxels=(-1, -1)),
mean=[102.9801, 115.9465, 122.7717],
std=[1.0, 1.0, 1.0],
bgr_to_rgb=False,
pad_size_divisor=32),

# ----------------------- image branch -----------------------
img_backbone=dict(
type='SQUEEZE',
in_channels=3,
out_channels=[64, 128, 256, 512],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
conv_cfg=dict(type='Conv2d', bias=False)),
img_neck=dict(
type='SQUEEZEFPN',
in_channels=[64, 128, 256, 512],
out_channels=[512, 512, 512, 512],
norm_cfg=dict(type='BN', requires_grad=False)),

# ----------------------- LiDAR voxel encoder ----------------
pts_voxel_encoder=dict(
type='DynamicVFE',
in_channels=4,
feat_channels=[64, 64],
with_distance=False,
voxel_size=voxel_size,
with_cluster_center=True,
with_voxel_center=True,
point_cloud_range=point_cloud_range,
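# PointFusion projects the LiDAR points into the image, samples the image
# features listed in img_levels at those locations, and fuses them with the
# per-point LiDAR features before voxel-wise aggregation.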
fusion_layer=dict(
type='PointFusion',
img_channels=512,
pts_channels=64,
mid_channels=128,
out_channels=128,
img_levels=[0, 1, 2, 3],
align_corners=False,
activate_out=True,
fuse_out=False)),

# ----------------------- Sparse middle encoder --------------
pts_middle_encoder=dict(
type='SparseEncoder',
in_channels=128,
sparse_shape=[41, 1600, 1408],
order=('conv', 'norm', 'act')),

# ----------------------- FireRPFNet backbone -------------
pts_backbone=dict(
type='FireRPFNet',
in_channels=256, # output of SparseEncoder
layer_channels=[128, 256, 256, 256],
with_cbam=True),

pts_neck=None, # RPFNet is already deep enough

# ----------------------- Anchor head ------------------------
pts_bbox_head=dict(
type='Anchor3DHead',
num_classes=3,
in_channels=256,
feat_channels=256,
use_direction_classifier=True,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
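# One anchor range and size per class, in class_names order: Pedestrian, Cyclist, Car.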
ranges=[
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -1.78, 70.4, 40.0, -1.78],
],
sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]],
rotations=[0, 1.57],
reshape_out=False),
assigner_per_size=True,
diff_rad_by_sin=True,
assign_per_class=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='mmdet.FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='mmdet.SmoothL1Loss', beta=1.0 / 9.0,
loss_weight=2.0),
loss_dir=dict(type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),

# ----------------------- Train / Test cfg -------------------
train_cfg=dict(
pts=dict(
assigner=[
dict(type='Max3DIoUAssigner', iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35, neg_iou_thr=0.2, min_pos_iou=0.2,
ignore_iof_thr=-1),
dict(type='Max3DIoUAssigner', iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35, neg_iou_thr=0.2, min_pos_iou=0.2,
ignore_iof_thr=-1),
dict(type='Max3DIoUAssigner', iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6, neg_iou_thr=0.45, min_pos_iou=0.45,
ignore_iof_thr=-1),
],
allowed_border=0, pos_weight=-1, debug=False)),

test_cfg=dict(
pts=dict(use_rotate_nms=True, nms_across_levels=False, nms_thr=0.01,
score_thr=0.1, min_bbox_size=0, nms_pre=100, max_num=50))
)

# -----------------------------------------------------------------------------
# Dataset & pipelines
# -----------------------------------------------------------------------------

# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
class_names = ['Pedestrian', 'Cyclist', 'Car']
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=True)
backend_args = None

train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4,
backend_args=backend_args),
dict(type='LoadImageFromFile', backend_args=backend_args),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True,
with_bbox=True, with_label=True),
dict(type='RandomResize', scale=[(320, 96), (1280, 384)], keep_ratio=True),
dict(type='GlobalRotScaleTrans', rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05], translation_std=[0.2, 0.2, 0.2]),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(type='Pack3DDetInputs', keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d',
'gt_bboxes', 'gt_labels'])
]

test_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4,
backend_args=backend_args),
dict(type='LoadImageFromFile', backend_args=backend_args),
dict(type='MultiScaleFlipAug3D', img_scale=(1280, 384), pts_scale_ratio=1,
flip=False,
transforms=[
dict(type='Resize', scale=0, keep_ratio=True),
dict(type='GlobalRotScaleTrans', rot_range=[0, 0], scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
]),
dict(type='Pack3DDetInputs', keys=['points', 'img'])
]

modality = dict(use_lidar=True, use_camera=True)

train_dataloader = dict(
batch_size=2, num_workers=2, sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(type='RepeatDataset', times=2, dataset=dict(
type=dataset_type, data_root=data_root, modality=modality,
ann_file='kitti_infos_train.pkl',
data_prefix=dict(pts='training/velodyne_reduced', img='training/image_2'),
pipeline=train_pipeline, filter_empty_gt=False, metainfo=metainfo,
box_type_3d='LiDAR', backend_args=backend_args)))

val_dataloader = dict(
batch_size=1, num_workers=1, sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(type=dataset_type, data_root=data_root, modality=modality,
ann_file='kitti_infos_val.pkl',
data_prefix=dict(pts='training/velodyne_reduced', img='training/image_2'),
pipeline=test_pipeline, metainfo=metainfo, test_mode=True,
box_type_3d='LiDAR', backend_args=backend_args))

test_dataloader = val_dataloader

# -----------------------------------------------------------------------------
# Optimizer / Schedulers / Runtime
# -----------------------------------------------------------------------------
optim_wrapper = dict(optimizer=dict(lr=0.001, weight_decay=0.01),
clip_grad=dict(max_norm=35, norm_type=2))
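# Only lr and weight_decay are overridden here; the optimizer type and the
# cosine learning-rate schedule come from ../_base_/schedules/cosine.py.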

val_evaluator = dict(type='KittiMetric', ann_file='data/kitti/kitti_infos_val.pkl')



test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=5)