diff --git a/README.md b/README.md index e35ed717d..d305ae9a0 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ Following models are supported: | vgg_unet | VGG 16 | U-Net | | resnet50_unet | Resnet-50 | U-Net | | mobilenet_unet | MobileNet | U-Net | +| unet3_plus | Vanilla CNN | U-Net 3+ | | segnet | Vanilla CNN | Segnet | | vgg_segnet | VGG 16 | Segnet | | resnet50_segnet | Resnet-50 | Segnet | diff --git a/keras_segmentation/models/all_models.py b/keras_segmentation/models/all_models.py index 4300c839c..2083638c9 100644 --- a/keras_segmentation/models/all_models.py +++ b/keras_segmentation/models/all_models.py @@ -1,5 +1,6 @@ from . import pspnet from . import unet +from . import unet3_plus from . import segnet from . import fcn model_from_name = {} @@ -35,6 +36,8 @@ model_from_name["resnet50_unet"] = unet.resnet50_unet model_from_name["mobilenet_unet"] = unet.mobilenet_unet +model_from_name["unet3_plus"] = unet3_plus.unet3_plus + model_from_name["segnet"] = segnet.segnet model_from_name["vgg_segnet"] = segnet.vgg_segnet diff --git a/keras_segmentation/models/unet3_plus.py b/keras_segmentation/models/unet3_plus.py new file mode 100644 index 000000000..fcae16d4e --- /dev/null +++ b/keras_segmentation/models/unet3_plus.py @@ -0,0 +1,203 @@ +""" + @Author: Hamid Ali + @Date: 10/18/2022 + @GitHub: https://github.com/hamidriasat + @Gmail: hamidriasat@gmail.com +""" +import tensorflow as tf +import keras as k +from keras.layers import * +from .model_utils import get_segmentation_model +from .config import IMAGE_ORDERING + +if IMAGE_ORDERING == 'channels_first': + MERGE_AXIS = 1 +elif IMAGE_ORDERING == 'channels_last': + MERGE_AXIS = -1 + +""" # Model Architecture """ + + +def __conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same', + is_bn=True, is_relu=True, n=2): + """ Custom function for conv2d: + Apply 3*3 convolutions with BN and relu. + """ + for i in range(1, n + 1): + x = k.layers.Conv2D(filters=kernels, kernel_size=kernel_size, + padding=padding, strides=strides, + kernel_regularizer=tf.keras.regularizers.l2(1e-4), + kernel_initializer=k.initializers.he_normal(seed=5))(x) + if is_bn: + x = k.layers.BatchNormalization()(x) + if is_relu: + x = k.activations.relu(x) + + return x + + +def __dotProduct(seg, cls): + B, H, W, N = k.backend.int_shape(seg) + seg = tf.reshape(seg, [-1, H * W, N]) + final = tf.einsum("ijk,ik->ijk", seg, cls) + final = tf.reshape(final, [-1, H, W, N]) + return final + + +""" UNet_3Plus """ + + +def __unet3_plus(n_classes, input_height=416, input_width=608, channels=3): + """ + Create model and pass it to segmentation head. + :param n_classes: number of output classes + :param input_height: input image height + :param input_width: input image width + :param channels: number of input channels + :return: image-segmentation-keras library compatible model + """ + assert input_height % 32 == 0 + assert input_width % 32 == 0 + + if IMAGE_ORDERING == 'channels_first': + img_input = Input(shape=(channels, input_height, input_width), name="img_input") + elif IMAGE_ORDERING == 'channels_last': + img_input = Input(shape=(input_height, input_width, channels), name="img_input") + + filters = [64, 128, 256, 512, 1024] + + """ Encoder""" + # block 1 + e1 = __conv_block(img_input, filters[0]) # 320*320*64 + + # block 2 + e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 160*160*64 + e2 = __conv_block(e2, filters[1]) # 160*160*128 + + # block 3 + e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2) # 80*80*128 + e3 = __conv_block(e3, filters[2]) # 80*80*256 + + # block 4 + e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 40*40*256 + e4 = __conv_block(e4, filters[3]) # 40*40*512 + + # block 5 + # bottleneck layer + e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4) # 20*20*512 + e5 = __conv_block(e5, filters[4]) # 20*20*1024 + + """ Decoder """ + cat_channels = filters[0] + cat_blocks = len(filters) + upsample_channels = cat_blocks * cat_channels + + """ d4 """ + e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1) # 320*320*64 --> 40*40*64 + e1_d4 = __conv_block(e1_d4, cat_channels, n=1) # 320*320*64 --> 40*40*64 + + e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2) # 160*160*128 --> 40*40*128 + e2_d4 = __conv_block(e2_d4, cat_channels, n=1) # 160*160*128 --> 40*40*64 + + e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 80*80*256 --> 40*40*256 + e3_d4 = __conv_block(e3_d4, cat_channels, n=1) # 80*80*256 --> 40*40*64 + + e4_d4 = __conv_block(e4, cat_channels, n=1) # 40*40*512 --> 40*40*64 + + e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5) # 80*80*256 --> 40*40*256 + e5_d4 = __conv_block(e5_d4, cat_channels, n=1) # 20*20*1024 --> 20*20*64 + + d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4]) + d4 = __conv_block(d4, upsample_channels, n=1) # 40*40*320 --> 40*40*320 + + """ d3 """ + e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1) # 320*320*64 --> 80*80*64 + e1_d3 = __conv_block(e1_d3, cat_channels, n=1) # 80*80*64 --> 80*80*64 + + e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2) # 160*160*256 --> 80*80*256 + e2_d3 = __conv_block(e2_d3, cat_channels, n=1) # 80*80*256 --> 80*80*64 + + e3_d3 = __conv_block(e3, cat_channels, n=1) # 80*80*512 --> 80*80*64 + + e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4) # 40*40*320 --> 80*80*320 + e4_d3 = __conv_block(e4_d3, cat_channels, n=1) # 80*80*320 --> 80*80*64 + + e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5) # 20*20*320 --> 80*80*320 + e5_d3 = __conv_block(e5_d3, cat_channels, n=1) # 80*80*320 --> 80*80*64 + + d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3]) + d3 = __conv_block(d3, upsample_channels, n=1) # 80*80*320 --> 80*80*320 + + """ d2 """ + e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 320*320*64 --> 160*160*64 + e1_d2 = __conv_block(e1_d2, cat_channels, n=1) # 160*160*64 --> 160*160*64 + + e2_d2 = __conv_block(e2, cat_channels, n=1) # 160*160*256 --> 160*160*64 + + d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3) # 80*80*320 --> 160*160*320 + d3_d2 = __conv_block(d3_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64 + + d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4) # 40*40*320 --> 160*160*320 + d4_d2 = __conv_block(d4_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64 + + e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5) # 20*20*320 --> 160*160*320 + e5_d2 = __conv_block(e5_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64 + + d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2]) + d2 = __conv_block(d2, upsample_channels, n=1) # 160*160*320 --> 160*160*320 + + """ d1 """ + e1_d1 = __conv_block(e1, cat_channels, n=1) # 320*320*64 --> 320*320*64 + + d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2) # 160*160*320 --> 320*320*320 + d2_d1 = __conv_block(d2_d1, cat_channels, n=1) # 160*160*320 --> 160*160*64 + + d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3) # 80*80*320 --> 320*320*320 + d3_d1 = __conv_block(d3_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64 + + d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4) # 40*40*320 --> 320*320*320 + d4_d1 = __conv_block(d4_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64 + + e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5) # 20*20*320 --> 320*320*320 + e5_d1 = __conv_block(e5_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64 + + d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1, ]) + d1 = __conv_block(d1, upsample_channels, n=1) # 320*320*320 --> 320*320*320 + + # last layer does not have batchnorm and relu + d = __conv_block(d1, n_classes, n=1, is_bn=False, is_relu=False) + + model = get_segmentation_model(img_input, d) + + return model + + +def unet3_plus(n_classes: int, input_height: int = 416, input_width: int = 608, channels: int = 3): + """ + Create UNet3+ model based on image-segmentation-keras requirements + :param n_classes: number of output classes + :param input_height: input image height + :param input_width: input image width + :param channels: number of input channels + :return: image-segmentation-keras library compatible model + """ + model = __unet3_plus( + n_classes, + input_height=input_height, + input_width=input_width, + channels=channels + ) + model.model_name = "unet3_plus" + return model + + +if __name__ == "__main__": + """## Model Compilation""" + OUTPUT_CHANNELS = 50 + + __unet_3P = unet3_plus(OUTPUT_CHANNELS) + __unet_3P.summary() + + # tf.keras.utils.plot_model(__unet_3P, show_layer_names=True, show_shapes=True) + + # __unet_3P.save("unet_3P.hdf5") diff --git a/test/test_models.py b/test/test_models.py index d55d5518f..cafe305dd 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,7 +1,7 @@ import numpy as np import tempfile -import sys +import sys import keras @@ -10,9 +10,8 @@ verify_segmentation_dataset, image_segmentation_generator from keras_segmentation.predict import predict_multiple, predict, evaluate - from keras_segmentation.model_compression import perform_distilation -from keras_segmentation.pretrained import pspnet_50_ADE_20K +from keras_segmentation.pretrained import pspnet_50_ADE_20K tr_im = "test/example_dataset/images_prepped_train" tr_an = "test/example_dataset/annotations_prepped_train" @@ -20,35 +19,32 @@ te_an = "test/example_dataset/annotations_prepped_test" - def test_models(): - n_c = 100 - models = [ ( "unet_mini" , 124 , 156 ) , ( "vgg_unet" , 224 , 224*2 ) , - ( 'resnet50_pspnet', 192*2 , 192*3 ) , ( 'mobilenet_unet', 224 , 224 ), ( 'mobilenet_unet', 224+32 , 224+32 ),( 'segnet', 224 , 224*2 ),( 'vgg_segnet', 224 , 224*2 ) ,( 'fcn_32', 224 , 224*2 ) ,( 'fcn_8_vgg', 224 , 224*2 ) ] + models = [("unet_mini", 124, 156), ("unet3_plus", 224, 224), ("vgg_unet", 224, 224 * 2), + ('resnet50_pspnet', 192 * 2, 192 * 3), ('mobilenet_unet', 224, 224), + ('mobilenet_unet', 224 + 32, 224 + 32), ('segnet', 224, 224 * 2), ('vgg_segnet', 224, 224 * 2), + ('fcn_32', 224, 224 * 2), ('fcn_8_vgg', 224, 224 * 2)] - for model_name , h , w in models: - m = all_models.model_from_name[model_name]( n_c, input_height=h, input_width=w) + for model_name, h, w in models: + m = all_models.model_from_name[model_name](n_c, input_height=h, input_width=w) m.train(train_images=tr_im, - train_annotations=tr_an, - steps_per_epoch=2, - epochs=2 ) + train_annotations=tr_an, + steps_per_epoch=2, + epochs=2) keras.backend.clear_session() - - - def test_verify(): verify_segmentation_dataset(tr_im, tr_an, 50) def test_datag(): g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an, - batch_size=3, n_classes=50, + batch_size=3, n_classes=50, input_height=224, input_width=324, output_height=114, output_width=134, do_augment=False) @@ -62,7 +58,7 @@ def test_datag(): # with augmentation def test_datag2(): g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an, - batch_size=3, n_classes=50, + batch_size=3, n_classes=50, input_height=224, input_width=324, output_height=114, output_width=134, do_augment=True) @@ -81,7 +77,7 @@ def test_model(): check_path = tempfile.mktemp() m = all_models.model_from_name[model_name]( - n_c, input_height=h, input_width=w) + n_c, input_height=h, input_width=w) m.train(train_images=tr_im, train_annotations=tr_an, @@ -102,7 +98,7 @@ def test_model(): predict_multiple( inp_dir=te_im, checkpoints_path=check_path, out_dir="/tmp") - predict_multiple(inps=[np.zeros((h, w, 3))]*3, + predict_multiple(inps=[np.zeros((h, w, 3))] * 3, checkpoints_path=check_path, out_dir="/tmp") ev = m.evaluate_segmentation(inp_images_dir=te_im, annotations_dir=te_an) @@ -111,7 +107,7 @@ def test_model(): o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path) o = predict(inp=np.zeros((h, w, 3)), checkpoints_path=check_path, - overlay_img=True, class_names=['nn']*n_c, show_legends=True) + overlay_img=True, class_names=['nn'] * n_c, show_legends=True) print("pr") o.shape @@ -123,10 +119,9 @@ def test_model(): def test_kd(): - if sys.version_info.major < 3: # KD wont work with python 2 - return + return model_name = "fcn_8" h = 224 @@ -135,9 +130,7 @@ def test_kd(): check_path1 = tempfile.mktemp() m1 = all_models.model_from_name[model_name]( - n_c, input_height=h, input_width=w) - - + n_c, input_height=h, input_width=w) model_name = "unet_mini" h = 124 @@ -146,53 +139,34 @@ def test_kd(): check_path2 = tempfile.mktemp() m2 = all_models.model_from_name[model_name]( - n_c, input_height=h, input_width=w) - + n_c, input_height=h, input_width=w) m1.train(train_images=tr_im, - train_annotations=tr_an, - steps_per_epoch=2, - epochs=2, - checkpoints_path=check_path1 - ) - - perform_distilation(m1 ,m2, tr_im , distilation_loss='kl' , - batch_size =2 ,checkpoints_path=check_path2 , epochs = 2 , steps_per_epoch=2, ) - - - perform_distilation(m1 ,m2, tr_im , distilation_loss='l2' , - batch_size =2 ,checkpoints_path=check_path2 , epochs = 2 , steps_per_epoch=2, ) - - - perform_distilation(m1 ,m2, tr_im , distilation_loss='l2' , - batch_size =2 ,checkpoints_path=check_path2 , epochs = 2 , steps_per_epoch=2, feats_distilation_loss='pa' ) - - - - - + train_annotations=tr_an, + steps_per_epoch=2, + epochs=2, + checkpoints_path=check_path1 + ) + perform_distilation(m1, m2, tr_im, distilation_loss='kl', + batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2, ) + perform_distilation(m1, m2, tr_im, distilation_loss='l2', + batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2, ) + perform_distilation(m1, m2, tr_im, distilation_loss='l2', + batch_size=2, checkpoints_path=check_path2, epochs=2, steps_per_epoch=2, + feats_distilation_loss='pa') def test_pretrained(): - - - model = pspnet_50_ADE_20K() + model = pspnet_50_ADE_20K() out = model.predict_segmentation( - inp=te_im+"/0016E5_07959.png", + inp=te_im + "/0016E5_07959.png", out_fname="/tmp/out.png" ) - - - - - - - # def test_models():