Skip to content

Helios object detection accuracy debugging (incl DETR example configs) #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions data/helios/satlas_marine_infra/baseline_satlaspretrain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
model:
class_path: rslearn.train.lightning_module.RslearnLightningModule
init_args:
model:
class_path: rslearn.models.multitask.MultiTaskModel
init_args:
encoder:
- class_path: rslearn.models.swin.Swin
init_args:
pretrained: true
input_channels: 9
output_layers: [1, 3, 5, 7]
- class_path: rslearn.models.fpn.Fpn
init_args:
in_channels: [128, 256, 512, 1024]
out_channels: 128
decoders:
detect:
- class_path: rslearn.models.faster_rcnn.FasterRCNN
init_args:
downsample_factors: [4, 8, 16, 32]
num_channels: 128
num_classes: 3
anchor_sizes: [[32], [64], [128], [256]]
lr: 0.0001
plateau: true
plateau_factor: 0.2
plateau_patience: 2
plateau_min_lr: 0
plateau_cooldown: 10
restore_config:
restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth
remap_prefixes:
- ["backbone.backbone.backbone.", "encoder.0.model."]
data:
class_path: rslearn.train.data_module.RslearnDataModule
init_args:
path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/
inputs:
image:
data_type: "raster"
layers: ["sentinel2"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
mask:
data_type: "raster"
layers: ["mask"]
bands: ["mask"]
passthrough: true
dtype: FLOAT32
is_target: true
targets:
data_type: "vector"
layers: ["label"]
is_target: true
task:
class_path: rslearn.train.tasks.multi_task.MultiTask
init_args:
tasks:
detect:
class_path: rslp.satlas.train.MarineInfraTask
init_args:
property_name: "category"
classes: ["unknown", "platform", "turbine"]
box_size: 15
remap_values: [[0, 0.25], [0, 255]]
image_bands: [2, 1, 0]
exclude_by_center: true
enable_map_metric: true
enable_f1_metric: true
f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]]
skip_unknown_categories: true
f1_metric_kwargs:
cmp_mode: "distance"
cmp_threshold: 15
flatten_classes: true
input_mapping:
detect:
targets: "targets"
batch_size: 4
num_workers: 32
default_config:
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
- class_path: rslp.transforms.mask.Mask
train_config:
patch_size: 256
tags:
split: train
nonempty: "yes"
val_config:
patch_size: 256
tags:
split: val
nonempty: "yes"
test_config:
patch_size: 256
tags:
split: val
nonempty: "yes"
trainer:
max_epochs: 500
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: "epoch"
- class_path: rslearn.train.prediction_writer.RslearnWriter
init_args:
path: placeholder
output_layer: output
selector: ["detect"]
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_top_k: 1
save_last: true
monitor: val_detect/mAP
mode: max
rslp_project: helios_finetuning
rslp_experiment: 20250404_marine_satlaspretrain_swinb_00
133 changes: 133 additions & 0 deletions data/helios/satlas_marine_infra/baseline_satlaspretrain_128.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
model:
class_path: rslearn.train.lightning_module.RslearnLightningModule
init_args:
model:
class_path: rslearn.models.multitask.MultiTaskModel
init_args:
encoder:
- class_path: rslearn.models.swin.Swin
init_args:
pretrained: true
input_channels: 9
output_layers: [1, 3, 5, 7]
- class_path: rslearn.models.fpn.Fpn
init_args:
in_channels: [128, 256, 512, 1024]
out_channels: 128
decoders:
detect:
- class_path: rslearn.models.faster_rcnn.FasterRCNN
init_args:
downsample_factors: [4, 8, 16, 32]
num_channels: 128
num_classes: 3
anchor_sizes: [[32], [64], [128], [256]]
lr: 0.0001
plateau: true
plateau_factor: 0.2
plateau_patience: 2
plateau_min_lr: 0
plateau_cooldown: 10
restore_config:
restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth
remap_prefixes:
- ["backbone.backbone.backbone.", "encoder.0.model."]
data:
class_path: rslearn.train.data_module.RslearnDataModule
init_args:
path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/
inputs:
image:
data_type: "raster"
layers: ["sentinel2"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
mask:
data_type: "raster"
layers: ["mask"]
bands: ["mask"]
passthrough: true
dtype: FLOAT32
is_target: true
targets:
data_type: "vector"
layers: ["label"]
is_target: true
task:
class_path: rslearn.train.tasks.multi_task.MultiTask
init_args:
tasks:
detect:
class_path: rslp.satlas.train.MarineInfraTask
init_args:
property_name: "category"
classes: ["unknown", "platform", "turbine"]
box_size: 15
remap_values: [[0, 0.25], [0, 255]]
image_bands: [2, 1, 0]
exclude_by_center: true
enable_map_metric: true
enable_f1_metric: true
f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]]
skip_unknown_categories: true
f1_metric_kwargs:
cmp_mode: "distance"
cmp_threshold: 15
flatten_classes: true
input_mapping:
detect:
targets: "targets"
batch_size: 4
num_workers: 32
default_config:
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
- class_path: rslp.transforms.mask.Mask
train_config:
patch_size: 128
tags:
split: train
nonempty: "yes"
val_config:
patch_size: 128
load_all_patches: true
tags:
split: val
nonempty: "yes"
test_config:
patch_size: 128
load_all_patches: true
tags:
split: val
nonempty: "yes"
trainer:
max_epochs: 500
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: "epoch"
- class_path: rslearn.train.prediction_writer.RslearnWriter
init_args:
path: placeholder
output_layer: output
selector: ["detect"]
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_top_k: 1
save_last: true
monitor: val_detect/mAP
mode: max
rslp_project: helios_finetuning
rslp_experiment: 20250514_marine_satlaspretrain_swinb_128_00
Loading
Loading