@@ -50,6 +50,8 @@ class RecipeTypes(Enum):
50
50
"""
51
51
52
52
ORIGINAL = "original"
53
+ SPARSE = "sparse"
54
+ TRANSFER = "transfer"
53
55
TRANSFER_LEARN = "transfer_learn"
54
56
55
57
@@ -455,9 +457,6 @@ def search_sparse_recipes(
455
457
"""
456
458
from sparsezoo .objects .model import Model
457
459
458
- if isinstance (recipe_type , str ):
459
- recipe_type = RecipeTypes (recipe_type ).value
460
-
461
460
if not isinstance (model , Model ):
462
461
model = Model .load_model_from_stub (model )
463
462
@@ -508,15 +507,21 @@ def recipe_type_original(self) -> bool:
508
507
:return: True if this is the original recipe that created the
509
508
model, False otherwise
510
509
"""
511
- return self .recipe_type == RecipeTypes .ORIGINAL .value
510
+ return any (
511
+ self .recipe_type .startswith (start )
512
+ for start in [RecipeTypes .ORIGINAL .value , RecipeTypes .SPARSE .value ]
513
+ )
512
514
513
515
@property
514
516
def recipe_type_transfer_learn (self ) -> bool :
515
517
"""
516
518
:return: True if this is a recipe for transfer learning from the
517
519
created model, False otherwise
518
520
"""
519
- return self .recipe_type == RecipeTypes .TRANSFER_LEARN .value
521
+ return any (
522
+ self .recipe_type .startswith (start )
523
+ for start in [RecipeTypes .TRANSFER .value , RecipeTypes .TRANSFER_LEARN .value ]
524
+ )
520
525
521
526
@property
522
527
def display_name (self ):
@@ -653,15 +658,17 @@ def download_base_framework_files(
653
658
return base_framework_files or framework_files
654
659
655
660
656
- def _get_stub_args_recipe_type (stub_args : Dict [str , str ]) -> str :
661
+ def _get_stub_args_recipe_type (stub_args : Dict [str , str ]) -> Optional [ str ] :
657
662
# check recipe type, default to original, and validate
658
663
recipe_type = stub_args .get ("recipe_type" )
659
-
660
- # validate
661
664
valid_recipe_types = list (map (lambda typ : typ .value , RecipeTypes ))
662
- if recipe_type not in valid_recipe_types and recipe_type is not None :
665
+
666
+ if recipe_type is not None and not any (
667
+ recipe_type .startswith (start ) for start in valid_recipe_types
668
+ ):
663
669
raise ValueError (
664
670
f"Invalid recipe_type: '{ recipe_type } '. "
665
- f"Valid recipe types : { valid_recipe_types } "
671
+ f"Valid recipes must start with one of : { valid_recipe_types } "
666
672
)
673
+
667
674
return recipe_type
0 commit comments