diff --git a/src/kagglehub/colab_cache_resolver.py b/src/kagglehub/colab_cache_resolver.py
index a15d9e3..3e02767 100644
--- a/src/kagglehub/colab_cache_resolver.py
+++ b/src/kagglehub/colab_cache_resolver.py
@@ -42,13 +42,23 @@ def is_supported(self, handle: ModelHandle, *_, **__) -> bool:  # noqa: ANN002,
         return True
 
     def _resolve(
-        self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
+        self,
+        h: ModelHandle,
+        path: Optional[str] = None,
+        *,
+        force_download: Optional[bool] = False,
+        target_path: Optional[str] = None,
     ) -> tuple[str, Optional[int]]:
         if force_download:
             logger.info(
                 "Ignoring `force_download` argument when running inside the Colab notebook environment.",
                 extra={**EXTRA_CONSOLE_BLOCK},
             )
+        if target_path:
+            logger.info(
+                "Ignoring `target_path` argument when running inside the Colab notebook environment.",
+                extra={**EXTRA_CONSOLE_BLOCK},
+            )
 
         api_client = ColabClient()
         data = {
@@ -118,13 +128,23 @@ def is_supported(self, handle: DatasetHandle, *_, **__) -> bool:  # noqa: ANN002
         return True
 
     def _resolve(
-        self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
+        self,
+        h: DatasetHandle,
+        path: Optional[str] = None,
+        *,
+        force_download: Optional[bool] = False,
+        target_path: Optional[str] = None,
     ) -> tuple[str, Optional[int]]:
         if force_download:
             logger.info(
                 "Ignoring `force_download` argument when running inside the Colab notebook environment.",
                 extra={**EXTRA_CONSOLE_BLOCK},
             )
+        if target_path:
+            logger.info(
+                "Ignoring `target_path` argument when running inside the Colab notebook environment.",
+                extra={**EXTRA_CONSOLE_BLOCK},
+            )
 
         api_client = ColabClient()
         data = {
diff --git a/src/kagglehub/datasets.py b/src/kagglehub/datasets.py
index 8a298ae..2d5e350 100755
--- a/src/kagglehub/datasets.py
+++ b/src/kagglehub/datasets.py
@@ -29,18 +29,19 @@
 }
 
 
-def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
+def dataset_download(handle: str, path: Optional[str] = None, *, target_path: Optional[str] = None, force_download: Optional[bool] = False) -> str:
     """Download dataset files
     Args:
         handle: (string) the dataset handle
         path: (string) Optional path to a file within a dataset
+        target_path: (string) Optional path to a directory where the dataset files will be downloaded.
         force_download: (bool) Optional flag to force download a dataset, even if it's cached
 
     Returns:
         A string representing the path to the requested dataset files.
""" h = parse_dataset_handle(handle) logger.info(f"Downloading Dataset: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - path, _ = registry.dataset_resolver(h, path, force_download=force_download) + path, _ = registry.dataset_resolver(h, path, force_download=force_download, target_path=target_path) return path diff --git a/src/kagglehub/http_resolver.py b/src/kagglehub/http_resolver.py index f44a2a9..4e532a8 100644 --- a/src/kagglehub/http_resolver.py +++ b/src/kagglehub/http_resolver.py @@ -1,5 +1,6 @@ import logging import os +import shutil import tarfile import zipfile from typing import Optional @@ -17,6 +18,7 @@ from kagglehub.clients import KaggleApiV1Client from kagglehub.exceptions import UnauthenticatedError from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle +from kagglehub.logger import EXTRA_CONSOLE_BLOCK from kagglehub.packages import PackageScope from kagglehub.resolver import Resolver @@ -35,8 +37,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return True def _resolve( - self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, + h: CompetitionHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: + if target_path: + logger.info( + "Ignoring `target_path` argument for competition downloads.", + extra={**EXTRA_CONSOLE_BLOCK}, + ) api_client = KaggleApiV1Client() cached_path = load_from_cache(h, path) @@ -99,7 +111,12 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return True def _resolve( - self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, + h: DatasetHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: api_client = KaggleApiV1Client() @@ -108,12 +125,16 @@ def _resolve( dataset_path = load_from_cache(h, path) if dataset_path and not force_download: + if target_path: + # Handle target_path for cached files + return _copy_to_target_path(dataset_path, target_path), h.version return dataset_path, h.version # Already cached elif dataset_path and force_download: delete_from_cache(h, path) url_path = _build_dataset_download_url_path(h, path) out_path = get_cached_path(h, path) + final_path = out_path # Create the intermediary directories if path: @@ -135,7 +156,12 @@ def _resolve( os.remove(archive_path) mark_as_complete(h, path) - return out_path, h.version + + # Handle target_path if specified + if target_path: + final_path = _copy_to_target_path(out_path, target_path) + + return final_path, h.version class ModelHttpResolver(Resolver[ModelHandle]): @@ -144,8 +170,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return True def _resolve( - self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, + h: ModelHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: + if target_path: + logger.info( + "Ignoring `target_path` argument for model downloads.", + extra={**EXTRA_CONSOLE_BLOCK}, + ) api_client = KaggleApiV1Client() if not h.is_versioned(): @@ -206,8 +242,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return True def _resolve( - self, h: NotebookHandle, path: Optional[str] = None, *, force_download: 
Optional[bool] = False + self, + h: NotebookHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: + if target_path: + logger.info( + "Ignoring `target_path` argument for notebook output downloads.", + extra={**EXTRA_CONSOLE_BLOCK}, + ) api_client = KaggleApiV1Client() if not h.is_versioned(): @@ -385,3 +431,30 @@ def _build_competition_download_all_url_path(h: CompetitionHandle) -> str: def _build_competition_download_file_url_path(h: CompetitionHandle, file: str) -> str: return f"competitions/data/download/{h.competition}/{file}" + + +def _copy_to_target_path(source_path: str, target_path: str) -> str: + """Copy file or directory from source_path to target_path. + + Args: + source_path: Path to the source file or directory + target_path: Path to the target directory + + Returns: + Path to the copied file or directory + """ + os.makedirs(target_path, exist_ok=True) + + # Determine final path + basename = os.path.basename(source_path) + final_path = os.path.join(target_path, basename) + + # Copy file or directory + if os.path.isdir(source_path): + if os.path.exists(final_path): + shutil.rmtree(final_path) + shutil.copytree(source_path, final_path) + else: + shutil.copy2(source_path, final_path) + + return final_path diff --git a/src/kagglehub/kaggle_cache_resolver.py b/src/kagglehub/kaggle_cache_resolver.py index 52df8f1..787018b 100644 --- a/src/kagglehub/kaggle_cache_resolver.py +++ b/src/kagglehub/kaggle_cache_resolver.py @@ -38,7 +38,12 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False def _resolve( - self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, + h: CompetitionHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: client = KaggleJwtClient() if force_download: @@ -46,6 +51,11 @@ def _resolve( "Ignoring `force_download` argument when running inside the Kaggle notebook environment.", extra={**EXTRA_CONSOLE_BLOCK}, ) + if target_path: + logger.info( + "Ignoring `target_path` argument when running inside the Kaggle notebook environment.", + extra={**EXTRA_CONSOLE_BLOCK}, + ) competition_ref = { "CompetitionSlug": h.competition, @@ -102,13 +112,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False def _resolve( - self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, + h: DatasetHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring `force_download` argument when running inside the Kaggle notebook environment.", extra={**EXTRA_CONSOLE_BLOCK}, ) + if target_path: + logger.info( + "Ignoring `target_path` argument when running inside the Kaggle notebook environment.", + extra={**EXTRA_CONSOLE_BLOCK}, + ) client = KaggleJwtClient() dataset_ref = { "OwnerSlug": h.owner, @@ -177,13 +197,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False def _resolve( - self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, + h: ModelHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring 
`force_download` argument when running inside the Kaggle notebook environment.", extra={**EXTRA_CONSOLE_BLOCK}, ) + if target_path: + logger.info( + "Ignoring `target_path` argument when running inside the Kaggle notebook environment.", + extra={**EXTRA_CONSOLE_BLOCK}, + ) client = KaggleJwtClient() model_ref = { "OwnerSlug": h.owner, @@ -254,13 +284,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False def _resolve( - self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, + h: NotebookHandle, + path: Optional[str] = None, + *, + force_download: Optional[bool] = False, + target_path: Optional[str] = None, ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring `force_download` argument when running inside the Kaggle notebook environment.", extra={**EXTRA_CONSOLE_BLOCK}, ) + if target_path: + logger.info( + "Ignoring `target_path` argument when running inside the Kaggle notebook environment.", + extra={**EXTRA_CONSOLE_BLOCK}, + ) client = KaggleJwtClient() kernel_ref = { "OwnerSlug": h.owner, diff --git a/src/kagglehub/resolver.py b/src/kagglehub/resolver.py index 71824db..dd2e884 100644 --- a/src/kagglehub/resolver.py +++ b/src/kagglehub/resolver.py @@ -13,7 +13,7 @@ class Resolver(Generic[T]): __metaclass__ = abc.ABCMeta def __call__( - self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False, target_path: Optional[str] = None ) -> tuple[str, Optional[int]]: """Resolves a handle into a path with the requested file(s) and the resource's version number. @@ -21,12 +21,13 @@ def __call__( handle: (T) the ResourceHandle to resolve. path: (string) Optional path to a file within the resource. force_download: (bool) Optional flag to force download, even if it's cached. + target_path: (string) Optional path to a directory where the files will be downloaded. Returns: A tuple of: (string representing the path, version number of resolved datasource if present) Some cases where version number might be missing: Competition datasource, API-based models. """ - path, version = self._resolve(handle, path, force_download=force_download) + path, version = self._resolve(handle, path, force_download=force_download, target_path=target_path) # Note handles are immutable, so _resolve() could not have altered our reference register_datasource_access(handle, version) @@ -35,7 +36,7 @@ def __call__( @abc.abstractmethod def _resolve( - self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False + self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False, target_path: Optional[str] = None ) -> tuple[str, Optional[int]]: """Resolves a handle into a path with the requested file(s) and the resource's version number. @@ -43,6 +44,7 @@ def _resolve( handle: (T) the ResourceHandle to resolve. path: (string) Optional path to a file within the resource. force_download: (bool) Optional flag to force download, even if it's cached. + target_path: (string) Optional path to a directory where the files will be downloaded. 
 
         Returns:
             A tuple of: (string representing the path, version number of resolved datasource if present)
diff --git a/tests/test_colab_cache_dataset_download.py b/tests/test_colab_cache_dataset_download.py
index 25b6177..3b91218 100644
--- a/tests/test_colab_cache_dataset_download.py
+++ b/tests/test_colab_cache_dataset_download.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 from unittest import mock
 
 import requests
@@ -73,6 +74,21 @@ def test_versioned_dataset_download_bad_handle_raises(self) -> None:
         with self.assertRaises(ValueError):
             kagglehub.dataset_download("bad handle")
 
+    def test_versioned_dataset_download_with_target_path(self) -> None:
+        with stub.create_env():
+            target_dir = os.path.join(os.getcwd(), "custom_target")
+            os.makedirs(target_dir, exist_ok=True)
+            try:
+                dataset_path = kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, target_path=target_dir)
+                # Colab cache resolver ignores target_path, so it should return the original path
+                self.assertNotEqual(target_dir, os.path.dirname(dataset_path))
+                # Check that original dataset path has expected ending
+                self.assertTrue(dataset_path.endswith("/1"))
+            finally:
+                # Clean up
+                if os.path.exists(target_dir):
+                    shutil.rmtree(target_dir)
+
 
 class TestNoInternetColabCacheModelDownload(BaseTestCase):
     def test_colab_resolver_skipped_when_dataset_not_present(self) -> None:
diff --git a/tests/test_http_dataset_download.py b/tests/test_http_dataset_download.py
index e796f26..7c795f6 100644
--- a/tests/test_http_dataset_download.py
+++ b/tests/test_http_dataset_download.py
@@ -48,26 +48,46 @@ def _download_dataset_and_assert_downloaded(
         dataset_handle: str,
         expected_subdir_or_subpath: str,
         expected_files: Optional[list[str]] = None,
+        expected_target_path: Optional[str] = None,
         **kwargs,  # noqa: ANN003
     ) -> None:
         # Download the full datasets and ensure all files are there.
         dataset_path = kagglehub.dataset_download(dataset_handle, **kwargs)
-        self.assertEqual(os.path.join(d, expected_subdir_or_subpath), dataset_path)
+        # If target_path was specified, check that the file was copied there
+        if expected_target_path:
+            self.assertEqual(expected_target_path, dataset_path)
+            self.assertTrue(os.path.exists(expected_target_path))
+        else:
+            self.assertEqual(os.path.join(d, expected_subdir_or_subpath), dataset_path)
+
+        path_to_check = dataset_path
 
         if not expected_files:
             expected_files = ["foo.txt"]
-        self.assertEqual(sorted(expected_files), sorted(os.listdir(dataset_path)))
+        self.assertEqual(sorted(expected_files), sorted(os.listdir(path_to_check)))
 
         # Assert that the archive file has been deleted
         archive_path = get_cached_archive_path(parse_dataset_handle(dataset_handle))
         self.assertFalse(os.path.exists(archive_path))
 
-    def _download_test_file_and_assert_downloaded(self, d: str, dataset_handle: str, **kwargs) -> None:  # noqa: ANN003
+    def _download_test_file_and_assert_downloaded(
+        self,
+        d: str,
+        dataset_handle: str,
+        expected_target_path: Optional[str] = None,
+        **kwargs,  # noqa: ANN003
+    ) -> None:
         dataset_path = kagglehub.dataset_download(dataset_handle, path=TEST_FILEPATH, **kwargs)
-        self.assertEqual(os.path.join(d, EXPECTED_DATASET_SUBPATH, TEST_FILEPATH), dataset_path)
-        with open(dataset_path) as dataset_file:
-            self.assertEqual(TEST_CONTENTS, dataset_file.read())
+
+        if expected_target_path:
+            self.assertEqual(expected_target_path, dataset_path)
+            with open(dataset_path) as dataset_file:
+                self.assertEqual(TEST_CONTENTS, dataset_file.read())
+        else:
+            self.assertEqual(os.path.join(d, EXPECTED_DATASET_SUBPATH, TEST_FILEPATH), dataset_path)
+            with open(dataset_path) as dataset_file:
+                self.assertEqual(TEST_CONTENTS, dataset_file.read())
 
     def _download_test_file_and_assert_downloaded_auto_compressed(
         self,
@@ -133,3 +153,26 @@ def test_unversioned_dataset_full_download_with_file_already_cached(self) -> Non
         # Download a single file first
         kagglehub.dataset_download(UNVERSIONED_DATASET_HANDLE, path=TEST_FILEPATH)
         self._download_dataset_and_assert_downloaded(d, UNVERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR)
+
+    def test_versioned_dataset_download_with_target_path(self) -> None:
+        with create_test_cache() as d:
+            target_dir = os.path.join(d, "custom_target")
+            os.makedirs(target_dir, exist_ok=True)
+            self._download_dataset_and_assert_downloaded(
+                d,
+                VERSIONED_DATASET_HANDLE,
+                EXPECTED_DATASET_SUBDIR,
+                target_path=target_dir,
+                expected_target_path=os.path.join(target_dir, os.path.basename(EXPECTED_DATASET_SUBPATH)),
+            )
+
+    def test_versioned_dataset_download_with_path_and_target_path(self) -> None:
+        with create_test_cache() as d:
+            target_dir = os.path.join(d, "custom_target")
+            os.makedirs(target_dir, exist_ok=True)
+            self._download_test_file_and_assert_downloaded(
+                d,
+                VERSIONED_DATASET_HANDLE,
+                target_path=target_dir,
+                expected_target_path=os.path.join(target_dir, TEST_FILEPATH),
+            )
diff --git a/tests/test_kaggle_cache_dataset_download.py b/tests/test_kaggle_cache_dataset_download.py
index c94adbb..7a55586 100644
--- a/tests/test_kaggle_cache_dataset_download.py
+++ b/tests/test_kaggle_cache_dataset_download.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 from unittest import mock
 
 import requests
@@ -84,3 +85,18 @@ def test_versioned_dataset_download_with_force_download_explicitly_false(self) -
         with stub.create_env():
             dataset_path = kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, force_download=False)
             self.assertEqual(["foo.txt"], sorted(os.listdir(dataset_path)))
+
+    def test_versioned_dataset_download_with_target_path(self) -> None:
+        with stub.create_env():
+            target_dir = os.path.join(os.getcwd(), "custom_target")
+            os.makedirs(target_dir, exist_ok=True)
+            try:
+                dataset_path = kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, target_path=target_dir)
+                # Kaggle cache resolver ignores target_path, so it should return the original path
+                self.assertNotEqual(target_dir, os.path.dirname(dataset_path))
+                # Check that original dataset path contains expected files
+                self.assertEqual(["foo.txt"], sorted(os.listdir(dataset_path)))
+            finally:
+                # Clean up
+                if os.path.exists(target_dir):
+                    shutil.rmtree(target_dir)
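
A minimal usage sketch of the new keyword, assuming this diff is applied; "owner/some-dataset" is a placeholder handle and /tmp/my_data an arbitrary directory:

import os

import kagglehub

# Default behavior: resolve into the kagglehub cache and return the cached path.
cached_dir = kagglehub.dataset_download("owner/some-dataset")

# With target_path, the HTTP resolver copies the resolved files into the given
# directory and returns the copy at target_path/<basename of the cached path>.
# The Colab and Kaggle notebook cache resolvers log a message and ignore it.
copied_dir = kagglehub.dataset_download("owner/some-dataset", target_path="/tmp/my_data")
print(os.listdir(copied_dir))

# path and target_path compose: the single requested file lands inside target_path.
one_file = kagglehub.dataset_download("owner/some-dataset", path="foo.txt", target_path="/tmp/my_data")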
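
The copy helper's path contract (the returned location nests the source's basename under target_path, and an existing directory copy is replaced rather than merged) can be exercised directly. A self-contained sketch, assuming the patched kagglehub.http_resolver module is importable:

import os
import tempfile

from kagglehub.http_resolver import _copy_to_target_path

with tempfile.TemporaryDirectory() as src_root, tempfile.TemporaryDirectory() as dst:
    # Mimic a versioned cache directory such as .../versions/1 with one file.
    src = os.path.join(src_root, "1")
    os.makedirs(src)
    with open(os.path.join(src, "foo.txt"), "w") as f:
        f.write("hello")

    out = _copy_to_target_path(src, dst)
    assert out == os.path.join(dst, "1")  # target_path/<basename of source>
    assert os.listdir(out) == ["foo.txt"]  # directory contents copied through

    # A second call replaces the existing copy instead of raising FileExistsError.
    assert _copy_to_target_path(src, dst) == out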