
Add support for custom target download directory in dataset resolvers #253


Open · wants to merge 3 commits into base: main
24 changes: 22 additions & 2 deletions src/kagglehub/colab_cache_resolver.py
@@ -42,13 +42,23 @@ def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002,
return True

def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: ModelHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Colab notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Colab notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)

api_client = ColabClient()
data = {
@@ -118,13 +128,23 @@ def is_supported(self, handle: DatasetHandle, *_, **__) -> bool: # noqa: ANN002
return True

def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: DatasetHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Colab notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Colab notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)

api_client = ColabClient()
data = {
5 changes: 3 additions & 2 deletions src/kagglehub/datasets.py
@@ -29,18 +29,19 @@
}


def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
def dataset_download(
handle: str,
path: Optional[str] = None,
*,
target_path: Optional[str] = None,
force_download: Optional[bool] = False,
) -> str:
"""Download dataset files
Args:
handle: (string) the dataset handle
path: (string) Optional path to a file within a dataset
target_path: (string) Optional path to a directory where the dataset files will be downloaded.
force_download: (bool) Optional flag to force download a dataset, even if it's cached
Returns:
A string representing the path to the requested dataset files.
"""
h = parse_dataset_handle(handle)
logger.info(f"Downloading Dataset: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
path, _ = registry.dataset_resolver(h, path, force_download=force_download)
path, _ = registry.dataset_resolver(h, path, force_download=force_download, target_path=target_path)
return path


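For context, here is a minimal usage sketch of the new keyword argument; the handle and directory below are hypothetical, not taken from this PR:

```python
import kagglehub

# Hypothetical handle and target directory, for illustration only.
# With the HTTP resolver, the resolved dataset files are copied under
# ./data/my-dataset; the notebook cache resolvers log that `target_path` is ignored.
path = kagglehub.dataset_download(
    "some-owner/some-dataset",
    target_path="./data/my-dataset",
)
print("Files available at:", path)
```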
83 changes: 78 additions & 5 deletions src/kagglehub/http_resolver.py
@@ -1,5 +1,6 @@
import logging
import os
import shutil
import tarfile
import zipfile
from typing import Optional
@@ -17,6 +18,7 @@
from kagglehub.clients import KaggleApiV1Client
from kagglehub.exceptions import UnauthenticatedError
from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle
from kagglehub.logger import EXTRA_CONSOLE_BLOCK
from kagglehub.packages import PackageScope
from kagglehub.resolver import Resolver

@@ -35,8 +37,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: CompetitionHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if target_path:
logger.info(
"Ignoring `target_path` argument for competition downloads.",
extra={**EXTRA_CONSOLE_BLOCK},
)
api_client = KaggleApiV1Client()

cached_path = load_from_cache(h, path)
@@ -99,7 +111,12 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: DatasetHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
api_client = KaggleApiV1Client()

@@ -108,12 +125,16 @@ def _resolve(

dataset_path = load_from_cache(h, path)
if dataset_path and not force_download:
if target_path:
# Handle target_path for cached files
return _copy_to_target_path(dataset_path, target_path), h.version
return dataset_path, h.version # Already cached
elif dataset_path and force_download:
delete_from_cache(h, path)

url_path = _build_dataset_download_url_path(h, path)
out_path = get_cached_path(h, path)
final_path = out_path

# Create the intermediary directories
if path:
@@ -135,7 +156,12 @@ def _resolve(
os.remove(archive_path)

mark_as_complete(h, path)
return out_path, h.version

# Handle target_path if specified
if target_path:
final_path = _copy_to_target_path(out_path, target_path)

return final_path, h.version


class ModelHttpResolver(Resolver[ModelHandle]):
@@ -144,8 +170,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: ModelHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if target_path:
logger.info(
"Ignoring `target_path` argument for model downloads.",
extra={**EXTRA_CONSOLE_BLOCK},
)
api_client = KaggleApiV1Client()

if not h.is_versioned():
@@ -206,8 +242,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: NotebookHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if target_path:
logger.info(
"Ignoring `target_path` argument for notebook output downloads.",
extra={**EXTRA_CONSOLE_BLOCK},
)
api_client = KaggleApiV1Client()

if not h.is_versioned():
@@ -385,3 +431,30 @@ def _build_competition_download_all_url_path(h: CompetitionHandle) -> str:

def _build_competition_download_file_url_path(h: CompetitionHandle, file: str) -> str:
return f"competitions/data/download/{h.competition}/{file}"


def _copy_to_target_path(source_path: str, target_path: str) -> str:
"""Copy file or directory from source_path to target_path.

Args:
source_path: Path to the source file or directory
target_path: Path to the target directory

Returns:
Path to the copied file or directory
"""
os.makedirs(target_path, exist_ok=True)

# Determine final path
basename = os.path.basename(source_path)
final_path = os.path.join(target_path, basename)

# Copy file or directory
if os.path.isdir(source_path):
if os.path.exists(final_path):
shutil.rmtree(final_path)
shutil.copytree(source_path, final_path)
else:
shutil.copy2(source_path, final_path)

return final_path
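One behavior worth noting, sketched below with hypothetical paths: because the helper joins the source's basename onto `target_path`, a cached dataset directory named after its version (e.g. `1`) lands at `<target_path>/1` rather than directly inside `target_path`, and an existing directory at that location is replaced.

```python
import os

# Hypothetical cached path as produced by the HTTP resolver (illustrative only).
cached = "/home/user/.cache/kagglehub/datasets/some-owner/some-dataset/versions/1"
target = "/data/downloads"

# _copy_to_target_path(cached, target) would return the source's basename
# joined onto the target directory:
expected = os.path.join(target, os.path.basename(cached))  # "/data/downloads/1"
```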
48 changes: 44 additions & 4 deletions src/kagglehub/kaggle_cache_resolver.py
@@ -38,14 +38,24 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: CompetitionHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
client = KaggleJwtClient()
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)

competition_ref = {
"CompetitionSlug": h.competition,
@@ -102,13 +112,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: DatasetHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
client = KaggleJwtClient()
dataset_ref = {
"OwnerSlug": h.owner,
@@ -177,13 +197,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: ModelHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
client = KaggleJwtClient()
model_ref = {
"OwnerSlug": h.owner,
@@ -254,13 +284,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: NotebookHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
client = KaggleJwtClient()
kernel_ref = {
"OwnerSlug": h.owner,
8 changes: 5 additions & 3 deletions src/kagglehub/resolver.py
@@ -13,20 +13,21 @@ class Resolver(Generic[T]):
__metaclass__ = abc.ABCMeta

def __call__(
self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
handle: T,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
"""Resolves a handle into a path with the requested file(s) and the resource's version number.

Args:
handle: (T) the ResourceHandle to resolve.
path: (string) Optional path to a file within the resource.
force_download: (bool) Optional flag to force download, even if it's cached.
target_path: (string) Optional path to a directory where the files will be downloaded.

Returns:
A tuple of: (string representing the path, version number of resolved datasource if present)
Some cases where version number might be missing: Competition datasource, API-based models.
"""
path, version = self._resolve(handle, path, force_download=force_download)
path, version = self._resolve(handle, path, force_download=force_download, target_path=target_path)

# Note handles are immutable, so _resolve() could not have altered our reference
register_datasource_access(handle, version)
@@ -35,14 +36,15 @@ def __call__(

@abc.abstractmethod
def _resolve(
self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
handle: T,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
"""Resolves a handle into a path with the requested file(s) and the resource's version number.

Args:
handle: (T) the ResourceHandle to resolve.
path: (string) Optional path to a file within the resource.
force_download: (bool) Optional flag to force download, even if it's cached.
target_path: (string) Optional path to a directory where the files will be downloaded.

Returns:
A tuple of: (string representing the path, version number of resolved datasource if present)
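As a rough sketch (not part of this PR), any `Resolver` subclass now needs to accept the new keyword; the illustrative class below shows the required `_resolve` signature and simply ignores `target_path`, much like the notebook cache resolvers do:

```python
from typing import Optional

from kagglehub.handle import DatasetHandle
from kagglehub.resolver import Resolver


class EchoDatasetResolver(Resolver[DatasetHandle]):
    """Illustrative resolver: accepts the new keyword but performs no real download."""

    def is_supported(self, *_, **__) -> bool:  # noqa: ANN002, ANN003
        return True

    def _resolve(
        self,
        h: DatasetHandle,
        path: Optional[str] = None,
        *,
        force_download: Optional[bool] = False,
        target_path: Optional[str] = None,
    ) -> tuple[str, Optional[int]]:
        # A real resolver would download here (and copy to target_path if given);
        # this sketch just echoes a made-up location with no version number.
        return "/tmp/example-dataset", None
```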
16 changes: 16 additions & 0 deletions tests/test_colab_cache_dataset_download.py
@@ -1,4 +1,5 @@
import os
import shutil
from unittest import mock

import requests
@@ -73,6 +74,21 @@ def test_versioned_dataset_download_bad_handle_raises(self) -> None:
with self.assertRaises(ValueError):
kagglehub.dataset_download("bad handle")

def test_versioned_dataset_download_with_target_path(self) -> None:
with stub.create_env():
target_dir = os.path.join(os.getcwd(), "custom_target")
os.makedirs(target_dir, exist_ok=True)
try:
dataset_path = kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, target_path=target_dir)
# Colab cache resolver ignores target_path, so it should return the original path
self.assertNotEqual(target_dir, os.path.dirname(dataset_path))
# Check that original dataset path has expected ending
self.assertTrue(dataset_path.endswith("/1"))
finally:
# Clean up
if os.path.exists(target_dir):
shutil.rmtree(target_dir)


class TestNoInternetColabCacheModelDownload(BaseTestCase):
def test_colab_resolver_skipped_when_dataset_not_present(self) -> None: