diff --git a/classy_vision/dataset/transforms/__init__.py b/classy_vision/dataset/transforms/__init__.py index 2c2d0d34c7..5b6e50380e 100644 --- a/classy_vision/dataset/transforms/__init__.py +++ b/classy_vision/dataset/transforms/__init__.py @@ -12,6 +12,7 @@ import torchvision.transforms as transforms import torchvision.transforms._transforms_video as transforms_video from classy_vision.generic.registry_utils import import_all_modules +from classy_vision.generic.util import log_class_usage from .classy_transform import ClassyTransform @@ -49,20 +50,23 @@ def build_transform(transform_config: Dict[str, Any]) -> Callable: transform_args = copy.deepcopy(transform_config) del transform_args["name"] if name in TRANSFORM_REGISTRY: - return TRANSFORM_REGISTRY[name].from_config(transform_args) - # the name should be available in torchvision.transforms - # if users specify the torchvision transform name in snake case, - # we need to convert it to title case. - if not (hasattr(transforms, name) or hasattr(transforms_video, name)): - name = name.title().replace("_", "") - assert hasattr(transforms, name) or hasattr(transforms_video, name), ( - f"{name} isn't a registered tranform" - ", nor is it available in torchvision.transforms" - ) - if hasattr(transforms, name): - return getattr(transforms, name)(**transform_args) + transform = TRANSFORM_REGISTRY[name].from_config(transform_args) else: - return getattr(transforms_video, name)(**transform_args) + # the name should be available in torchvision.transforms + # if users specify the torchvision transform name in snake case, + # we need to convert it to title case. + if not (hasattr(transforms, name) or hasattr(transforms_video, name)): + name = name.title().replace("_", "") + assert hasattr(transforms, name) or hasattr(transforms_video, name), ( + f"{name} isn't a registered tranform" + ", nor is it available in torchvision.transforms" + ) + if hasattr(transforms, name): + transform = getattr(transforms, name)(**transform_args) + else: + transform = getattr(transforms_video, name)(**transform_args) + log_class_usage("Transform", transform.__class__) + return transform def build_transforms(transforms_config: List[Dict[str, Any]]) -> Callable: