Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions classy_vision/dataset/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from classy_vision.generic.registry_utils import import_all_modules
from classy_vision.generic.util import log_class_usage

from .classy_transform import ClassyTransform

Expand Down Expand Up @@ -49,20 +50,23 @@ def build_transform(transform_config: Dict[str, Any]) -> Callable:
transform_args = copy.deepcopy(transform_config)
del transform_args["name"]
if name in TRANSFORM_REGISTRY:
return TRANSFORM_REGISTRY[name].from_config(transform_args)
# the name should be available in torchvision.transforms
# if users specify the torchvision transform name in snake case,
# we need to convert it to title case.
if not (hasattr(transforms, name) or hasattr(transforms_video, name)):
name = name.title().replace("_", "")
assert hasattr(transforms, name) or hasattr(transforms_video, name), (
f"{name} isn't a registered tranform"
", nor is it available in torchvision.transforms"
)
if hasattr(transforms, name):
return getattr(transforms, name)(**transform_args)
transform = TRANSFORM_REGISTRY[name].from_config(transform_args)
else:
return getattr(transforms_video, name)(**transform_args)
# the name should be available in torchvision.transforms
# if users specify the torchvision transform name in snake case,
# we need to convert it to title case.
if not (hasattr(transforms, name) or hasattr(transforms_video, name)):
name = name.title().replace("_", "")
assert hasattr(transforms, name) or hasattr(transforms_video, name), (
f"{name} isn't a registered tranform"
", nor is it available in torchvision.transforms"
)
if hasattr(transforms, name):
transform = getattr(transforms, name)(**transform_args)
else:
transform = getattr(transforms_video, name)(**transform_args)
log_class_usage("Transform", transform.__class__)
return transform


def build_transforms(transforms_config: List[Dict[str, Any]]) -> Callable:
Expand Down