diff --git a/guides/img/object_detection_retinanet/object_detection_retinanet_14_2.png b/guides/img/object_detection_retinanet/object_detection_retinanet_14_2.png
new file mode 100644
index 0000000000..a1dc0b1517
Binary files /dev/null and b/guides/img/object_detection_retinanet/object_detection_retinanet_14_2.png differ
diff --git a/guides/img/object_detection_retinanet/object_detection_retinanet_14_3.png b/guides/img/object_detection_retinanet/object_detection_retinanet_14_3.png
new file mode 100644
index 0000000000..a1dc0b1517
Binary files /dev/null and b/guides/img/object_detection_retinanet/object_detection_retinanet_14_3.png differ
diff --git a/guides/img/object_detection_retinanet/object_detection_retinanet_23_0.png b/guides/img/object_detection_retinanet/object_detection_retinanet_23_0.png
new file mode 100644
index 0000000000..7b07840a97
Binary files /dev/null and b/guides/img/object_detection_retinanet/object_detection_retinanet_23_0.png differ
diff --git a/guides/img/object_detection_retinanet/object_detection_retinanet_35_0.png b/guides/img/object_detection_retinanet/object_detection_retinanet_35_0.png
new file mode 100644
index 0000000000..0bd518fc73
Binary files /dev/null and b/guides/img/object_detection_retinanet/object_detection_retinanet_35_0.png differ
diff --git a/guides/img/object_detection_retinanet/object_detection_retinanet_41_27687.png b/guides/img/object_detection_retinanet/object_detection_retinanet_41_27687.png
new file mode 100644
index 0000000000..bbbbba87d4
Binary files /dev/null and b/guides/img/object_detection_retinanet/object_detection_retinanet_41_27687.png differ
diff --git a/guides/img/object_detection_retinanet/object_detection_retinanet_41_27688.png b/guides/img/object_detection_retinanet/object_detection_retinanet_41_27688.png
new file mode 100644
index 0000000000..6e76955f15
Binary files /dev/null and b/guides/img/object_detection_retinanet/object_detection_retinanet_41_27688.png differ
diff --git a/guides/img/object_detection_retinanet/retinanet_architecture.png b/guides/img/object_detection_retinanet/retinanet_architecture.png
new file mode 100644
index 0000000000..95ef949713
Binary files /dev/null and b/guides/img/object_detection_retinanet/retinanet_architecture.png differ
diff --git a/guides/ipynb/keras_hub/object_detection_retinanet.ipynb b/guides/ipynb/keras_hub/object_detection_retinanet.ipynb
new file mode 100644
index 0000000000..0d80dc05d9
--- /dev/null
+++ b/guides/ipynb/keras_hub/object_detection_retinanet.ipynb
@@ -0,0 +1,1149 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Object Detection with KerasHub\n",
+ "\n",
+ "**Authors:** [Siva Sravana Kumar Neeli](https://github.com/sineeli), [Sachin Prasad](https://github.com/sachinprasadhs)<br>\n",
+ "**Date created:** 2025/04/28<br>\n",
+ "**Last modified:** 2025/04/28<br>\n",
+ "**Description:** RetinaNet Object Detection: Training, Fine-tuning, and Inference."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "## Introduction\n",
+ "\n",
+ "Object detection is a crucial computer vision task that goes beyond simple image\n",
+ "classification. It requires models to not only identify the types of objects\n",
+ "present in an image but also pinpoint their locations using bounding boxes. This\n",
+ "dual requirement of classification and localization makes object detection a\n",
+ "more complex and powerful tool.\n",
+ "Object detection models are broadly classified into two categories: \"two-stage\"\n",
+ "and \"single-stage\" detectors. Two-stage detectors often achieve higher accuracy\n",
+ "by first proposing regions of interest and then classifying them. However, this\n",
+ "approach can be computationally expensive. Single-stage detectors, on the other\n",
+ "hand, aim for speed by directly predicting object classes and bounding boxes in\n",
+ "a single pass.\n",
+ "\n",
+ "In this tutorial, we'll be diving into `RetinaNet`, a powerful object detection\n",
+ "model known for its speed and precision. `RetinaNet` is a single-stage detector,\n",
+ "a design choice that allows it to be remarkably efficient. Its impressive\n",
+ "performance stems from two key architectural innovations:\n",
+ "1. **Feature Pyramid Network (FPN):** FPN equips `RetinaNet` with the ability to\n",
+ "seamlessly detect objects of all scales, from distant, tiny instances to large,\n",
+ "prominent ones.\n",
+ "2. **Focal Loss:** This ingenious loss function tackles the common challenge of\n",
+ "imbalanced data by focusing the model's learning on the most crucial and\n",
+ "challenging object examples, leading to enhanced accuracy without compromising\n",
+ "speed.\n",
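+ "\n",
+ "In equation form, focal loss rescales the standard cross-entropy term with a\n",
+ "modulating factor so that easy, well-classified examples contribute very\n",
+ "little to the loss:\n",
+ "\n",
+ "$$FL(p_t) = -(1 - p_t)^{\\gamma} \\log(p_t)$$\n",
+ "\n",
+ "where $p_t$ is the model's predicted probability for the true class and\n",
+ "$\\gamma$ is the focusing parameter (the paper also adds a class-balancing\n",
+ "weight and uses $\\gamma = 2$).\n",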
+ "\n",
+ "\n",
+ "\n",
+ "### References\n",
+ "\n",
+ "- [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)\n",
+ "- [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup and Imports\n",
+ "\n",
+ "Let's install the dependencies and import the necessary modules.\n",
+ "\n",
+ "To run this tutorial, you will need to install the following packages:\n",
+ "\n",
+ "* `keras-hub`\n",
+ "* `keras`\n",
+ "* `opencv-python`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q --upgrade keras-hub\n",
+ "!pip install -q --upgrade keras\n",
+ "!pip install -q opencv-python"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"jax\" # or \"tensorflow\" or \"torch\"\n",
+ "import keras\n",
+ "import keras_hub\n",
+ "import tensorflow as tf"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Helper functions\n",
+ "\n",
+ "We download the Pascal VOC 2012 and 2007 datasets using these helper functions,\n",
+ "prepare them for the object detection task, and split them into training and\n",
+ "validation datasets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Helper functions\n",
+ "import logging\n",
+ "import multiprocessing\n",
+ "from builtins import open\n",
+ "import os.path\n",
+ "import xml.etree.ElementTree\n",
+ "\n",
+ "import tensorflow_datasets as tfds\n",
+ "\n",
+ "VOC_2007_URL = (\n",
+ " \"http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar\"\n",
+ ")\n",
+ "VOC_2012_URL = (\n",
+ " \"http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar\"\n",
+ ")\n",
+ "VOC_2007_test_URL = (\n",
+ " \"http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar\"\n",
+ ")\n",
+ "\n",
+ "# Note that this list doesn't contain the background class. In the\n",
+ "# classification use case, the label is 0 based (aeroplane -> 0), whereas in\n",
+ "# segmentation use case, the 0 is reserved for background, so aeroplane maps to\n",
+ "# 1.\n",
+ "CLASSES = [\n",
+ " \"aeroplane\",\n",
+ " \"bicycle\",\n",
+ " \"bird\",\n",
+ " \"boat\",\n",
+ " \"bottle\",\n",
+ " \"bus\",\n",
+ " \"car\",\n",
+ " \"cat\",\n",
+ " \"chair\",\n",
+ " \"cow\",\n",
+ " \"diningtable\",\n",
+ " \"dog\",\n",
+ " \"horse\",\n",
+ " \"motorbike\",\n",
+ " \"person\",\n",
+ " \"pottedplant\",\n",
+ " \"sheep\",\n",
+ " \"sofa\",\n",
+ " \"train\",\n",
+ " \"tvmonitor\",\n",
+ "]\n",
+ "COCO_90_CLASS_MAPPING = {\n",
+ " 1: \"person\",\n",
+ " 2: \"bicycle\",\n",
+ " 3: \"car\",\n",
+ " 4: \"motorcycle\",\n",
+ " 5: \"airplane\",\n",
+ " 6: \"bus\",\n",
+ " 7: \"train\",\n",
+ " 8: \"truck\",\n",
+ " 9: \"boat\",\n",
+ " 10: \"traffic light\",\n",
+ " 11: \"fire hydrant\",\n",
+ " 13: \"stop sign\",\n",
+ " 14: \"parking meter\",\n",
+ " 15: \"bench\",\n",
+ " 16: \"bird\",\n",
+ " 17: \"cat\",\n",
+ " 18: \"dog\",\n",
+ " 19: \"horse\",\n",
+ " 20: \"sheep\",\n",
+ " 21: \"cow\",\n",
+ " 22: \"elephant\",\n",
+ " 23: \"bear\",\n",
+ " 24: \"zebra\",\n",
+ " 25: \"giraffe\",\n",
+ " 27: \"backpack\",\n",
+ " 28: \"umbrella\",\n",
+ " 31: \"handbag\",\n",
+ " 32: \"tie\",\n",
+ " 33: \"suitcase\",\n",
+ " 34: \"frisbee\",\n",
+ " 35: \"skis\",\n",
+ " 36: \"snowboard\",\n",
+ " 37: \"sports ball\",\n",
+ " 38: \"kite\",\n",
+ " 39: \"baseball bat\",\n",
+ " 40: \"baseball glove\",\n",
+ " 41: \"skateboard\",\n",
+ " 42: \"surfboard\",\n",
+ " 43: \"tennis racket\",\n",
+ " 44: \"bottle\",\n",
+ " 46: \"wine glass\",\n",
+ " 47: \"cup\",\n",
+ " 48: \"fork\",\n",
+ " 49: \"knife\",\n",
+ " 50: \"spoon\",\n",
+ " 51: \"bowl\",\n",
+ " 52: \"banana\",\n",
+ " 53: \"apple\",\n",
+ " 54: \"sandwich\",\n",
+ " 55: \"orange\",\n",
+ " 56: \"broccoli\",\n",
+ " 57: \"carrot\",\n",
+ " 58: \"hot dog\",\n",
+ " 59: \"pizza\",\n",
+ " 60: \"donut\",\n",
+ " 61: \"cake\",\n",
+ " 62: \"chair\",\n",
+ " 63: \"couch\",\n",
+ " 64: \"potted plant\",\n",
+ " 65: \"bed\",\n",
+ " 67: \"dining table\",\n",
+ " 70: \"toilet\",\n",
+ " 72: \"tv\",\n",
+ " 73: \"laptop\",\n",
+ " 74: \"mouse\",\n",
+ " 75: \"remote\",\n",
+ " 76: \"keyboard\",\n",
+ " 77: \"cell phone\",\n",
+ " 78: \"microwave\",\n",
+ " 79: \"oven\",\n",
+ " 80: \"toaster\",\n",
+ " 81: \"sink\",\n",
+ " 82: \"refrigerator\",\n",
+ " 84: \"book\",\n",
+ " 85: \"clock\",\n",
+ " 86: \"vase\",\n",
+ " 87: \"scissors\",\n",
+ " 88: \"teddy bear\",\n",
+ " 89: \"hair drier\",\n",
+ " 90: \"toothbrush\",\n",
+ "}\n",
+ "# This is used to map between string class to index.\n",
+ "CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}\n",
+ "INDEX_TO_CLASS = {index: name for index, name in enumerate(CLASSES)}\n",
+ "\n",
+ "\n",
+ "def get_image_ids(data_dir, split):\n",
+ " \"\"\"Gets image ids from the \"train\", \"eval\", \"trainval\", or \"test\" split files of VOC data.\"\"\"\n",
+ " data_file_mapping = {\n",
+ " \"train\": \"train.txt\",\n",
+ " \"eval\": \"val.txt\",\n",
+ " \"trainval\": \"trainval.txt\",\n",
+ " \"test\": \"test.txt\",\n",
+ " }\n",
+ " with open(\n",
+ " os.path.join(data_dir, \"ImageSets\", \"Main\", data_file_mapping[split]),\n",
+ " \"r\",\n",
+ " ) as f:\n",
+ " image_ids = f.read().splitlines()\n",
+ " logging.info(f\"Received {len(image_ids)} images for {split} dataset.\")\n",
+ " return image_ids\n",
+ "\n",
+ "\n",
+ "def load_images(example):\n",
+ " \"\"\"Loads VOC images for the object detection task from the provided paths.\"\"\"\n",
+ " image_file_path = example.pop(\"image/file_path\")\n",
+ " image = tf.io.read_file(image_file_path)\n",
+ " image = tf.image.decode_jpeg(image)\n",
+ "\n",
+ " example.update(\n",
+ " {\n",
+ " \"image\": image,\n",
+ " }\n",
+ " )\n",
+ " return example\n",
+ "\n",
+ "\n",
+ "def parse_annotation_data(annotation_file_path):\n",
+ " \"\"\"Parse the annotation XML file for the image.\n",
+ "\n",
+ " The annotation contains the metadata, as well as the object bounding box\n",
+ " information.\n",
+ "\n",
+ " \"\"\"\n",
+ " with open(annotation_file_path, \"r\") as f:\n",
+ " root = xml.etree.ElementTree.parse(f).getroot()\n",
+ "\n",
+ " size = root.find(\"size\")\n",
+ " width = int(size.find(\"width\").text)\n",
+ " height = int(size.find(\"height\").text)\n",
+ " filename = root.find(\"filename\").text\n",
+ "\n",
+ " objects = []\n",
+ " for obj in root.findall(\"object\"):\n",
+ " # Get object's label name.\n",
+ " label = CLASS_TO_INDEX[obj.find(\"name\").text.lower()]\n",
+ " bndbox = obj.find(\"bndbox\")\n",
+ " xmax = int(float(bndbox.find(\"xmax\").text))\n",
+ " xmin = int(float(bndbox.find(\"xmin\").text))\n",
+ " ymax = int(float(bndbox.find(\"ymax\").text))\n",
+ " ymin = int(float(bndbox.find(\"ymin\").text))\n",
+ " objects.append(\n",
+ " {\n",
+ " \"label\": label,\n",
+ " \"bbox\": [ymin, xmin, ymax, xmax],\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " return {\n",
+ " \"image/filename\": filename,\n",
+ " \"width\": width,\n",
+ " \"height\": height,\n",
+ " \"objects\": objects,\n",
+ " }\n",
+ "\n",
+ "\n",
+ "def parse_single_image(annotation_file_path):\n",
+ " \"\"\"Creates metadata of VOC images and path.\"\"\"\n",
+ " data_dir, annotation_file_name = os.path.split(annotation_file_path)\n",
+ " data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))\n",
+ " image_annotations = parse_annotation_data(annotation_file_path)\n",
+ "\n",
+ " result = {\n",
+ " \"image/file_path\": os.path.join(\n",
+ " data_dir, \"JPEGImages\", image_annotations[\"image/filename\"]\n",
+ " )\n",
+ " }\n",
+ " result.update(image_annotations)\n",
+ " # The labels field is the sorted set of unique 'objects/label' values.\n",
+ " labels = list(set([o[\"label\"] for o in result[\"objects\"]]))\n",
+ " result[\"labels\"] = sorted(labels)\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "def build_metadata(data_dir, image_ids):\n",
+ " \"\"\"Transposes the metadata from a list of dicts to a dict of lists.\"\"\"\n",
+ " # Parallel process the annotation files for the requested split's image ids.\n",
+ " annotation_file_paths = [\n",
+ " os.path.join(data_dir, \"Annotations\", i + \".xml\") for i in image_ids\n",
+ " ]\n",
+ " pool_size = min(10, len(annotation_file_paths))\n",
+ " with multiprocessing.Pool(pool_size) as p:\n",
+ " metadata = p.map(parse_single_image, annotation_file_paths)\n",
+ "\n",
+ " keys = [\n",
+ " \"image/filename\",\n",
+ " \"image/file_path\",\n",
+ " \"labels\",\n",
+ " \"width\",\n",
+ " \"height\",\n",
+ " ]\n",
+ " result = {}\n",
+ " for key in keys:\n",
+ " values = [value[key] for value in metadata]\n",
+ " result[key] = values\n",
+ "\n",
+ " # The ragged objects need some special handling\n",
+ " for key in [\"label\", \"bbox\"]:\n",
+ " values = []\n",
+ " objects = [value[\"objects\"] for value in metadata]\n",
+ " for object in objects:\n",
+ " values.append([o[key] for o in object])\n",
+ " result[\"objects/\" + key] = values\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "def build_dataset_from_metadata(metadata):\n",
+ " \"\"\"Builds TensorFlow dataset from the image metadata of VOC dataset.\"\"\"\n",
+ " # The objects need some manual conversion to ragged tensor.\n",
+ " metadata[\"labels\"] = tf.ragged.constant(metadata[\"labels\"])\n",
+ " metadata[\"objects/label\"] = tf.ragged.constant(metadata[\"objects/label\"])\n",
+ " metadata[\"objects/bbox\"] = tf.ragged.constant(\n",
+ " metadata[\"objects/bbox\"], ragged_rank=1\n",
+ " )\n",
+ "\n",
+ " dataset = tf.data.Dataset.from_tensor_slices(metadata)\n",
+ " dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ " return dataset\n",
+ "\n",
+ "\n",
+ "def load_voc(\n",
+ " year=\"2007\",\n",
+ " split=\"trainval\",\n",
+ " data_dir=\"./\",\n",
+ " voc_url=VOC_2007_URL,\n",
+ "):\n",
+ " extracted_dir = os.path.join(\"VOCdevkit\", f\"VOC{year}\")\n",
+ " get_data = keras.utils.get_file(\n",
+ " fname=os.path.basename(voc_url),\n",
+ " origin=voc_url,\n",
+ " cache_dir=data_dir,\n",
+ " extract=True,\n",
+ " )\n",
+ " data_dir = os.path.join(get_data, extracted_dir)\n",
+ " image_ids = get_image_ids(data_dir, split)\n",
+ " metadata = build_metadata(data_dir, image_ids)\n",
+ " dataset = build_dataset_from_metadata(metadata)\n",
+ "\n",
+ " return dataset\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Load the dataset\n",
+ "\n",
+ "Let's load the training data. Here, we load both the VOC 2007 and 2012 datasets\n",
+ "and split them into training and validation sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "train_ds_2007 = load_voc(\n",
+ " year=\"2007\",\n",
+ " split=\"trainval\",\n",
+ " data_dir=\"./\",\n",
+ " voc_url=VOC_2007_URL,\n",
+ ")\n",
+ "train_ds_2012 = load_voc(\n",
+ " year=\"2012\",\n",
+ " split=\"trainval\",\n",
+ " data_dir=\"./\",\n",
+ " voc_url=VOC_2012_URL,\n",
+ ")\n",
+ "eval_ds = load_voc(\n",
+ " year=\"2007\",\n",
+ " split=\"test\",\n",
+ " data_dir=\"./\",\n",
+ " voc_url=VOC_2007_test_URL,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Inference using a pre-trained object detector\n",
+ "\n",
+ "Let's begin with the simplest `KerasHub` API: a pre-trained object detector. In\n",
+ "this example, we will construct an object detector that was pre-trained on the\n",
+ "`COCO` dataset. We'll use this model to detect objects in a sample image.\n",
+ "\n",
+ "The highest-level module in KerasHub is a `task`. A `task` is a `keras.Model`\n",
+ "consisting of a (generally pre-trained) backbone model and task-specific layers.\n",
+ "Here's an example using `keras_hub.models.ImageObjectDetector` with the\n",
+ "`RetinaNet` model architecture and `ResNet50` as the backbone.\n",
+ "\n",
+ "`ResNet` is a great starting backbone for many computer vision pipelines,\n",
+ "including object detection: it achieves high accuracy while using a relatively\n",
+ "small number of parameters. If a ResNet isn't powerful enough for the task you\n",
+ "are trying to solve, check out\n",
+ "[KerasHub's other available backbones](https://keras.io/keras_hub/presets/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "object_detector = keras_hub.models.ImageObjectDetector.from_preset(\n",
+ " \"retinanet_resnet50_fpn_coco\"\n",
+ ")\n",
+ "object_detector.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Preprocessing Layers\n",
+ "\n",
+ "Let's define the following preprocessing layers:\n",
+ "\n",
+ "- Resizing layer: Resizes the image, preserving the aspect ratio by applying\n",
+ "padding when `pad_to_aspect_ratio=True`. Its `bounding_box_format` argument\n",
+ "declares how the bounding boxes are encoded so they are resized consistently.\n",
+ "- Max bounding box layer: Limits the maximum number of bounding boxes per image."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "image_size = (800, 800)\n",
+ "batch_size = 4\n",
+ "bbox_format = \"yxyx\"\n",
+ "epochs = 5\n",
+ "\n",
+ "resizing = keras.layers.Resizing(\n",
+ " height=image_size[0],\n",
+ " width=image_size[1],\n",
+ " interpolation=\"bilinear\",\n",
+ " pad_to_aspect_ratio=True,\n",
+ " bounding_box_format=bbox_format,\n",
+ ")\n",
+ "\n",
+ "max_box_layer = keras.layers.MaxNumBoundingBoxes(\n",
+ " max_number=100, bounding_box_format=bbox_format\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Predict and Visualize\n",
+ "\n",
+ "Next, let's load a sample image, run it through the object detector, and\n",
+ "visualize the predictions. For plotting, we resize the image with the layer\n",
+ "defined above so it matches the resolution used by the prediction pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "filepath = keras.utils.get_file(\n",
+ " origin=\"http://farm4.staticflickr.com/3755/10245052896_958cbf4766_z.jpg\"\n",
+ ")\n",
+ "image = keras.utils.load_img(filepath)\n",
+ "image = keras.ops.cast(image, \"float32\")\n",
+ "image = keras.ops.expand_dims(image, axis=0)\n",
+ "\n",
+ "predictions = object_detector.predict(image, batch_size=1)\n",
+ "\n",
+ "keras.visualization.plot_bounding_box_gallery(\n",
+ " resizing(image), # resize image as per prediction preprocessing pipeline\n",
+ " bounding_box_format=bbox_format,\n",
+ " y_pred=predictions,\n",
+ " scale=4,\n",
+ " class_mapping=COCO_90_CLASS_MAPPING,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Fine-tuning a pretrained object detector\n",
+ "\n",
+ "Next, we'll assemble a full training pipeline for a KerasHub `RetinaNet` object\n",
+ "detection model: data loading, preprocessing, training, and inference on the\n",
+ "Pascal VOC 2007 & 2012 datasets!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## TFDS Preprocessing\n",
+ "\n",
+ "This preprocessing step prepares the TFDS dataset for object detection. It\n",
+ "includes:\n",
+ "- Merging the Pascal VOC 2007 and 2012 datasets.\n",
+ "- Resizing all images to a resolution of 800x800 pixels.\n",
+ "- Limiting the number of bounding boxes per image to a maximum of 100.\n",
+ "- Finally, the resulting dataset is batched into sets of 4 images and bounding\n",
+ "box annotations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def decode_custom_tfds(record):\n",
+ " \"\"\"Decodes a custom TFDS record into a dictionary.\n",
+ "\n",
+ " Args:\n",
+ " record: A dictionary representing a single TFDS record.\n",
+ "\n",
+ " Returns:\n",
+ " A dictionary with \"images\" and \"bounding_boxes\".\n",
+ " \"\"\"\n",
+ " image = record[\"image\"]\n",
+ " boxes = record[\"objects/bbox\"]\n",
+ " labels = record[\"objects/label\"]\n",
+ "\n",
+ " bounding_boxes = {\"boxes\": boxes, \"labels\": labels}\n",
+ "\n",
+ " return {\"images\": image, \"bounding_boxes\": bounding_boxes}\n",
+ "\n",
+ "\n",
+ "def convert_to_tuple(record):\n",
+ " \"\"\"Converts a decoded TFDS record to a tuple for keras-hub.\n",
+ "\n",
+ " Args:\n",
+ " record: A dictionary returned by `decode_custom_tfds` or `decode_tfds`.\n",
+ "\n",
+ " Returns:\n",
+ " A tuple (image, bounding_boxes).\n",
+ " \"\"\"\n",
+ " return record[\"images\"], {\n",
+ " \"boxes\": record[\"bounding_boxes\"][\"boxes\"],\n",
+ " \"labels\": record[\"bounding_boxes\"][\"labels\"],\n",
+ " }\n",
+ "\n",
+ "\n",
+ "def decode_tfds(record):\n",
+ " \"\"\"Decodes a standard TFDS object detection record.\n",
+ "\n",
+ " Args:\n",
+ " record: A dictionary representing a single TFDS record.\n",
+ "\n",
+ " Returns:\n",
+ " A dictionary with \"images\" and \"bounding_boxes\".\n",
+ " \"\"\"\n",
+ " image = record[\"image\"]\n",
+ " image_shape = tf.shape(image)\n",
+ " height, width = image_shape[0], image_shape[1]\n",
+ " boxes = keras.utils.bounding_boxes.convert_format(\n",
+ " record[\"objects\"][\"bbox\"],\n",
+ " source=\"rel_yxyx\",\n",
+ " target=bbox_format,\n",
+ " height=height,\n",
+ " width=width,\n",
+ " )\n",
+ " labels = record[\"objects\"][\"label\"]\n",
+ "\n",
+ " bounding_boxes = {\"boxes\": boxes, \"labels\": labels}\n",
+ "\n",
+ " return {\"images\": image, \"bounding_boxes\": bounding_boxes}\n",
+ "\n",
+ "\n",
+ "def preprocess_tfds(ds):\n",
+ " \"\"\"Preprocesses a dataset for object detection.\n",
+ "\n",
+ " Uses the globally defined `resizing` layer, `max_box_layer` and `batch_size`.\n",
+ "\n",
+ " Args:\n",
+ " ds: The `tf.data.Dataset` to preprocess.\n",
+ "\n",
+ " Returns:\n",
+ " A preprocessed, batched `tf.data.Dataset`.\n",
+ " \"\"\"\n",
+ " ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ " ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ " ds = ds.batch(batch_size, drop_remainder=True)\n",
+ " return ds\n",
+ ""
+ ]
+ },
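+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The `decode_tfds` helper above is not used in the rest of this guide, but it\n",
+ "shows how the same pipeline could consume the standard TFDS Pascal VOC dataset\n",
+ "instead of the manually downloaded files. A sketch (assuming the `\"voc/2007\"`\n",
+ "TFDS name is available in your `tensorflow-datasets` version):\n",
+ "\n",
+ "```python\n",
+ "tfds_train_ds = tfds.load(\"voc/2007\", split=\"train\")\n",
+ "tfds_train_ds = tfds_train_ds.map(decode_tfds, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ "tfds_train_ds = preprocess_tfds(tfds_train_ds)\n",
+ "```"
+ ]
+ },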
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Now, let's concatenate the VOC 2007 and 2012 datasets and run them through the preprocessing pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "train_ds = train_ds_2007.concatenate(train_ds_2012)\n",
+ "train_ds = train_ds.map(decode_custom_tfds, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ "train_ds = preprocess_tfds(train_ds)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Decode and preprocess the eval data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "eval_ds = eval_ds.map(decode_custom_tfds, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ "eval_ds = preprocess_tfds(eval_ds)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Let's visualize a batch of training data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "record = next(iter(train_ds.shuffle(100).take(1)))\n",
+ "keras.visualization.plot_bounding_box_gallery(\n",
+ " record[\"images\"],\n",
+ " bounding_box_format=bbox_format,\n",
+ " y_true=record[\"bounding_boxes\"],\n",
+ " scale=3,\n",
+ " rows=2,\n",
+ " cols=2,\n",
+ " class_mapping=INDEX_TO_CLASS,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Convert the decoded records to tuples for KerasHub"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "train_ds = train_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ "train_ds = train_ds.prefetch(tf.data.AUTOTUNE)\n",
+ "\n",
+ "eval_ds = eval_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ "eval_ds = eval_ds.prefetch(tf.data.AUTOTUNE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Configure RetinaNet Model\n",
+ "\n",
+ "Configure the model with `backbone`, `num_classes` and `preprocessor`.\n",
+ "Use callbacks for recording logs and saving checkpoints."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def get_callbacks(experiment_path):\n",
+ " \"\"\"Creates a list of callbacks for model training.\n",
+ "\n",
+ " Args:\n",
+ " experiment_path (str): Path to the experiment directory.\n",
+ "\n",
+ " Returns:\n",
+ " List of keras callback instances.\n",
+ " \"\"\"\n",
+ " tb_logs_path = os.path.join(experiment_path, \"logs\")\n",
+ " ckpt_path = os.path.join(experiment_path, \"weights\")\n",
+ " return [\n",
+ " keras.callbacks.BackupAndRestore(ckpt_path, delete_checkpoint=False),\n",
+ " keras.callbacks.TensorBoard(\n",
+ " tb_logs_path,\n",
+ " update_freq=1,\n",
+ " ),\n",
+ " keras.callbacks.ModelCheckpoint(\n",
+ " ckpt_path + \"/{epoch:04d}-{val_loss:.2f}.weights.h5\",\n",
+ " save_best_only=True,\n",
+ " save_weights_only=True,\n",
+ " verbose=1,\n",
+ " ),\n",
+ " ]\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Load backbone weights and preprocessor config\n",
+ "\n",
+ "Let's load the pretrained backbone weights and the matching preprocessor\n",
+ "configuration from the \"retinanet_resnet50_fpn_coco\" preset.\n",
+ "Then define a RetinaNet object detector with this backbone and preprocessor,\n",
+ "setting `num_classes` to 20 for the Pascal VOC object categories.\n",
+ "Finally, compile the model using Mean Absolute Error (MAE) as the box loss."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "backbone = keras_hub.models.Backbone.from_preset(\"retinanet_resnet50_fpn_coco\")\n",
+ "\n",
+ "preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor.from_preset(\n",
+ " \"retinanet_resnet50_fpn_coco\"\n",
+ ")\n",
+ "model = keras_hub.models.RetinaNetObjectDetector(\n",
+ " backbone=backbone, num_classes=len(CLASSES), preprocessor=preprocessor\n",
+ ")\n",
+ "model.compile(box_loss=keras.losses.MeanAbsoluteError(reduction=\"sum\"))"
+ ]
+ },
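+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "As a side note, RetinaNet-style detectors are also commonly trained with a\n",
+ "smooth L1 (Huber) box loss. If you want to experiment with that instead of MAE,\n",
+ "a minimal variation using the stock Keras loss would be:\n",
+ "\n",
+ "```python\n",
+ "model.compile(box_loss=keras.losses.Huber(reduction=\"sum\"))\n",
+ "```"
+ ]
+ },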
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Train the model\n",
+ "\n",
+ "Now that the object detector model is compiled, let's train it using the\n",
+ "training and validation data we created earlier.\n",
+ "For demonstration purposes, we use a small number of epochs; increase it to\n",
+ "achieve better results.\n",
+ "\n",
+ "**Note:** This guide was run on an L4 GPU; training for 5 epochs on a T4 GPU\n",
+ "takes approximately 7 hours."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model.fit(\n",
+ " train_ds,\n",
+ " epochs=epochs,\n",
+ " validation_data=eval_ds,\n",
+ " callbacks=get_callbacks(\"fine_tuning\"),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Prediction on evaluation data\n",
+ "\n",
+ "Let's run the model on a batch from our evaluation dataset and collect its predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "images, y_true = next(iter(eval_ds.shuffle(50).take(1)))\n",
+ "y_pred = model.predict(images)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Plot the predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "keras.visualization.plot_bounding_box_gallery(\n",
+ " images,\n",
+ " bounding_box_format=bbox_format,\n",
+ " y_true=y_true,\n",
+ " y_pred=y_pred,\n",
+ " scale=3,\n",
+ " rows=2,\n",
+ " cols=2,\n",
+ " class_mapping=INDEX_TO_CLASS,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Custom training object detector\n",
+ "\n",
+ "Additionally, you can customize the object detector by modifying the image\n",
+ "converter, selecting a different image encoder, etc.\n",
+ "\n",
+ "### Image Converter\n",
+ "\n",
+ "The `RetinaNetImageConverter` class prepares images for use with the `RetinaNet`\n",
+ "object detection model. Here's what it does:\n",
+ "\n",
+ "- Scaling and Offsetting\n",
+ "- ImageNet Normalization\n",
+ "- Resizing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "image_converter = keras_hub.layers.RetinaNetImageConverter(scale=1 / 255)\n",
+ "\n",
+ "preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor(\n",
+ " image_converter=image_converter\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Image Encoder and RetinaNet Backbone\n",
+ "\n",
+ "The image encoder, while typically initialized with pre-trained weights\n",
+ "(e.g., from ImageNet), can also be instantiated without them. This results in\n",
+ "the image encoder (and, consequently, the entire object detection network built\n",
+ "upon it) having randomly initialized weights.\n",
+ "\n",
+ "Here, we load a pre-trained ResNet50 model to serve as the base for extracting\n",
+ "image features.\n",
+ "\n",
+ "We then build the RetinaNet Feature Pyramid Network (FPN) on top of the ResNet50\n",
+ "backbone. The FPN creates multi-scale feature maps for better object detection\n",
+ "at different sizes.\n",
+ "\n",
+ "**Note:**\n",
+ "`use_p5`: If True, the output of the last backbone layer (typically `P5` in an\n",
+ "`FPN`) is used as input to create higher-level feature maps (e.g., `P6`, `P7`)\n",
+ "through additional convolutional layers. If `False`, the original `P5` feature\n",
+ "map from the backbone is directly used as input for creating the coarser levels,\n",
+ "bypassing any further processing of `P5` within the feature pyramid. Defaults to\n",
+ "`False`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "image_encoder = keras_hub.models.Backbone.from_preset(\"resnet_50_imagenet\")\n",
+ "\n",
+ "backbone = keras_hub.models.RetinaNetBackbone(\n",
+ " image_encoder=image_encoder, min_level=3, max_level=5, use_p5=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Train and visualize RetinaNet model\n",
+ "\n",
+ "**Note:** We train the model for only 5 epochs for demonstration purposes. In a\n",
+ "real scenario, you would train for many more epochs (often hundreds) to achieve\n",
+ "good results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras_hub.models.RetinaNetObjectDetector(\n",
+ " backbone=backbone,\n",
+ " num_classes=len(CLASSES),\n",
+ " preprocessor=preprocessor,\n",
+ " use_prediction_head_norm=True,\n",
+ ")\n",
+ "model.compile(\n",
+ " optimizer=keras.optimizers.Adam(learning_rate=0.001),\n",
+ " box_loss=keras.losses.MeanAbsoluteError(reduction=\"sum\"),\n",
+ ")\n",
+ "\n",
+ "model.fit(\n",
+ " train_ds,\n",
+ " epochs=epochs,\n",
+ " validation_data=eval_ds,\n",
+ " callbacks=get_callbacks(\"custom_training\"),\n",
+ ")\n",
+ "\n",
+ "images, y_true = next(iter(eval_ds.shuffle(50).take(1)))\n",
+ "y_pred = model.predict(images)\n",
+ "\n",
+ "keras.visualization.plot_bounding_box_gallery(\n",
+ " images,\n",
+ " bounding_box_format=bbox_format,\n",
+ " y_true=y_true,\n",
+ " y_pred=y_pred,\n",
+ " scale=3,\n",
+ " rows=2,\n",
+ " cols=2,\n",
+ " class_mapping=INDEX_TO_CLASS,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Conclusion\n",
+ "\n",
+ "In this tutorial, you learned how to custom train and fine-tune the RetinaNet\n",
+ "object detector.\n",
+ "\n",
+ "You can experiment with different ImageNet-pretrained backbones as the image\n",
+ "encoder, or bring your own backbone.\n",
+ "\n",
+ "In the custom setup, only the image encoder starts from pretrained weights,\n",
+ "while the FPN and prediction heads are randomly initialized, so it is much\n",
+ "closer to training from scratch than to fine-tuning a fully pretrained\n",
+ "detector.\n",
+ "Training from scratch generally requires significantly more data and\n",
+ "computational resources to achieve performance comparable to fine-tuning.\n",
+ "\n",
+ "To achieve better results when fine-tuning the model, you can increase the\n",
+ "number of epochs and experiment with different hyperparameter values.\n",
+ "In addition to the training data used here, you can also train on other object\n",
+ "detection datasets, but keep in mind that this requires substantial GPU\n",
+ "memory."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "object_detection_retinanet",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/guides/keras_hub/object_detection_retinanet.py b/guides/keras_hub/object_detection_retinanet.py
new file mode 100644
index 0000000000..3c659f34b2
--- /dev/null
+++ b/guides/keras_hub/object_detection_retinanet.py
@@ -0,0 +1,812 @@
+"""
+Title: Object Detection with KerasHub
+Authors: [Siva Sravana Kumar Neeli](https://github.com/sineeli), [Sachin Prasad](https://github.com/sachinprasadhs)
+Date created: 2025/04/28
+Last modified: 2025/04/28
+Description: RetinaNet Object Detection: Training, Fine-tuning, and Inference.
+Accelerator: GPU
+"""
+
+"""
+
+
+## Introduction
+
+Object detection is a crucial computer vision task that goes beyond simple image
+classification. It requires models to not only identify the types of objects
+present in an image but also pinpoint their locations using bounding boxes. This
+dual requirement of classification and localization makes object detection a
+more complex and powerful tool.
+Object detection models are broadly classified into two categories: "two-stage"
+and "single-stage" detectors. Two-stage detectors often achieve higher accuracy
+by first proposing regions of interest and then classifying them. However, this
+approach can be computationally expensive. Single-stage detectors, on the other
+hand, aim for speed by directly predicting object classes and bounding boxes in
+a single pass.
+
+In this tutorial, we'll be diving into `RetinaNet`, a powerful object detection
+model known for its speed and precision. `RetinaNet` is a single-stage detector,
+a design choice that allows it to be remarkably efficient. Its impressive
+performance stems from two key architectural innovations:
+1. **Feature Pyramid Network (FPN):** FPN equips `RetinaNet` with the ability to
+seamlessly detect objects of all scales, from distant, tiny instances to large,
+prominent ones.
+2. **Focal Loss:** This ingenious loss function tackles the common challenge of
+imbalanced data by focusing the model's learning on the most crucial and
+challenging object examples, leading to enhanced accuracy without compromising
+speed.
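+
+In equation form, focal loss rescales the standard cross-entropy term with a
+modulating factor so that easy, well-classified examples contribute very little
+to the loss:
+
+$$FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$
+
+where $p_t$ is the model's predicted probability for the true class and
+$\gamma$ is the focusing parameter (the paper also adds a class-balancing
+weight and uses $\gamma = 2$).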
+
+
+
+### References
+
+- [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
+- [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
+"""
+
+"""
+## Setup and Imports
+
+Let's install the dependencies and import the necessary modules.
+
+To run this tutorial, you will need to install the following packages:
+
+* `keras-hub`
+* `keras`
+* `opencv-python`
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras
+pip install -q opencv-python
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"
+import keras
+import keras_hub
+import tensorflow as tf
+
+"""
+### Helper functions
+
+We download the Pascal VOC 2012 and 2007 datasets using these helper functions,
+prepare them for the object detection task, and split them into training and
+validation datasets.
+"""
+# @title Helper functions
+import logging
+import multiprocessing
+from builtins import open
+import os.path
+import xml.etree.ElementTree
+
+import tensorflow_datasets as tfds
+
+VOC_2007_URL = (
+ "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar"
+)
+VOC_2012_URL = (
+ "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
+)
+VOC_2007_test_URL = (
+ "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar"
+)
+
+# Note that this list doesn't contain the background class. In the
+# classification use case, the label is 0 based (aeroplane -> 0), whereas in
+# segmentation use case, the 0 is reserved for background, so aeroplane maps to
+# 1.
+CLASSES = [
+ "aeroplane",
+ "bicycle",
+ "bird",
+ "boat",
+ "bottle",
+ "bus",
+ "car",
+ "cat",
+ "chair",
+ "cow",
+ "diningtable",
+ "dog",
+ "horse",
+ "motorbike",
+ "person",
+ "pottedplant",
+ "sheep",
+ "sofa",
+ "train",
+ "tvmonitor",
+]
+COCO_90_CLASS_MAPPING = {
+ 1: "person",
+ 2: "bicycle",
+ 3: "car",
+ 4: "motorcycle",
+ 5: "airplane",
+ 6: "bus",
+ 7: "train",
+ 8: "truck",
+ 9: "boat",
+ 10: "traffic light",
+ 11: "fire hydrant",
+ 13: "stop sign",
+ 14: "parking meter",
+ 15: "bench",
+ 16: "bird",
+ 17: "cat",
+ 18: "dog",
+ 19: "horse",
+ 20: "sheep",
+ 21: "cow",
+ 22: "elephant",
+ 23: "bear",
+ 24: "zebra",
+ 25: "giraffe",
+ 27: "backpack",
+ 28: "umbrella",
+ 31: "handbag",
+ 32: "tie",
+ 33: "suitcase",
+ 34: "frisbee",
+ 35: "skis",
+ 36: "snowboard",
+ 37: "sports ball",
+ 38: "kite",
+ 39: "baseball bat",
+ 40: "baseball glove",
+ 41: "skateboard",
+ 42: "surfboard",
+ 43: "tennis racket",
+ 44: "bottle",
+ 46: "wine glass",
+ 47: "cup",
+ 48: "fork",
+ 49: "knife",
+ 50: "spoon",
+ 51: "bowl",
+ 52: "banana",
+ 53: "apple",
+ 54: "sandwich",
+ 55: "orange",
+ 56: "broccoli",
+ 57: "carrot",
+ 58: "hot dog",
+ 59: "pizza",
+ 60: "donut",
+ 61: "cake",
+ 62: "chair",
+ 63: "couch",
+ 64: "potted plant",
+ 65: "bed",
+ 67: "dining table",
+ 70: "toilet",
+ 72: "tv",
+ 73: "laptop",
+ 74: "mouse",
+ 75: "remote",
+ 76: "keyboard",
+ 77: "cell phone",
+ 78: "microwave",
+ 79: "oven",
+ 80: "toaster",
+ 81: "sink",
+ 82: "refrigerator",
+ 84: "book",
+ 85: "clock",
+ 86: "vase",
+ 87: "scissors",
+ 88: "teddy bear",
+ 89: "hair drier",
+ 90: "toothbrush",
+}
+# This is used to map between string class to index.
+CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}
+INDEX_TO_CLASS = {index: name for index, name in enumerate(CLASSES)}
+
+
+def get_image_ids(data_dir, split):
+ """Gets image ids from the "train", "eval", "trainval", or "test" split files of VOC data."""
+ data_file_mapping = {
+ "train": "train.txt",
+ "eval": "val.txt",
+ "trainval": "trainval.txt",
+ "test": "test.txt",
+ }
+ with open(
+ os.path.join(data_dir, "ImageSets", "Main", data_file_mapping[split]),
+ "r",
+ ) as f:
+ image_ids = f.read().splitlines()
+ logging.info(f"Received {len(image_ids)} images for {split} dataset.")
+ return image_ids
+
+
+def load_images(example):
+ """Loads VOC images for the object detection task from the provided paths."""
+ image_file_path = example.pop("image/file_path")
+ image = tf.io.read_file(image_file_path)
+ image = tf.image.decode_jpeg(image)
+
+ example.update(
+ {
+ "image": image,
+ }
+ )
+ return example
+
+
+def parse_annotation_data(annotation_file_path):
+ """Parse the annotation XML file for the image.
+
+ The annotation contains the metadata, as well as the object bounding box
+ information.
+
+ """
+ with open(annotation_file_path, "r") as f:
+ root = xml.etree.ElementTree.parse(f).getroot()
+
+ size = root.find("size")
+ width = int(size.find("width").text)
+ height = int(size.find("height").text)
+ filename = root.find("filename").text
+
+ objects = []
+ for obj in root.findall("object"):
+ # Get object's label name.
+ label = CLASS_TO_INDEX[obj.find("name").text.lower()]
+ bndbox = obj.find("bndbox")
+ xmax = int(float(bndbox.find("xmax").text))
+ xmin = int(float(bndbox.find("xmin").text))
+ ymax = int(float(bndbox.find("ymax").text))
+ ymin = int(float(bndbox.find("ymin").text))
+ objects.append(
+ {
+ "label": label,
+ "bbox": [ymin, xmin, ymax, xmax],
+ }
+ )
+
+ return {
+ "image/filename": filename,
+ "width": width,
+ "height": height,
+ "objects": objects,
+ }
+
+
+def parse_single_image(annotation_file_path):
+ """Creates metadata of VOC images and path."""
+ data_dir, annotation_file_name = os.path.split(annotation_file_path)
+ data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
+ image_annotations = parse_annotation_data(annotation_file_path)
+
+ result = {
+ "image/file_path": os.path.join(
+ data_dir, "JPEGImages", image_annotations["image/filename"]
+ )
+ }
+ result.update(image_annotations)
+ # The labels field is the sorted set of unique 'objects/label' values.
+ labels = list(set([o["label"] for o in result["objects"]]))
+ result["labels"] = sorted(labels)
+ return result
+
+
+def build_metadata(data_dir, image_ids):
+ """Transposes the metadata from a list of dicts to a dict of lists."""
+ # Parallel process the annotation files for the requested split's image ids.
+ annotation_file_paths = [
+ os.path.join(data_dir, "Annotations", i + ".xml") for i in image_ids
+ ]
+ pool_size = min(10, len(annotation_file_paths))
+ with multiprocessing.Pool(pool_size) as p:
+ metadata = p.map(parse_single_image, annotation_file_paths)
+
+ keys = [
+ "image/filename",
+ "image/file_path",
+ "labels",
+ "width",
+ "height",
+ ]
+ result = {}
+ for key in keys:
+ values = [value[key] for value in metadata]
+ result[key] = values
+
+ # The ragged objects need some special handling
+ for key in ["label", "bbox"]:
+ values = []
+ objects = [value["objects"] for value in metadata]
+ for object in objects:
+ values.append([o[key] for o in object])
+ result["objects/" + key] = values
+ return result
+
+
+def build_dataset_from_metadata(metadata):
+ """Builds TensorFlow dataset from the image metadata of VOC dataset."""
+ # The objects need some manual conversion to ragged tensor.
+ metadata["labels"] = tf.ragged.constant(metadata["labels"])
+ metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
+ metadata["objects/bbox"] = tf.ragged.constant(
+ metadata["objects/bbox"], ragged_rank=1
+ )
+
+ dataset = tf.data.Dataset.from_tensor_slices(metadata)
+ dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
+ return dataset
+
+
+def load_voc(
+ year="2007",
+ split="trainval",
+ data_dir="./",
+ voc_url=VOC_2007_URL,
+):
+ extracted_dir = os.path.join("VOCdevkit", f"VOC{year}")
+ get_data = keras.utils.get_file(
+ fname=os.path.basename(voc_url),
+ origin=voc_url,
+ cache_dir=data_dir,
+ extract=True,
+ )
+ data_dir = os.path.join(get_data, extracted_dir)
+ image_ids = get_image_ids(data_dir, split)
+ metadata = build_metadata(data_dir, image_ids)
+ dataset = build_dataset_from_metadata(metadata)
+
+ return dataset
+
+
+"""
+## Load the dataset
+
+Let's load the training data. Here, we load both the VOC 2007 and 2012 datasets
+and split them into training and validation sets.
+"""
+train_ds_2007 = load_voc(
+ year="2007",
+ split="trainval",
+ data_dir="./",
+ voc_url=VOC_2007_URL,
+)
+train_ds_2012 = load_voc(
+ year="2012",
+ split="trainval",
+ data_dir="./",
+ voc_url=VOC_2012_URL,
+)
+eval_ds = load_voc(
+ year="2007",
+ split="test",
+ data_dir="./",
+ voc_url=VOC_2007_test_URL,
+)
+
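+"""
+As an optional sanity check, let's print how many images each split holds.
+Datasets built with `from_tensor_slices` report their cardinality directly, so
+this does not iterate over the data.
+"""
+print("VOC 2007 trainval images:", train_ds_2007.cardinality().numpy())
+print("VOC 2012 trainval images:", train_ds_2012.cardinality().numpy())
+print("VOC 2007 test images:", eval_ds.cardinality().numpy())
+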
+"""
+## Inference using a pre-trained object detector
+
+Let's begin with the simplest `KerasHub` API: a pre-trained object detector. In
+this example, we will construct an object detector that was pre-trained on the
+`COCO` dataset. We'll use this model to detect objects in a sample image.
+
+The highest-level module in KerasHub is a `task`. A `task` is a `keras.Model`
+consisting of a (generally pre-trained) backbone model and task-specific layers.
+Here's an example using `keras_hub.models.ImageObjectDetector` with the
+`RetinaNet` model architecture and `ResNet50` as the backbone.
+
+`ResNet` is a great starting backbone for many computer vision pipelines,
+including object detection: it achieves high accuracy while using a relatively
+small number of parameters. If a ResNet isn't powerful enough for the task you
+are trying to solve, check out
+[KerasHub's other available backbones](https://keras.io/keras_hub/presets/).
+"""
+
+object_detector = keras_hub.models.ImageObjectDetector.from_preset(
+ "retinanet_resnet50_fpn_coco"
+)
+object_detector.summary()
+
+"""
+## Preprocessing Layers
+
+Let's define the following preprocessing layers:
+
+- Resizing layer: Resizes the image, preserving the aspect ratio by applying
+padding when `pad_to_aspect_ratio=True`. Its `bounding_box_format` argument
+declares how the bounding boxes are encoded so they are resized consistently.
+- Max bounding box layer: Limits the maximum number of bounding boxes per image.
+"""
+image_size = (800, 800)
+batch_size = 4
+bbox_format = "yxyx"
+epochs = 5
+
+resizing = keras.layers.Resizing(
+ height=image_size[0],
+ width=image_size[1],
+ interpolation="bilinear",
+ pad_to_aspect_ratio=True,
+ bounding_box_format=bbox_format,
+)
+
+max_box_layer = keras.layers.MaxNumBoundingBoxes(
+ max_number=100, bounding_box_format=bbox_format
+)
+
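+"""
+As a quick, illustrative sanity check of the `"yxyx"` convention used above, we
+can convert a single made-up box from the more common `"xyxy"` layout with the
+same `keras.utils.bounding_boxes.convert_format` utility that is used later in
+this guide.
+"""
+example_box = keras.ops.convert_to_tensor([[20.0, 10.0, 220.0, 110.0]])  # x_min, y_min, x_max, y_max
+converted_box = keras.utils.bounding_boxes.convert_format(
+ example_box, source="xyxy", target=bbox_format, height=800, width=800
+)
+print(converted_box)  # -> [[10.0, 20.0, 110.0, 220.0]], i.e. y_min, x_min, y_max, x_max
+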
+"""
+### Predict and Visualize
+
+Next, let's load a sample image, run it through the object detector, and
+visualize the predictions. For plotting, we resize the image with the layer
+defined above so it matches the resolution used by the prediction pipeline.
+"""
+
+filepath = keras.utils.get_file(
+ origin="http://farm4.staticflickr.com/3755/10245052896_958cbf4766_z.jpg"
+)
+image = keras.utils.load_img(filepath)
+image = keras.ops.cast(image, "float32")
+image = keras.ops.expand_dims(image, axis=0)
+
+predictions = object_detector.predict(image, batch_size=1)
+
+keras.visualization.plot_bounding_box_gallery(
+ resizing(image), # resize image as per prediction preprocessing pipeline
+ bounding_box_format=bbox_format,
+ y_pred=predictions,
+ scale=4,
+ class_mapping=COCO_90_CLASS_MAPPING,
+)
+
+"""
+## Fine-tuning a pretrained object detector
+
+Next, we'll assemble a full training pipeline for a KerasHub `RetinaNet` object
+detection model: data loading, preprocessing, training, and inference on the
+Pascal VOC 2007 & 2012 datasets!
+"""
+
+"""
+## TFDS Preprocessing
+
+This preprocessing step prepares the TFDS dataset for object detection. It
+includes:
+- Merging the Pascal VOC 2007 and 2012 datasets.
+- Resizing all images to a resolution of 800x800 pixels.
+- Limiting the number of bounding boxes per image to a maximum of 100.
+- Finally, the resulting dataset is batched into sets of 4 images and bounding
+box annotations.
+"""
+
+
+def decode_custom_tfds(record):
+ """Decodes a custom TFDS record into a dictionary.
+
+ Args:
+ record: A dictionary representing a single TFDS record.
+
+ Returns:
+ A dictionary with "images" and "bounding_boxes".
+ """
+ image = record["image"]
+ boxes = record["objects/bbox"]
+ labels = record["objects/label"]
+
+ bounding_boxes = {"boxes": boxes, "labels": labels}
+
+ return {"images": image, "bounding_boxes": bounding_boxes}
+
+
+def convert_to_tuple(record):
+ """Converts a decoded TFDS record to a tuple for keras-hub.
+
+ Args:
+ record: A dictionary returned by `decode_custom_tfds` or `decode_tfds`.
+
+ Returns:
+ A tuple (image, bounding_boxes).
+ """
+ return record["images"], {
+ "boxes": record["bounding_boxes"]["boxes"],
+ "labels": record["bounding_boxes"]["labels"],
+ }
+
+
+def decode_tfds(record):
+ """Decodes a standard TFDS object detection record.
+
+ Args:
+ record: A dictionary representing a single TFDS record.
+
+ Returns:
+ A dictionary with "images" and "bounding_boxes".
+ """
+ image = record["image"]
+ image_shape = tf.shape(image)
+ height, width = image_shape[0], image_shape[1]
+ boxes = keras.utils.bounding_boxes.convert_format(
+ record["objects"]["bbox"],
+ source="rel_yxyx",
+ target=bbox_format,
+ height=height,
+ width=width,
+ )
+ labels = record["objects"]["label"]
+
+ bounding_boxes = {"boxes": boxes, "labels": labels}
+
+ return {"images": image, "bounding_boxes": bounding_boxes}
+
+
+def preprocess_tfds(ds):
+ """Preprocesses a dataset for object detection.
+
+ Uses the globally defined `resizing` layer, `max_box_layer` and `batch_size`.
+
+ Args:
+ ds: The `tf.data.Dataset` to preprocess.
+
+ Returns:
+ A preprocessed, batched `tf.data.Dataset`.
+ """
+ ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
+ ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
+ ds = ds.batch(batch_size, drop_remainder=True)
+ return ds
+
+
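+"""
+The `decode_tfds` helper above is not used in the rest of this guide, but it
+shows how the same pipeline could consume the standard TFDS Pascal VOC dataset
+instead of the manually downloaded files. A sketch (assuming the `"voc/2007"`
+TFDS name is available in your `tensorflow-datasets` version):
+
+```python
+tfds_train_ds = tfds.load("voc/2007", split="train")
+tfds_train_ds = tfds_train_ds.map(decode_tfds, num_parallel_calls=tf.data.AUTOTUNE)
+tfds_train_ds = preprocess_tfds(tfds_train_ds)
+```
+"""
+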
+"""
+Now, let's concatenate the VOC 2007 and 2012 datasets and run them through the
+preprocessing pipeline.
+"""
+train_ds = train_ds_2007.concatenate(train_ds_2012)
+train_ds = train_ds.map(decode_custom_tfds, num_parallel_calls=tf.data.AUTOTUNE)
+train_ds = preprocess_tfds(train_ds)
+
+"""
+Decode and preprocess the eval data
+"""
+eval_ds = eval_ds.map(decode_custom_tfds, num_parallel_calls=tf.data.AUTOTUNE)
+eval_ds = preprocess_tfds(eval_ds)
+
+"""
+### Let's visualize a batch of training data
+"""
+record = next(iter(train_ds.shuffle(100).take(1)))
+keras.visualization.plot_bounding_box_gallery(
+ record["images"],
+ bounding_box_format=bbox_format,
+ y_true=record["bounding_boxes"],
+ scale=3,
+ rows=2,
+ cols=2,
+ class_mapping=INDEX_TO_CLASS,
+)
+
+"""
+### Convert the decoded records to tuples for KerasHub
+"""
+train_ds = train_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
+train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
+
+eval_ds = eval_ds.map(convert_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
+eval_ds = eval_ds.prefetch(tf.data.AUTOTUNE)
+
+"""
+## Configure RetinaNet Model
+
+Configure the model with `backbone`, `num_classes` and `preprocessor`.
+Use callbacks for recording logs and saving checkpoints.
+"""
+
+
+def get_callbacks(experiment_path):
+ """Creates a list of callbacks for model training.
+
+ Args:
+ experiment_path (str): Path to the experiment directory.
+
+ Returns:
+ List of keras callback instances.
+ """
+ tb_logs_path = os.path.join(experiment_path, "logs")
+ ckpt_path = os.path.join(experiment_path, "weights")
+ return [
+ keras.callbacks.BackupAndRestore(ckpt_path, delete_checkpoint=False),
+ keras.callbacks.TensorBoard(
+ tb_logs_path,
+ update_freq=1,
+ ),
+ keras.callbacks.ModelCheckpoint(
+ ckpt_path + "/{epoch:04d}-{val_loss:.2f}.weights.h5",
+ save_best_only=True,
+ save_weights_only=True,
+ verbose=1,
+ ),
+ ]
+
+
+"""
+## Load backbone weights and preprocessor config
+
+Let's load the pretrained backbone weights and the matching preprocessor
+configuration from the "retinanet_resnet50_fpn_coco" preset.
+Then define a RetinaNet object detector with this backbone and preprocessor,
+setting `num_classes` to 20 for the Pascal VOC object categories.
+Finally, compile the model using Mean Absolute Error (MAE) as the box loss.
+"""
+
+backbone = keras_hub.models.Backbone.from_preset("retinanet_resnet50_fpn_coco")
+
+preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor.from_preset(
+ "retinanet_resnet50_fpn_coco"
+)
+model = keras_hub.models.RetinaNetObjectDetector(
+ backbone=backbone, num_classes=len(CLASSES), preprocessor=preprocessor
+)
+model.compile(box_loss=keras.losses.MeanAbsoluteError(reduction="sum"))
+
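+"""
+As a side note, RetinaNet-style detectors are also commonly trained with a
+smooth L1 (Huber) box loss. If you want to experiment with that instead of MAE,
+a minimal variation using the stock Keras loss would be:
+
+```python
+model.compile(box_loss=keras.losses.Huber(reduction="sum"))
+```
+"""
+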
+"""
+## Train the model
+
+Now that the object detector model is compiled, let's train it using the
+training and validation data we created earlier.
+For demonstration purposes, we use a small number of epochs; increase it to
+achieve better results.
+
+**Note:** This guide was run on an L4 GPU; training for 5 epochs on a T4 GPU
+takes approximately 7 hours.
+"""
+
+model.fit(
+ train_ds,
+ epochs=epochs,
+ validation_data=eval_ds,
+ callbacks=get_callbacks("fine_tuning"),
+)
+
+"""
+### Prediction on evaluation data
+
+Let's run the model on a batch from our evaluation dataset and collect its
+predictions.
+"""
+images, y_true = next(iter(eval_ds.shuffle(50).take(1)))
+y_pred = model.predict(images)
+
+"""
+### Plot the predictions
+"""
+keras.visualization.plot_bounding_box_gallery(
+ images,
+ bounding_box_format=bbox_format,
+ y_true=y_true,
+ y_pred=y_pred,
+ scale=3,
+ rows=2,
+ cols=2,
+ class_mapping=INDEX_TO_CLASS,
+)
+
+"""
+## Custom training object detector
+
+Additionally, you can customize the object detector by modifying the image
+converter, selecting a different image encoder, etc.
+
+### Image Converter
+
+The `RetinaNetImageConverter` class prepares images for use with the `RetinaNet`
+object detection model. Here's what it does:
+
+- Scaling and Offsetting
+- ImageNet Normalization
+- Resizing
+"""
+
+image_converter = keras_hub.layers.RetinaNetImageConverter(scale=1 / 255)
+
+preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor(
+ image_converter=image_converter
+)
+
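+"""
+The preset preprocessor handles the ImageNet normalization mentioned above for
+you. If you build your own converter and want the same behavior, one way to
+express it (an assumption-laden sketch: it presumes `RetinaNetImageConverter`
+accepts per-channel `scale`/`offset` values and applies `image * scale + offset`,
+like other KerasHub image converters) is to fold the ImageNet mean/std into
+those two arguments:
+
+```python
+mean = [0.485, 0.456, 0.406]
+std = [0.229, 0.224, 0.225]
+image_converter = keras_hub.layers.RetinaNetImageConverter(
+    scale=[1 / (255.0 * s) for s in std],
+    offset=[-m / s for m, s in zip(mean, std)],
+)
+```
+"""
+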
+"""
+### Image Encoder and RetinaNet Backbone
+
+The image encoder, while typically initialized with pre-trained weights
+(e.g., from ImageNet), can also be instantiated without them. This results in
+the image encoder (and, consequently, the entire object detection network built
+upon it) having randomly initialized weights.
+
+Here, we load a pre-trained ResNet50 model to serve as the base for extracting
+image features.
+
+We then build the RetinaNet Feature Pyramid Network (FPN) on top of the ResNet50
+backbone. The FPN creates multi-scale feature maps for better object detection
+at different sizes.
+
+**Note:**
+`use_p5`: If True, the output of the last backbone layer (typically `P5` in an
+`FPN`) is used as input to create higher-level feature maps (e.g., `P6`, `P7`)
+through additional convolutional layers. If `False`, the original `P5` feature
+map from the backbone is directly used as input for creating the coarser levels,
+bypassing any further processing of `P5` within the feature pyramid. Defaults to
+`False`.
+"""
+
+image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")
+
+backbone = keras_hub.models.RetinaNetBackbone(
+ image_encoder=image_encoder, min_level=3, max_level=5, use_p5=True
+)
+
+"""
+### Train and visualize RetinaNet model
+
+**Note:** We train the model for only 5 epochs for demonstration purposes. In a
+real scenario, you would train for many more epochs (often hundreds) to achieve
+good results.
+"""
+model = keras_hub.models.RetinaNetObjectDetector(
+ backbone=backbone,
+ num_classes=len(CLASSES),
+ preprocessor=preprocessor,
+ use_prediction_head_norm=True,
+)
+model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=0.001),
+ box_loss=keras.losses.MeanAbsoluteError(reduction="sum"),
+)
+
+model.fit(
+ train_ds,
+ epochs=epochs,
+ validation_data=eval_ds,
+ callbacks=get_callbacks("custom_training"),
+)
+
+images, y_true = next(iter(eval_ds.shuffle(50).take(1)))
+y_pred = model.predict(images)
+
+keras.visualization.plot_bounding_box_gallery(
+ images,
+ bounding_box_format=bbox_format,
+ y_true=y_true,
+ y_pred=y_pred,
+ scale=3,
+ rows=2,
+ cols=2,
+ class_mapping=INDEX_TO_CLASS,
+)
+
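+"""
+Once you are satisfied with the results, you can persist the trained detector
+with the standard Keras saving API. This is a minimal sketch: the file name is
+arbitrary, and reloading assumes the KerasHub classes are importable in the
+environment where you call `load_model`.
+"""
+
+# Save the whole task model (architecture, weights, and preprocessor config).
+model.save("retinanet_custom_trained.keras")
+# Later, restore it for inference.
+restored_model = keras.saving.load_model("retinanet_custom_trained.keras")
+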
+"""
+## Conclusion
+
+In this tutorial, you learned how to custom train and fine-tune the RetinaNet
+object detector.
+
+You can experiment with different existing backbones trained on ImageNet as the
+image encoder, or you can fine-tune your own backbone.
+
+Instantiating the image encoder without pre-trained weights is equivalent to
+training the model from scratch, as opposed to fine-tuning a pre-trained model.
+
+Training from scratch generally requires significantly more data and
+computational resources to achieve performance comparable to fine-tuning.
+
+To achieve better results when fine-tuning the model, you can increase the
+number of epochs and experiment with different hyperparameter values.
+In addition to the dataset used here, you can also train on other object
+detection datasets, but keep in mind that custom training on them requires a
+large amount of GPU memory.
+"""
diff --git a/guides/md/keras_hub/object_detection_retinanet.md b/guides/md/keras_hub/object_detection_retinanet.md
new file mode 100644
index 0000000000..40b392fbf7
--- /dev/null
+++ b/guides/md/keras_hub/object_detection_retinanet.md
@@ -0,0 +1,1118 @@
+# Object Detection with KerasHub
+
+**Authors:** [Siva Sravana Kumar Neeli](https://github.com/sineeli), [Sachin Prasad](https://github.com/sachinprasadhs)
+**Date created:** 2025/04/28
+**Last modified:** 2025/04/28
+**Description:** RetinaNet Object Detection: Training, Fine-tuning, and Inference.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/keras_hub/object_detection_retinanet.ipynb) •
+[**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/keras_hub/object_detection_retinanet.py)
+
+
+
+
+
+---
+## Introduction
+
+Object detection is a crucial computer vision task that goes beyond simple image
+classification. It requires models to not only identify the types of objects
+present in an image but also pinpoint their locations using bounding boxes. This
+dual requirement of classification and localization makes object detection a
+more complex and powerful tool.
+Object detection models are broadly classified into two categories: "two-stage"
+and "single-stage" detectors. Two-stage detectors often achieve higher accuracy
+by first proposing regions of interest and then classifying them. However, this
+approach can be computationally expensive. Single-stage detectors, on the other
+hand, aim for speed by directly predicting object classes and bounding boxes in
+a single pass.
+
+In this tutorial, we'll be diving into `RetinaNet`, a powerful object detection
+model known for its speed and precision. `RetinaNet` is a single-stage detector,
+a design choice that allows it to be remarkably efficient. Its impressive
+performance stems from two key architectural innovations:
+1. **Feature Pyramid Network (FPN):** FPN equips `RetinaNet` with the ability to
+seamlessly detect objects of all scales, from distant, tiny instances to large,
+prominent ones.
+2. **Focal Loss:** This ingenious loss function tackles the common challenge of
+imbalanced data by focusing the model's learning on the most crucial and
+challenging object examples, leading to enhanced accuracy without compromising
+speed. (A small sketch of this loss follows below.)
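+
+For intuition, here is a small, self-contained sketch of the focal loss from the
+paper, `FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)`. It is written with
+`keras.ops` for illustration only and is not the exact KerasHub implementation
+used by `RetinaNet`:
+
+```python
+import keras
+
+
+def binary_focal_loss(y_true, logits, alpha=0.25, gamma=2.0):
+    # Probability assigned to the ground-truth class for each anchor.
+    p = keras.ops.sigmoid(logits)
+    p_t = y_true * p + (1.0 - y_true) * (1.0 - p)
+    # Class-balancing weight alpha_t and the focusing term (1 - p_t)^gamma.
+    alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)
+    cross_entropy = -keras.ops.log(keras.ops.clip(p_t, 1e-7, 1.0))
+    return alpha_t * (1.0 - p_t) ** gamma * cross_entropy
+```
+
+Easy examples (high `p_t`) are strongly down-weighted by the `(1 - p_t)^gamma`
+term, so the many trivially classified background anchors no longer dominate
+the gradient.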
+
+
+
+### References
+
+- [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
+- [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
+
+---
+## Setup and Imports
+
+Let's install the dependencies and import the necessary modules.
+
+To run this tutorial, you will need to install the following packages:
+
+* `keras-hub`
+* `keras`
+* `opencv-python`
+
+
+```python
+!pip install -q --upgrade keras-hub
+!pip install -q --upgrade keras
+!pip install -q opencv-python
+```
+
+```python
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"
+import keras
+import keras_hub
+import tensorflow as tf
+```
+
Preprocessor: "retina_net_object_detector_preprocessor"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Config ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ retina_net_image_converter (RetinaNetImageConverter) │ Image size: (800, 800) │ +└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘ ++ + + + +
Model: "retina_net_object_detector"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ images (InputLayer) │ (None, None, None, 3) │ 0 │ - │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ retina_net_backbone │ [(None, None, None, 256), │ 27,429,824 │ images[0][0] │ +│ (RetinaNetBackbone) │ (None, None, None, 256), │ │ │ +│ │ (None, None, None, 256), │ │ │ +│ │ (None, None, None, 256), │ │ │ +│ │ (None, None, None, 256)] │ │ │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ box_head (PredictionHead) │ (None, None, None, 36) │ 2,443,300 │ retina_net_backbone[0][0], │ +│ │ │ │ retina_net_backbone[0][1], │ +│ │ │ │ retina_net_backbone[0][2], │ +│ │ │ │ retina_net_backbone[0][3], │ +│ │ │ │ retina_net_backbone[0][4] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ classification_head │ (None, None, None, 819) │ 4,248,115 │ retina_net_backbone[0][0], │ +│ (PredictionHead) │ │ │ retina_net_backbone[0][1], │ +│ │ │ │ retina_net_backbone[0][2], │ +│ │ │ │ retina_net_backbone[0][3], │ +│ │ │ │ retina_net_backbone[0][4] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ box_pred_P3 (Reshape) │ (None, None, 4) │ 0 │ box_head[0][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ box_pred_P4 (Reshape) │ (None, None, 4) │ 0 │ box_head[1][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ box_pred_P5 (Reshape) │ (None, None, 4) │ 0 │ box_head[2][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ box_pred_P6 (Reshape) │ (None, None, 4) │ 0 │ box_head[3][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ box_pred_P7 (Reshape) │ (None, None, 4) │ 0 │ box_head[4][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ cls_pred_P3 (Reshape) │ (None, None, 91) │ 0 │ classification_head[0][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ cls_pred_P4 (Reshape) │ (None, None, 91) │ 0 │ classification_head[1][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ cls_pred_P5 (Reshape) │ (None, None, 91) │ 0 │ classification_head[2][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ cls_pred_P6 (Reshape) │ (None, None, 91) │ 0 │ classification_head[3][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ cls_pred_P7 (Reshape) │ (None, None, 91) │ 0 │ classification_head[4][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ bbox_regression (Concatenate) │ (None, None, 4) │ 0 │ box_pred_P3[0][0], │ +│ │ │ │ box_pred_P4[0][0], │ +│ │ │ │ box_pred_P5[0][0], │ +│ │ │ │ box_pred_P6[0][0], │ +│ │ │ │ box_pred_P7[0][0] │ 
+├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ +│ cls_logits (Concatenate) │ (None, None, 91) │ 0 │ cls_pred_P3[0][0], │ +│ │ │ │ cls_pred_P4[0][0], │ +│ │ │ │ cls_pred_P5[0][0], │ +│ │ │ │ cls_pred_P6[0][0], │ +│ │ │ │ cls_pred_P7[0][0] │ +└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘ ++ + + + +
Total params: 34,121,239 (130.16 MB) ++ + + + +
Trainable params: 34,068,119 (129.96 MB) ++ + + + +
Non-trainable params: 53,120 (207.50 KB) ++ + + +--- +## Preprocessing Layers + +Let's define the below preprocessing layers: + +- Resizing Layer: Resizes the image and maintains the aspect ratio by applying +padding when `pad_to_aspect_ratio=True`. Also, sets the default bounding box +format for representing the data. +- Max Bounding Box Layer: Limits the maximum number of bounding boxes per image. + + +```python +image_size = (800, 800) +batch_size = 4 +bbox_format = "yxyx" +epochs = 5 + +resizing = keras.layers.Resizing( + height=image_size[0], + width=image_size[1], + interpolation="bilinear", + pad_to_aspect_ratio=True, + bounding_box_format=bbox_format, +) + +max_box_layer = keras.layers.MaxNumBoundingBoxes( + max_number=100, bounding_box_format=bbox_format +) +``` + +### Predict and Visualize + +Next, let's obtain predictions from our object detector by loading the image and +visualizing them. We'll apply the preprocessing pipeline defined in the +preprocessing layers step. + + +```python +filepath = keras.utils.get_file( + origin="http://farm4.staticflickr.com/3755/10245052896_958cbf4766_z.jpg" +) +image = keras.utils.load_img(filepath) +image = keras.ops.cast(image, "float32") +image = keras.ops.expand_dims(image, axis=0) + +predictions = object_detector.predict(image, batch_size=1) + +keras.visualization.plot_bounding_box_gallery( + resizing(image), # resize image as per prediction preprocessing pipeline + bounding_box_format=bbox_format, + y_pred=predictions, + scale=4, + class_mapping=COCO_90_CLASS_MAPPING, +) +``` + +