diff --git a/.github/workflows/premerge-cpu.yml b/.github/workflows/premerge-cpu.yml index d365d9bb..5d204dcb 100644 --- a/.github/workflows/premerge-cpu.yml +++ b/.github/workflows/premerge-cpu.yml @@ -16,29 +16,70 @@ jobs: premerge-cpu: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: + python-version: 3.10.14 # This will be the default Python + + - name: Set up Miniconda + uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + # The python-version here is for Conda's base, if activated. + # We set auto-activate-base to false, so actions/setup-python's version remains default. python-version: 3.10.14 + auto-activate-base: false # Crucial: ensures the Python 3.10 from actions/setup-python is the default + + - name: Initialize Conda for shell integration + shell: bash + run: | + conda init bash + eval "$(conda shell.bash hook)" + - name: cache weekly timestamp - id: pip-cache + id: pip-cache-ts run: | - echo "::set-output name=datew::$(date '+%Y-%V')" - - name: cache for pip + echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT + + - name: Get pip cache directory + id: pip-cache-dir + # This runs using the default Python (3.10) + shell: bash + run: echo "dir=$(python -m pip cache dir)" >> $GITHUB_OUTPUT + + - name: Cache pip dependencies (for Python 3.10 global env) + uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ steps.pip-cache-dir.outputs.dir }} + key: ${{ runner.os }}-pip-3.10-${{ steps.pip-cache-ts.outputs.datew }}-${{ hashFiles('**/requirements.txt', '**/setup.py') }} + restore-keys: | + ${{ runner.os }}-pip-3.10-${{ steps.pip-cache-ts.outputs.datew }}- + + # Cache Conda packages (for environments your bash script might create) + # setup-miniconda action sets MINICONDA_PATH_0 environment variable + - name: Cache Conda packages uses: actions/cache@v4 - id: cache + id: cache-conda with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - - name: Install dependencies + path: ${{ env.MINICONDA_PATH_0 }}/pkgs + key: ${{ runner.os }}-conda-pkgs-${{ steps.pip-cache-ts.outputs.datew }}-${{ hashFiles('**/environment*.yml', '**/requirements*.txt') }} # Adjust hashFiles as needed + restore-keys: | + ${{ runner.os }}-conda-pkgs-${{ steps.pip-cache-ts.outputs.datew }}- + + - name: Install global Python dependencies (using Python 3.10) + shell: bash run: | python -m pip install --upgrade pip wheel python -m pip install --upgrade setuptools - - name: check + + - name: Run check script + shell: bash -el {0} run: | # clean up temporary files $(pwd)/runtests.sh --clean df -h - bash $(pwd)/ci/run_premerge_cpu.sh changed - shell: bash + bash $(pwd)/ci/run_premerge_cpu.sh diff --git a/ci/bundle_custom_data.py b/ci/bundle_custom_data.py index 30aaa2b9..4d7d58f8 100644 --- a/ci/bundle_custom_data.py +++ b/ci/bundle_custom_data.py @@ -23,11 +23,16 @@ "maisi_ct_generative", "cxr_image_synthesis_latent_diffusion_model", "brain_image_synthesis_latent_diffusion_model", + "retinalOCT_RPD_segmentation", ] # This list is used for our CI tests to determine whether a bundle contains the preferred files. # If a bundle does not have any of the preferred files, please add the bundle name into the list. 
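For context, these module-level lists and dicts are plain data that the CI scripts import and consult with simple membership tests. A minimal sketch of the pattern, assuming it is executed from the `ci/` directory (illustrative only; the real checks live in `ci/verify_bundle.py` and the `run_premerge` scripts, and `plan_bundle_checks` is a hypothetical helper):
```
# Illustrative sketch only -- the actual CI logic lives in ci/verify_bundle.py and ci/run_premerge_*.sh.
from bundle_custom_data import exclude_verify_preferred_files_list, install_dependency_dict

def plan_bundle_checks(bundle_name: str) -> dict:
    """Hypothetical helper showing how the lists/dict in this module are typically consumed."""
    return {
        # skip the preferred-files check for bundles listed as exclusions
        "verify_preferred_files": bundle_name not in exclude_verify_preferred_files_list,
        # bundles whose dependencies cannot be pip-installed directly point to an install script
        "extra_install_script": install_dependency_dict.get(bundle_name),
    }
```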
-exclude_verify_preferred_files_list = ["pediatric_abdominal_ct_segmentation", "maisi_ct_generative"] +exclude_verify_preferred_files_list = [ + "pediatric_abdominal_ct_segmentation", + "maisi_ct_generative", + "retinalOCT_RPD_segmentation", +] # This list is used for our CI tests to determine whether a bundle needs to be tested with # the `verify_export_torchscript` function in `verify_bundle.py`. @@ -47,6 +52,7 @@ "mednist_ddpm", "cxr_image_synthesis_latent_diffusion_model", "brain_image_synthesis_latent_diffusion_model", + "retinalOCT_RPD_segmentation", ] # This list is used for our CI tests to determine whether a bundle needs to be tested after downloading @@ -58,7 +64,9 @@ # This dict is used for our CI tests to install required dependencies that cannot be installed by `pip install` directly. # If a bundle has this kind of dependencies, please add the bundle name (key), and the path of the install script (value) # into the dict. -install_dependency_dict = {} +install_dependency_dict = { + "retinalOCT_RPD_segmentation": "ci/install_scripts/install_retinalOCT_RPD_segmentation_dependency.sh" +} # This list is used for our CI tests to determine whether a bundle supports TensorRT export. Related # test will be employed for bundles in the dict. diff --git a/ci/get_bundle_requirements.py b/ci/get_bundle_requirements.py index 9262b112..d70d3bf0 100644 --- a/ci/get_bundle_requirements.py +++ b/ci/get_bundle_requirements.py @@ -19,6 +19,8 @@ ALLOW_MONAI_RC = os.environ.get("ALLOW_MONAI_RC", "false").lower() in ("true", "1", "t", "y", "yes") +special_dependencies_list = ["detectron2"] + def increment_version(version): """ @@ -79,6 +81,8 @@ def get_requirements(bundle, models_path, requirements_file): if package_key in metadata.keys(): optional_dict = metadata[package_key] for name, version in optional_dict.items(): + if name in special_dependencies_list: + continue libs.append(f"{name}=={version}") if len(libs) > 0: diff --git a/ci/install_scripts/install_retinalOCT_RPD_segmentation_dependency.sh b/ci/install_scripts/install_retinalOCT_RPD_segmentation_dependency.sh new file mode 100644 index 00000000..74fae23c --- /dev/null +++ b/ci/install_scripts/install_retinalOCT_RPD_segmentation_dependency.sh @@ -0,0 +1,2 @@ +# ensure that using evironment with python==3.9 +python -m pip install 'git+https://github.com/facebookresearch/detectron2.git@65184fc057d4fab080a98564f6b60fae0b94edc4' diff --git a/ci/run_premerge_cpu.sh b/ci/run_premerge_cpu.sh index 000424c3..1e6ddad8 100755 --- a/ci/run_premerge_cpu.sh +++ b/ci/run_premerge_cpu.sh @@ -30,6 +30,12 @@ elif [[ $# -gt 1 ]]; then exit 1 fi +# Bunles that requires special python version +declare -A bundle_python_versions=( + ["retinalOCT_RPD_segmentation"]="3.9" +) +DEFAULT_PYTHON_VERSION_FOR_VENV="3.10" + # Usually, CPU test is required, but for some bundles that are too large to run in Github Actions, we can exclude them. exclude_test_list=("maisi_ct_generative") is_excluded() { @@ -41,22 +47,23 @@ is_excluded() { return 1 # Return false (1) if not excluded } +install_common_deps_in_activated_env() { + python -m pip install --upgrade pip wheel + python -m pip install --upgrade setuptools + python -m pip install jsonschema gdown pyyaml parameterized fire + export PYTHONPATH=$PWD +} + init_venv() { if [ ! 
-d "model_zoo_venv" ]; then # Check if the venv directory does not exist echo "initializing pip environment" python -m venv model_zoo_venv source model_zoo_venv/bin/activate - pip install --upgrade pip wheel - pip install --upgrade setuptools - pip install jsonschema gdown pyyaml parameterized fire - export PYTHONPATH=$PWD + install_common_deps_in_activated_env else echo "Virtual environment model_zoo_venv already exists. Activating..." source model_zoo_venv/bin/activate - pip install --upgrade pip wheel - pip install --upgrade setuptools - pip install jsonschema gdown pyyaml parameterized fire - export PYTHONPATH=$PWD + install_common_deps_in_activated_env fi } @@ -70,6 +77,42 @@ remove_venv() { fi } +init_conda_env() { + local python_version_to_create="$1" + local bundle_identifier="$2" + local conda_env_name="conda_env_${bundle_identifier}" + + # Always source conda.sh to ensure conda activate is available + if [ -n "$CONDA_EXE" ] && [ -f "$(dirname "$CONDA_EXE")/../etc/profile.d/conda.sh" ]; then + source "$(dirname "$CONDA_EXE")/../etc/profile.d/conda.sh" + elif [ -n "$MINICONDA_PATH_0" ] && [ -f "$MINICONDA_PATH_0/etc/profile.d/conda.sh" ]; then + source "$MINICONDA_PATH_0/etc/profile.d/conda.sh" + else + echo "Warning: Could not reliably source conda.sh for Conda activation." + fi + + if conda env list | grep -q "^${conda_env_name}[[:space:]]"; then + echo "Conda env '$conda_env_name' already exists. Removing for a clean start..." + conda env remove -n "$conda_env_name" -y + fi + + conda create -n "$conda_env_name" python="$python_version_to_create" -y + conda activate "$conda_env_name" + install_common_deps_in_activated_env + conda deactivate 2>/dev/null || true +} + +remove_conda_env() { + local conda_env_name_to_remove="$1" + if [ -z "$conda_env_name_to_remove" ]; then + echo "Warning: No Conda env name provided to remove_conda_env." 
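+        # nothing to clean up; return early rather than calling `conda env remove` with an empty name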
+ return + fi + echo "Deactivating and removing Conda environment: $conda_env_name_to_remove" + conda deactivate 2>/dev/null || true + conda env remove -n "$conda_env_name_to_remove" -y +} + verify_bundle() { for dir in /opt/hostedtoolcache/*; do if [[ $dir != "/opt/hostedtoolcache/Python" ]]; then @@ -106,7 +149,22 @@ verify_bundle() { else include_pre_release="" fi - init_venv + # determine if conda env should be used for the bundle + active_conda_env_for_bundle="" + required_python_version="${bundle_python_versions[$bundle]}" + use_conda_for_bundle=false + if [[ -n "$required_python_version" && "$required_python_version" != "$DEFAULT_PYTHON_VERSION_FOR_VENV" ]] + then + use_conda_for_bundle=true + fi + if $use_conda_for_bundle + then + init_conda_env "$required_python_version" "$bundle" + active_conda_env_for_bundle="conda_env_${bundle}" + conda activate "$active_conda_env_for_bundle" + else + init_venv + fi # Check if the requirements file exists and is not empty if [ -s "$requirements_file" ]; then echo "install required libraries for bundle: $bundle" @@ -114,7 +172,13 @@ verify_bundle() { fi # verify bundle python $(pwd)/ci/verify_bundle.py -b "$bundle" -m "min" # min tests on cpu - remove_venv + # cleanup + if $use_conda_for_bundle + then + remove_conda_env "$active_conda_env_for_bundle" + else + remove_venv + fi fi done else diff --git a/ci/run_premerge_gpu.sh b/ci/run_premerge_gpu.sh index 6aa745b6..5fc85f7c 100755 --- a/ci/run_premerge_gpu.sh +++ b/ci/run_premerge_gpu.sh @@ -29,22 +29,28 @@ if [[ $# -gt 1 ]]; then exit 1 fi +declare -A bundle_python_versions=( + ["retinalOCT_RPD_segmentation"]="3.9" +) +DEFAULT_PYTHON_VERSION_FOR_VENV="3.10" + +install_common_deps_in_activated_env() { + python -m pip install --upgrade pip wheel + python -m pip install --upgrade setuptools + python -m pip install jsonschema gdown pyyaml parameterized fire + export PYTHONPATH=$PWD +} + init_venv() { if [ ! -d "model_zoo_venv" ]; then # Check if the venv directory does not exist echo "initializing pip environment" python -m venv model_zoo_venv source model_zoo_venv/bin/activate - pip install --upgrade pip wheel - pip install --upgrade setuptools - pip install jsonschema gdown pyyaml parameterized fire - export PYTHONPATH=$PWD + install_common_deps_in_activated_env else echo "Virtual environment model_zoo_venv already exists. Activating..." source model_zoo_venv/bin/activate - pip install --upgrade pip wheel - pip install --upgrade setuptools - pip install jsonschema gdown pyyaml parameterized fire - export PYTHONPATH=$PWD + install_common_deps_in_activated_env fi } @@ -58,6 +64,42 @@ remove_venv() { fi } +init_conda_env() { + local python_version_to_create="$1" + local bundle_identifier="$2" + local conda_env_name="conda_env_${bundle_identifier}" + + # Always source conda.sh to ensure conda activate is available + if [ -n "$CONDA_EXE" ] && [ -f "$(dirname "$CONDA_EXE")/../etc/profile.d/conda.sh" ]; then + source "$(dirname "$CONDA_EXE")/../etc/profile.d/conda.sh" + elif [ -n "$MINICONDA_PATH_0" ] && [ -f "$MINICONDA_PATH_0/etc/profile.d/conda.sh" ]; then + source "$MINICONDA_PATH_0/etc/profile.d/conda.sh" + else + echo "Warning: Could not reliably source conda.sh for Conda activation." + fi + + if conda env list | grep -q "^${conda_env_name}[[:space:]]"; then + echo "Conda env '$conda_env_name' already exists. Removing for a clean start..." 
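+        # remove the stale environment non-interactively (-y) so repeated CI runs start from a clean state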
+ conda env remove -n "$conda_env_name" -y + fi + + conda create -n "$conda_env_name" python="$python_version_to_create" -y + conda activate "$conda_env_name" + install_common_deps_in_activated_env + conda deactivate 2>/dev/null || true +} + +remove_conda_env() { + local conda_env_name_to_remove="$1" + if [ -z "$conda_env_name_to_remove" ]; then + echo "Warning: No Conda env name provided to remove_conda_env." + return + fi + echo "Deactivating and removing Conda environment: $conda_env_name_to_remove" + conda deactivate 2>/dev/null || true + conda env remove -n "$conda_env_name_to_remove" -y +} + verify_bundle() { echo 'Run verify bundle...' head_ref=$(git rev-parse HEAD) @@ -82,7 +124,22 @@ verify_bundle() { else include_pre_release="" fi - init_venv + # determine if conda env should be used for the bundle + active_conda_env_for_bundle="" + required_python_version="${bundle_python_versions[$bundle]}" + use_conda_for_bundle=false + if [[ -n "$required_python_version" && "$required_python_version" != "$DEFAULT_PYTHON_VERSION_FOR_VENV" ]] + then + use_conda_for_bundle=true + fi + if $use_conda_for_bundle + then + init_conda_env "$required_python_version" "$bundle" + active_conda_env_for_bundle="conda_env_${bundle}" + conda activate "$active_conda_env_for_bundle" + else + init_venv + fi # Check if the requirements file exists and is not empty if [ -s "$requirements_file" ]; then echo "install required libraries for bundle: $bundle" @@ -101,7 +158,13 @@ verify_bundle() { test_cmd="torchrun $(pwd)/ci/unit_tests/runner.py --b \"$bundle\" --dist True" fi eval $test_cmd - remove_venv + # cleanup + if $use_conda_for_bundle + then + remove_conda_env "$active_conda_env_for_bundle" + else + remove_venv + fi done else echo "this pull request does not change any bundles, skip verify." diff --git a/ci/run_regular_tests_cpu.sh b/ci/run_regular_tests_cpu.sh index 253cc86c..52936c0e 100755 --- a/ci/run_regular_tests_cpu.sh +++ b/ci/run_regular_tests_cpu.sh @@ -30,8 +30,7 @@ elif [[ $# -gt 1 ]]; then exit 1 fi -# Usually, CPU test is required, but for some bundles that are too large to run in Github Actions, we can exclude them. -exclude_test_list=("maisi_ct_generative") +exclude_test_list=("maisi_ct_generative" "retinalOCT_RPD_segmentation") is_excluded() { for item in "${exclude_test_list[@]}"; do # Use exclude_test_list here if [ "$1" == "$item" ]; then diff --git a/ci/unit_tests/test_retinalOCT_RPD_segmentation.py b/ci/unit_tests/test_retinalOCT_RPD_segmentation.py new file mode 100644 index 00000000..f245b0fe --- /dev/null +++ b/ci/unit_tests/test_retinalOCT_RPD_segmentation.py @@ -0,0 +1,118 @@ +import glob +import json +import os +import shutil +import subprocess +import sys +import tempfile +import unittest + +import pandas as pd +import yaml + + +class TestRPDInference(unittest.TestCase): + def setUp(self): + print(os.getcwd()) + # set the bundle root to the directory the test is being run from. + self.bundle_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "models", "retinalOCT_RPD_segmentation") + ) + # Change the current working directory to bundle_root + os.chdir(self.bundle_root) + + # Create a temporary directory for test data + self.test_data_dir = tempfile.mkdtemp() + self.extracted_dir = os.path.join(self.bundle_root, "sample_data") + + # create a dummy metadata.json file. 
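+        # (a minimal stub is enough: it only has to satisfy the --meta_file argument passed to `monai.bundle run` below; the real bundle metadata lives in configs/metadata.json)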
+ metadata_file = os.path.join(self.test_data_dir, "metadata.json") + metadata = { + "version": "0.0.1", + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", + } + with open(metadata_file, "w") as f: + json.dump(metadata, f) + + # create output directory. + self.output_dir = os.path.join(self.test_data_dir, "output") + os.makedirs(self.output_dir) + + def tearDown(self): + # Clean up the temporary directory + shutil.rmtree(self.test_data_dir) + + def test_inference_run(self): + # Override configuration parameters + override = { + "args": { + "extracted_dir": self.extracted_dir, + "output_dir": self.output_dir, + "run_extract": False, + "create_dataset": True, + "run_inference": True, + "binary_mask": True, + "binary_mask_overlay": True, + "instance_mask_overlay": True, + "dataset_name": "testDataset", + } + } + + # Load the original inference.yaml + inference_yaml_path = "configs/inference.yaml" + with open(inference_yaml_path, "r") as f: + inference_yaml = yaml.safe_load(f) + + # Modify inference.yaml with override parameters. + inference_yaml["args"].update(override["args"]) + + # Create a new inference.yaml in the test_data_dir + test_inference_yaml_path = os.path.join(self.test_data_dir, "inference.yaml") + with open(test_inference_yaml_path, "w") as f: + yaml.dump(inference_yaml, f) + + # Run the inference command using subprocess + cmd = [ + sys.executable, # Use the same Python interpreter + "-m", + "monai.bundle", + "run", + "inference", + "--bundle_root", + self.bundle_root, + "--config_file", + test_inference_yaml_path, # Use the new file + "--meta_file", + os.path.join(self.test_data_dir, "metadata.json"), + ] + + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + self.fail(f"Inference command failed: {e}") + + # Add assertions to check the output + # Check if output files were created + output_files = os.listdir(self.output_dir) + self.assertTrue(len(output_files) > 0) + + # Check for the COCO JSON file + coco_file_found = glob.glob(os.path.join(self.output_dir, "**", "coco_instances_results.json"), recursive=True) + print(coco_file_found) + self.assertTrue(len(coco_file_found) == 6) + + # Check for the TIFF files. + tiff_files_found = glob.glob(os.path.join(self.output_dir, "**", "*.tiff"), recursive=True) + self.assertTrue(len(tiff_files_found) == 6) + + # Check for the html files. + html_files_found = glob.glob(os.path.join(self.output_dir, "*.html")) + self.assertTrue(len(html_files_found) == 2) + + # At least 10 RPD present in sample data + dfvol = pd.read_html(os.path.join(self.output_dir, "dfvol_testDataset.html"))[0] + self.assertTrue(dfvol["dt_instances"].sum().iloc[0] > 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/models/retinalOCT_RPD_segmentation/LICENSE b/models/retinalOCT_RPD_segmentation/LICENSE new file mode 100644 index 00000000..dc32dc5a --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/LICENSE @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2022, uw-biomedical-ml +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/models/retinalOCT_RPD_segmentation/configs/inference.yaml b/models/retinalOCT_RPD_segmentation/configs/inference.yaml new file mode 100644 index 00000000..fd148017 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/configs/inference.yaml @@ -0,0 +1,23 @@ +imports: +- $import scripts +- $import scripts.inference + +args: + run_extract : False + input_dir : "/path/to/data" + extracted_dir : "/path/to/extracted/data" + input_format : "dicom" + create_dataset : True + dataset_name : "my_dataset_name" + + output_dir : "/path/to/model/output" + run_inference : True + create_tables : True + +# create visuals + binary_mask : False + binary_mask_overlay : True + instance_mask_overlay : False + +inference: +- $scripts.inference.main(@args) diff --git a/models/retinalOCT_RPD_segmentation/configs/metadata.json b/models/retinalOCT_RPD_segmentation/configs/metadata.json new file mode 100644 index 00000000..8b1c0f19 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/configs/metadata.json @@ -0,0 +1,146 @@ +{ + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", + "version": "0.0.1", + "changelog": { + "0.0.1": "Initial version" + }, + "monai_version": "1.5.0", + "pytorch_version": "2.6.0", + "numpy_version": "1.26.4", + "optional_packages_version": {}, + "required_packages_version": { + "setuptools": "75.8.0", + "opencv-python-headless": "4.11.0.86", + "pandas": "2.3.0", + "seaborn": "0.13.2", + "scikit-learn": "1.6.1", + "progressbar": "2.5", + "pydicom": "3.0.1", + "fire": "0.7.0", + "torchvision": "0.21.0", + "detectron2": "0.6", + "lxml": "5.4.0", + "pillow": "11.2.1" + }, + "name": "retinalOCT_RPD_segmentation", + "task": "Reticular Pseudodrusen (RPD) instance segmentation.", + "description": "This network detects and segments Reticular Pseudodrusen (RPD) instances in Optical Coherence Tomography (OCT) B-scans which can be presented in a vol or dicom format.", + "authors": "Yelena Bagdasarova, Scott Song", + "copyright": "Copyright (c) 2022, uw-biomedical-ml", + "network_data_format": { + "inputs": { + "image": { + "type": "image", + "format": "magnitude", + "modality": "OCT", + "num_channels": 1, + "spatial_shape": [ + 496, + 1024 + ], + "dtype": "int16", + "value_range": [ + 0, + 256 + ], + "is_patch_data": false, + "channel_def": { + "0": "image" + } + } + }, + "preprocessed_data_sources": { + "vol_file": { + "type": "image", + "format": "magnitude", + "modality": "OCT", + "num_channels": 1, + "spatial_shape": [ + 496, + 1024, + "D" + ], + "dtype": 
"int16", + "value_range": [ + 0, + 256 + ], + "description": "The pixel array of each OCT slice is extracted with volreader and the png files saved to ////_oct_.png on disk, where is the slice number and a nested hierarchy of folders is created using the underscores in the original filename. " + }, + "dicom_series": { + "type": "image", + "format": "magnitude", + "modality": "OCT", + "SOP class UID": "1.2.840.10008.5.1.4.1.1.77.1.5.4", + "num_channels": 1, + "spatial_shape": [ + 496, + 1024, + "D" + ], + "dtype": "int16", + "value_range": [ + 0, + 256 + ], + "description": "The pixel array of each OCT slice is extracted with pydicom and the png files saved to //_oct_.png on disk, where is the slice number. " + } + }, + "outputs": { + "pred": { + "dtype": "dictionary", + "type": "dictionary", + "format": "COCO", + "modality": "n/a", + "value_range": [ + 0, + 1 + ], + "num_channels": 1, + "spatial_shape": [ + 496, + 1024 + ], + "channel_def": { + "0": "RPD" + }, + "description": "This output is a JSON file in COCO Instance Segmentation format, containing bounding boxes, segmentation masks, and output probabilities for detected instances." + } + }, + "post_processed_outputs": { + "binary segmentation": { + "type": "image", + "format": "TIFF", + "modality": "OCT", + "num_channels": 3, + "spatial_shape": [ + 496, + 1024 + ], + "description": "This output is a multi-page TIFF file. Each page of the TIFF image corresponds to a binary segmentation mask for a single OCT slice from the input volume. The segmentation masks are stacked in the same order as the original OCT slices." + }, + "binary segmentation overlay": { + "type": "image", + "format": "TIFF", + "modality": "OCT", + "num_channels": 3, + "spatial_shape": [ + 496, + 1024 + ], + "description": "This output is a multi-page TIFF file. Each page of the TIFF image corresponds to a single OCT slice from the input volume overlayed with the detected binary segmentation mask." + }, + "instance segmentation overlay": { + "type": "image", + "format": "TIFF", + "modality": "OCT", + "num_channels": 3, + "spatial_shape": [ + 496, + 1024 + ], + "description": "This output is a multi-page TIFF file. Each page of the TIFF image corresponds to a single OCT slice from the input volume overlayed with the detected binary segmentation mask." + } + } + } +} diff --git a/models/retinalOCT_RPD_segmentation/docs/Figure1.jpg b/models/retinalOCT_RPD_segmentation/docs/Figure1.jpg new file mode 100644 index 00000000..c3964f0e Binary files /dev/null and b/models/retinalOCT_RPD_segmentation/docs/Figure1.jpg differ diff --git a/models/retinalOCT_RPD_segmentation/docs/README.md b/models/retinalOCT_RPD_segmentation/docs/README.md new file mode 100644 index 00000000..39b78ce0 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/docs/README.md @@ -0,0 +1,159 @@ + +# RPD OCT Segmentation +### **Authors** +Himeesh Kumar, Yelena Bagdasarova, Scott Song, Doron G. Hickey, Amy C. Cohn, Mali Okada, Robert P. Finger, Jan H. Terheyden, Ruth E. Hogg, Pierre-Henry Gabrielle, Louis Arnould, Maxime Jannaud, Xavier Hadoux, Peter van Wijngaarden, Carla J. Abbott, Lauren A.B. Hodgson, Roy Schwartz, Adnan Tufail, Emily Y. Chew, Cecilia S. Lee, Erica L. Fletcher, Melanie Bahlo, Brendan R.E. Ansell, Alice Pébay, Robyn H. Guymer, Aaron Y. Lee, Zhichao Wu + +### **Tags** +Reticular Pseudodrusen, AMD, OCT, Segmentation + +## **Model Description** +This model detects and segments Reticular Pseudodrusen (RPD) instances in Optical Coherence Tomography (OCT) B-scans. 
The instance segmentation model used a Mask-RCNN [1] head with a ResNeXt-101-32x8d-FPN [2] backbone (pretrained on ImageNet), implemented via the Detectron2 framework [3]. The model produces outputs that consist of bounding boxes and segmentation masks that delineate the coordinates and pixels of each instance detected, which are assigned a corresponding output probability. A tuneable probability threshold can then be applied to finalise the binary detection of an RPD instance.
+
+Five segmentation models were trained on these RPD instance labels on the OCT B-scans using five-fold cross-validation, and combined into a final ensemble model using soft voting (see the supplementary material of the paper for more information on model training).
+
+## **Data**
+The model was trained using the prospectively collected, baseline OCT scans (prior to any treatments) of individuals enrolled in the LEAD study [4], imaged using Heidelberg Spectralis HRA+OCT. OCT B-scans from 200 eyes of 100 individuals in the LEAD study were randomly selected to undergo manual annotation of RPD by a single grader (HK) at the pixel level, following training from two senior investigators (RHG and ZW). Only definite RPD lesions, defined as subretinal hyperreflective accumulations that altered the contour of, or broke through, the overlying photoreceptor ellipsoid zone on the OCT B-scans, were annotated.
+
+The model was then internally tested on a different set of OCT scans from 125 eyes of 92 individuals from the LEAD study, and externally tested on five independent datasets: the MACUSTAR study [5], the Northern Ireland Cohort for Longitudinal Study of Ageing (NICOLA) study [6], the Montrachet study [7], AMD observational studies at the University of Bonn, Germany (UB), and a routine clinical care cohort seen at the University of Washington (UW). The presence of RPD was graded either as part of each study (MACUSTAR and UB datasets) or by one of the study investigators (HK; in the NICOLA, UW, and Montrachet datasets). All these studies defined RPD based on the presence of five or more definite lesions on more than one OCT B-scan that corresponded to hyporeflective lesions seen on near-infrared reflectance imaging.
+
+#### **Preprocessing**
+Scans were kept at native resolution (1024 x 496 pixels).
+
+## **Performance**
+In the external test datasets, the overall performance for detecting RPD in a volume scan was AUC = 0·94 (95% CI = 0·92–0·97). In the internal test dataset, the Dice coefficient (DSC) between the model and manual annotations by retinal specialists was calculated for each B-scan, and the average over the dataset is listed in the table below. Note that the DSC was assigned a value of 1·0 for all pairwise comparisons where no pixels on a B-scan were labelled as having RPD.
+
+![](Table2.gif)
+
+![](Figure1.jpg)
+
+For more details regarding evaluation results, please see the Results section of the paper.
+
+
+## INSTALL
+This bundle has been installed and tested using Python 3.10. From the bundle directory, install the required packages using
+```
+pip install -r ./docs/requirements.txt
+```
+Install detectron2 using
+```
+python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
+```
+
+## USAGE
+The expected image data is in PNG format at the scan level, VOL format at the volume level, or DICOM format at the volume level.
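Before editing the config, it can help to confirm which of these formats a data directory actually contains. A minimal illustrative helper (not part of the bundle; the function name and extension list are assumptions) that inventories a directory so you can decide how to set `run_extract` and `input_format`:
```
from collections import Counter
from pathlib import Path

# Illustrative only: summarize candidate OCT inputs by file extension.
def inventory_input_dir(input_dir: str) -> Counter:
    wanted = {".png", ".vol", ".dcm", ".dicom"}
    return Counter(
        p.suffix.lower()
        for p in Path(input_dir).rglob("*")
        if p.is_file() and p.suffix.lower() in wanted
    )

# Only ".png" present           -> extraction already done, keep run_extract: False
# ".vol" files present          -> run_extract: True, input_format: "vol"
# ".dcm"/".dicom" files present -> run_extract: True, input_format: "dicom"
print(inventory_input_dir("/path/to/data"))
```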
To run inference, modify the parameters of the `inference.yaml` config file in the `configs` folder, which looks like this:
+
+```
+imports:
+- $import scripts
+- $import scripts.inference
+
+args:
+  run_extract : False
+  input_dir : "/path/to/data"
+  extracted_dir : "/path/to/extracted/data"
+  input_format : "dicom"
+  create_dataset : True
+  dataset_name : "my_dataset_name"
+
+  output_dir : "/path/to/model/output"
+  run_inference : True
+  create_tables : True
+
+# create visuals
+  binary_mask : False
+  binary_mask_overlay : True
+  instance_mask_overlay : False
+
+inference:
+- $scripts.inference.main(@args)
+```
+Then, in your bash shell, run
+```
+BUNDLE="/path/to/bundle/RPDBundle"
+
+python -m monai.bundle run inference \
+    --bundle_root "$BUNDLE" \
+    --config_file "$BUNDLE/configs/inference.yaml" \
+    --meta_file "$BUNDLE/configs/metadata.json"
+```
+### VOL/DICOM EXTRACTION
+If extracting DICOM or VOL files:
+* set `run_extract` to `True`
+* specify `input_dir`, the path to the directory that contains the VOL or DICOM files
+* specify `extracted_dir`, the path to the directory where extracted images will be stored
+* set `input_format` to "dicom" or "vol"
+
+The VOL or DICOM files can be in a nested hierarchy of folders, and all files in that directory with a VOL or DICOM extension will be extracted.
+
+For DICOM files, each OCT slice will be saved as a png file to `//_oct_.png` on disk, where `` is the slice number.
+
+For VOL files, each OCT slice will be saved as a png file to `////_oct_.png` on disk, where `` is the slice number and a nested hierarchy of folders is created using the underscores in the original filename.
+
+### DATASET PACKAGING
+Once you have the scans in PNG format, you can create a "dataset" in Detectron2 dictionary format for model consumption:
+* specify `extracted_dir`, the path to the directory where the PNG files are stored
+* set `create_dataset` to `True`
+* set `dataset_name` to the chosen name of your dataset
+
+The summary tables and visual output are organized around OCT volumes, so please make sure that the basename of the PNG files looks like `_.` The dataset dictionary will be saved as a pickle file in `////RPDBundle/datasets/.pk`
+
+### INFERENCE
+To run inference on your dataset:
+* set `dataset_name` to the name of the dataset you created in the previous step, which resides in `////RPDBundle/datasets/.pk`
+* set `output_dir`, the path to the directory where model predictions and other data will be stored
+* set `run_inference` to `True`
+
+The final ensembled predictions will be saved in COCO Instance Segmentation format in `coco_instances_results.json` in the output directory. The output directory will also be populated with five folders with the prefix 'fold', which contain predictions from the individual models of the ensemble.
+
+### SUMMARY TABLES and VISUAL OUTPUT
+Tables and images can be created from the predictions and written to the output directory. A confidence threshold of 0.5 is applied to the scored predictions by default. To change the threshold, set the `prob_thresh` value between 0.0 and 1.0.
+
+The tables can be created by setting `create_tables` to `True`:
+* HTML table called `dfimg_.html` indexed by OCT B-scan with columns listing the detected number of RPD instances (dt_instances), pixels (dt_pixels), and horizontal pixels (dt_xpxs) in that B-scan.
+* HTML table called `dfvol_.html` indexed by OCT volume with columns listing the detected number of RPD instances (dt_instances), pixels (dt_pixels), and horizontal pixels (dt_xpxs) in that volume. + +The predicted segmentations can be output as multi-page TIFFs, where each TIFF file corresponds to an input volume of the dataset, and each page to an OCT slice from the volume in original order. The output images can be binary masks, binary masks overlaying the original B-scan, and instance masks overlaying the original B-scan. Set the `binary_mask`, `binary_mask_overlay` and `instance_mask_overlay` flags in the yaml file to `True` accordingly. + +### SAMPLE DATA +As a reference, sample OCT-B scans are provided in PNG format under the sample_data directory. Set `extracted_dir` in `inference.yaml` to `sample_data` to run inference on these few set of images. + +## **System Configuration** +Inference on one Nvidia A100 gpu takes about 0.041 s/batch of 14 images, about 3G of gpu memory, and 6G of RAM. + +## **Limitations** +This model has not been tested for robustness of performance on OCTs imaged with other devices and with different scan parameters. + +## **Citation Info** + +``` +@article {Kumar2024.09.11.24312817, + author = {Kumar, Himeesh and Bagdasarova, Yelena and Song, Scott and Hickey, Doron G. and Cohn, Amy C. and Okada, Mali and Finger, Robert P. and Terheyden, Jan H. and Hogg, Ruth E. and Gabrielle, Pierre-Henry and Arnould, Louis and Jannaud, Maxime and Hadoux, Xavier and van Wijngaarden, Peter and Abbott, Carla J. and Hodgson, Lauren A.B. and Schwartz, Roy and Tufail, Adnan and Chew, Emily Y. and Lee, Cecilia S. and Fletcher, Erica L. and Bahlo, Melanie and Ansell, Brendan R.E. and P{\'e}bay, Alice and Guymer, Robyn H. and Lee, Aaron Y. and Wu, Zhichao}, + title = {Deep Learning-Based Detection of Reticular Pseudodrusen in Age-Related Macular Degeneration on Optical Coherence Tomography}, + elocation-id = {2024.09.11.24312817}, + year = {2024}, + doi = {10.1101/2024.09.11.24312817}, + publisher = {Cold Spring Harbor Laboratory Press}, + URL = {https://www.medrxiv.org/content/early/2024/09/12/2024.09.11.24312817}, + eprint = {https://www.medrxiv.org/content/early/2024/09/12/2024.09.11.24312817.full.pdf}, + journal = {medRxiv} +} +``` + +## **References** +[1]: He, Kaiming, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. "Mask R-CNN." In Proceedings of the IEEE international conference on computer vision (ICCV), pp. 2961-2969. 2017. + +[2]: Xie, Saining, Ross Girshick, Piotr Dollár, Zhuowen Tu, and Kaiming He. "Aggregated residual transformations for deep neural networks." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1492-1500. 2017. + +[3]: Wu, Yuxin, Alexander Kirillov, Francisco Massa, Wan-Yen Lo, and Ross Girshick. "Detectron2." arXiv preprint arXiv:1902.09615 (2019). + +[4]: Liu X, Faes L, Kale AU, et al. A comparison of deep learning performance against health-care professionals in detecting diseases from medical imaging: a systematic review and meta-analysis. The Lancet Digital Health. 2019;1(6):e271–e97. + +[5]: Finger RP, Schmitz-Valckenberg S, Schmid M, et al. MACUSTAR: Development and Clinical Validation of Functional, Structural, and Patient-Reported Endpoints in Intermediate Age-Related Macular Degeneration. Ophthalmologica. 2019;241(2):61–72. + +[6]: Hogg RE, Wright DM, Quinn NB, et al. 
Prevalence and risk factors for age-related macular degeneration in a population-based cohort study of older adults in Northern Ireland using multimodal imaging: NICOLA Study. Br J Ophthalmol. 2022:bjophthalmol-2021-320469. + +[7]: Gabrielle P-H, Seydou A, Arnould L, et al. Subretinal Drusenoid Deposits in the Elderly in a Population-Based Study (the Montrachet Study). Invest Ophthalmol Vis Sci. 2019;60(14):4838–48. diff --git a/models/retinalOCT_RPD_segmentation/docs/Table2.gif b/models/retinalOCT_RPD_segmentation/docs/Table2.gif new file mode 100644 index 00000000..d69eead0 Binary files /dev/null and b/models/retinalOCT_RPD_segmentation/docs/Table2.gif differ diff --git a/models/retinalOCT_RPD_segmentation/docs/requirements.txt b/models/retinalOCT_RPD_segmentation/docs/requirements.txt new file mode 100644 index 00000000..1520ff54 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/docs/requirements.txt @@ -0,0 +1,14 @@ +setuptools==75.8.0 +monai==1.5.0 +torch==2.6.0 +numpy==1.26.4 +opencv-python-headless==4.11.0.86 +pandas==2.3.0 +seaborn==0.13.2 +scikit-learn==1.6.1 +progressbar==2.5 +pydicom==3.0.1 +fire==0.7.0 +torchvision==0.21.0 +lxml==5.4.0 +pillow==11.2.1 \ No newline at end of file diff --git a/models/retinalOCT_RPD_segmentation/large_files.yaml b/models/retinalOCT_RPD_segmentation/large_files.yaml new file mode 100644 index 00000000..35eb4c99 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/large_files.yaml @@ -0,0 +1,25 @@ +large_files: + - path: "models/fold1_model_final.pth" + url: "https://s3.us-west-2.amazonaws.com/comp.ophthalmology.uw.edu/fold1_model_final.pth" + hash_val: "" + hash_type: "" + + - path: "models/fold2_model_final.pth" + url: "https://s3.us-west-2.amazonaws.com/comp.ophthalmology.uw.edu/fold2_model_final.pth" + hash_val: "" + hash_type: "" + + - path: "models/fold3_model_final.pth" + url: "https://s3.us-west-2.amazonaws.com/comp.ophthalmology.uw.edu/fold3_model_final.pth" + hash_val: "" + hash_type: "" + + - path: "models/fold4_model_final.pth" + url: "https://s3.us-west-2.amazonaws.com/comp.ophthalmology.uw.edu/fold4_model_final.pth" + hash_val: "" + hash_type: "" + + - path: "models/fold5_model_final.pth" + url: "https://s3.us-west-2.amazonaws.com/comp.ophthalmology.uw.edu/fold5_model_final.pth" + hash_val: "" + hash_type: "" diff --git a/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_026.png b/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_026.png new file mode 100644 index 00000000..25bdb40f Binary files /dev/null and b/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_026.png differ diff --git a/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_051.png b/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_051.png new file mode 100644 index 00000000..d649eaa4 Binary files /dev/null and b/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_051.png differ diff --git a/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_060.png b/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_060.png new file mode 100644 index 
00000000..dfca1fd4 Binary files /dev/null and b/models/retinalOCT_RPD_segmentation/sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_060.png differ diff --git a/models/retinalOCT_RPD_segmentation/sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_027.png b/models/retinalOCT_RPD_segmentation/sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_027.png new file mode 100644 index 00000000..e0de9a3e Binary files /dev/null and b/models/retinalOCT_RPD_segmentation/sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_027.png differ diff --git a/models/retinalOCT_RPD_segmentation/sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_033.png b/models/retinalOCT_RPD_segmentation/sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_033.png new file mode 100644 index 00000000..79cbddce Binary files /dev/null and b/models/retinalOCT_RPD_segmentation/sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_033.png differ diff --git a/models/retinalOCT_RPD_segmentation/scripts/Base-RCNN-FPN.yaml b/models/retinalOCT_RPD_segmentation/scripts/Base-RCNN-FPN.yaml new file mode 100644 index 00000000..7ee168fe --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/Base-RCNN-FPN.yaml @@ -0,0 +1,41 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + # Detectron1 uses 2000 proposals per-batch, + # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) + # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +SOLVER: + IMS_PER_BATCH: 14 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (496,) +VERSION: 2 +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: False diff --git a/models/retinalOCT_RPD_segmentation/scripts/Ensembler.py b/models/retinalOCT_RPD_segmentation/scripts/Ensembler.py new file mode 100644 index 00000000..1c225083 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/Ensembler.py @@ -0,0 +1,198 @@ +import json +import os + +import numpy as np +import pandas as pd +import torch +from pycocotools.coco import COCO +from torchvision.ops.boxes import box_convert, box_iou +from tqdm import tqdm + + +class NpEncoder(json.JSONEncoder): + """Custom JSON encoder for NumPy data types. + + This encoder handles NumPy-specific types that are not serializable by + the default JSON library by converting them into standard Python types. + """ + + def default(self, obj): + """Converts NumPy objects to their native Python equivalents. + + Args: + obj (any): The object to encode. + + Returns: + any: The JSON-serializable representation of the object. 
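+
+            Anything not handled here falls back to ``json.JSONEncoder.default``, which raises ``TypeError``.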
+ """ + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super(NpEncoder, self).default(obj) + + +class Ensembler: + """A class to ensemble predictions from multiple object detection models. + + This class loads ground truth data and predictions from several models, + performs non-maximum suppression (NMS) to merge overlapping detections, + and saves the final ensembled results in COCO format. + """ + + def __init__( + self, output_dir, dataset_name, grplist, iou_thresh, coco_gt_path=None, coco_instances_results_fname=None + ): + """Initializes the Ensembler. + + Args: + output_dir (str): The base directory where model outputs and + ensembled results are stored. + dataset_name (str): The name of the dataset being evaluated. + grplist (list[str]): A list of subdirectory names, where each + subdirectory contains the prediction file from one model. + iou_thresh (float): The IoU threshold for considering two bounding + boxes as overlapping during NMS. + coco_gt_path (str, optional): The full path to the ground truth + COCO JSON file. If None, it's assumed to be in `output_dir`. + Defaults to None. + coco_instances_results_fname (str, optional): The filename for the + COCO prediction files within each model's subdirectory. + Defaults to "coco_instances_results.json". + """ + self.output_dir = output_dir + self.dataset_name = dataset_name + self.grplist = grplist + self.iou_thresh = iou_thresh + self.n_detectors = len(grplist) + + if coco_gt_path is None: + fname_gt = os.path.join(output_dir, dataset_name + "_coco_format.json") + else: + fname_gt = coco_gt_path + + if coco_instances_results_fname is None: + fname_dt = "coco_instances_results.json" + else: + fname_dt = coco_instances_results_fname + + # load in ground truth (form image lists) + coco_gt = COCO(fname_gt) + # populate detector truths + dtlist = [] + for grp in grplist: + fname = os.path.join(output_dir, grp, fname_dt) + dtlist.append(coco_gt.loadRes(fname)) + print("Successfully loaded {} into memory. {} instance detected.\n".format(fname, len(dtlist[-1].anns))) + + self.coco_gt = coco_gt + self.cats = [cat["id"] for cat in self.coco_gt.dataset["categories"]] + self.dtlist = dtlist + self.results = [] + + print( + "Working with {} models, {} categories, and {} images.".format( + self.n_detectors, len(self.cats), len(self.coco_gt.imgs.keys()) + ) + ) + + def mean_score_nms(self): + """Performs non-maximum suppression by merging overlapping boxes. + + This method iterates through all images and categories, merging sets of + overlapping bounding boxes from different detectors based on the IoU + threshold. For each merged set, it calculates a mean score and selects + the single box with the highest original score as the representative + detection for the ensembled output. + + Returns: + Ensembler: The instance itself, with the `self.results` attribute + populated with the ensembled predictions. + """ + + def nik_merge(lsts): + """Niklas B. 
https://github.com/rikpg/IntersectionMerge/blob/master/core.py""" + sets = [set(lst) for lst in lsts if lst] + merged = 1 + while merged: + merged = 0 + results = [] + while sets: + common, rest = sets[0], sets[1:] + sets = [] + for x in rest: + if x.isdisjoint(common): + sets.append(x) + else: + merged = 1 + common |= x + results.append(common) + sets = results + return sets + + winning_list = [] + print( + "Computing mean score non-max suppression ensembling for {} images.".format(len(self.coco_gt.imgs.keys())) + ) + for img in tqdm(self.coco_gt.imgs.keys()): + # print(img) + dflist = [] # a dataframe of detections + obj_set = set() # a set of objects (frozensets) + for i, coco_dt in enumerate(self.dtlist): # for each detector append predictions to df + dflist.append(pd.DataFrame(coco_dt.imgToAnns[img]).assign(det=i)) + df = pd.concat(dflist, ignore_index=True) + if not df.empty: + for cat in self.cats: # for each category + dfcat = df[df["category_id"] == cat] + ts = box_convert( + torch.tensor(dfcat["bbox"]), in_fmt="xywh", out_fmt="xyxy" + ) # list of tensor boxes for cateogory + iou_bool = np.array((box_iou(ts, ts) > self.iou_thresh)) # compute IoU matrix and threshold + for i in range(len(dfcat)): # for each detection in that category + fset = frozenset(dfcat.index[iou_bool[i]]) + obj_set.add(fset) # compute set of sets representing objects + # find overlapping sets + + # for fs in obj_set: #for existing sets + # if fs&fset: #check for + # fsnew = fs.union(fset) + # obj_set.remove(fs) + # obj_set.add(fsnew) + obj_set = nik_merge(obj_set) + for s in obj_set: # for each detected objects, find winning box and assign score as mean of scores + dfset = dfcat.loc[list(s)] + mean_score = dfset["score"].sum() / max( + self.n_detectors, len(s) + ) # allows for more detections than detectors + winning_box = dfset.iloc[dfset["score"].argmax()].to_dict() + winning_box["score"] = mean_score + winning_list.append(winning_box) + print("{} resulting instances from NMS".format(len(winning_list))) + self.results = winning_list + return self + + def save_coco_instances(self, fname="coco_instances_results.json"): + """Saves the ensembled prediction results to a JSON file. + + The output file follows the COCO instance format and can be used for + further evaluation. + + Args: + fname (str, optional): The filename for the output JSON file. + Defaults to "coco_instances_results.json". + """ + if self.results: + with open(os.path.join(self.output_dir, fname), "w") as f: + f.write(json.dumps(self.results, cls=NpEncoder)) + f.flush() + + +if __name__ == "__main__": + # Example usage: + # This assumes an 'output' directory with subdirectories 'fold1', 'fold2', etc., + # each containing a 'coco_instances_results.json' file. + ens = Ensembler("dev", ["fold1", "fold2", "fold3", "fold4", "fold5"], 0.2) + ens.mean_score_nms() diff --git a/models/retinalOCT_RPD_segmentation/scripts/__init__.py b/models/retinalOCT_RPD_segmentation/scripts/__init__.py new file mode 100644 index 00000000..79a1bc6b --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/__init__.py @@ -0,0 +1 @@ +from .inference import main # Import main from inference.py diff --git a/models/retinalOCT_RPD_segmentation/scripts/analysis_lib.py b/models/retinalOCT_RPD_segmentation/scripts/analysis_lib.py new file mode 100644 index 00000000..8e5e622c --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/analysis_lib.py @@ -0,0 +1,1243 @@ +""" +Utiltites for analyizing and visualizing model segmentations on dataset. 
+Yelena Bagdasarova, Scott Song +""" + +import json +import os +import pickle +import sys +import warnings + +import cv2 +import detectron2 +import detectron2.utils.comm as comm +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import torch +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.engine import DefaultPredictor +from detectron2.evaluation import COCOEvaluator +from detectron2.utils.visualizer import Visualizer +from matplotlib.backends.backend_pdf import PdfPages +from PIL import Image +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from pycocotools.mask import decode +from sklearn.metrics import average_precision_score, precision_recall_curve +from tqdm import tqdm + +# current_directory = os.getcwd() +# print(current_directory) +plt.style.use("./scripts/ybpres.mplstyle") + + +def grab_dataset(name): + """Creates a function to load a pickled dataset by name. + + This function returns another function that, when called, loads a dataset + from a pickle file located in the "datasets/" directory. + + Args: + name (str): The base name of the dataset file (without extension). + + Returns: + function: A zero-argument function that loads and returns the dataset. + """ + + def f(): + return pickle.load(open("datasets/" + name + ".pk", "rb")) + + return f + + +class OutputVis: + """A class to visualize model outputs and ground truth annotations.""" + + def __init__( + self, + dataset_name, + cfg=None, + prob_thresh=0.5, + pred_mode="model", + pred_file=None, + has_annotations=True, + draw_mode="default", + ): + """Initializes the OutputVis class. + + Args: + dataset_name (str): The name of the registered Detectron2 dataset. + cfg (CfgNode, optional): The Detectron2 configuration object. + Required if `pred_mode` is "model". Defaults to None. + prob_thresh (float, optional): The probability threshold to apply + to model predictions for visualization. Defaults to 0.5. + pred_mode (str, optional): The mode for getting predictions. Must be + either "model" (to use a live predictor) or "file" (to load + from a COCO results file). Defaults to "model". + pred_file (str, optional): The path to the COCO JSON results file. + Required if `pred_mode` is "file". Defaults to None. + has_annotations (bool, optional): Whether the dataset has ground + truth annotations to visualize. Defaults to True. + draw_mode (str, optional): The drawing style for visualizations. + Can be "default" (color) or "bw" (monochrome). Defaults to "default". + """ + self.dataset_name = dataset_name + self.cfg = cfg + self.prob_thresh = prob_thresh + self.data = DatasetCatalog.get(dataset_name) + if pred_mode == "model": + self.predictor = DefaultPredictor(cfg) + self._mode = "model" + elif pred_mode == "file": + with open(pred_file, "r") as f: + self.pred_instances = json.load(f) + self.instance_img_list = [p["image_id"] for p in self.pred_instances] + self._mode = "file" + else: + sys.exit('Invalid mode. Only "model" or "file" permitted.') + self.has_annotations = has_annotations + self.permitted_draw_modes = ["default", "bw"] + self.set_draw_mode(draw_mode) + self.font_size = 16 # 28 for ARVO + self.annotation_color = "r" + self.scale = 3.0 + + def set_draw_mode(self, draw_mode): + """Sets the drawing mode for visualizations. + + Args: + draw_mode (str): The drawing style. Must be one of the permitted + modes (e.g., "default", "bw"). 
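+
+        Raises:
+            SystemExit: If ``draw_mode`` is not one of the permitted modes.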
+ """ + if draw_mode not in self.permitted_draw_modes: + sys.exit("draw_mode must be one of the following: {}".format(self.permitted_draw_modes)) + self.draw_mode = draw_mode + + def get_ori_image(self, imgid): + """Retrieves the original image for a given image ID. + + The image is scaled up by a factor of 3 for better visualization. + + Args: + imgid (str): The 'image_id' from the dataset dictionary. + + Returns: + PIL.Image: The original image. + """ + dat = self.get_gt_image_data(imgid) # gt + im = cv2.imread(dat["file_name"]) # input to model + v_gt = Visualizer(im, MetadataCatalog.get(self.dataset_name), scale=self.scale) + result_image = v_gt.output.get_image() # get original image + img = Image.fromarray(result_image) + return img + + def get_gt_image_data(self, imgid): + """Returns the ground truth data dictionary for a given image ID. + + Args: + imgid (str): The 'image_id' from the dataset dictionary. + + Returns: + dict: The dataset dictionary for the specified image. + """ + gt_data = next(item for item in self.data if (item["image_id"] == imgid)) + return gt_data + + def produce_gt_image(self, dat, im): + """Creates an image with ground truth annotations overlaid. + + The visualization can be in color or monochrome depending on the draw mode. + + Args: + dat (dict): The dataset dictionary containing ground truth annotations. + im (np.ndarray): The input image in RGB format (H, W, C) as a NumPy array. + + Returns: + PIL.Image: The image with ground truth instances overlaid. + """ + v_gt = Visualizer(im, MetadataCatalog.get(self.dataset_name), scale=self.scale) + if self.has_annotations: # ground truth boxes and masks + segs = [ddict["segmentation"] for ddict in dat["annotations"]] + if self.draw_mode == "bw": + _bboxes = None + assigned_colors = [self.annotation_color] * len(segs) + else: # default behavior + bboxes = [ddict["bbox"] for ddict in dat["annotations"]] + _bboxes = detectron2.structures.Boxes(bboxes) + _bboxes = detectron2.structures.BoxMode.convert( + _bboxes.tensor, from_mode=1, to_mode=0 + ) # 0= XYXY, 1 = XYWH + assigned_colors = None + + result_image = v_gt.overlay_instances( + boxes=_bboxes, masks=segs, assigned_colors=assigned_colors, alpha=1.0 + ).get_image() + else: + result_image = v_gt.output.get_image() # get original image if no annotations + img = Image.fromarray(result_image) + return img + + def produce_model_image(self, imgid, dat, im): + """Creates an image with model-predicted instances overlaid. + + Predictions are either generated by the model or loaded from a file, + based on the configured `pred_mode`. + + Args: + imgid (str): The 'image_id' from the dataset dictionary. + dat (dict): The dataset dictionary for the image (used for height/width). + im (np.ndarray): The input image in RGB format (H, W, C) as a NumPy array. + + Returns: + PIL.Image: The image with model-predicted instances overlaid. 
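+
+        Note:
+            Detections with scores at or below ``self.prob_thresh`` are filtered out before drawing.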
+ """ + v_dt = Visualizer(im, MetadataCatalog.get(self.dataset_name), scale=self.scale) + v_dt._default_font_size = self.font_size + + # get predictions from model or file + if self._mode == "model": + outputs = self.predictor(im)["instances"].to("cpu") + elif self._mode == "file": + outputs = self.get_outputs_from_file(imgid, (dat["height"], dat["width"])) + outputs = outputs[outputs.scores > self.prob_thresh] # apply probability threshold to instances + if self.draw_mode == "bw": + result_model = v_dt.overlay_instances( + masks=outputs.pred_masks, assigned_colors=[self.annotation_color] * len(outputs), alpha=1.0 + ).get_image() + else: # default behavior + result_model = v_dt.draw_instance_predictions(outputs).get_image() + img_model = Image.fromarray(result_model) + return img_model + + def get_image(self, imgid): + """Generates both ground truth and model prediction overlay images. + + Args: + imgid (str): The 'image_id' from the dataset dictionary. + + Returns: + tuple[PIL.Image, PIL.Image]: A tuple containing the ground truth + image and the model prediction image. + """ + dat = self.get_gt_image_data(imgid) # gt + im = cv2.imread(dat["file_name"]) # input to model + img = self.produce_gt_image(dat, im) + img_model = self.produce_model_image(imgid, dat, im) + return img, img_model + + def get_outputs_from_file(self, imgid, imgsize): + """Loads and formats model predictions from a COCO results file. + + Converts COCO-formatted instances into a Detectron2 `Instances` object + suitable for the visualizer. + + Args: + imgid (str): The 'image_id' of the desired image. + imgsize (tuple[int, int]): The (height, width) of the image. + + Returns: + detectron2.structures.Instances: An `Instances` object containing + the predictions. + """ + + pred_boxes = [] + scores = [] + pred_classes = [] + pred_masks = [] + for i, img in enumerate(self.instance_img_list): + if img == imgid: + pred_boxes.append(self.pred_instances[i]["bbox"]) + scores.append(self.pred_instances[i]["score"]) + pred_classes.append(int(self.pred_instances[i]["category_id"])) + # pred_masks_rle.append(self.pred_instances[i]['segmentation']) + pred_masks.append(decode(self.pred_instances[i]["segmentation"])) + _bboxes = detectron2.structures.Boxes(pred_boxes) + pred_boxes = detectron2.structures.BoxMode.convert(_bboxes.tensor, from_mode=1, to_mode=0) # 0= XYXY, 1 = XYWH + inst_dict = dict( + pred_boxes=pred_boxes, + scores=torch.tensor(np.array(scores)), + pred_classes=torch.tensor(np.array(pred_classes)), + pred_masks=torch.tensor(np.array(pred_masks)).to(torch.bool), + ) # pred_masks_rle=pred_masks_rle) + outputs = detectron2.structures.Instances(imgsize, **inst_dict) + return outputs + + @staticmethod + def height_crop_range(im, height_target=256): + """Calculates a vertical crop range centered on the brightest part of an image. + + Args: + im (np.ndarray): The input image as a NumPy array (H, W, C). + height_target (int, optional): The desired height of the crop. + Defaults to 256. + + Returns: + range: A range object representing the start and end pixel rows for the crop. 
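+
+        The crop window is shifted to stay within the image bounds when the brightness-weighted centre lies near the top or bottom edge.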
+ """ + yhist = im.sum(axis=1) # integrate over width of image + mu = np.average(np.arange(yhist.shape[0]), weights=yhist) + h1 = int(np.floor(mu - height_target / 2)) # inclusive + h2 = int(np.ceil(mu + height_target / 2)) # exclusive + if h1 < 0: + h1 = 0 + h2 = height_target + if h2 > yhist.shape[0]: + h2 = yhist.shape[0] + h1 = h2 - height_target + return range(h1, h2) + + def output_to_pdf(self, imgids, outname, dfimg=None): + """Exports visualizations of ground truth and model predictions to a PDF file. + + Each page of the PDF contains the ground truth and model prediction for one image. + + Args: + imgids (list[str]): A list of 'image_id' values to include in the PDF. + outname (str): The path and filename for the output PDF. + dfimg (pd.DataFrame, optional): A DataFrame with image statistics + to display on each page. Index should be `imgid`. Defaults to None. + """ + + gtstr = "" + dtstr = "" + + if dfimg is not None: + gtcols = dfimg.columns[["gt_" in col for col in dfimg.columns]] + dtcols = dfimg.columns[["dt_" in col for col in dfimg.columns]] + + with PdfPages(outname) as pdf: + for imgid in tqdm(imgids): + img, img_model = self.get_image(imgid) + # pdb.set_trace() + crop_range = self.height_crop_range(np.array(img.convert("L")), height_target=256 * self.scale) + img = np.array(img)[crop_range] + img_model = np.array(img_model)[crop_range] + + fig, ax = plt.subplots(2, 1, figsize=[22, 10], dpi=200) + ax[0].imshow(img) + ax[0].set_title(imgid + " Ground Truth") + ax[0].set_axis_off() + ax[1].imshow(img_model) + ax[1].set_title(imgid + " Model Prediction") + ax[1].set_axis_off() + if dfimg is not None: # annotate with provided stats + gtstr = ["{:s}={:.2f}".format(col, dfimg.loc[imgid, col]) for col in gtcols] + ax[0].text(0, 0.05 * (ax[0].get_ylim()[0]), gtstr, color="white", fontsize=14) + dtstr = ["{:s}={:.2f}".format(col, dfimg.loc[imgid, col]) for col in dtcols] + ax[1].text(0, 0.05 * (ax[1].get_ylim()[0]), dtstr, color="white", fontsize=14) + pdf.savefig(fig) + plt.close(fig) + + def save_imgarr_to_tiff(self, imgs, outname): + """Saves a list of PIL images to a multi-page TIFF file. + + Args: + imgs (list[PIL.Image]): A list of images to save. + outname (str): The path and filename for the output TIFF. + """ + if len(imgs) > 1: + imgs[0].save(outname, dpi=(400, 400), tags="", compression=None, save_all=True, append_images=imgs[1:]) + else: + imgs[0].save(outname) + + def output_ori_to_tiff(self, imgids, outname): + """Saves the original images for a list of IDs to a multi-page TIFF. + + Args: + imgids (list[str]): A list of 'image_id' values. + outname (str): The path and filename for the output TIFF. + """ + imgs = [] + for imgid in tqdm(imgids): + img_ori = self.get_ori_image(imgid) # PIL Image + imgs.append(img_ori) + self.save_imgarr_to_tiff(imgs, outname) + + def output_pred_to_tiff(self, imgids, outname, pred_only=False): + """Saves model prediction overlays for a list of IDs to a multi-page TIFF. + + Args: + imgids (list[str]): A list of 'image_id' values. + outname (str): The path and filename for the output TIFF. + pred_only (bool, optional): If True, overlays predictions on a + black background instead of the original image. Defaults to False. + """ + imgs = self.output_pred_to_list(imgids, pred_only) + self.save_imgarr_to_tiff(imgs, outname) + + def output_pred_to_list(self, imgids, pred_only=False): + """Generates a list of images with model predictions overlaid. + + Args: + imgids (list[str]): A list of 'image_id' values. 
+ pred_only (bool, optional): If True, overlays predictions on a + black background. Defaults to False. + + Returns: + list[PIL.Image]: A list of the generated visualization images. + """ + imgs = [] + for imgid in tqdm(imgids): + dat = self.get_gt_image_data(imgid) # gt + if pred_only: + im = np.zeros((dat["height"], dat["width"], 3)) # blank image for overlay + assert ( + self._mode == "file" + ), 'pred_mode must be "file" when pred_only flage is set to True.' # fix this later + else: + im = cv2.imread(dat["file_name"]) # input to model + img_dt = self.produce_model_image(imgid, dat, im) + imgs.append(img_dt) + return imgs + + def output_all_to_tiff(self, imgids, outname): + """Saves a combined visualization (original, GT, prediction) to a TIFF. + + For each image ID, it creates a single composite image by concatenating + the original, ground truth overlay, and model prediction overlay, then + saves them to a multi-page TIFF. + + Args: + imgids (list[str]): A list of 'image_id' values. + outname (str): The path and filename for the output TIFF. + """ + imgs = [] + for imgid in tqdm(imgids): + img_gt, img_dt = self.get_image(imgid) + img_ori = self.get_ori_image(imgid) + hcrange = list(self.height_crop_range(np.array(img_ori.convert("L")), height_target=256 * self.scale)) + img_result = Image.fromarray( + np.concatenate( + ( + np.array(img_ori.convert("RGB"))[hcrange, :], + np.array(img_gt)[hcrange, :], + np.array(img_dt)[hcrange], + ) + ) + ) + imgs.append(img_result) + self.save_imgarr_to_tiff(imgs, outname) + + def get_enface_dt(self, grp, scan_height, scan_width, scan_spacing): + """Generates an en-face view of model predictions for a scan volume. + + Args: + grp (pd.DataFrame): DataFrame for a single scan volume, indexed by imgid. + scan_height (int): The height of a single scan image in pixels. + scan_width (int): The width of a single scan image in pixels. + scan_spacing (float): The spacing between scan centers in pixels. + + Returns: + np.ndarray: An en-face image of the model predictions. + """ + grp = grp.sort_index() + nscans = len(grp) + enface_height = int(np.ceil((nscans - 1) * scan_spacing)) + enface = np.zeros((enface_height, scan_width, 3), dtype=int) + for i, imgid in enumerate(grp.index): + pos = int(np.clip(np.floor(scan_spacing * i), 0, scan_width - 1)) # vertical enface position + + outputs = self.get_outputs_from_file(imgid, (scan_height, scan_width)) + outputs = outputs[outputs.scores > self.prob_thresh] + instances = outputs.pred_boxes[:, (0, 2)].round().clip(0, scan_width - 1).to(np.int) + + for inst in instances: + try: + enface[max(pos - 4, 0) : min(pos + 4, scan_width - 1), inst[0] : inst[1]] = np.array( + [255, 255, 255] + ) # random_color(rgb = True) + except IndexError: + print(pos, inst[0], inst[1]) + return enface + + def get_enface_gt(self, grp, scan_height, scan_width, scan_spacing): + """Generates an en-face view of ground truth annotations for a scan volume. + + Args: + grp (pd.DataFrame): DataFrame for a single scan volume, indexed by imgid. + scan_height (int): The height of a single scan image in pixels. + scan_width (int): The width of a single scan image in pixels. + scan_spacing (float): The spacing between scan centers in pixels. + + Returns: + np.ndarray: An en-face image of the ground truth annotations. 
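+
+        Example:
+            Sketch assuming ``vis`` is an ``OutputVis`` instance and ``dataset_table``
+            comes from ``CreatePlotsRPD``; the volume id and scan geometry values are
+            placeholders:
+
+                dfimg = dataset_table.dfimg
+                grp = dfimg[dfimg["volID"] == "patient01_OD"]
+                enface = vis.get_enface_gt(grp, scan_height=496, scan_width=1024, scan_spacing=8.0)
+                plt.imshow(enface)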
+ """ + grp = grp.sort_index() + nscans = len(grp) + enface_height = int(np.ceil((nscans - 1) * scan_spacing)) + enface = np.zeros((enface_height, scan_width, 3), dtype=int) + if not self.has_annotations: + enface[:, :] = np.array([100, 100, 100]) + + else: + # minx = scan_width + for i, imgid in enumerate(grp.index): + pos = int(np.clip(np.floor(scan_spacing * i), 0, scan_width - 1)) + instances = self.get_gt_image_data(imgid)["annotations"] + for inst in instances: + x1 = inst["bbox"][0] + # minx = min(minx,x1) + x2 = x1 + inst["bbox"][2] + try: + enface[max(pos - 4, 0) : min(pos + 4, scan_width - 1), x1:x2] = np.array( + [255, 255, 255] + ) # random_color(rgb = True) + except IndexError: + print(pos, x1, x2) + return enface + + def compare_enface(self, grp, name, scan_height, scan_width, scan_spacing): + """Creates a figure comparing the en-face views of predictions and ground truth. + + Args: + grp (pd.DataFrame): DataFrame for a single scan volume, indexed by imgid. + name (str): The name/ID of the scan volume for the plot title. + scan_height (int): The height of a single scan image in pixels. + scan_width (int): The width of a single scan image in pixels. + scan_spacing (float): The spacing between scan centers in pixels. + + Returns: + tuple[plt.Figure, np.ndarray]: A tuple containing the figure and axes objects. + """ + fig, ax = plt.subplots(1, 2, figsize=[18, 9], dpi=120) + + enface = self.get_enface_dt(grp, scan_height, scan_width, scan_spacing) + ax[0].imshow(enface) + ax[0].set_title(str(name) + " DT") + ax[0].set_aspect("equal") + + enface = self.get_enface_gt(grp, scan_height, scan_width, scan_spacing) + ax[1].imshow(enface) + ax[1].set_title(str(name) + " GT") + ax[1].set_aspect("equal") + return fig, ax + + +def wilson_ci(p, n, z): + """Calculates the Wilson score interval for a binomial proportion. + + Args: + p (float): The observed proportion of successes. + n (int): The total number of trials. + z (float): The z-score for the desired confidence level (e.g., 1.96 for 95%). + + Returns: + tuple[float, float]: A tuple containing the lower and upper bounds of the confidence interval. + """ + if p < 0 or p > 1 or n == 0: + if p < 0 or p > 1: + warnings.warn(f"The value of proportion {p} must be in the range [0,1]. Returning identity for CIs.") + else: + warnings.warn(f"The number of counts {n} must be above zero. Returning identity for CIs.") + return (p, p) + sym = z * (p * (1 - p) / n + z * z / 4 / n / n) ** 0.5 + asym = p + z * z / 2 / n + fact = 1 / (1 + z * z / n) + upper = fact * (asym + sym) + lower = fact * (asym - sym) + return (lower, upper) + + +class EvaluateClass(COCOEvaluator): + """A custom evaluation class extending COCOEvaluator for detailed analysis.""" + + def __init__(self, dataset_name, output_dir, prob_thresh=0.5, iou_thresh=0.1, evalsuper=True): + """Initializes the custom evaluator. + + Args: + dataset_name (str): The name of the registered Detectron2 dataset. + output_dir (str): Directory to store temporary evaluation files. + prob_thresh (float, optional): Probability threshold for calculating + precision, recall, and FPR. Defaults to 0.5. + iou_thresh (float, optional): IoU threshold for defining a true positive. + Defaults to 0.1. + evalsuper (bool, optional): If True, run the parent COCOEvaluator's + evaluate method to generate standard COCO metrics. Defaults to True. 
+ """ + super().__init__(dataset_name, tasks={"bbox", "segm"}, output_dir=output_dir) + self.dataset_name = dataset_name + self.mycoco = None # pycocotools.cocoEval instance + self.cocoDt = None + self.cocoGt = None + self.evalsuper = evalsuper # if True, run COCOEvaluator.evaluate() when self.evaluate is run + self.prob_thresh = prob_thresh # instance probabilty threshold for scalars (precision,recall,fpr for scans) + self.iou_thresh = iou_thresh # iou threshold for defining precision,recall + self.pr = None + self.rc = None + self.fpr = None + + def reset(self): + """Resets the evaluator's state for a new evaluation run.""" + super().reset() + self.mycoco = None + + def process(self, inputs, outputs): + """Processes a batch of inputs and outputs from the model. + + This method is called by the evaluation loop for each batch. + + Args: + inputs (list[dict]): A list of dataset dictionaries. + outputs (list[dict]): A list of model output dictionaries. + """ + super().process(inputs, outputs) + + def evaluate(self): + """Runs the evaluation and calculates detailed performance metrics. + + This method orchestrates the COCO evaluation, calculates precision-recall + curves, and other custom metrics. + + Returns: + tuple[float, float]: The precision and recall at the specified + `prob_thresh` and `iou_thresh`. + """ + if self.evalsuper: + _ = super().evaluate() # this call populates coco_instances_results.json + comm.synchronize() + if not comm.is_main_process(): + return () + self.cocoGt = COCO( + os.path.join(self._output_dir, self.dataset_name + "_coco_format.json") + ) # produced when super is initialized + self.cocoDt = self.cocoGt.loadRes( + os.path.join(self._output_dir, "coco_instances_results.json") + ) # load detector results + self.mycoco = COCOeval(self.cocoGt, self.cocoDt, iouType="segm") + self.num_images = len(self.mycoco.params.imgIds) + print("Calculated metrics for {} images".format(self.num_images)) + self.mycoco.params.iouThrs = np.arange(0.10, 0.6, 0.1) + self.mycoco.params.maxDets = [100] + self.mycoco.params.areaRng = [[0, 10000000000.0]] + + self.mycoco.evaluate() + self.mycoco.accumulate() + + self.pr = self.mycoco.eval["precision"][ + :, :, 0, 0, 0 # iouthresh # recall level # catagory # area range + ] # max detections per image + self.rc = self.mycoco.params.recThrs + self.iou = self.mycoco.params.iouThrs + self.scores = self.mycoco.eval["scores"][:, :, 0, 0, 0] # unreliable if GT has no instances + p, r = self.get_precision_recall() + return p, r + + def plot_pr_curve(self, ax=None): + """Plots precision-recall curves for various IoU thresholds. + + Args: + ax (plt.Axes, optional): A matplotlib axes object to plot on. If None, + a new figure and axes are created. + """ + if ax is None: + fig, ax = plt.subplots(1, 1) + for i in range(len(self.iou)): + ax.plot(self.rc, self.pr[i], label="{:.2}".format(self.iou[i])) + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + ax.set_title("") + ax.legend(title="IoU") + + def plot_recall_vs_prob(self): + """Plots model score thresholds versus recall for various IoU thresholds.""" + plt.figure() + for i in range(len(self.iou)): + plt.plot(self.rc, self.scores[i], label="{:.2}".format(self.iou[i])) + plt.ylabel("Model probability") + plt.xlabel("Recall") + plt.legend(title="IoU") + + def get_precision_recall(self): + """Gets the precision and recall for the configured IoU and probability thresholds. + + Returns: + tuple[float, float]: The calculated precision and recall. 
+ """ + iou_idx, rc_idx = self._find_iou_rc_inds() + precision = self.pr[iou_idx, rc_idx] + recall = self.rc[rc_idx] + return precision, recall + + def _calculate_fpr_matrix(self): + """(Private) Calculates the false positive rate matrix across all IoU and recall thresholds.""" + + # FP rate, 1 RPD in image = FP + if (self.scores.min() == -1) and (self.scores.max() == -1): + print( + "WARNING: Scores for all iou thresholds and all recall levels are not defined. " + "This can arise if ground truth annotations contain no instances. Leaving fpr matrix as None" + ) + self.fpr = None + return + + fpr = np.zeros((len(self.iou), len(self.rc))) + for i in range(len(self.iou)): + for j, s in enumerate(self.scores[i]): # j -> recall level, s -> corresponding score + ng = 0 # number of negative images + fp = 0 # number of false positives images + for el in self.mycoco.evalImgs: + if el is None: # no predictions, no gts + ng = ng + 1 + elif len(el["gtIds"]) == 0: # some predictions and no gts + ng = ng + 1 + if ( + np.array(el["dtScores"]) > s + ).sum() > 0: # if at least one score over threshold for recall level + fp = fp + 1 # count as FP + else: + continue + fpr[i, j] = fp / ng + self.fpr = fpr + + def _calculate_fpr(self): + """(Private) Calculates FPR for a single probability threshold. + + This is an alternate calculation used when the main FPR matrix cannot + be computed (e.g., no positive ground truth instances). + + Returns: + float: The calculated false positive rate. + """ + print("Using alternate calculation for fpr at instance score threshold of {}".format(self.prob_thresh)) + ng = 0 # number of negative images + fp = 0 # number of false positives images + for el in self.mycoco.evalImgs: + if el is None: # no predictions, no gts + ng = ng + 1 + elif len(el["gtIds"]) == 0: # some predictions and no gts + ng = ng + 1 + if ( + np.array(el["dtScores"]) > self.prob_thresh + ).sum() > 0: # if at least one score over threshold for recall level + fp = fp + 1 # count as FP + else: # gt has instances + continue + return fp / (ng + 1e-5) + + def _find_iou_rc_inds(self): + """(Private) Finds the indices corresponding to the configured IoU and probability thresholds. + + Returns: + tuple[int, int]: The index for the IoU threshold and the index for the recall level. + """ + try: + iou_idx = np.argwhere(self.iou == self.iou_thresh)[0][0] # first instance of + except IndexError: + print( + "iou threshold {} not found in mycoco.params.iouThrs {}".format( + self.iou_thresh, self.mycoco.params.iouThrs + ) + ) + exit(1) + # test above for out of bounds + inds = np.argwhere(self.scores[iou_idx] >= self.prob_thresh) + if len(inds) > 0: + rc_idx = inds[-1][0] # get recall index corresponding to prob_thresh + else: + rc_idx = 0 + return iou_idx, rc_idx + + def get_fpr(self): + """Gets the false positive rate for the configured thresholds. + + Returns: + float: The calculated false positive rate. Returns -1 if it cannot be computed. + """ + if self.fpr is None: + self._calculate_fpr_matrix() + + if self.fpr is not None: + iou_idx, rc_idx = self._find_iou_rc_inds() + fpr = self.fpr[iou_idx, rc_idx] + elif len(self.mycoco.cocoGt.anns) == 0: + fpr = self._calculate_fpr() + else: + fpr = -1 + return fpr + + def summarize_scalars(self): # for pretty printing + """Generates a dictionary summarizing key performance metrics with confidence intervals. + + Returns: + dict: A dictionary containing precision, recall, F1-score, FPR, + and their confidence intervals. 
+ """ + p, r = self.get_precision_recall() + f1 = 2 * (p * r) / (p + r) + fpr = self.get_fpr() + + # Confidence intervals + z = 1.96 # 95% Gaussian + # instance count + inst_cnt = self.count_instances() + n_r = inst_cnt["gt_instances"] + n_p = inst_cnt["dt_instances"] + n_fpr = inst_cnt["gt_neg_scans"] + + def stat_ci(p, n, z): + return z * np.sqrt(p * (1 - p) / n) + + r_ci = wilson_ci(r, n_r, z) + p_ci = wilson_ci(p, n_p, z) + fpr_ci = wilson_ci(fpr, n_fpr, z) + + # propogate errors for f1 + int_r = stat_ci(r, n_r, z) + int_p = stat_ci(p, n_p, z) + int_f1 = (f1) * np.sqrt(int_r**2 * (1 / r - 1 / (p + r)) ** 2 + int_p**2 * (1 / p - 1 / (p + r)) ** 2) + f1_ci = (f1 - int_f1, f1 + int_f1) + + dd = dict( + dataset=self.dataset_name, + precision=float(p), + precision_ci=p_ci, + recall=float(r), + recall_ci=r_ci, + f1=float(f1), + f1_ci=f1_ci, + fpr=float(fpr), + fpr_ci=fpr_ci, + iou=self.iou_thresh, + probability=self.prob_thresh, + ) + return dd + + def count_instances(self): + """Counts ground truth and detected instances across the dataset. + + Returns: + dict: A dictionary with counts for 'gt_instances', 'dt_instances', + and 'gt_neg_scans' (images with no GT instances). + """ + gt_inst = 0 + dt_inst = 0 + gt_neg_scans = 0 + for _, val in self.cocoGt.imgs.items(): + imgid = val["id"] + # Gt instances + annids_gt = self.cocoGt.getAnnIds([imgid]) + anns_gt = self.cocoGt.loadAnns(annids_gt) + gt_inst += len(anns_gt) + if len(anns_gt) == 0: + gt_neg_scans += 1 + + # Dt instances + annids_dt = self.cocoDt.getAnnIds([imgid]) + anns_dt = self.cocoDt.loadAnns(annids_dt) + anns_dt = [ann for ann in anns_dt if ann["score"] > self.prob_thresh] + dt_inst += len(anns_dt) + + return dict(gt_instances=gt_inst, dt_instances=dt_inst, gt_neg_scans=gt_neg_scans) + + +class CreatePlotsRPD: + """A class to create various plots for analyzing RPD (Reticular Pseudodrusen) data.""" + + def __init__(self, dfimg): + """Initializes the plotting class with image-level data. + + Args: + dfimg (pd.DataFrame): A DataFrame where each row corresponds to an + image, containing counts for ground truth and detected instances + and pixels. Must include a 'volID' column. + """ + self.dfimg = dfimg + self.dfvol = self.dfimg.groupby(["volID"])[ + ["gt_instances", "gt_pxs", "gt_xpxs", "dt_instances", "dt_pxs", "dt_xpxs"] + ].sum() + + @classmethod + def initfromcoco(cls, mycoco, prob_thresh): + """Initializes the class from a COCOeval object. + + Args: + mycoco (COCOeval): An evaluated COCOeval object. + prob_thresh (float): The probability threshold to apply to detections. + + Returns: + CreatePlotsRPD: An instance of the class. 
+ """ + df = pd.DataFrame( + index=mycoco.cocoGt.imgs.keys(), + columns=["gt_instances", "gt_pxs", "gt_xpxs", "dt_instances", "dt_pxs", "dt_xpxs"], + dtype=np.uint64, + ) + + for key, val in mycoco.cocoGt.imgs.items(): + imgid = val["id"] + # Gt instances + annids_gt = mycoco.cocoGt.getAnnIds([imgid]) + anns_gt = mycoco.cocoGt.loadAnns(annids_gt) + inst_gt = [mycoco.cocoGt.annToMask(ann).sum() for ann in anns_gt] + xproj_gt = [(mycoco.cocoGt.annToMask(ann).sum(axis=0) > 0).astype("uint8").sum() for ann in anns_gt] + # Dt instances + annids_dt = mycoco.cocoDt.getAnnIds([imgid]) + anns_dt = mycoco.cocoDt.loadAnns(annids_dt) + anns_dt = [ann for ann in anns_dt if ann["score"] > prob_thresh] + inst_dt = [mycoco.cocoDt.annToMask(ann).sum() for ann in anns_dt] + xproj_dt = [(mycoco.cocoDt.annToMask(ann).sum(axis=0) > 0).astype("uint8").sum() for ann in anns_dt] + + dat = [ + len(inst_gt), + np.array(inst_gt).sum(), + np.array(xproj_gt).sum(), + len(inst_dt), + np.array(inst_dt).sum(), + np.array(xproj_dt).sum(), + ] + df.loc[key] = dat + + newdf = pd.DataFrame( + [idx.rsplit(".", 1)[0].rsplit("_", 1) for idx in df.index], columns=["volID", "scan"], index=df.index + ) + df = df.merge(newdf, how="inner", left_index=True, right_index=True) + return cls(df) + + @classmethod + def initfromcsv(cls, fname): + """Initializes the class from a CSV file. + + Args: + fname (str): The path to the CSV file. + + Returns: + CreatePlotsRPD: An instance of the class. + """ + df = pd.read_csv(fname) + return cls(df) + + def get_max_limits(self, df): + """Calculates the maximum values for plotting limits. + + Args: + df (pd.DataFrame): The DataFrame to analyze. + + Returns: + tuple[int, int, int]: Max values for instances, x-pixels, and total pixels. + """ + max_inst = np.max([df.gt_instances.max(), df.dt_instances.max()]) + max_xpxs = np.max([df.gt_xpxs.max(), df.dt_xpxs.max()]) + max_pxs = np.max([df.gt_pxs.max(), df.dt_pxs.max()]) + # print('Max instances:',max_inst) + # print('Max xpxs:',max_xpxs) + # print('Max pxs:',max_pxs) + return max_inst, max_xpxs, max_pxs + + def vol_level_prc(self, df, gt_thresh=5, ax=None): + """Plots a volume-level precision-recall curve. + + Args: + df (pd.DataFrame): DataFrame with volume-level statistics. + gt_thresh (int, optional): The minimum number of ground truth + instances for a volume to be considered positive. Defaults to 5. + ax (plt.Axes, optional): Axes to plot on. Defaults to None. + + Returns: + tuple[float, tuple]: The average precision and the PR curve data. + """ + prc = precision_recall_curve(df.gt_instances >= gt_thresh, df.dt_instances) + if ax is None: + fig, ax = plt.subplots(1, 1) + ax.plot(prc[1], prc[0]) + ax.set_xlabel("RPD Volume Recall") + ax.set_ylabel("RPD Volume Precision") + + ap = average_precision_score(df.gt_instances >= gt_thresh, df.dt_instances) + return ap, prc + + def plot_img_level_instance_thresholding(self, df, inst): + """Plots P/R/FPR as a function of the instance count threshold. + + Args: + df (pd.DataFrame): DataFrame with image-level statistics. + inst (list[int]): A list of instance count thresholds to evaluate. + + Returns: + tuple[np.ndarray, np.ndarray, np.ndarray]: Arrays for precision, + recall, and FPR at each threshold. 
+ """ + rc = np.zeros((len(inst),)) + pr = np.zeros((len(inst),)) + fpr = np.zeros((len(inst),)) + + fig, ax = plt.subplots(1, 3, figsize=[15, 5]) + for i, dt_thresh in enumerate(inst): + gt = df.gt_instances > dt_thresh + dt = df.dt_instances > dt_thresh + rc[i] = (gt & dt).sum() / gt.sum() + pr[i] = (gt & dt).sum() / dt.sum() + fpr[i] = ((~gt) & (dt)).sum() / ((~gt).sum()) + + ax[1].plot(inst, pr) + ax[1].set_ylim(0.45, 1.01) + ax[1].set_xlabel("instance threshold") + ax[1].set_ylabel("Precision") + + ax[0].plot(inst, rc) + ax[0].set_ylim(0.45, 1.01) + ax[0].set_ylabel("Recall") + ax[0].set_xlabel("instance threshold") + + ax[2].plot(inst, fpr) + ax[2].set_ylim(0, 0.80) + ax[2].set_xlabel("instance threshold") + ax[2].set_ylabel("FPR") + + plt.tight_layout() + return pr, rc, fpr + + def plot_img_level_instance_thresholding2(self, df, inst, gt_thresh, plot=True): + """Plots P/R/FPR vs. instance threshold with confidence intervals. + + Args: + df (pd.DataFrame): DataFrame with image-level statistics. + inst (list[int]): A list of instance count thresholds to evaluate. + gt_thresh (int): The ground truth instance threshold. + plot (bool, optional): Whether to generate a plot. Defaults to True. + + Returns: + dict: A dictionary containing arrays for P/R/FPR and their CIs. + """ + + rc = np.zeros((len(inst),)) + pr = np.zeros((len(inst),)) + fpr = np.zeros((len(inst),)) + rc_ci = np.zeros((len(inst), 2)) + pr_ci = np.zeros((len(inst), 2)) + fpr_ci = np.zeros((len(inst), 2)) + + for i, dt_thresh in enumerate(inst): + gt = df.gt_instances >= gt_thresh + dt = df.dt_instances >= dt_thresh + rc[i] = (gt & dt).sum() / gt.sum() + pr[i] = (gt & dt).sum() / dt.sum() + fpr[i] = ((~gt) & (dt)).sum() / ((~gt).sum()) + rc_ci[i, :] = wilson_ci(rc[i], gt.sum(), 1.96) + pr_ci[i, :] = wilson_ci(pr[i], dt.sum(), 1.96) + fpr_ci[i, :] = wilson_ci(fpr[i], ((~gt).sum()), 1.96) + + if plot: + fig, ax = plt.subplots(1, 3, figsize=[15, 5]) + # ax[0].plot(rc,pr) + # ax[0].set_xlabel('Recall') + # ax[0].set_ylabel('Precision') + + ax[1].plot(inst, pr) + ax[1].fill_between(inst, pr_ci[:, 0], pr_ci[:, 1], alpha=0.25) + # ax[1].set_ylim(0.45,1.01) + ax[1].set_xlabel("instance threshold") + ax[1].set_ylabel("Precision") + + ax[0].plot(inst, rc) + ax[0].fill_between(inst, rc_ci[:, 0], rc_ci[:, 1], alpha=0.25) + # ax[0].set_ylim(0.45,1.01) + ax[0].set_ylabel("Recall") + ax[0].set_xlabel("instance threshold") + + ax[2].plot(inst, fpr) + ax[2].fill_between(inst, fpr_ci[:, 0], fpr_ci[:, 1], alpha=0.25) + # ax[2].set_ylim(0,0.80) + ax[2].set_xlabel("instance threshold") + ax[2].set_ylabel("FPR") + + plt.tight_layout() + return dict(precision=pr, precision_ci=pr_ci, recall=rc, recall_ci=rc_ci, fpr=fpr, fpr_ci=fpr_ci) + + def gt_vs_dt_instances(self, ax=None): + """Plots mean detected instances vs. ground truth instances with error bars. + + Args: + ax (plt.Axes, optional): Axes to plot on. Defaults to None. + + Returns: + plt.Axes: The axes object with the plot. 
+ """ + df = self.dfimg + max_inst, max_xpxs, max_pxs = self.get_max_limits(df) + idx = (df.gt_instances > 0) & (df.dt_instances > 0) + + if ax is None: + fig = plt.figure(dpi=100) + ax = fig.add_subplot(111) + + y = df[idx].groupby("gt_instances")["dt_instances"].mean() + yerr = df[idx].groupby("gt_instances")["dt_instances"].std() + ax.errorbar(y.index, y.values, yerr.values, fmt="*") + plt.plot([0, max_inst], [0, max_inst], alpha=0.5) + plt.xlim(0, max_inst + 1) + plt.ylim(0, max_inst + 1) + ax.set_aspect(1) + plt.xlabel("gt_instances") + plt.ylabel("dt_instances") + plt.tight_layout() + return ax + + def gt_vs_dt_instances_boxplot(self, ax=None): + """Creates a boxplot of detected instances for each ground truth instance count. + + Args: + ax (plt.Axes, optional): Axes to plot on. Defaults to None. + + Returns: + plt.Axes: The axes object with the plot. + """ + df = self.dfimg + max_inst, max_xpxs, max_pxs = self.get_max_limits(df) + max_inst = int(max_inst) + if ax is None: + fig = plt.figure(dpi=100) + ax = fig.add_subplot(111) + + ax.plot([0, max_inst + 1], [0, max_inst + 1], alpha=0.5) + x = df["gt_instances"].values.astype(int) + y = df["dt_instances"].values.astype(int) + sns.boxplot(x, y, ax=ax, width=0.5) + ax.set_xbound(0, max_inst + 1) + ax.set_ybound(0, max_inst + 1) + ax.set_aspect("equal") + + ax.set_title("") + ax.set_xlabel("gt_instances") + ax.set_ylabel("dt_instances") + + import matplotlib.ticker as pltticker + + loc = pltticker.MultipleLocator(base=2.0) + ax.xaxis.set_major_locator(loc) + ax.yaxis.set_major_locator(loc) + + return ax + + def gt_vs_dt_xpxs(self): + """Creates scatter plots comparing ground truth and detected x-pixels. + + Returns: + tuple[plt.Figure, plt.Figure, plt.Figure]: Figure handles for the three generated plots. + """ + df = self.dfimg + max_inst, max_xpxs, max_pxs = self.get_max_limits(df) + idx = (df.gt_instances > 0) & (df.dt_instances > 0) + dfsub = df[idx] + + fig1 = plt.figure(figsize=[10, 10], dpi=100) + ax = fig1.add_subplot(111) + sc = ax.scatter(dfsub["gt_xpxs"], dfsub["dt_xpxs"], c=dfsub["gt_instances"], cmap="viridis") + ax.set_aspect(1) + # ax = dfsub.plot(kind = 'scatter',x=,y=,c='gt_instances') + plt.plot([0, max_xpxs], [0, max_xpxs], alpha=0.5) + plt.xlim(0, max_xpxs) + plt.ylim(0, max_xpxs) + plt.xlabel("gt_xpxs") + plt.ylabel("dt_xpxs") + cbar = plt.colorbar(sc) + cbar.ax.set_ylabel("gt_instances") + plt.tight_layout() + + fig2 = plt.figure(figsize=[10, 10], dpi=100) + ax = fig2.add_subplot(111) + sc = ax.scatter(dfsub["gt_xpxs"], dfsub["gt_xpxs"] - dfsub["dt_xpxs"], c=dfsub["gt_instances"], cmap="viridis") + # ax = dfsub.plot(kind = 'scatter',x=,y=,c='gt_instances') + plt.plot([0, max_xpxs], [0, 0], alpha=0.5) + plt.xlabel("gt_xpxs") + plt.ylabel("gt_xpxs-dt_xpxs") + cbar = plt.colorbar(sc) + cbar.ax.set_ylabel("gt_instances") + plt.tight_layout() + + fig3 = plt.figure(dpi=100) + plt.hist(dfsub["gt_xpxs"] - dfsub["dt_xpxs"]) + plt.xlabel("gt_xpxs - dt_xpxs") + plt.ylabel("B-scans") + + return fig1, fig2, fig3 + + def gt_vs_dt_xpxs_mu(self): + """Plots binned means of detected vs. ground truth x-pixels. + + Returns: + plt.Figure: The figure handle for the plot. 
+ """ + df = self.dfimg + max_inst, max_xpxs, max_pxs = self.get_max_limits(df) + idx = (df.gt_instances > 0) & (df.dt_instances > 0) + dfsub = df[idx] + + from scipy import stats + + mu_dt, bins, bnum = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["dt_xpxs"], statistic="mean", bins=10) + std_dt, _, _ = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["dt_xpxs"], statistic="std", bins=bins) + mu_gt, _, _ = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["gt_xpxs"], statistic="mean", bins=bins) + std_gt, _, _ = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["gt_xpxs"], statistic="std", bins=bins) + fig = plt.figure(dpi=100) + plt.errorbar(mu_gt, mu_dt, yerr=std_dt, xerr=std_gt, fmt="*") + plt.xlabel("gt_xpxs") + plt.ylabel("dt_xpxs") + plt.plot([0, max_xpxs], [0, max_xpxs], alpha=0.5) + plt.xlim(0, max_xpxs) + plt.ylim(0, max_xpxs) + plt.gca().set_aspect(1) + plt.tight_layout() + return fig + + def gt_dt_fp_fn_count(self): + """Plots histograms of false positive and false negative instance counts. + + Returns: + plt.Figure: The figure handle for the plot. + """ + df = self.dfimg + fig, ax = plt.subplots(1, 2, figsize=[10, 5]) + + idx = (df.gt_instances == 0) & (df.dt_instances > 0) + ax[0].hist(df[idx]["dt_instances"], bins=range(1, 10)) + ax[0].set_xlabel("dt instances") + ax[0].set_ylabel("B-scans") + ax[0].set_title("FP dt instance count per B-scan") + + idx = (df.gt_instances > 0) & (df.dt_instances == 0) + ax[1].hist(df[idx]["gt_instances"], bins=range(1, 10)) + ax[1].set_xlabel("gt instances") + ax[1].set_ylabel("B-scans") + ax[1].set_title("FN gt instance count per B-scan") + + plt.tight_layout() + return fig + + def avg_inst_size(self): + """Plots histograms of the average instance size in pixels. + + Compares the average size (in both total pixels and x-axis projection) + between ground truth and detected instances. + + Returns: + plt.Figure: The figure handle for the plot. 
+ """ + df = self.dfimg + max_inst, max_xpxs, max_pxs = self.get_max_limits(df) + idx = (df.gt_instances > 0) & (df.dt_instances > 0) + dfsub = df[idx] + + fig = plt.figure(figsize=[10, 5]) + plt.subplot(121) + bins = np.arange(0, 120, 10) + ax = (dfsub.gt_xpxs / dfsub.gt_instances).hist(bins=bins, alpha=0.5, label="gt") + ax = (dfsub.dt_xpxs / dfsub.dt_instances).hist(bins=bins, alpha=0.5, label="dt") + ax.set_xlabel("xpxs") + ax.set_ylabel("B-scans") + ax.set_title("Average size of instance") + ax.legend() + + plt.subplot(122) + bins = np.arange(0, 600, 40) + ax = (dfsub.gt_pxs / dfsub.gt_instances).hist(bins=bins, alpha=0.5, label="gt") + ax = (dfsub.dt_pxs / dfsub.dt_instances).hist(bins=bins, alpha=0.5, label="dt") + ax.set_xlabel("pxs") + ax.set_ylabel("B-scans") + ax.set_title("Average size of instance") + ax.legend() + + plt.tight_layout() + return fig diff --git a/models/retinalOCT_RPD_segmentation/scripts/datasets/__init__.py b/models/retinalOCT_RPD_segmentation/scripts/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/retinalOCT_RPD_segmentation/scripts/datasets/data.py b/models/retinalOCT_RPD_segmentation/scripts/datasets/data.py new file mode 100644 index 00000000..77ff2552 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/datasets/data.py @@ -0,0 +1,132 @@ +import distutils.util +import glob +import os +import shutil + +import cv2 +import pandas as pd +from PIL import Image +from pydicom import dcmread +from pydicom.fileset import FileSet +from tqdm import tqdm + +from .volReader import VolFile + +script_dir = os.path.dirname(__file__) + + +class Error(Exception): + """Base class for exceptions in this module.""" + + pass + + +def extract_files(dirtoextract, extracted_path, input_format): + """Extracts individual image frames from .vol or DICOM files. + + This function scans a directory for source files of a specified format + and extracts them into a structured output directory as PNG images. + It handles both .vol files and standard DICOM files. If the + output directory already contains files, it will prompt the user + before proceeding to overwrite them. + + Args: + dirtoextract (str): The root directory to search for source files. + extracted_path (str): The destination directory where the extracted + PNG images will be saved. + input_format (str): The format of the input files. Must be either + "vol" or "dicom". + """ + assert input_format in ["vol", "dicom"], 'Error: input_format must be "vol" or "dicom".' + proceed = True + if (os.path.isdir(extracted_path)) and (len(os.listdir(extracted_path)) != 0): + val = input( + f"{extracted_path} exists and is not empty. Files may be overwritten. Proceed with extraction? 
(Y/N)" + ) + proceed = bool(distutils.util.strtobool(val)) + if proceed: + print(f"Extracting files from {dirtoextract} into {extracted_path}...") + if input_format == "vol": + files_to_extract = glob.glob(os.path.join(dirtoextract, "**/*.vol"), recursive=True) + for _, line in enumerate(tqdm(files_to_extract)): + fpath = line.strip("\n") + vol = VolFile(fpath) + fpath = fpath.replace("\\", "/") + path, scan_str = fpath.strip(".vol").rsplit("/", 1) + extractpath = os.path.join(extracted_path, scan_str.replace("_", "/")) + os.makedirs(extractpath, exist_ok=True) + preffix = os.path.join(extractpath, scan_str + "_oct") + vol.render_oct_scans(preffix) + elif input_format == "dicom": + keywords = ["SOPInstanceUID", "PatientID", "ImageLaterality", "SeriesDate"] + list_of_dicts = [] + dirgen = glob.iglob(os.path.join(dirtoextract, "**/DICOMDIR"), recursive=True) + + for dsstr in dirgen: + fs = FileSet(dcmread(dsstr)) + fsgenopt = gen_opt_fs(fs) + for fi in tqdm(fsgenopt): + dd = dict() + # top level keywords + for key in keywords: + dd[key] = fi.get(key) + + volpath = os.path.join(extracted_path, f"{fi.SOPInstanceUID}") + shutil.rmtree(volpath, ignore_errors=True) + os.mkdir(volpath) + n = fi.NumberOfFrames + for i in range(n): + fname = os.path.join(volpath, f"{fi.SOPInstanceUID}_oct_{i:03d}.png") + Image.fromarray(fi.pixel_array[i]).save(fname) + list_of_dicts.append(dd.copy()) + dfoct = pd.DataFrame(list_of_dicts, columns=keywords) + dfoct.to_csv(os.path.join(extracted_path, "basic_meta.csv")) + else: + pass + + +def rpd_data(extracted_path): + """Generates a dataset list from a directory of extracted image files. + + Scans a directory recursively for PNG images and creates a list of + dictionaries, one for each image. This format is designed to be compatible + with Detectron2's `DatasetCatalog` and can be adapted to hold ground truth instances for evaluation. + + Args: + extracted_path (str): The root directory containing the extracted + .png image files to be included in the dataset. + + Returns: + list[dict]: A list where each dictionary represents an image and + contains its file path, dimensions, and a unique ID. + """ + dataset = [] + extracted_files = glob.glob(os.path.join(extracted_path, "**/*.[Pp][Nn][Gg]"), recursive=True) + print("Generating dataset of images...") + for fn in tqdm(extracted_files): + fn_adjusted = fn.replace("\\", "/") + imageid = fn_adjusted.split("/")[-1] + im = cv2.imread(fn) + dat = dict(file_name=fn_adjusted, height=im.shape[0], width=im.shape[1], image_id=imageid) + dataset.append(dat) + print(f"Found {len(dataset)} images") + return dataset + + +def gen_opt_fs(fs): + """A generator for finding and loading OPT modality DICOM datasets. + + This function filters a pydicom `FileSet` object for instances that have + the modality set to "OPT" (Ophthalmic Tomography) and yields each one + as a fully loaded pydicom dataset. + + Args: + fs (pydicom.fileset.FileSet): The pydicom FileSet to search through. + + Yields: + pydicom.dataset.FileDataset: A loaded DICOM dataset for each instance + with the "OPT" modality found in the FileSet. + """ + for instance in fs.find(Modality="OPT"): + ds = instance.load() + yield ds diff --git a/models/retinalOCT_RPD_segmentation/scripts/datasets/volReader.py b/models/retinalOCT_RPD_segmentation/scripts/datasets/volReader.py new file mode 100644 index 00000000..d8b8bc5d --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/datasets/volReader.py @@ -0,0 +1,397 @@ +# Aaron Y. 
Lee MD MSCI (University of Washington) Copyright 2019 +# +# Code ported from Markus Mayer's excellent work (https://www5.cs.fau.de/research/software/octseg/) +# +# Also thanks to who contributed to the original openVol.m in Markus's project +# Radim Kolar, Brno University, Czech Republic +# Kris Sheets, Retinal Cell Biology Lab, Neuroscience Center of Excellence, LSU Health Sciences Center, New Orleans + + +import array +import codecs +import datetime +import struct +from collections import OrderedDict + +import numpy as np + + +class VolFile: + def __init__(self, filename): + """ + Parses Heyex Spectralis *.vol files. + + Args: + filename (str): Path to vol file + + Returns: + volFile class + + """ + self.__parse_volfile(filename) + + @property + def oct(self): + """ + Retrieve OCT volume as a 3D numpy array. + + Returns: + 3D numpy array with OCT intensities as 'uint8' array + + """ + return self.wholefile["cScan"] + + @property + def irslo(self): + """ + Retrieve IR SLO image as 2D numpy array + + Returns: + 2D numpy array with IR reflectance SLO image as 'uint8' array. + + """ + return self.wholefile["sloImage"] + + @property + def grid(self): + """ + Retrieve the IR SLO pixel coordinates for the B scan OCT slices + + Returns: + 2D numpy array with the number of b scan images in the first dimension + and x_0, y_0, x_1, y_1 defining the line of the B scan on the pixel + coordinates of the IR SLO image. + + """ + wf = self.wholefile + grid = [] + for bi in range(len(wf["slice-headers"])): + bscan_head = wf["slice-headers"][bi] + x_0 = int(bscan_head["startX"] / wf["header"]["scaleXSlo"]) + x_1 = int(bscan_head["endX"] / wf["header"]["scaleXSlo"]) + y_0 = int(bscan_head["startY"] / wf["header"]["scaleYSlo"]) + y_1 = int(bscan_head["endY"] / wf["header"]["scaleYSlo"]) + grid.append([x_0, y_0, x_1, y_1]) + return grid + + def render_ir_slo(self, filename, render_grid=False): + """ + Renders IR SLO image as a PNG file and optionally overlays grid of B scans + + Args: + filename (str): filename to save IR SLO image + renderGrid (bool): True will render red lines for the location of the B scans. + + Returns: + None + + """ + from PIL import Image, ImageDraw + + wf = self.wholefile + a = np.copy(wf["sloImage"]) + if render_grid: + a = np.stack((a,) * 3, axis=-1) + a = Image.fromarray(a) + draw = ImageDraw.Draw(a) + grid = self.grid + for x_0, y_0, x_1, y_1 in grid: + draw.line((x_0, y_0, x_1, y_1), fill=(255, 0, 0), width=3) + a.save(filename) + else: + Image.fromarray(a).save(filename) + + def render_oct_scans(self, filepre="oct", render_seg=False): + """ + Renders OCT images a PNG file and optionally overlays segmentation lines + Also creates a CSV file of vol file features. + + Args: + filepre (str): filename prefix. OCT Images will be named as "_001.png" + renderSeg (bool): True will render colored lines for the segmentation of the RPE, ILM, and NFL on the B scans. 
+ + Returns: + None + + """ + from PIL import Image + + wf = self.wholefile + for i in range(wf["cScan"].shape[0]): + a = np.copy(wf["cScan"][i]) + if render_seg: + a = np.stack((a,) * 3, axis=-1) + for li in range(wf["segmentations"].shape[0]): + for x in range(wf["segmentations"].shape[2]): + a[int(wf["segmentations"][li, i, x]), x, li] = 255 + + Image.fromarray(a).save("%s_%03d.png" % (filepre, i)) + + def __parse_volfile(self, fn, parse_seg=False): + print(fn) + wholefile = OrderedDict() + decode_hex = codecs.getdecoder("hex_codec") + with open(fn, "rb") as fin: + header = OrderedDict() + header["version"] = fin.read(12) + header["octSizeX"] = struct.unpack("I", fin.read(4))[0] # lateral resolution + header["numBscan"] = struct.unpack("I", fin.read(4))[0] + header["octSizeZ"] = struct.unpack("I", fin.read(4))[0] # OCT depth + header["scaleX"] = struct.unpack("d", fin.read(8))[0] + header["distance"] = struct.unpack("d", fin.read(8))[0] + header["scaleZ"] = struct.unpack("d", fin.read(8))[0] + header["sizeXSlo"] = struct.unpack("I", fin.read(4))[0] + header["sizeYSlo"] = struct.unpack("I", fin.read(4))[0] + header["scaleXSlo"] = struct.unpack("d", fin.read(8))[0] + header["scaleYSlo"] = struct.unpack("d", fin.read(8))[0] + header["fieldSizeSlo"] = struct.unpack("I", fin.read(4))[0] # FOV in degrees + header["scanFocus"] = struct.unpack("d", fin.read(8))[0] + header["scanPos"] = fin.read(4) + header["examTime"] = struct.unpack("=q", fin.read(8))[0] / 1e7 + header["examTime"] = datetime.datetime.utcfromtimestamp( + header["examTime"] - (369 * 365.25 + 4) * 24 * 60 * 60 + ) # needs to be checked + header["scanPattern"] = struct.unpack("I", fin.read(4))[0] + header["BscanHdrSize"] = struct.unpack("I", fin.read(4))[0] + header["ID"] = fin.read(16) + header["ReferenceID"] = fin.read(16) + header["PID"] = struct.unpack("I", fin.read(4))[0] + header["PatientID"] = fin.read(21) + header["unknown2"] = fin.read(3) + header["DOB"] = struct.unpack("d", fin.read(8))[0] - 25569 + header["DOB"] = datetime.datetime.utcfromtimestamp(0) + datetime.timedelta( + seconds=header["DOB"] * 24 * 60 * 60 + ) # needs to be checked + header["VID"] = struct.unpack("I", fin.read(4))[0] + header["VisitID"] = fin.read(24) + header["VisitDate"] = struct.unpack("d", fin.read(8))[0] - 25569 + header["VisitDate"] = datetime.datetime.utcfromtimestamp(0) + datetime.timedelta( + seconds=header["VisitDate"] * 24 * 60 * 60 + ) # needs to be checked + header["GridType"] = struct.unpack("I", fin.read(4))[0] + header["GridOffset"] = struct.unpack("I", fin.read(4))[0] + + wholefile["header"] = header + fin.seek(2048) + u = array.array("B") + u.frombytes(fin.read(header["sizeXSlo"] * header["sizeYSlo"])) + u = np.array(u).astype("uint8").reshape((header["sizeXSlo"], header["sizeYSlo"])) + wholefile["sloImage"] = u + + slo_offset = 2048 + header["sizeXSlo"] * header["sizeYSlo"] + oct_offset = header["BscanHdrSize"] + header["octSizeX"] * header["octSizeZ"] * 4 + bscans = [] + bscanheaders = [] + bscanqualities = [] + if parse_seg: + segmentations = None + for i in range(header["numBscan"]): + fin.seek(16 + slo_offset + i * oct_offset) + bscan_head = OrderedDict() + bscan_head["startX"] = struct.unpack("d", fin.read(8))[0] + bscan_head["startY"] = struct.unpack("d", fin.read(8))[0] + bscan_head["endX"] = struct.unpack("d", fin.read(8))[0] + bscan_head["endY"] = struct.unpack("d", fin.read(8))[0] + bscan_head["numSeg"] = struct.unpack("I", fin.read(4))[0] + bscan_head["offSeg"] = struct.unpack("I", fin.read(4))[0] + 
bscan_head["quality"] = struct.unpack("f", fin.read(4))[0] + bscan_head["shift"] = struct.unpack("I", fin.read(4))[0] + bscanheaders.append(bscan_head) + bscanqualities.append(bscan_head["quality"]) + + # extract OCT B scan data + fin.seek(header["BscanHdrSize"] + slo_offset + i * oct_offset) + u = array.array("f") + u.frombytes(fin.read(4 * header["octSizeX"] * header["octSizeZ"])) + u = np.array(u).reshape((header["octSizeZ"], header["octSizeX"])) + # remove out of boundary + v = struct.unpack("f", decode_hex("FFFF7F7F")[0]) + u[u == v] = 0 + # log normalize + u = np.log(10000 * u + 1) + u = (255.0 * (np.clip(u, 0, np.max(u)) / np.max(u))).astype("uint8") + bscans.append(u) + if parse_seg: + # extract OCT segmentations data + fin.seek(256 + slo_offset + i * oct_offset) + u = array.array("f") + u.frombytes(fin.read(4 * header["octSizeX"] * bscan_head["numSeg"])) + u = np.array(u) + print(u.shape) + u[u == v] = 0.0 + if segmentations is None: + segmentations = [] + for _ in range(bscan_head["numSeg"]): + segmentations.append([]) + + for j in range(bscan_head["numSeg"]): + segmentations[j].append(u[j * header["octSizeX"] : (j + 1) * header["octSizeX"]].tolist()) + wholefile["cScan"] = np.array(bscans) + if parse_seg: + wholefile["segmentations"] = np.array(segmentations) + wholefile["slice-headers"] = bscanheaders + wholefile["average-quality"] = np.mean(bscanqualities) + self.wholefile = wholefile + import csv + from pathlib import Path, PurePath + + vol_features = [ + PurePath(fn).name, + wholefile["header"]["version"].decode("utf-8").rstrip("\x00"), + wholefile["header"]["numBscan"], + wholefile["header"]["octSizeX"], + wholefile["header"]["octSizeZ"], + wholefile["header"]["distance"], + wholefile["header"]["scaleX"], + wholefile["header"]["scaleZ"], + wholefile["header"]["sizeXSlo"], + wholefile["header"]["sizeYSlo"], + wholefile["header"]["scaleXSlo"], + wholefile["header"]["scaleYSlo"], + wholefile["header"]["fieldSizeSlo"], + wholefile["header"]["scanFocus"], + wholefile["header"]["scanPos"].decode("utf-8").rstrip("\x00"), + wholefile["header"]["examTime"], + wholefile["header"]["scanPattern"], + wholefile["header"]["BscanHdrSize"], + wholefile["header"]["ID"].decode("utf-8").rstrip("\x00"), + wholefile["header"]["ReferenceID"].decode("utf-8").rstrip("\x00"), + wholefile["header"]["PID"], + wholefile["header"]["PatientID"].decode("utf-8").rstrip("\x00"), + wholefile["header"]["DOB"], + wholefile["header"]["VID"], + wholefile["header"]["VisitID"].decode("utf-8").rstrip("\x00"), + wholefile["header"]["VisitDate"], + wholefile["header"]["GridType"], + wholefile["header"]["GridOffset"], + wholefile["average-quality"], + ] + output_dir = PurePath(fn).parent + output_csv = output_dir.joinpath("vols.csv") + if not Path(output_csv).exists(): + print("Creating vols.csv as it does not exist.") + with open(output_csv, "w", newline="") as file: + writer = csv.writer(file) + writer.writerow( + [ + "filename", + "version", + "numBscan", + "octSizeX", + "octSizeZ", + "distance", + "scaleX", + "scaleZ", + "sizeXSlo", + "sizeYSlo", + "scaleXSlo", + "scaleYSlo", + "fieldSizeSlo", + "scanFocus", + "scanPos", + "examTime", + "scanPattern", + "BscanHdrSize", + "ID", + "ReferenceID", + "PID", + "PatientID", + "DOB", + "VID", + "VisitID", + "VisitDate", + "GridType", + "GridOffset", + "Average Quality", + ] + ) + with open(output_csv, "r", newline="") as file: + existing_vols = csv.reader(file) + for vol in existing_vols: + if vol[0] == PurePath(fn).name: + print("Skipping,", PurePath(fn).name, "already 
present in vols.csv.") + return + with open(output_csv, "a", newline="") as file: + print("Adding", PurePath(fn).name, "to vols.csv.") + writer = csv.writer(file) + writer.writerow(vol_features) + + @property + def file_header(self): + """ + Retrieve vol header fields + + Returns: + Dictionary with the following keys + - version: version number of vol file definition + - numBscan: number of B scan images in the volume + - octSizeX: number of pixels in the width of the OCT B scan + - octSizeZ: number of pixels in the height of the OCT B scan + - distance: unknown + - scaleX: resolution scaling factor of the width of the OCT B scan + - scaleZ: resolution scaling factor of the height of the OCT B scan + - sizeXSlo: number of pixels in the width of the IR SLO image + - sizeYSlo: number of pixels in the height of the IR SLO image + - scaleXSlo: resolution scaling factor of the width of the IR SLO image + - scaleYSlo: resolution scaling factor of the height of the IR SLO image + - fieldSizeSlo: field of view (FOV) of the retina in degrees + - scanFocus: unknown + - scanPos: Left or Right eye scanned + - examTime: Datetime of the scan (needs to be checked) + - scanPattern: unknown + - BscanHdrSize: size of B scan header in bytes + - ID: unknown + - ReferenceID + - PID: unknown + - PatientID: Patient ID string + - DOB: Date of birth + - VID: unknown + - VisitID: Visit ID string + - VisitDate: Datetime of visit (needs to be checked) + - GridType: unknown + - GridOffset: unknown + + """ + return self.wholefile["header"] + + def bscan_header(self, slicei): + """ + Retrieve the B Scan header information per slice. + + Args: + slicei (int): index of B scan + + Returns: + Dictionary with the following keys + - startX: x-coordinate for B scan on IR. (see getGrid) + - startY: y-coordinate for B scan on IR. (see getGrid) + - endX: x-coordinate for B scan on IR. (see getGrid) + - endY: y-coordinate for B scan on IR. (see getGrid) + - numSeg: 2 or 3 segmentation lines for the B scan + - quality: OCT signal quality + - shift: unknown + + """ + return self.wholefile["slice-headers"][slicei] + + def save_grid(self, outfn): + """ + Saves the grid coordinates mapping OCT Bscans to the IR SLO image to a text file. The text file + will be a tab-delimited file with 5 columns: The bscan number, x_0, y_0, x_1, y_1 in pixel space + scaled to the resolution of the IR SLO image. 
+ + Args: + outfn (str): location of where to output the file + + Returns: + None + + """ + grid = self.grid + with open(outfn, "w") as fout: + fout.write("bscan\tx_0\ty_0\tx_1\ty_1\n") + ri = 0 + for r in grid: + r = [ri] + r + fout.write("%s\n" % "\t".join(map(str, r))) + ri += 1 diff --git a/models/retinalOCT_RPD_segmentation/scripts/inference.py b/models/retinalOCT_RPD_segmentation/scripts/inference.py new file mode 100644 index 00000000..8eaf9846 --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/inference.py @@ -0,0 +1,355 @@ +import json +import logging +import os +import pickle + +import pandas as pd +import progressbar +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader +from detectron2.evaluation import COCOEvaluator, inference_on_dataset +from detectron2.modeling import build_model + +from .analysis_lib import CreatePlotsRPD, EvaluateClass, OutputVis, grab_dataset +from .datasets import data +from .Ensembler import Ensembler +from .table_styles import styles + +# Change directory to the script's location to ensure relative paths work correctly. +os.chdir(os.path.dirname(os.path.abspath(__file__))) + + +logging.basicConfig(level=logging.INFO) + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +dpi = 120 + + +class MyProgressBar: + # https://stackoverflow.com/a/53643011/3826929 + # George C + def __init__(self): + self.pbar = None + + def __call__(self, block_num, block_size, total_size): + if not self.pbar: + self.pbar = progressbar.ProgressBar(maxval=total_size) + self.pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + self.pbar.update(downloaded) + else: + self.pbar.finish() + + +def create_dataset(dataset_name, extracted_path): + """Creates a pickled dataset file from a directory of extracted images. + + This function scans the `extracted_path` for images, formats them into a + list of dictionaries compatible with Detectron2, and saves the list as a + pickle file. + + Args: + dataset_name (str): The name for the dataset, used for the output .pk file. + extracted_path (str): The directory containing the extracted image files. + """ + stored_data = data.rpd_data(extracted_path) + pickle.dump(stored_data, open(os.path.join(data.script_dir, f"{dataset_name}.pk"), "wb")) + + +def configure_model(): + """Loads and returns the model configuration from a YAML file. + + It reads a 'working.yaml' file located in the same directory as the script + to set up the Detectron2 configuration. + + Returns: + detectron2.config.CfgNode: The configuration object for the model. + """ + cfg = get_cfg() + moddir = os.path.dirname(os.path.realpath(__file__)) + name = "working.yaml" + cfg_path = os.path.join(moddir, name) + cfg.merge_from_file(cfg_path) + return cfg + + +def register_dataset(dataset_name): + """Registers a dataset with Detectron2's DatasetCatalog. + + This makes the dataset available to be loaded by Detectron2's data loaders. + It sets the class metadata to 'rpd'. + + Args: + dataset_name (str): The name under which to register the dataset. + """ + for name in [dataset_name]: + try: + DatasetCatalog.register(name, grab_dataset(name)) + except AssertionError as e: + print(f"Assertion failed: {e}. 
Already registered.") + MetadataCatalog.get(name).thing_classes = ["rpd"] + + +def run_prediction(cfg, dataset_name, output_path): + """Runs inference on a dataset using a cross-validation ensemble of models. + + It loads five different model weight files (fold1 to fold5), runs inference + for each model on the specified dataset, and saves the predictions in + separate subdirectories within `output_path`. + + Args: + cfg (CfgNode): The model configuration object. + dataset_name (str): The name of the registered dataset to run inference on. + output_path (str): The base directory to save prediction outputs. + """ + model = build_model(cfg) # returns a torch.nn.Module + myloader = build_detection_test_loader(cfg, dataset_name) + myeval = COCOEvaluator( + dataset_name, tasks={"bbox", "segm"}, output_dir=output_path + ) # produces _coco_format.json when initialized + for mdl in ("fold1", "fold2", "fold3", "fold4", "fold5"): + extract_directory = "../models" + file_name = mdl + "_model_final.pth" + model_weights_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), extract_directory, file_name) + print(model_weights_path) + DetectionCheckpointer(model).load(model_weights_path) # load a file, usually from cfg.MODEL.WEIGHTS + model.eval() # set model in evaluation mode + myeval.reset() + output_dir = os.path.join(output_path, mdl) + myeval._output_dir = output_dir + print("Running inference with model ", mdl) + _ = inference_on_dataset( + model, myloader, myeval + ) # produces coco_instance_results.json when myeval.evaluate is called + print("Done with predictions!") + + +def run_ensemble(dataset_name, output_path, iou_thresh=0.2): + """Ensembles predictions from multiple models using NMS. + + It initializes an `Ensembler`, runs the non-maximum suppression logic, and + saves the final combined predictions to a single COCO results file. + + Args: + dataset_name (str): The name of the dataset. + output_path (str): The base directory containing the individual model + prediction subdirectories. + iou_thresh (float, optional): The IoU threshold for ensembling. Defaults to 0.2. + + Returns: + Ensembler: The ensembler instance after running NMS. + """ + ens = Ensembler(output_path, dataset_name, ["fold1", "fold2", "fold3", "fold4", "fold5"], iou_thresh=iou_thresh) + ens.mean_score_nms() + ens.save_coco_instances() + return ens + + +def evaluate_dataset(dataset_name, output_path, iou_thresh=0.2, prob_thresh=0.5): + """Evaluates the final ensembled predictions against ground truth. + + It uses the custom `EvaluateClass` to calculate performance metrics and saves + a summary to a JSON file. + + Args: + dataset_name (str): The name of the dataset. + output_path (str): The directory containing the ensembled predictions file. + iou_thresh (float, optional): The IoU threshold for evaluation. Defaults to 0.2. + prob_thresh (float, optional): The probability threshold for evaluation. Defaults to 0.5. + + Returns: + EvaluateClass: The evaluation object containing detailed metrics. + """ + myeval = EvaluateClass(dataset_name, output_path, iou_thresh=iou_thresh, prob_thresh=prob_thresh, evalsuper=False) + myeval.evaluate() + with open(os.path.join(output_path, "scalar_dict.json"), "w") as outfile: + json.dump(obj=myeval.summarize_scalars(), fp=outfile) + return myeval + + +def create_table(myeval): + """Creates a DataFrame of per-image statistics from evaluation results. + + Args: + myeval (EvaluateClass): The evaluation object containing COCO results. 
+ + Returns: + CreatePlotsRPD: An object containing DataFrames for image and volume stats. + """ + dataset_table = CreatePlotsRPD.initfromcoco(myeval.mycoco, myeval.prob_thresh) + dataset_table.dfimg.sort_index(inplace=True) + return dataset_table + # dataset_table.dfimg['scan'] = dataset_table.dfimg['scan'].astype('int') #depends on what we want scan field to be + + +def output_vol_predictions(dataset_table, vis, volid, output_path, output_mode="pred_overlay"): + """Generates and saves visualization TIFFs for a single scan volume. + + Args: + dataset_table (CreatePlotsRPD): Object containing the image/volume stats. + vis (OutputVis): The visualization object. + volid (str): The ID of the volume to visualize. + output_path (str): The directory to save the output TIFF file. + output_mode (str, optional): The type of visualization to create. + Options: "pred_overlay", "pred_only", "originals", "all". + Defaults to "pred_overlay". + """ + dfimg = dataset_table.dfimg + imgids = dfimg[dfimg["volID"] == volid].sort_index().index.values + outname = os.path.join(output_path, f"{volid}.tiff") + if output_mode == "pred_overlay": + vis.output_pred_to_tiff(imgids, outname, pred_only=False) + elif output_mode == "pred_only": + vis.output_pred_to_tiff(imgids, outname, pred_only=True) + elif output_mode == "originals": + vis.output_ori_to_tiff(imgids, outname) + elif output_mode == "all": + vis.output_all_to_tiff(imgids, outname) + else: + print(f"Invalid mode {output_mode} for function output_vol_predictions.") + + +def output_dataset_predictions(dataset_table, vis, output_path, output_mode="pred_overlay", draw_mode="default"): + """Generates and saves visualization TIFFs for all volumes in a dataset. + + Args: + dataset_table (CreatePlotsRPD): Object containing the image/volume stats. + vis (OutputVis): The visualization object. + output_path (str): The base directory to save the output TIFF files. + output_mode (str, optional): The type of visualization to create. + Defaults to "pred_overlay". + draw_mode (str, optional): The drawing style ("default" or "bw"). + Defaults to "default". + """ + vis.set_draw_mode(draw_mode) + os.makedirs(output_path, exist_ok=True) + for volid in dataset_table.dfvol.index: + output_vol_predictions(dataset_table, vis, volid, output_path, output_mode) + + +def create_dfvol(dataset_name, output_path, dataset_table): + """Creates and saves a styled HTML table of volume-level statistics. + + Args: + dataset_name (str): The name of the dataset. + output_path (str): The directory to save the HTML file. + dataset_table (CreatePlotsRPD): Object containing the volume DataFrame. + """ + dfvol = dataset_table.dfvol.sort_values(by=["dt_instances"], ascending=False) + with pd.option_context("styler.render.max_elements", int(dfvol.size) + 1): + html_str = dfvol.style.format("{:.0f}").set_table_styles(styles).to_html() + html_file = open(os.path.join(output_path, "dfvol_" + dataset_name + ".html"), "w") + html_file.write(html_str) + html_file.close() + + +def create_dfimg(dataset_name, output_path, dataset_table): + """Creates and saves a styled HTML table of image-level statistics. + + Args: + dataset_name (str): The name of the dataset. + output_path (str): The directory to save the HTML file. + dataset_table (CreatePlotsRPD): Object containing the image DataFrame. 
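+
+    Example:
+        Sketch following the flow of ``main`` in this module; the dataset name and
+        output directory are placeholders:
+
+            create_dfimg("rpd_test", "./output", dataset_table)
+            # writes ./output/dfimg_rpd_test.html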
+    """
+    dfimg = dataset_table.dfimg.sort_index()
+    with pd.option_context("styler.render.max_elements", int(dfimg.size) + 1):
+        html_str = dfimg.style.set_table_styles(styles).to_html()
+    html_file = open(os.path.join(output_path, "dfimg_" + dataset_name + ".html"), "w")
+    html_file.write(html_str)
+    html_file.close()
+
+
+def main(args):
+    """Main function to orchestrate the end-to-end analysis pipeline.
+
+    This function controls the flow from data extraction to evaluation and
+    visualization based on the provided arguments.
+
+    Args:
+        args (dict): A dictionary of command-line arguments and flags
+            controlling the pipeline execution.
+    """
+    print(f"Received arguments: {args}")
+
+    # Unpack arguments from the dictionary with default values
+    dataset_name = args.get("dataset_name")
+    input_dir = args.get("input_dir")
+    extracted_dir = args.get("extracted_dir")
+    input_format = args.get("input_format")
+    output_dir = args.get("output_dir")
+    run_extract = args.get("run_extract", True)
+    make_dataset = args.get("create_dataset", True)
+    run_inference = args.get("run_inference", True)
+    prob_thresh = args.get("prob_thresh", 0.5)
+    iou_thresh = args.get("iou_thresh", 0.2)
+    create_tables = args.get("create_tables", True)
+
+    # Visualization flags
+    bm = args.get("binary_mask", False)
+    bmo = args.get("binary_mask_overlay", False)
+    imo = args.get("instance_mask_overlay", False)
+    make_visuals = bm or bmo or imo
+
+    # --- Pipeline Steps ---
+    if run_extract:
+        os.makedirs(extracted_dir, exist_ok=True)
+        print("Starting file extraction...")
+        data.extract_files(input_dir, extracted_dir, input_format)
+        print("Image extraction complete!")
+    if make_dataset:
+        print("Creating dataset from extracted images...")
+        create_dataset(dataset_name, extracted_dir)
+    if run_inference:
+        print("Configuring model...")
+        cfg = configure_model()
+        print("Registering dataset...")
+        register_dataset(dataset_name)
+        os.makedirs(output_dir, exist_ok=True)
+        print("Running inference...")
+        run_prediction(cfg, dataset_name, output_dir)
+        print("Inference complete, running ensemble...")
+        run_ensemble(dataset_name, output_dir, iou_thresh)
+        print("Ensemble complete!")
+    if create_tables or make_visuals:
+        print("Registering dataset for evaluation...")
+        register_dataset(dataset_name)
+        print("Evaluating dataset...")
+        eval_obj = evaluate_dataset(dataset_name, output_dir, iou_thresh, prob_thresh)
+        print("Creating dataset table...")
+        table = create_table(eval_obj)
+        if create_tables:
+            create_dfvol(dataset_name, output_dir, table)
+            create_dfimg(dataset_name, output_dir, table)
+            print("Dataset HTML tables complete!")
+        if make_visuals:
+            print("Initializing visualizer...")
+            vis = OutputVis(
+                dataset_name,
+                prob_thresh=eval_obj.prob_thresh,
+                pred_mode="file",
+                pred_file=os.path.join(output_dir, "coco_instances_results.json"),
+                has_annotations=False,  # Assuming we are visualizing on test data without GT
+            )
+            vis.scale = 1.0  # Use original scale for output visuals
+            if bm:
+                print("Creating binary masks TIFF (no overlay)...")
+                vis.annotation_color = "w"
+                output_dataset_predictions(
+                    table, vis, os.path.join(output_dir, "predicted_binary_masks"), "pred_only", "bw"
+                )
+            if bmo:
+                print("Creating binary masks TIFF (with overlay)...")
+                output_dataset_predictions(
+                    table, vis, os.path.join(output_dir, "predicted_binary_overlays"), "pred_overlay", "bw"
+                )
+            if imo:
+                print("Creating instance masks TIFF (with overlay)...")
+                output_dataset_predictions(
+                    table, vis, os.path.join(output_dir, "predicted_instance_overlays"), "pred_overlay", "default"
+                )
+            print("Visualizations complete!")
"predicted_instance_overlays"), "pred_overlay", "default" + ) + print("Visualizations complete!") diff --git a/models/retinalOCT_RPD_segmentation/scripts/mask_rcnn_X_101_32x8d_FPN_1x.yaml b/models/retinalOCT_RPD_segmentation/scripts/mask_rcnn_X_101_32x8d_FPN_1x.yaml new file mode 100644 index 00000000..2664d88d --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/mask_rcnn_X_101_32x8d_FPN_1x.yaml @@ -0,0 +1,56 @@ +_BASE_: "Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + MASK_ON: True + RESNETS: + STRIDE_IN_1X1: False # this is a C2 model + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 + ROI_HEADS: + NUM_CLASSES: 1 + SCORE_THRESH_TEST: 0.001 + NMS_THRESH_TEST: .01 +INPUT: + MIN_SIZE_TRAIN: (496,) + MIN_SIZE_TEST: 496 +SOLVER: + BASE_LR: 0.02 + #GAMMA: 0.05 + #STEPS: (3000, 7000, 11000, 15000) + #MAX_ITER: 18000 + GAMMA: 0.1 + STEPS: (3000, 4500) + MAX_ITER: 6000 + CHECKPOINT_PERIOD: 300 + IMS_PER_BATCH: 14 +TEST: + DETECTIONS_PER_IMAGE: 30 # LVIS allows up to 300 + EVAL_PERIOD: 300 +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 + NUM_WORKERS: 4 +# DATASETS: +# TRAIN: ("fold1","fold2","fold3","fold4",) +# TEST: ("fold5",) +# OUTPUT_DIR: "./output_valid_fold5" +# DATASETS: +# TRAIN: ("fold2","fold3","fold4","fold5",) +# TEST: ("fold1",) +# OUTPUT_DIR: "./output_valid_fold1" +# DATASETS: +# TRAIN: ("fold3","fold4","fold5","fold1",) +# TEST: ("fold2",) +# OUTPUT_DIR: "./output_valid_fold2" +# DATASETS: +# TRAIN: ("fold4","fold5","fold1","fold2",) +# TEST: ("fold3",) +# OUTPUT_DIR: "./output_valid_fold3" +# DATASETS: +# TRAIN: ("fold5","fold1","fold2","fold3",) +# TEST: ("fold4",) +# OUTPUT_DIR: "./output_valid_fold4" + +#modifiying to commit again diff --git a/models/retinalOCT_RPD_segmentation/scripts/table_styles.py b/models/retinalOCT_RPD_segmentation/scripts/table_styles.py new file mode 100644 index 00000000..eb32e97c --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/table_styles.py @@ -0,0 +1,32 @@ +def hover(hover_color="#add8e6"): + return dict(selector="tbody tr:hover", props=[("background-color", "%s" % hover_color)]) + + +styles = [ + # table properties + dict( + selector=" ", + props=[ + ("margin", "0"), + ("font-family", '"Helvetica", "Arial", sans-serif'), + ("border-collapse", "collapse"), + ("border", "none"), + ("border", "2px solid #ccf"), + ], + ), + # #header color - optional + # dict(selector="thead", + # props=[("background-color","#cc8484") + # ]), + # background shading + dict(selector="tbody tr:nth-child(even)", props=[("background-color", "#fff")]), + dict(selector="tbody tr:nth-child(odd)", props=[("background-color", "#eee")]), + # cell spacing + dict(selector="td", props=[("padding", ".5em"), ("text-align", "center")]), + # header cell properties + dict(selector="th", props=[("font-size", "125%"), ("text-align", "center")]), + # caption placement + dict(selector="caption", props=[("caption-side", "bottom")]), + # render hover last to override background-color + hover(), +] diff --git a/models/retinalOCT_RPD_segmentation/scripts/working.yaml b/models/retinalOCT_RPD_segmentation/scripts/working.yaml new file mode 100644 index 00000000..2664d88d --- /dev/null +++ b/models/retinalOCT_RPD_segmentation/scripts/working.yaml @@ -0,0 +1,56 @@ +_BASE_: "Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + MASK_ON: True + RESNETS: + STRIDE_IN_1X1: 
+    NUM_GROUPS: 32
+    WIDTH_PER_GROUP: 8
+    DEPTH: 101
+  ROI_HEADS:
+    NUM_CLASSES: 1
+    SCORE_THRESH_TEST: 0.001
+    NMS_THRESH_TEST: .01
+INPUT:
+  MIN_SIZE_TRAIN: (496,)
+  MIN_SIZE_TEST: 496
+SOLVER:
+  BASE_LR: 0.02
+  #GAMMA: 0.05
+  #STEPS: (3000, 7000, 11000, 15000)
+  #MAX_ITER: 18000
+  GAMMA: 0.1
+  STEPS: (3000, 4500)
+  MAX_ITER: 6000
+  CHECKPOINT_PERIOD: 300
+  IMS_PER_BATCH: 14
+TEST:
+  DETECTIONS_PER_IMAGE: 30  # LVIS allows up to 300
+  EVAL_PERIOD: 300
+DATALOADER:
+  SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+  REPEAT_THRESHOLD: 0.001
+  NUM_WORKERS: 4
+# DATASETS:
+# TRAIN: ("fold1","fold2","fold3","fold4",)
+# TEST: ("fold5",)
+# OUTPUT_DIR: "./output_valid_fold5"
+# DATASETS:
+# TRAIN: ("fold2","fold3","fold4","fold5",)
+# TEST: ("fold1",)
+# OUTPUT_DIR: "./output_valid_fold1"
+# DATASETS:
+# TRAIN: ("fold3","fold4","fold5","fold1",)
+# TEST: ("fold2",)
+# OUTPUT_DIR: "./output_valid_fold2"
+# DATASETS:
+# TRAIN: ("fold4","fold5","fold1","fold2",)
+# TEST: ("fold3",)
+# OUTPUT_DIR: "./output_valid_fold3"
+# DATASETS:
+# TRAIN: ("fold5","fold1","fold2","fold3",)
+# TEST: ("fold4",)
+# OUTPUT_DIR: "./output_valid_fold4"
+
+#modifying to commit again
diff --git a/models/retinalOCT_RPD_segmentation/scripts/ybpres.mplstyle b/models/retinalOCT_RPD_segmentation/scripts/ybpres.mplstyle
new file mode 100644
index 00000000..a0c9a964
--- /dev/null
+++ b/models/retinalOCT_RPD_segmentation/scripts/ybpres.mplstyle
@@ -0,0 +1,6 @@
+axes.titlesize : 16
+axes.labelsize : 16
+lines.linewidth : 2
+lines.markersize : 6
+xtick.labelsize : 15
+ytick.labelsize : 15
diff --git a/setup.cfg b/setup.cfg
index 53aed841..06fa1478 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -20,7 +20,11 @@ ignore =
     B028
     B907
     C419
-per_file_ignores = __init__.py: F401, __main__.py: F401
+# lowercase checks are not needed for the following files in the retinalOCT_RPD_segmentation model
+# https://github.com/Project-MONAI/model-zoo/pull/748#issuecomment-2877638507
+per_file_ignores =
+    __init__.py: F401
+    __main__.py: F401
 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py
 
 [isort]
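
Reviewer note (not part of the diff): the `register_dataset` helper near the top of the inference script guards against double registration and pins `thing_classes` to the single "rpd" class. A minimal sketch of that standard detectron2 registration pattern is below; `load_rpd_dicts` is a placeholder loader, not part of this PR, which builds its records from the extracted OCT images.

    # Minimal sketch of the detectron2 registration pattern used by the bundle.
    # `load_rpd_dicts` is a hypothetical loader named only for illustration.
    from detectron2.data import DatasetCatalog, MetadataCatalog

    def load_rpd_dicts():
        # Each record needs at least file_name, height, width and image_id.
        return []

    name = "rpd_test_set"
    if name not in DatasetCatalog.list():
        DatasetCatalog.register(name, load_rpd_dicts)
    else:
        print(f"Dataset '{name}' already registered.")
    MetadataCatalog.get(name).thing_classes = ["rpd"]  # single foreground class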
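
The script's `main()` consumes a plain dict of flags; the keys it reads with `args.get(...)` are visible in the diff above. The snippet below is only a hypothetical driver for local testing: the bundle's real entry point or CLI wrapper is not part of this hunk, and every path and the `"dicom"` format value are placeholders.

    # Hypothetical driver for main(); paths and the input format are placeholders.
    args = {
        "dataset_name": "rpd_test_set",
        "input_dir": "./input",            # raw OCT volumes to extract
        "extracted_dir": "./extracted",    # where extracted images are written
        "input_format": "dicom",           # assumed value; see data.extract_files
        "output_dir": "./output",
        "run_extract": True,
        "create_dataset": True,
        "run_inference": True,
        "create_tables": True,
        "instance_mask_overlay": True,     # enables the TIFF overlay visualizations
        "prob_thresh": 0.5,
        "iou_thresh": 0.2,
    }
    main(args)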
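
The `styles` list from table_styles.py feeds directly into pandas' Styler, as `create_dfvol`/`create_dfimg` do above. A small self-contained illustration, assuming the scripts directory is importable and using made-up demo data:

    # Illustration of applying the table_styles.py `styles` to a DataFrame;
    # the DataFrame content here is fabricated purely for demonstration.
    import pandas as pd
    from table_styles import styles  # assumes the scripts directory is on sys.path

    dfvol = pd.DataFrame({"dt_instances": [12, 3]}, index=["vol_a", "vol_b"])
    with pd.option_context("styler.render.max_elements", int(dfvol.size) + 1):
        html_str = dfvol.style.format("{:.0f}").set_table_styles(styles).to_html()
    with open("dfvol_demo.html", "w") as f:
        f.write(html_str)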