diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index cf85488b7aa0..fcdf49156a8f 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -402,15 +402,17 @@ def _get_checkpoint_shard_files( allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] ignore_patterns = ["*.json", "*.md"] - # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) - for shard_file in original_shard_filenames: - shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) - if not shard_file_present: - raise EnvironmentError( - f"{shards_path} does not appear to have a file named {shard_file} which is " - "required according to the checkpoint index." - ) + + # If the repo doesn't have the required shards, error out early even before downloading anything. + if not local_files_only: + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: + raise EnvironmentError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) try: # Load from URL @@ -437,6 +439,11 @@ def _get_checkpoint_shard_files( ) from e cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames] + for cached_file in cached_filenames: + if not os.path.isfile(cached_file): + raise EnvironmentError( + f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index." + ) return cached_filenames, sharded_metadata diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 0e16f95a4276..1e08191f56aa 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -36,12 +36,12 @@ import torch import torch.nn as nn from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size -from huggingface_hub import ModelCard, delete_repo, snapshot_download +from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache from huggingface_hub.utils import is_jinja_available from parameterized import parameterized from requests.exceptions import HTTPError -from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel +from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor, AttnProcessor2_0, @@ -291,6 +291,54 @@ def test_cached_files_are_used_when_no_internet(self): if p1.data.ne(p2.data).sum() > 0: assert False, "Parameters not the same!" + def test_local_files_only_with_sharded_checkpoint(self): + repo_id = "hf-internal-testing/tiny-flux-sharded" + error_response = mock.Mock( + status_code=500, + headers={}, + raise_for_status=mock.Mock(side_effect=HTTPError), + json=mock.Mock(return_value={}), + ) + + with tempfile.TemporaryDirectory() as tmpdir: + model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir) + + with mock.patch("requests.Session.get", return_value=error_response): + # Should fail with local_files_only=False (network required) + # We would make a network call with model_info + with self.assertRaises(OSError): + FluxTransformer2DModel.from_pretrained( + repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False + ) + + # Should succeed with local_files_only=True (uses cache) + # model_info call skipped + local_model = FluxTransformer2DModel.from_pretrained( + repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True + ) + + assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), ( + "Model parameters don't match!" + ) + + # Remove a shard file + cached_shard_file = try_to_load_from_cache( + repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir + ) + os.remove(cached_shard_file) + + # Attempting to load from cache should raise an error + with self.assertRaises(OSError) as context: + FluxTransformer2DModel.from_pretrained( + repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True + ) + + # Verify error mentions the missing shard + error_msg = str(context.exception) + assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, ( + f"Expected error about missing shard, got: {error_msg}" + ) + @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners") @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.") def test_one_request_upon_cached(self):