Commit fc63fc3

Tag constants after quantization
1 parent d4129b7 commit fc63fc3

examples/models/llama/export_llama_lib.py

Lines changed: 22 additions & 21 deletions
@@ -854,6 +854,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

@@ -876,9 +877,25 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True

-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
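
Note (not part of the diff): the hunk above reorders the XNNPACK path so constant tagging now runs between pt2e_quantize and to_edge_transform_and_lower, i.e. the tags are applied to the already-quantized constants. Below is a condensed sketch of that flow using only the calls this commit imports; the helper name quantize_tag_then_lower and its argument list are illustrative and do not exist in export_llama_lib.py.

# Condensed sketch of the new flow inside _to_edge_and_lower_llama_xnnpack.
# Assumes torch and executorch are installed; only the calls added by this
# commit are used, everything else is a hypothetical wrapper.
from functools import partial
from typing import Callable, List, Optional

import torch

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
    external_constants_pass,
)


def quantize_tag_then_lower(
    builder_exported,
    quantizers,
    partitioners,
    additional_passes: List,
    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
):
    builder = builder_exported.pt2e_quantize(quantizers)  # 1. quantize first
    if gen_tag_fn is not None:
        # 2. Tag the (now quantized) constants that will be delegated ...
        delegate_external_constants_pass_unlifted(
            module=builder_exported.pre_autograd_graph_module,
            gen_tag_fn=gen_tag_fn,
        )
        # ... and queue a to_executorch pass for weights that are not delegated.
        additional_passes.append(
            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
        )
    return builder.to_edge_transform_and_lower(partitioners)  # 3. delegate / lower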

@@ -1088,31 +1105,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True

     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if llm_config.export.foundation_weights_file is not None:
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
             )

-            from executorch.exir.passes.external_constants_pass import (
-                delegate_external_constants_pass_unlifted,
-                external_constants_pass,
-            )
-
-            assert (
-                builder_exported.pre_autograd_graph_module is not None
-            ), "pre_autograd_graph_module shouldn't be None here"
-            delegate_external_constants_pass_unlifted(
-                module=builder_exported.pre_autograd_graph_module,
-                gen_tag_fn=gen_tag_fn,
-            )
-
-            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-            additional_passes.append(
-                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-            )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1123,6 +1123,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     else:
         builder = _to_edge_and_lower_llama(
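
Note (not part of the diff): after this commit, _export_llama only builds gen_tag_fn from the config and passes it down; the tagging itself happens inside the XNNPACK lowering helper. A minimal, self-contained sketch of the tagging rule follows; the real callback receives a torch.fx.Node and reads x.name, while the plain strings, file name, and helper name below are made up for illustration.

# Self-contained sketch of the gen_tag_fn rule built in _export_llama.
from typing import Optional

foundation_weights_file = "foundation_weights.ptd"  # illustrative path


def tag_for_foundation(node_name: str) -> Optional[str]:
    # Non-LoRA constants share the foundation-weights file; LoRA weights get
    # None, so they stay out of that external file.
    return foundation_weights_file if "lora" not in node_name else None


print(tag_for_foundation("layers.0.attention.wq.weight"))         # foundation_weights.ptd
print(tag_for_foundation("layers.0.attention.wq.lora_a.weight"))  # None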
