@@ -65,33 +65,18 @@ def get_args_and_kwargs_add(
65
65
dequants_inputs : List [fx .Node ],
66
66
quant_node : fx .Node ,
67
67
) -> Tuple [Tuple [ArgsType , ...], Dict [str , ArgsType ]]:
68
- X_scale_ = graph_module .graph .call_function (
69
- torch .ops .aten .full .default ,
70
- ([1 ], dequants_inputs [0 ].args [1 ]),
71
- {"dtype" : torch .float },
72
- )
73
- X_zero_point_ = graph_module .graph .call_function (
74
- torch .ops .aten .full .default ,
75
- ([1 ], dequants_inputs [0 ].args [2 ]),
76
- {"dtype" : torch .int32 },
77
- )
78
- Y_scale_ = graph_module .graph .call_function (
79
- torch .ops .aten .full .default ,
80
- ([1 ], dequants_inputs [1 ].args [1 ]),
81
- {"dtype" : torch .float },
82
- )
83
- Y_zero_point_ = graph_module .graph .call_function (
84
- torch .ops .aten .full .default ,
85
- ([1 ], dequants_inputs [1 ].args [2 ]),
86
- {"dtype" : torch .int32 },
87
- )
68
+ X_scale = dequants_inputs [0 ].args [1 ]
69
+
70
+ X_zero_point = dequants_inputs [0 ].args [2 ]
71
+ Y_scale = dequants_inputs [1 ].args [1 ]
72
+ Y_zero_point = dequants_inputs [1 ].args [2 ]
88
73
args = (
89
74
inputs_inputs [0 ],
90
- X_scale_ ,
91
- X_zero_point_ ,
75
+ X_scale ,
76
+ X_zero_point ,
92
77
inputs_inputs [1 ],
93
- Y_scale_ ,
94
- Y_zero_point_ ,
78
+ Y_scale ,
79
+ Y_zero_point ,
95
80
quant_node .args [1 ],
96
81
quant_node .args [2 ],
97
82
)
@@ -129,31 +114,12 @@ def get_args_and_kwargs_linear(
129
114
else :
130
115
bias = bias_inputs [0 ]
131
116
132
- # Create single element tensors for weight_zero_point, out_multiplier, out_shift.
133
- # Note that the function expects int32_t, when it would default to int64_t, so
134
- # we explicitly require that type.
135
- weight_zero_point_ = graph_module .graph .call_function (
136
- torch .ops .aten .full .default ,
137
- ([1 ], dequants_weights [0 ].args [2 ]),
138
- {"dtype" : torch .int32 },
139
- )
140
- out_multiplier_ = graph_module .graph .call_function (
141
- torch .ops .aten .full .default ,
142
- ([1 ], out_multiplier [0 ].item ()),
143
- {"dtype" : torch .int32 },
144
- )
145
- out_shift_ = graph_module .graph .call_function (
146
- torch .ops .aten .full .default ,
147
- ([1 ], out_shift [0 ].item ()),
148
- {"dtype" : torch .int32 },
149
- )
150
-
151
117
args = tuple (inputs_inputs + weights_inputs + [bias ])
152
118
kwargs = {
153
119
"src_zero_point" : dequants_inputs [0 ].args [2 ],
154
- "weight_zero_point" : weight_zero_point_ ,
155
- "out_multiplier" : out_multiplier_ ,
156
- "out_shift" : out_shift_ ,
120
+ "weight_zero_point" : dequants_weights [ 0 ]. args [ 2 ] ,
121
+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
122
+ "out_shift" : out_shift [ 0 ]. item () ,
157
123
"out_zero_point" : quant_node .args [2 ],
158
124
"offset" : None ,
159
125
}
@@ -178,22 +144,8 @@ def get_args_and_kwargs_layer_norm(
178
144
), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"
179
145
180
146
# Read the input scale and zero_point directly from the first dequant node's args
181
- scale_tensor = graph_module .graph .call_function (
182
- torch .ops .aten .full .default ,
183
- (
184
- [1 ],
185
- dequants_inputs [0 ].args [1 ],
186
- ),
187
- {"dtype" : torch .float32 },
188
- )
189
- zero_point_tensor = graph_module .graph .call_function (
190
- torch .ops .aten .full .default ,
191
- (
192
- [1 ],
193
- dequants_inputs [0 ].args [2 ],
194
- ),
195
- {"dtype" : torch .int32 },
196
- )
147
+ scale = dequants_inputs [0 ].args [1 ]
148
+ zero_point = dequants_inputs [0 ].args [2 ]
197
149
198
150
weight = other_inputs [1 ] if len (other_inputs ) > 1 else None
199
151
@@ -220,7 +172,7 @@ def get_args_and_kwargs_layer_norm(
220
172
)
221
173
222
174
# Make the args and kwargs for the replacement op
223
- args = tuple (inputs_inputs + [scale_tensor ] + [ zero_point_tensor ])
175
+ args = tuple (inputs_inputs + [scale , zero_point ])
224
176
kwargs = {
225
177
"normalized_shape" : other_inputs [0 ],
226
178
"weight" : weight ,
@@ -308,31 +260,6 @@ def get_args_and_kwargs_conv(
308
260
309
261
(out_multiplier , out_shift ) = quantize_tensor_multiplier (requantize_scale_t )
310
262
311
- out_multiplier_ = graph_module .graph .call_function (
312
- torch .ops .aten .full .default ,
313
- ([1 ], out_multiplier [0 ].item ()),
314
- {"dtype" : torch .int32 },
315
- )
316
- out_shift_ = graph_module .graph .call_function (
317
- torch .ops .aten .full .default ,
318
- ([1 ], out_shift [0 ].item ()),
319
- {"dtype" : torch .int32 },
320
- )
321
-
322
- # Create a single element tensor for the weight zero point
323
- weight_zero_point_tensor = graph_module .graph .call_function (
324
- torch .ops .aten .full .default ,
325
- ([1 ], weight_zero_point ),
326
- {"dtype" : torch .int32 },
327
- )
328
-
329
- # Create a single element tensor for the bias scale
330
- bias_scale_tensor = graph_module .graph .call_function (
331
- torch .ops .aten .full .default ,
332
- ([1 ], bias_scale ),
333
- {"dtype" : torch .float32 },
334
- )
335
-
336
263
# Make the args and kwargs for the replacement op
337
264
args = tuple (inputs_inputs + weights_inputs + [bias ])
338
265
kwargs = {
@@ -341,12 +268,12 @@ def get_args_and_kwargs_conv(
341
268
"dilation" : dilation ,
342
269
"groups" : groups ,
343
270
"input_zero_point" : dequants_inputs [0 ].args [2 ],
344
- "weight_zero_point" : weight_zero_point_tensor ,
345
- "bias_scale" : bias_scale_tensor ,
271
+ "weight_zero_point" : weight_zero_point ,
272
+ "bias_scale" : bias_scale ,
346
273
"out_scale" : quant_node .args [1 ],
347
274
"out_zero_point" : quant_node .args [2 ],
348
- "out_multiplier" : out_multiplier_ ,
349
- "out_shift" : out_shift_ ,
275
+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
276
+ "out_shift" : out_shift [ 0 ]. item () ,
350
277
}
351
278
return args , kwargs
352
279
@@ -367,27 +294,11 @@ def get_args_and_kwargs_relu(
367
294
# Make the args and kwargs for the replacement op
368
295
args = tuple (inputs_inputs )
369
296
370
- X_zero_point = graph_module .graph .call_function (
371
- torch .ops .aten .full .default ,
372
- ([1 ], dequants_inputs [0 ].args [2 ]),
373
- {"dtype" : torch .int32 },
374
- )
375
- out_multiplier_ = graph_module .graph .call_function (
376
- torch .ops .aten .full .default ,
377
- ([1 ], out_multiplier [0 ].item ()),
378
- {"dtype" : torch .int32 },
379
- )
380
- out_shift_ = graph_module .graph .call_function (
381
- torch .ops .aten .full .default ,
382
- ([1 ], out_shift [0 ].item ()),
383
- {"dtype" : torch .int32 },
384
- )
385
-
386
297
kwargs = {
387
- "X_zero_point" : X_zero_point ,
298
+ "X_zero_point" : dequants_inputs [ 0 ]. args [ 2 ] ,
388
299
"out_zero_point" : quant_node .args [2 ],
389
- "out_multiplier" : out_multiplier_ ,
390
- "out_shift" : out_shift_ ,
300
+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
301
+ "out_shift" : out_shift [ 0 ]. item () ,
391
302
}
392
303
return args , kwargs
393
304
@@ -435,48 +346,20 @@ def get_args_and_kwargs_softmax(
435
346
{"dtype" : torch .int32 },
436
347
)
437
348
# Read the input/output scale and zero_point directly from the dequant and quant node args
438
- in_scale_tensor = graph_module .graph .call_function (
439
- torch .ops .aten .full .default ,
440
- (
441
- [1 ],
442
- dequants_inputs [0 ].args [1 ],
443
- ),
444
- {"dtype" : torch .float32 },
445
- )
446
- in_zero_point_tensor = graph_module .graph .call_function (
447
- torch .ops .aten .full .default ,
448
- (
449
- [1 ],
450
- dequants_inputs [0 ].args [2 ],
451
- ),
452
- {"dtype" : torch .int32 },
453
- )
454
- out_scale_tensor = graph_module .graph .call_function (
455
- torch .ops .aten .full .default ,
456
- (
457
- [1 ],
458
- quant_node .args [1 ],
459
- ),
460
- {"dtype" : torch .float32 },
461
- )
462
- out_zero_point_tensor = graph_module .graph .call_function (
463
- torch .ops .aten .full .default ,
464
- (
465
- [1 ],
466
- quant_node .args [2 ],
467
- ),
468
- {"dtype" : torch .int32 },
469
- )
349
+ in_scale = dequants_inputs [0 ].args [1 ]
350
+ in_zero_point = dequants_inputs [0 ].args [2 ]
351
+ out_scale = quant_node .args [1 ]
352
+ out_zero_point = quant_node .args [2 ]
470
353
471
354
# Make the args and kwargs for the replacement op
472
355
args = (
473
356
inputs_inputs [0 ],
474
357
mask_tensor ,
475
358
op_node .args [1 ],
476
- in_scale_tensor ,
477
- in_zero_point_tensor ,
478
- out_scale_tensor ,
479
- out_zero_point_tensor ,
359
+ in_scale ,
360
+ in_zero_point ,
361
+ out_scale ,
362
+ out_zero_point ,
480
363
)
481
364
kwargs = {}
482
365
0 commit comments