
[core] respect local_files_only=True when using sharded checkpoints #12005


Open
wants to merge 10 commits into base: main
33 changes: 31 additions & 2 deletions src/diffusers/utils/hub_utils.py
@@ -45,6 +45,7 @@
)
from packaging import version
from requests import HTTPError
from requests.exceptions import ConnectionError

from .. import __version__
from .constants import (
@@ -403,9 +404,27 @@ def _get_checkpoint_shard_files(

ignore_patterns = ["*.json", "*.md"]
# The `model_info` call must be guarded with the above condition.
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
Collaborator


So the purpose of this check is to verify if the necessary sharded files are present in the model repo before attempting a download, presumably to avoid a large download if all files aren't present. If we cannot connect to the hub, we just have to assume the necessary shard files are already present locally.

I think we can just skip this check if `local_files_only=True` and then check if all the shard filenames are present in the `cached_folder`.
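
A minimal sketch of that suggestion, for illustration only (the `cached_folder` and `original_shard_filenames` names come from the surrounding function; the helper below is hypothetical and not part of the PR):

import os

def _missing_local_shards(cached_folder, original_shard_filenames):
    # Hypothetical helper: list the shard filenames that are not already on disk.
    return [
        shard_file
        for shard_file in original_shard_filenames
        if not os.path.isfile(os.path.join(cached_folder, shard_file))
    ]

# Inside _get_checkpoint_shard_files, the hub query could then be skipped:
# if local_files_only:
#     missing = _missing_local_shards(cached_folder, original_shard_filenames)
#     if missing:
#         raise EnvironmentError(f"Shard files {missing} were not found in the local cache.")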

Member Author


How about now?

local = False
if local_files_only:
temp_dir = snapshot_download(
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
)
model_files_info = _get_filepaths_for_folder(temp_dir)
local = True
else:
try:
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
except ConnectionError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e

for shard_file in original_shard_filenames:
if local:
shard_file_present = any(shard_file in k for k in model_files_info)
else:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
@@ -441,6 +460,16 @@ def _get_checkpoint_shard_files(
return cached_filenames, sharded_metadata


def _get_filepaths_for_folder(folder):
relative_paths = []
for root, dirs, files in os.walk(folder):
for fname in files:
abs_path = os.path.join(root, fname)
rel_path = os.path.relpath(abs_path, start=folder)
relative_paths.append(rel_path)
return relative_paths
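
For context, a small usage sketch of this helper (the filenames below are illustrative assumptions, not taken from the PR):

# `temp_dir` is the locally cached snapshot returned by snapshot_download above.
# The helper walks it and returns paths relative to the snapshot root, e.g.
#   ["config.json", "diffusion_pytorch_model-00001-of-00002.safetensors", ...]
# which the shard-presence loop then matches with `shard_file in k`.
model_files_info = _get_filepaths_for_folder(temp_dir)
has_shard = any("diffusion_pytorch_model-00002-of-00002.safetensors" in k for k in model_files_info)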


def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
if filenames and folder:
raise ValueError("Both `filenames` and `folder` cannot be provided.")