Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 3c6091c

Browse files
authored
Expand supported types of recipes (#109) (#110)
* Expand supported types of recipes * fix for tests
1 parent ff43990 commit 3c6091c

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

src/sparsezoo/objects/recipe.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class RecipeTypes(Enum):
5050
"""
5151

5252
ORIGINAL = "original"
53+
SPARSE = "sparse"
54+
TRANSFER = "transfer"
5355
TRANSFER_LEARN = "transfer_learn"
5456

5557

@@ -455,9 +457,6 @@ def search_sparse_recipes(
455457
"""
456458
from sparsezoo.objects.model import Model
457459

458-
if isinstance(recipe_type, str):
459-
recipe_type = RecipeTypes(recipe_type).value
460-
461460
if not isinstance(model, Model):
462461
model = Model.load_model_from_stub(model)
463462

@@ -508,15 +507,21 @@ def recipe_type_original(self) -> bool:
508507
:return: True if this is the original recipe that created the
509508
model, False otherwise
510509
"""
511-
return self.recipe_type == RecipeTypes.ORIGINAL.value
510+
return any(
511+
self.recipe_type.startswith(start)
512+
for start in [RecipeTypes.ORIGINAL.value, RecipeTypes.SPARSE.value]
513+
)
512514

513515
@property
514516
def recipe_type_transfer_learn(self) -> bool:
515517
"""
516518
:return: True if this is a recipe for transfer learning from the
517519
created model, False otherwise
518520
"""
519-
return self.recipe_type == RecipeTypes.TRANSFER_LEARN.value
521+
return any(
522+
self.recipe_type.startswith(start)
523+
for start in [RecipeTypes.TRANSFER.value, RecipeTypes.TRANSFER_LEARN.value]
524+
)
520525

521526
@property
522527
def display_name(self):
@@ -653,15 +658,17 @@ def download_base_framework_files(
653658
return base_framework_files or framework_files
654659

655660

656-
def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> str:
661+
def _get_stub_args_recipe_type(stub_args: Dict[str, str]) -> Optional[str]:
657662
# check recipe type, default to original, and validate
658663
recipe_type = stub_args.get("recipe_type")
659-
660-
# validate
661664
valid_recipe_types = list(map(lambda typ: typ.value, RecipeTypes))
662-
if recipe_type not in valid_recipe_types and recipe_type is not None:
665+
666+
if recipe_type is not None and not any(
667+
recipe_type.startswith(start) for start in valid_recipe_types
668+
):
663669
raise ValueError(
664670
f"Invalid recipe_type: '{recipe_type}'. "
665-
f"Valid recipe types: {valid_recipe_types}"
671+
f"Valid recipes must start with one of: {valid_recipe_types}"
666672
)
673+
667674
return recipe_type

0 commit comments

Comments
 (0)