@@ -854,6 +854,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

@@ -876,9 +877,25 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True

-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)

@@ -1088,31 +1105,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True

     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if llm_config.export.foundation_weights_file is not None:
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
             )

-            from executorch.exir.passes.external_constants_pass import (
-                delegate_external_constants_pass_unlifted,
-                external_constants_pass,
-            )
-
-            assert (
-                builder_exported.pre_autograd_graph_module is not None
-            ), "pre_autograd_graph_module shouldn't be None here"
-            delegate_external_constants_pass_unlifted(
-                module=builder_exported.pre_autograd_graph_module,
-                gen_tag_fn=gen_tag_fn,
-            )
-
-            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-            additional_passes.append(
-                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-            )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1123,6 +1123,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     else:
         builder = _to_edge_and_lower_llama(
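
Aside (not part of the patch): a minimal sketch of how a caller could build the new gen_tag_fn argument, mirroring the lambda kept in _export_llama above. make_gen_tag_fn and foundation_file are illustrative names for this sketch, not existing ExecuTorch API.

    from typing import Callable, Optional

    import torch

    def make_gen_tag_fn(
        foundation_file: str,
    ) -> Callable[[torch.fx.Node], Optional[str]]:
        # Illustrative helper (assumed name): tag every non-LoRA constant so it
        # is serialized into the shared foundation weights file, while LoRA
        # adapter weights return None and stay with the exported program.
        def gen_tag_fn(node: torch.fx.Node) -> Optional[str]:
            return foundation_file if "lora" not in node.name else None

        return gen_tag_fn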