diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 87e218a0c..8a4f20820 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -133,19 +133,7 @@ def color(image, factor): def contrast(image, factor): """Equivalent of PIL Contrast.""" - degenerate = tf.image.rgb_to_grayscale(image) - # Cast before calling tf.histogram. - degenerate = tf.cast(degenerate, tf.int32) - - # Compute the grayscale histogram, then compute the mean pixel value, - # and create a constant image size of that value. Use that as the - # blending degenerate target of the original image. - hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) - mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 - degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean - degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) - degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) - return blend(degenerate, image, factor) + return tf.image.adjust_contrast(image, factor) def brightness(image, factor):