diff --git a/tutorials/generative/2d_ldm/2d_finetuning_stable_diffusion.ipynb b/tutorials/generative/2d_ldm/2d_finetuning_stable_diffusion.ipynb new file mode 100644 index 00000000..c2bea969 --- /dev/null +++ b/tutorials/generative/2d_ldm/2d_finetuning_stable_diffusion.ipynb @@ -0,0 +1,360 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "734c91f1", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) MONAI Consortium\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "bb70a99f", + "metadata": {}, + "source": [ + "# Finetuning Stable Diffusion to Generate 2D Medical Images\n", + "\n", + "In this tutorial, we will convert the Stable Diffusion weights to be loaded using MONAI Generative Model classes. Next, we will use a similar approach presented in [1,2] and finetune (and train from scratch) the second stage of the latent diffusion model.\n", + "\n", + "[1] - Chambon et al. \"RoentGen: Vision-Language Foundation Model for Chest X-ray Generation.\" https://arxiv.org/abs/2211.12737\n", + "\n", + "[2] - Chambon et al. \"Adapting Pretrained Vision-Language Foundational Models to Medical Imaging Domains.\" https://arxiv.org/abs/2210.04133" + ] + }, + { + "cell_type": "markdown", + "id": "b97c43d9", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57bd0843", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "import time\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import sys\n", + "from monai import transforms\n", + "from monai.apps import DecathlonDataset\n", + "from monai.config import print_config\n", + "from monai.data import DataLoader\n", + "from monai.utils import set_determinism\n", + "from torch.cuda.amp import GradScaler, autocast\n", + "from tqdm import tqdm\n", + "\n", + "from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "df46172d", + "metadata": {}, + "source": [ + "### Setup data directory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a58dcafa", + "metadata": {}, + "outputs": [], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory" + ] + }, + { + "cell_type": "markdown", + "id": "880d3b1a", + "metadata": {}, + "source": [ + "### Set deterministic training for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01459e7e", + "metadata": {}, + "outputs": [], + "source": [ + "set_determinism(42)" + ] + }, + { + "cell_type": "markdown", + "id": "22aa0d4b", + "metadata": {}, + "source": [ + "## Setup BRATS Dataset - Transforms for extracting 2D slices from 3D volumes\n", + "\n", + 
"We now download the BraTS dataset and extract the 2D slices from the 3D volumes. The `slice_label` is used to indicate whether the slice contains an anomaly or not.\n", + "\n", + "Here we use transforms to augment the training dataset, as usual:\n", + "\n", + "1. `LoadImaged` loads the brain images from files.\n", + "2. `EnsureChannelFirstd` ensures the original data to construct \"channel first\" shape.\n", + "3. The first `Lambdad` transform chooses the first channel of the image, which is the T1-weighted image.\n", + "4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper.\n", + "5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n", + "6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n", + "6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` )." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ffea194", + "metadata": {}, + "outputs": [], + "source": [ + "channel = 0 # 0 = Flair\n", + "assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n", + "\n", + "train_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\", \"label\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", + " transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, :, :, :]),\n", + " transforms.AddChanneld(keys=[\"image\"]),\n", + " transforms.EnsureTyped(keys=[\"image\", \"label\"]),\n", + " transforms.Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " transforms.Spacingd(keys=[\"image\", \"label\"], pixdim=(3.0, 3.0, 2.0), mode=(\"bilinear\", \"nearest\")),\n", + " transforms.CenterSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 44)),\n", + " transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower=0, upper=99.5, b_min=0, b_max=1),\n", + " transforms.RandSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 1), random_size=False),\n", + " transforms.Lambdad(keys=[\"image\", \"label\"], func=lambda x: x.squeeze(-1)),\n", + " transforms.CopyItemsd(keys=[\"label\"], times=1, names=[\"slice_label\"]),\n", + " transforms.Lambdad(keys=[\"slice_label\"], func=lambda x: 2.0 if x.sum() > 0 else 1.0),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "94518df7", + "metadata": {}, + "source": [ + "### Load Training and Validation Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce3e3517", + "metadata": {}, + "outputs": [], + "source": [ + "train_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"training\",\n", + " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "print(f\"Length of training data: {len(train_ds)}\")\n", + "print(f'Train image shape {train_ds[0][\"image\"].shape}')\n", + "\n", + "val_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"validation\",\n", + " cache_rate=1.0, # you may need a few Gb of RAM... 
Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "print(f\"Length of training data: {len(val_ds)}\")\n", + "print(f'Validation Image shape {val_ds[0][\"image\"].shape}')" + ] + }, + { + "cell_type": "markdown", + "id": "18394239", + "metadata": {}, + "source": [ + "## Converting Stable Diffusion weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b77aca02", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "8830c874", + "metadata": {}, + "source": [ + "## Finetuning Diffusion Model\n", + "\n", + "At this step, we instantiate the MONAI components to create a DDIM, the UNET with conditioning, the noise scheduler, and the inferer used for training and sampling. We are using\n", + "the deterministic DDIM scheduler containing 1000 timesteps, and a 2D UNET with attention mechanisms.\n", + "\n", + "The `attention` mechanism is essential for ensuring good conditioning and images manipulation here.\n", + "\n", + "An `embedding layer`, which is also optimised during training, is used in the original work because it was empirically shown to improve conditioning compared to a single scalar information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a52880db", + "metadata": {}, + "outputs": [], + "source": [ + "condition_dropout = 0.15\n", + "n_iterations = 2e4\n", + "batch_size = 64\n", + "val_interval = 100\n", + "iter_loss_list = []\n", + "val_iter_loss_list = []\n", + "iterations = []\n", + "iteration = 0\n", + "iter_loss = 0\n", + "\n", + "train_loader = DataLoader(\n", + " train_ds, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, persistent_workers=True\n", + ")\n", + "val_loader = DataLoader(\n", + " val_ds, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=True, persistent_workers=True\n", + ")\n", + "\n", + "scaler = GradScaler()\n", + "total_start = time.time()\n", + "\n", + "while iteration < n_iterations:\n", + " for batch in train_loader:\n", + " iteration += 1\n", + " model.train()\n", + " images, classes = batch[\"image\"].to(device), batch[\"slice_label\"].to(device)\n", + " # 15% of the time, class conditioning dropout\n", + " classes = classes * (torch.rand_like(classes) > condition_dropout)\n", + " # cross attention expects shape [batch size, sequence length, channels]\n", + " class_embedding = embed(classes.long().to(device)).unsqueeze(1)\n", + " optimizer.zero_grad(set_to_none=True)\n", + " # pick a random time step t\n", + " timesteps = torch.randint(0, 1000, (len(images),)).to(device)\n", + "\n", + " with autocast(enabled=True):\n", + " # Generate random noise\n", + " noise = torch.randn_like(images).to(device)\n", + " # Get model prediction\n", + " noise_pred = inferer(\n", + " inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps, condition=class_embedding\n", + " )\n", + " loss = F.mse_loss(noise_pred.float(), noise.float())\n", + "\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " iter_loss += loss.item()\n", + " sys.stdout.write(f\"Iteration {iteration}/{n_iterations} - train Loss {loss.item():.4f}\" + \"\\r\")\n", + " sys.stdout.flush()\n", + "\n", + " if (iteration) % val_interval == 0:\n", + " model.eval()\n", + " val_iter_loss = 0\n", + " for val_step, val_batch in enumerate(val_loader):\n", + " images, classes = 
val_batch[\"image\"].to(device), val_batch[\"slice_label\"].to(device)\n", + " # cross attention expects shape [batch size, sequence length, channels]\n", + " class_embedding = embed(classes.long().to(device)).unsqueeze(1)\n", + " timesteps = torch.randint(0, 1000, (len(images),)).to(device)\n", + " with torch.no_grad():\n", + " with autocast(enabled=True):\n", + " noise = torch.randn_like(images).to(device)\n", + " noise_pred = inferer(\n", + " inputs=images,\n", + " diffusion_model=model,\n", + " noise=noise,\n", + " timesteps=timesteps,\n", + " condition=class_embedding,\n", + " )\n", + " val_loss = F.mse_loss(noise_pred.float(), noise.float())\n", + " val_iter_loss += val_loss.item()\n", + " iter_loss_list.append(iter_loss / val_interval)\n", + " val_iter_loss_list.append(val_iter_loss / (val_step + 1))\n", + " iterations.append(iteration)\n", + " iter_loss = 0\n", + " print(\n", + " f\"Train Loss {loss.item():.4f}, Interval Loss {iter_loss_list[-1]:.4f}, Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n", + " )\n", + "\n", + "\n", + "total_time = time.time() - total_start\n", + "\n", + "print(f\"train diffusion completed, total time: {total_time}.\")\n", + "\n", + "plt.style.use(\"seaborn-bright\")\n", + "plt.title(\"Learning Curves Diffusion Model\", fontsize=20)\n", + "plt.plot(iterations, iter_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n", + "plt.plot(\n", + " iterations, val_iter_loss_list, color=\"C1\", linewidth=2.0, label=\"Validation\"\n", + ") # np.linspace(1, n_iterations, len(val_iter_loss_list))\n", + "plt.yticks(fontsize=12), plt.xticks(fontsize=12)\n", + "plt.xlabel(\"Iterations\", fontsize=16), plt.ylabel(\"Loss\", fontsize=16)\n", + "plt.legend(prop={\"size\": 14})\n", + "plt.show()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "auto:light,ipynb", + "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/generative/2d_ldm/2d_finetuning_stable_diffusion.py b/tutorials/generative/2d_ldm/2d_finetuning_stable_diffusion.py new file mode 100644 index 00000000..28c6442d --- /dev/null +++ b/tutorials/generative/2d_ldm/2d_finetuning_stable_diffusion.py @@ -0,0 +1,217 @@ +# + +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# - + +# # Finetuning Stable Diffusion to Generate 2D Medical Images +# +# In this tutorial, we will convert the Stable Diffusion weights to be loaded using MONAI Generative Model classes. Next, we will use a similar approach presented in [1,2] and finetune (and train from scratch) the second stage of the latent diffusion model. +# +# [1] - Chambon et al. 
"RoentGen: Vision-Language Foundation Model for Chest X-ray Generation." https://arxiv.org/abs/2211.12737 +# +# [2] - Chambon et al. "Adapting Pretrained Vision-Language Foundational Models to Medical Imaging Domains." https://arxiv.org/abs/2210.04133 + +# ## Setup imports + +# + +import tempfile +import time +import os +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F +import sys +from monai import transforms +from monai.apps import DecathlonDataset +from monai.config import print_config +from monai.data import DataLoader +from monai.utils import set_determinism +from torch.cuda.amp import GradScaler, autocast + +print_config() +# - + +# ### Setup data directory + +directory = os.environ.get("MONAI_DATA_DIRECTORY") +root_dir = tempfile.mkdtemp() if directory is None else directory +root_dir = '/tmp/tmpic4meymr' +# ### Set deterministic training for reproducibility + +set_determinism(42) + +# ## Setup BRATS Dataset - Transforms for extracting 2D slices from 3D volumes +# +# We now download the BraTS dataset and extract the 2D slices from the 3D volumes. The `slice_label` is used to indicate whether the slice contains an anomaly or not. +# +# Here we use transforms to augment the training dataset, as usual: +# +# 1. `LoadImaged` loads the brain images from files. +# 2. `EnsureChannelFirstd` ensures the original data to construct "channel first" shape. +# 3. The first `Lambdad` transform chooses the first channel of the image, which is the T1-weighted image. +# 4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper. +# 5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images. +# 6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image. +# 6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` ). + +# + +channel = 1 # 1 = T1-weighted +train_transforms = transforms.Compose( + [ + transforms.LoadImaged(keys=["image"]), + transforms.EnsureChannelFirstd(keys=["image"]), + transforms.Lambdad(keys=["image"], func=lambda x: x[channel, :, :, :]), + transforms.AddChanneld(keys=["image"]), + transforms.EnsureTyped(keys=["image"]), + transforms.Orientationd(keys=["image"], axcodes="RAS"), + transforms.ScaleIntensityRangePercentilesd(keys="image", lower=0, upper=99.5, b_min=0, b_max=1), + transforms.RandSpatialCropd(keys=["image"], roi_size=(240, 240, 1), random_size=False), + transforms.Lambdad(keys=["image"], func=lambda x: x.squeeze(-1)), + ] +) +# - + +# ### Load Training and Validation Datasets + +# + +train_ds = DecathlonDataset( + root_dir=root_dir, + task="Task01_BrainTumour", + section="training", + cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise + num_workers=4, + download=True, # Set download to True if the dataset hasnt been downloaded yet + seed=0, + transform=train_transforms, +) +print(f"Length of training data: {len(train_ds)}") +print(f'Train image shape {train_ds[0]["image"].shape}') + +val_ds = DecathlonDataset( + root_dir=root_dir, + task="Task01_BrainTumour", + section="validation", + cache_rate=1.0, # you may need a few Gb of RAM... 
+# ## Converting Stable Diffusion weights
+
+
+
+# ## Finetuning Diffusion Model
+#
+# At this step, we instantiate the MONAI components of the diffusion model: the conditioned 2D UNet, the noise scheduler, and the inferer used for training and sampling. We use the deterministic DDIM scheduler with 1000 timesteps and a 2D UNet with attention mechanisms.
+#
+# The `attention` mechanism is essential for good conditioning and image manipulation here.
+#
+# An `embedding layer`, which is also optimised during training, is used in the original work because it was empirically shown to improve conditioning compared to using a single scalar value directly.
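+# The training loop below relies on `device`, `model`, `scheduler`, `inferer`, `embed` and `optimizer` being defined
+# at this point, with the converted Stable Diffusion weights loaded into `model`. Since the conversion step is not
+# included above, the cell below is only a minimal sketch of one possible setup: the network hyper-parameters,
+# embedding size and learning rate are illustrative assumptions, not the converted Stable Diffusion configuration.
+
+# +
+from generative.inferers import DiffusionInferer
+from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet
+from generative.networks.schedulers import DDIMScheduler
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+embedding_dimension = 64
+
+# conditioned 2D UNet with cross-attention; hyper-parameters here are assumptions for illustration
+model = DiffusionModelUNet(
+    spatial_dims=2,
+    in_channels=1,
+    out_channels=1,
+    num_res_blocks=1,
+    num_channels=(128, 256, 256),
+    attention_levels=(False, True, True),
+    num_head_channels=256,
+    with_conditioning=True,
+    cross_attention_dim=embedding_dimension,
+).to(device)
+
+# class 0 is reserved for the dropped-out (unconditioned) case, 1 = healthy, 2 = anomalous
+embed = torch.nn.Embedding(num_embeddings=3, embedding_dim=embedding_dimension, padding_idx=0).to(device)
+
+scheduler = DDIMScheduler(num_train_timesteps=1000)
+inferer = DiffusionInferer(scheduler)
+
+# the embedding layer is optimised jointly with the UNet, as described above
+optimizer = torch.optim.Adam(params=list(model.parameters()) + list(embed.parameters()), lr=2.5e-5)
+# -
+
+# Using `padding_idx=0` means that the class indices zeroed out by the conditioning dropout below map to a fixed
+# "unconditioned" embedding, which is what enables classifier-free-guidance-style training.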
+# +
+condition_dropout = 0.15
+n_iterations = 2e4
+batch_size = 64
+val_interval = 100
+iter_loss_list = []
+val_iter_loss_list = []
+iterations = []
+iteration = 0
+iter_loss = 0
+
+train_loader = DataLoader(
+    train_ds, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, persistent_workers=True
+)
+val_loader = DataLoader(
+    val_ds, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=True, persistent_workers=True
+)
+
+scaler = GradScaler()
+total_start = time.time()
+
+while iteration < n_iterations:
+    for batch in train_loader:
+        iteration += 1
+        model.train()
+        images, classes = batch["image"].to(device), batch["slice_label"].to(device)
+        # 15% of the time, drop the class conditioning (classifier-free-guidance-style dropout)
+        classes = classes * (torch.rand_like(classes) > condition_dropout)
+        # cross attention expects shape [batch size, sequence length, channels]
+        class_embedding = embed(classes.long().to(device)).unsqueeze(1)
+        optimizer.zero_grad(set_to_none=True)
+        # pick a random time step t
+        timesteps = torch.randint(0, 1000, (len(images),)).to(device)
+
+        with autocast(enabled=True):
+            # Generate random noise
+            noise = torch.randn_like(images).to(device)
+            # Get model prediction
+            noise_pred = inferer(
+                inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps, condition=class_embedding
+            )
+            loss = F.mse_loss(noise_pred.float(), noise.float())
+
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        iter_loss += loss.item()
+        sys.stdout.write(f"Iteration {iteration}/{n_iterations} - train Loss {loss.item():.4f}" + "\r")
+        sys.stdout.flush()
+
+        if iteration % val_interval == 0:
+            model.eval()
+            val_iter_loss = 0
+            for val_step, val_batch in enumerate(val_loader):
+                images, classes = val_batch["image"].to(device), val_batch["slice_label"].to(device)
+                # cross attention expects shape [batch size, sequence length, channels]
+                class_embedding = embed(classes.long().to(device)).unsqueeze(1)
+                timesteps = torch.randint(0, 1000, (len(images),)).to(device)
+                with torch.no_grad():
+                    with autocast(enabled=True):
+                        noise = torch.randn_like(images).to(device)
+                        noise_pred = inferer(
+                            inputs=images,
+                            diffusion_model=model,
+                            noise=noise,
+                            timesteps=timesteps,
+                            condition=class_embedding,
+                        )
+                        val_loss = F.mse_loss(noise_pred.float(), noise.float())
+                val_iter_loss += val_loss.item()
+            iter_loss_list.append(iter_loss / val_interval)
+            val_iter_loss_list.append(val_iter_loss / (val_step + 1))
+            iterations.append(iteration)
+            iter_loss = 0
+            print(
+                f"Train Loss {loss.item():.4f}, Interval Loss {iter_loss_list[-1]:.4f}, Interval Loss Val {val_iter_loss_list[-1]:.4f}"
+            )
+
+
+total_time = time.time() - total_start
+
+print(f"Train diffusion completed, total time: {total_time}.")
+
+plt.style.use("seaborn-bright")
+plt.title("Learning Curves Diffusion Model", fontsize=20)
+plt.plot(iterations, iter_loss_list, color="C0", linewidth=2.0, label="Train")
+plt.plot(iterations, val_iter_loss_list, color="C1", linewidth=2.0, label="Validation")
+plt.yticks(fontsize=12), plt.xticks(fontsize=12)
+plt.xlabel("Iterations", fontsize=16), plt.ylabel("Loss", fontsize=16)
+plt.legend(prop={"size": 14})
+plt.show()
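+# ## Sampling from the finetuned model
+#
+# As mentioned above, the inferer is also used for sampling. The cell below is a minimal sketch of how synthetic
+# slices could be generated for each class once training has finished; the number of DDIM inference steps and the
+# image size are assumptions, not values prescribed by the original work.
+
+# +
+model.eval()
+scheduler.set_timesteps(num_inference_steps=50)
+
+fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
+for i, class_label in enumerate([1.0, 2.0]):
+    # start from pure noise with the same spatial size as the training slices
+    noise = torch.randn((1, 1, 240, 240)).to(device)
+    conditioning = embed(torch.tensor([class_label]).long().to(device)).unsqueeze(1)
+    with torch.no_grad(), autocast(enabled=True):
+        sample = inferer.sample(
+            input_noise=noise, diffusion_model=model, scheduler=scheduler, conditioning=conditioning
+        )
+    axs[i].imshow(sample[0, 0].float().cpu(), vmin=0, vmax=1, cmap="gray")
+    axs[i].set_title(f"slice_label: {class_label}")
+    axs[i].axis("off")
+plt.show()
+# -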