diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index 19bca043..a3a0aec4 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -1,4 +1,5 @@ import logging +from shutil import copytree from typing import Optional, Union from kagglehub import registry @@ -13,22 +14,44 @@ DEFAULT_IGNORE_PATTERNS = [".git/", "*/.git/", ".cache/", ".huggingface/"] -def model_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: +def model_download( + handle: str, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + output_dir: Optional[str] = None) -> str: """Download model files. Args: handle: (string) the model handle. path: (string) Optional path to a file within the model bundle. force_download: (bool) Optional flag to force download a model, even if it's cached. - + output_dir: (string) Optional path to copy model files to after successful download. Returns: A string representing the path to the requested model files. """ h = parse_model_handle(handle) logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - return registry.model_resolver(h, path, force_download=force_download) + cached_dir = registry.model_resolver(h, path, force_download=force_download) + + if output_dir is None or output_dir == cached_dir: + return cached_dir + try: + # only copying so that we can maintain the cached files + logger.info( + f"Copying model files to requested directory: {output_dir} ...", + extra={**EXTRA_CONSOLE_BLOCK} + ) + true_output_dir = copytree(cached_dir, output_dir, dirs_exist_ok=True) + return true_output_dir + except Exception as e: + logger.warn( + f"Successfully downloaded {handle}, but failed to copy from {cached_dir} " + f"to requested output directory {output_dir}. Encountered error: {e}" + ) + return cached_dir def model_upload( handle: str, diff --git a/tests/test_http_model_download.py b/tests/test_http_model_download.py index 9ba6c392..e2f55d9d 100644 --- a/tests/test_http_model_download.py +++ b/tests/test_http_model_download.py @@ -1,5 +1,7 @@ import os +from tempfile import TemporaryDirectory from typing import Optional +from unittest import mock import requests @@ -147,6 +149,32 @@ def test_versioned_model_download_with_path_with_force_download(self) -> None: with create_test_cache() as d: self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, force_download=True) + def test_versioned_model_download_with_output_dir(self) -> None: + with create_test_cache() as d: + with TemporaryDirectory() as expected_output_dir: + self._download_model_and_assert_downloaded( + d, + VERSIONED_MODEL_HANDLE, + expected_output_dir, + output_dir=expected_output_dir + ) + + def test_versioned_model_download_with_bad_output_dir(self) -> None: + with ( + create_test_cache() as d, + TemporaryDirectory() as placeholder_dir, + mock.patch("kagglehub.models.copytree") as mock_copytree + ): + mock_copytree.side_effect = Exception("Mock exception") + expected_output_dir = EXPECTED_MODEL_SUBDIR # falls back to default + self._download_model_and_assert_downloaded( + d, + VERSIONED_MODEL_HANDLE, + expected_output_dir, + # note: placeholder name is irrelevant since copytree is mocked to throw + output_dir=placeholder_dir + ) + def test_unversioned_model_download_with_path_with_force_download(self) -> None: with create_test_cache() as d: self._download_test_file_and_assert_downloaded(d, UNVERSIONED_MODEL_HANDLE, force_download=True) @@ -188,7 +216,6 @@ def test_versioned_model_download_with_path_already_cached_with_force_download_e self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBPATH), model_path) - class TestHttpNoInternet(BaseTestCase): def test_versioned_model_download_already_cached_with_force_download(self) -> None: with create_test_cache():