diff --git a/src/sagemaker_training/_entry_point_type.py b/src/sagemaker_training/_entry_point_type.py index 790a292b..634c25d0 100644 --- a/src/sagemaker_training/_entry_point_type.py +++ b/src/sagemaker_training/_entry_point_type.py @@ -21,11 +21,13 @@ class _EntryPointType(enum.Enum): """Enumerated type consisting of valid types of training entry points.""" + PYTHON_MODULE = "PYTHON_MODULE" PYTHON_PACKAGE = "PYTHON_PACKAGE" PYTHON_PROGRAM = "PYTHON_PROGRAM" COMMAND = "COMMAND" +PYTHON_MODULE = _EntryPointType.PYTHON_MODULE PYTHON_PACKAGE = _EntryPointType.PYTHON_PACKAGE PYTHON_PROGRAM = _EntryPointType.PYTHON_PROGRAM COMMAND = _EntryPointType.COMMAND @@ -46,5 +48,7 @@ def get(path, name): # type: (str, str) -> _EntryPointType return _EntryPointType.PYTHON_PACKAGE elif name.endswith(".py"): return _EntryPointType.PYTHON_PROGRAM + elif name.startswith("-m "): + return _EntryPointType.PYTHON_MODULE else: return _EntryPointType.COMMAND diff --git a/src/sagemaker_training/torch_distributed.py b/src/sagemaker_training/torch_distributed.py index ead9cf84..3d3c4a1c 100644 --- a/src/sagemaker_training/torch_distributed.py +++ b/src/sagemaker_training/torch_distributed.py @@ -93,7 +93,7 @@ def _create_command(self): "Please use a python script as the entry-point" ) - if entrypoint_type is _entry_point_type.PYTHON_PROGRAM: + if entrypoint_type is _entry_point_type.PYTHON_PROGRAM or entrypoint_type is _entry_point_type.PYTHON_MODULE: num_hosts = len(self._hosts) torchrun_cmd = [] @@ -135,7 +135,7 @@ def _create_command(self): torchrun_cmd += self._args return torchrun_cmd else: - raise errors.ClientError("Unsupported entry point type for torch_distributed") + raise errors.ClientError(f"Unsupported entry point type for torch_distributed: {entrypoint_type}") def run(self, capture_error=True, wait=True): """ diff --git a/test/unit/test_entry_point_type.py b/test/unit/test_entry_point_type.py index adb6c538..32b81a96 100644 --- a/test/unit/test_entry_point_type.py +++ b/test/unit/test_entry_point_type.py @@ -36,6 +36,10 @@ def has_requirements(): yield +def test_get_module(): + assert _entry_point_type.get("bla", "-m program") == _entry_point_type.PYTHON_MODULE + + def test_get_package(entry_point_type_module): assert _entry_point_type.get("bla", "program.py") == _entry_point_type.PYTHON_PACKAGE