Port onnx float16 from onnxconverter-common #86
base: main
Conversation
Force-pushed 6b1c5e3 to 5d05b60
Thank you! For this pass we would want to replace the onnx.helper usages with the relevant onnx_ir APIs. Could you do that?
You can start by removing the imports onnx, helper, and onnx_proto.
You may also refer to the rest of the files in the passes directory for examples.
@justinchuby thanks for looking at this so quickly, but can you be a bit more specific please? I don't see any functions in the other modules of the passes directory which look like they replicate the
@justinchuby and when you say that you want me to remove the import, do you mean usages like the one referenced here: https://github.com/onnx/ir-py/blob/main/src/onnx_ir/passes/common/shape_inference.py#L22
Sorry, I forgot to mention: the shape inference and checker passes are the only two not to use as references, because they need to use onnx APIs directly. You may see https://github.com/onnx/ir-py/blob/main/src/onnx_ir/passes/common/constant_manipulation.py as an example. The essential idea is this: onnx_ir provides a complete set of APIs to manipulate an ONNX graph, so you don't need to work directly with protobuf. Therefore any import from the onnx package in a graph transformation pass is not needed; that is, APIs like onnx.numpy_helper and onnx.helper, which directly generate protobuf objects, should not be used.
You may find the API documentation here: https://onnx.ai/ir-py/api/index.html, as well as an AI-generated version: https://deepwiki.com/onnx/ir-py
Specifically, you may use ir.node() in place of onnx.helper.make_node.
…ter-common
Signed-off-by: bjeffrey92 <[email protected]>
Signed-off-by: bjeffrey92 <[email protected]>
Signed-off-by: bjeffrey92 <[email protected]>
Force-pushed 5d05b60 to 365554f
import numpy as np
import numpy.typing as npt
import onnx
Check warning (Code scanning / lintrunner): RUFF/TID251 banned API. See https://docs.astral.sh/ruff/rules/banned-api
import numpy.typing as npt
import onnx
import packaging.version as pv
from onnx import helper, numpy_helper
Check warning (Code scanning / lintrunner): RUFF/TID251 banned API
import onnx
import packaging.version as pv
from onnx import helper, numpy_helper
from onnx import onnx_pb as onnx_proto
Check warning (Code scanning / lintrunner): RUFF/TID251 banned API
np_array = np.where(
    between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array
)
return np.float16(np_array)  # pyright: ignore[reportReturnType]
Check failure (Code scanning / lintrunner): MYPY/return-value
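The flagged snippet saturates out-of-range values before casting to float16. A minimal standalone sketch of that saturation step, in pure numpy; the function name here is a hypothetical stand-in, not the name used in the ported module:

```python
import numpy as np

FLOAT16_MAX = float(np.finfo(np.float16).max)  # 65504.0, largest finite float16


def to_float16_saturated(np_array, max_finite_val=FLOAT16_MAX):
    """Hypothetical stand-in for the conversion step above: clamp values
    whose magnitude exceeds the largest finite float16 before casting,
    so the cast saturates instead of overflowing to +/-inf."""
    np_array = np.asarray(np_array, dtype=np.float32)
    # Clamp large negative values up and large positive values down.
    np_array = np.where(np_array < -max_finite_val, -max_finite_val, np_array)
    np_array = np.where(np_array > max_finite_val, max_finite_val, np_array)
    return np.float16(np_array)


print(to_float16_saturated([1.0, 1e6, -1e6]))
```

Without the clamping, `np.float16(1e6)` would overflow to `inf`, which is exactly what the original float16 pass is designed to avoid.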
func_infer_shape = None
if not disable_shape_infer and pv.Version(onnx.__version__) >= pv.Version("1.2"):  # pyright: ignore[reportPrivateImportUsage]
    try:
        from onnx.shape_inference import infer_shapes
Check warning (Code scanning / lintrunner): RUFF/TID251 banned API
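The snippet above gates shape inference on the installed onnx version and tolerates a missing import. A pure-stdlib sketch of the same version-gate pattern; `parse_version` is a simplified, hypothetical stand-in for `packaging.version.Version`:

```python
def parse_version(v):
    """Simplified stand-in for packaging.version.Version: handles plain
    dotted numeric versions only (e.g. "1.2", "1.17.0")."""
    return tuple(int(part) for part in v.split(".") if part.isdigit())


def resolve_infer_shape(installed_version, disable_shape_infer):
    """Return the shape-inference callable when enabled, the installed
    version is new enough, and the import succeeds; otherwise None."""
    func_infer_shape = None
    if not disable_shape_infer and parse_version(installed_version) >= parse_version("1.2"):
        try:
            from onnx.shape_inference import infer_shapes  # may be absent
            func_infer_shape = infer_shapes
        except ImportError:
            pass
    return func_infer_shape
```

Tuple comparison gives the usual semantics here, e.g. `(1, 10, 0) > (1, 2)`, which naive string comparison would get wrong.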
Codecov Report
Attention: Patch coverage is

@@ Coverage Diff @@
##             main      #86      +/-   ##
==========================================
- Coverage   74.52%   71.66%   -2.87%
==========================================
  Files          38       39       +1
  Lines        4687     4895     +208
  Branches      957     1017      +60
==========================================
+ Hits         3493     3508      +15
- Misses        841     1033     +192
- Partials      353      354       +1

View full report in Codecov by Sentry.
A few weeks ago I created an issue on the onnx library repo enquiring about the state of the onnxconverter-common library, which appears to no longer be actively maintained. @justinchuby responded to me and suggested that the module of onnxconverter-common which I require, for converting a single precision onnx model to mixed precision, could be ported into this library.

Here, I have ported the relevant functions from the float16 module of onnxconverter-common (https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/float16.py) to this repo at the location suggested by @justinchuby. I also copied the relevant tests across.

Note that this is not originally my code, hence I have retained the attribution and the original MIT license information with the source code. The modifications I made were just to:

I have not made any material change to the logic of any of the functions in the original module in onnxconverter-common.