Commit 33a47e1

Andrew Grebenisan authored and facebook-github-bot committed
Refactor quantizer: Only replace with per-tensor variants (#14974)
Summary: In our previous flow, we would replace ops with their default variants, run a special fusion pass that constructed singleton tensors for a variety of fused quantized ops, and then run a replace-ops pass to turn those ops into their per-tensor variants. I confirmed this indirection existed only for legacy reasons, so a cleanup was long overdue. This diff also fixes the affected reference implementations as part of the refactor.

Reviewed By: zonglinpeng

Differential Revision: D83873738
1 parent d2672a6 commit 33a47e1
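Illustrative sketch (not part of the diff): the core of the change, using names taken from the fusion-pass hunks below. The default op variants consumed quantization parameters as [1]-shaped tensors, so the fusion pass had to insert aten.full nodes into the graph; the per-tensor variants take plain scalars, so those wrapper nodes go away:

# Before: wrap the scalar zero point in a singleton tensor so the
# default variant can consume it as a tensor argument.
X_zero_point_ = graph_module.graph.call_function(
    torch.ops.aten.full.default,
    ([1], dequants_inputs[0].args[2]),
    {"dtype": torch.int32},
)

# After: forward the scalar from the dequantize node's args straight
# to the per-tensor variant. No graph node is created.
X_zero_point = dequants_inputs[0].args[2]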

File tree

5 files changed: +122 -404 lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -425,6 +425,7 @@ python_unittest(
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:lib",
+        ":ref_implementations",
     ],
 )

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 30 additions & 147 deletions
@@ -65,33 +65,18 @@ def get_args_and_kwargs_add(
     dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
 ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
-    X_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[1]),
-        {"dtype": torch.float},
-    )
-    X_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    Y_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[1]),
-        {"dtype": torch.float},
-    )
-    Y_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[2]),
-        {"dtype": torch.int32},
-    )
+    X_scale = dequants_inputs[0].args[1]
+
+    X_zero_point = dequants_inputs[0].args[2]
+    Y_scale = dequants_inputs[1].args[1]
+    Y_zero_point = dequants_inputs[1].args[2]
     args = (
         inputs_inputs[0],
-        X_scale_,
-        X_zero_point_,
+        X_scale,
+        X_zero_point,
         inputs_inputs[1],
-        Y_scale_,
-        Y_zero_point_,
+        Y_scale,
+        Y_zero_point,
         quant_node.args[1],
         quant_node.args[2],
     )
@@ -129,31 +114,12 @@ def get_args_and_kwargs_linear(
     else:
         bias = bias_inputs[0]

-    # Create single element tensors for weight_zero_point, out_multiplier, out_shift.
-    # Note that the function expects int32_t, when it would default to int64_t, so
-    # we explicitly require that type.
-    weight_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_weights[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
         "src_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_,
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "weight_zero_point": dequants_weights[0].args[2],
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
         "out_zero_point": quant_node.args[2],
         "offset": None,
     }
@@ -178,22 +144,8 @@ def get_args_and_kwargs_layer_norm(
     ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"

     # Make the scale and zero_point tensors
-    scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    scale = dequants_inputs[0].args[1]
+    zero_point = dequants_inputs[0].args[2]

     weight = other_inputs[1] if len(other_inputs) > 1 else None

@@ -220,7 +172,7 @@ def get_args_and_kwargs_layer_norm(
     )

     # Make the args and kwargs for the replacement op
-    args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor])
+    args = tuple(inputs_inputs + [scale, zero_point])
     kwargs = {
         "normalized_shape": other_inputs[0],
         "weight": weight,
@@ -308,31 +260,6 @@ def get_args_and_kwargs_conv(

     (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the weight zero point
-    weight_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], weight_zero_point),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the bias scale
-    bias_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], bias_scale),
-        {"dtype": torch.float32},
-    )
-
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
@@ -341,12 +268,12 @@ def get_args_and_kwargs_conv(
         "dilation": dilation,
         "groups": groups,
         "input_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_tensor,
-        "bias_scale": bias_scale_tensor,
+        "weight_zero_point": weight_zero_point,
+        "bias_scale": bias_scale,
         "out_scale": quant_node.args[1],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

@@ -367,27 +294,11 @@ def get_args_and_kwargs_relu(
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)

-    X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     kwargs = {
-        "X_zero_point": X_zero_point,
+        "X_zero_point": dequants_inputs[0].args[2],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

@@ -435,48 +346,20 @@ def get_args_and_kwargs_softmax(
         {"dtype": torch.int32},
     )
     # Make the scale and zero_point tensors
-    in_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    in_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
-    out_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    out_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    in_scale = dequants_inputs[0].args[1]
+    in_zero_point = dequants_inputs[0].args[2]
+    out_scale = quant_node.args[1]
+    out_zero_point = quant_node.args[2]

     # Make the args and kwargs for the replacement op
     args = (
         inputs_inputs[0],
         mask_tensor,
         op_node.args[1],
-        in_scale_tensor,
-        in_zero_point_tensor,
-        out_scale_tensor,
-        out_zero_point_tensor,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
     )
     kwargs = {}

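Illustrative sketch (not part of the diff): with the helpers above returning scalars, the fused node carries per-tensor parameters directly. Assuming the usual fx plumbing around these helpers (the call_function site for the fused op is not shown in this diff), the fused relu node would now be built roughly like:

# Sketch only: the kwargs mirror get_args_and_kwargs_relu above; every
# quantization parameter is a Python scalar, not an aten.full graph node.
fused = graph_module.graph.call_function(
    torch.ops.cadence.quantized_relu.per_tensor,
    tuple(inputs_inputs),
    {
        "X_zero_point": dequants_inputs[0].args[2],
        "out_zero_point": quant_node.args[2],
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
    },
)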

backends/cadence/aot/quantizer/patterns.py

Lines changed: 12 additions & 11 deletions
@@ -112,7 +112,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
+        return torch.ops.cadence.quantized_linear.per_tensor


 class AddPattern(QuantizationPattern):

@@ -150,7 +150,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_add.default
+        return torch.ops.cadence.quantized_add.per_tensor


 class BmmPattern(QuantizationPattern):

@@ -265,7 +265,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor


 class Conv2dPattern(QuantizationPattern):

@@ -307,7 +307,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor


 class LayerNormPattern(QuantizationPattern):

@@ -345,7 +345,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_layer_norm.default
+        return torch.ops.cadence.quantized_layer_norm.per_tensor


 class LinearPattern(QuantizationPattern):

@@ -387,7 +387,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
+        return torch.ops.cadence.quantized_linear.per_tensor


 class MatmulPattern(QuantizationPattern):

@@ -411,6 +411,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
+        # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op
         return torch.ops.cadence.quantized_matmul.default


@@ -437,7 +438,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_relu.default
+        return torch.ops.cadence.quantized_relu.per_tensor


 # Regular relu op

@@ -496,7 +497,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor


 # Conv1d + regular relu op fusion

@@ -544,7 +545,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_softmax.default
+        return torch.ops.cadence.quantized_softmax.per_tensor


 class MixedW8A32LinearPattern(QuantizationPattern):

@@ -598,7 +599,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_w8a32_linear.default
+        return torch.ops.cadence.quantized_w8a32_linear.per_tensor


 class MixedW8A32ConvPattern(QuantizationPattern):

@@ -660,4 +661,4 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_w8a32_conv.default
+        return torch.ops.cadence.quantized_w8a32_conv.per_tensor
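Illustrative sketch (not part of the diff): after this change, every pattern's replacement_op names a .per_tensor overload directly, with quantized_matmul left as the one pending rename tracked in T240804887. A new pattern written against this convention would look roughly like the following, where MyOpPattern and quantized_my_op are hypothetical names, not ops in this diff:

# Hypothetical pattern: replacement_op returns the per-tensor overload
# up front, so no later replace-ops pass needs to rewrite the node.
class MyOpPattern(QuantizationPattern):
    # get_anchors() and the rest of the pattern are elided for brevity.
    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_my_op.per_tensor  # hypothetical op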
