From 63e08cc6cbcd1b822f293b4eaf767c28ef5d0cb2 Mon Sep 17 00:00:00 2001
From: bjeffrey92 <36240394+bjeffrey92@users.noreply.github.com>
Date: Tue, 17 Jun 2025 12:23:45 +0100
Subject: [PATCH 1/3] ported module for converting onnx model to mixed
 precision from onnxconverter-common

Signed-off-by: bjeffrey92 <36240394+bjeffrey92@users.noreply.github.com>
---
 src/onnx_ir/passes/common/onnx_float_16.py | 431 +++++++++++++++++++++
 tests/onnx_float_16_test.py                |  62 +++
 2 files changed, 493 insertions(+)
 create mode 100644 src/onnx_ir/passes/common/onnx_float_16.py
 create mode 100644 tests/onnx_float_16_test.py

diff --git a/src/onnx_ir/passes/common/onnx_float_16.py b/src/onnx_ir/passes/common/onnx_float_16.py
new file mode 100644
index 00000000..7f9a30d1
--- /dev/null
+++ b/src/onnx_ir/passes/common/onnx_float_16.py
@@ -0,0 +1,431 @@
+import itertools
+import warnings
+
+import numpy as np
+import numpy.typing as npt
+import onnx
+import packaging.version as pv
+from onnx import helper, numpy_helper
+from onnx import onnx_pb as onnx_proto
+
+DEFAULT_OP_BLOCK_LIST = [
+    "ArrayFeatureExtractor",
+    "Binarizer",
+    "CastMap",
+    "CategoryMapper",
+    "DictVectorizer",
+    "FeatureVectorizer",
+    "Imputer",
+    "LabelEncoder",
+    "LinearClassifier",
+    "LinearRegressor",
+    "Normalizer",
+    "OneHotEncoder",
+    "RandomUniformLike",
+    "SVMClassifier",
+    "SVMRegressor",
+    "Scaler",
+    "TreeEnsembleClassifier",
+    "TreeEnsembleRegressor",
+    "ZipMap",
+    "NonMaxSuppression",
+    "TopK",
+    "RoiAlign",
+    "Resize",
+    "Range",
+    "CumSum",
+    "Min",
+    "Max",
+    "Upsample",
+]
+
+
+def _npfloat16_to_int(np_list: list[np.float16] | npt.NDArray[np.float16]) -> list[int]:
+    """Convert numpy float16 values to python ints holding their bit patterns.
+
+    :param np_list: numpy float16 list
+    :return int_list: python int list
+    """
+    return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list]
+
+
+def convert_np_to_float16(
+    np_array: npt.NDArray[np.float32],
+    min_positive_val: float = 1e-7,
+    max_finite_val: float = 1e4,
+) -> npt.NDArray[np.float16]:
+    """Convert a float32 numpy array to float16 without changing sign or finiteness.
+
+    Positive values less than min_positive_val are mapped to min_positive_val.
+    Positive finite values greater than max_finite_val are mapped to max_finite_val.
+    Negative values are treated symmetrically. NaN, 0, inf, and -inf are unchanged.
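+
+    For example, with the default thresholds, the float32 values
+    ``[1e-10, 0.0, 3.0, 1e6]`` are clamped to roughly ``[1e-7, 0.0, 3.0, 1e4]``
+    before the final cast to float16.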
+    """
+
+    def between(
+        a: npt.NDArray | int | float,
+        b: npt.NDArray | int | float,
+        c: npt.NDArray | int | float,
+    ) -> npt.NDArray[np.bool_]:
+        return np.logical_and(a < b, b < c)
+
+    if np_array[np.where(np_array > 0)].shape[0] > 0:
+        pos_max = np_array[np.where(np_array > 0)].max()
+        pos_min = np_array[np.where(np_array > 0)].min()
+
+        if pos_max >= max_finite_val:
+            warnings.warn(
+                f"the float32 number {pos_max} will be truncated to {max_finite_val}",
+                stacklevel=2,
+            )
+
+        if pos_min <= min_positive_val:
+            warnings.warn(
+                f"the float32 number {pos_min} will be truncated to {min_positive_val}",
+                stacklevel=2,
+            )
+
+    if np_array[np.where(np_array < 0)].shape[0] > 0:
+        neg_max = np_array[np.where(np_array < 0)].max()
+        neg_min = np_array[np.where(np_array < 0)].min()
+
+        if neg_min <= -max_finite_val:
+            warnings.warn(
+                f"the float32 number {neg_min} will be truncated to {-max_finite_val}",
+                stacklevel=2,
+            )
+
+        if neg_max >= -min_positive_val:
+            warnings.warn(
+                f"the float32 number {neg_max} will be truncated to {-min_positive_val}",
+                stacklevel=2,
+            )
+
+    np_array = np.where(between(0, np_array, min_positive_val), min_positive_val, np_array)
+    np_array = np.where(between(-min_positive_val, np_array, 0), -min_positive_val, np_array)
+    np_array = np.where(
+        between(max_finite_val, np_array, float("inf")), max_finite_val, np_array
+    )
+    np_array = np.where(
+        between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array
+    )
+    return np.float16(np_array)  # pyright: ignore[reportReturnType]
+
+
+def convert_tensor_float_to_float16(
+    tensor: onnx_proto.TensorProto,
+    min_positive_val: float = 1e-7,
+    max_finite_val: float = 1e4,
+) -> onnx_proto.TensorProto:
+    """Convert a float TensorProto to float16 in place and return it.
+
+    :param tensor: TensorProto object
+    :return tensor_float16: converted TensorProto object
+
+    Example:
+    ::
+        from onnx_ir.passes.common.onnx_float_16 import convert_tensor_float_to_float16
+        new_tensor = convert_tensor_float_to_float16(tensor)
+    """
+    if tensor.data_type == onnx_proto.TensorProto.FLOAT:
+        tensor.data_type = onnx_proto.TensorProto.FLOAT16
+        # convert float_data (float type) to float16 and write to int32_data
+        if tensor.float_data:
+            float16_data = convert_np_to_float16(
+                np.array(tensor.float_data), min_positive_val, max_finite_val
+            )
+            int_list = _npfloat16_to_int(float16_data)
+            tensor.int32_data[:] = int_list
+            tensor.float_data[:] = []
+        # convert raw_data (bytes type)
+        if tensor.raw_data:
+            # convert tensor.raw_data (bytes) to a float32 array
+            float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
+            # convert float to float16
+            float16_list = convert_np_to_float16(
+                float32_list, min_positive_val, max_finite_val
+            )
+            # convert float16 to bytes and write back to raw_data
+            tensor.raw_data = float16_list.tobytes()
+    return tensor
+
+
+def make_value_info_from_tensor(
+    tensor: onnx_proto.TensorProto,
+) -> onnx_proto.ValueInfoProto:
+    shape = numpy_helper.to_array(tensor).shape
+    return helper.make_tensor_value_info(tensor.name, tensor.data_type, shape)
+
+
+def convert_float_to_float16(
+    model: onnx_proto.ModelProto,
+    min_positive_val: float = 1e-7,
+    max_finite_val: float = 1e4,
+    keep_io_types: bool = False,
+    disable_shape_infer: bool = False,
+    op_blocks: list[str] | None = None,
+    node_blocks: list[str] | None = None,
+) -> onnx_proto.ModelProto:
+    """Convert tensor float type in the ONNX ModelProto input to tensor float16.
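+
+    When keep_io_types is True, the graph's float32 inputs and outputs are
+    preserved and Cast nodes are inserted at the graph boundary instead.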
+
+    :param model: ONNX ModelProto object
+    :param disable_shape_infer: Type/shape information is needed for conversion to work.
+        Set to True only if the model already has type/shape information for all tensors.
+    :return: converted ONNX ModelProto object
+
+    Examples:
+    ::
+
+        Example 1: Convert an ONNX ModelProto object:
+        from onnx_ir.passes.common.onnx_float_16 import convert_float_to_float16
+        new_onnx_model = convert_float_to_float16(onnx_model)
+
+        Example 2: Convert an ONNX model binary file:
+        import onnx
+        from onnx_ir.passes.common.onnx_float_16 import convert_float_to_float16
+        onnx_model = onnx.load('model.onnx')
+        new_onnx_model = convert_float_to_float16(onnx_model)
+        onnx.save(new_onnx_model, 'new_model.onnx')
+    """
+    func_infer_shape = None
+    if not disable_shape_infer and pv.Version(onnx.__version__) >= pv.Version("1.2"):  # pyright: ignore[reportPrivateImportUsage]
+        try:
+            from onnx.shape_inference import infer_shapes
+
+            func_infer_shape = infer_shapes
+        except ImportError:
+            # fall back to converting without shape inference
+            pass
+
+    # create blocklists
+    if op_blocks is None:
+        op_blocks = DEFAULT_OP_BLOCK_LIST
+    if node_blocks is None:
+        node_blocks = []
+    op_block_list = set(op_blocks)
+    node_block_list = set(node_blocks)
+    # create a queue for BFS
+    queue = []
+    value_info_list = []
+    node_list = []
+    # key = node name, value = the graph that owns the node, used to
+    # distinguish the global graph from sub-graphs
+    node_dict = {}
+    # type inference on input model
+    if func_infer_shape is not None:
+        model = func_infer_shape(model)
+    queue.append(model)
+    name_mapping = {}
+    graph_io_to_skip = set()
+    io_casts = set()
+    if keep_io_types:
+        for i, n in enumerate(model.graph.input):
+            if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                output_name = "graph_input_cast_" + str(i)
+                name_mapping[n.name] = output_name
+                graph_io_to_skip.add(n.name)
+
+                node_name = "graph_input_cast" + str(i)
+                new_value_info = model.graph.value_info.add()
+                new_value_info.CopyFrom(n)
+                new_value_info.name = output_name
+                new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+                # add Cast node (from tensor(float) to tensor(float16)) after
+                # the graph input
+                new_node = [
+                    helper.make_node("Cast", [n.name], [output_name], to=10, name=node_name)
+                ]
+                model.graph.node.extend(new_node)
+                value_info_list.append(new_value_info)
+                io_casts.add(node_name)
+
+        for i, n in enumerate(model.graph.output):
+            if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                input_name = "graph_output_cast_" + str(i)
+                name_mapping[n.name] = input_name
+                graph_io_to_skip.add(n.name)
+
+                node_name = "graph_output_cast" + str(i)
+                # add Cast node (from tensor(float16) to tensor(float)) before
+                # the graph output
+                new_value_info = model.graph.value_info.add()
+                new_value_info.CopyFrom(n)
+                new_value_info.name = input_name
+                new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+                new_node = [
+                    helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)
+                ]
+                model.graph.node.extend(new_node)
+                value_info_list.append(new_value_info)
+                io_casts.add(node_name)
+
+    while queue:
+        next_level = []
+        for q in queue:
+            # if q is model, push q.graph (GraphProto)
+            if isinstance(q, onnx_proto.ModelProto):
+                next_level.append(q.graph)
+            # if q is model.graph, push q.node.attribute (AttributeProto)
+            if isinstance(q, onnx_proto.GraphProto):
+                for n in q.node:
+                    # skip the Cast nodes inserted above for graph I/O; they
+                    # must keep their float/float16 types
+                    if n.name in io_casts:
+                        continue
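+                    # rewire inputs/outputs that were renamed when the
+                    # graph-level Cast nodes were inserted above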
+                    for i in range(len(n.input)):
+                        if n.input[i] in name_mapping:
+                            n.input[i] = name_mapping[n.input[i]]
+                    for i in range(len(n.output)):
+                        if n.output[i] in name_mapping:
+                            n.output[i] = name_mapping[n.output[i]]
+                    # a node in the op or node block list does not support
+                    # float16: do not push its attributes onto next_level, so
+                    # nothing inside it is converted, and save the node for
+                    # post-processing
+                    if n.op_type in op_block_list or n.name in node_block_list:
+                        node_list.append(n)
+                        node_dict[n.name] = q
+                    else:
+                        if n.op_type == "Cast":
+                            for attr in n.attribute:
+                                if attr.name == "to" and attr.i == 1:
+                                    attr.i = 10
+                                    break
+                        for attr in n.attribute:
+                            next_level.append(attr)
+            # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
+            # and process node.attribute.t and node.attribute.tensors (TensorProto)
+            if isinstance(q, onnx_proto.AttributeProto):
+                next_level.append(q.g)
+                for n in q.graphs:
+                    next_level.append(n)
+                q.t.CopyFrom(
+                    convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val)
+                )
+                for n in q.tensors:
+                    n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val)
+            # if q is graph, process graph.initializer (TensorProto), input, output
+            # and value_info (ValueInfoProto)
+            if isinstance(q, onnx_proto.GraphProto):
+                for n in q.initializer:  # TensorProto type
+                    if n.data_type == onnx_proto.TensorProto.FLOAT:
+                        n = convert_tensor_float_to_float16(
+                            n, min_positive_val, max_finite_val
+                        )
+                        value_info_list.append(make_value_info_from_tensor(n))
+                # convert all ValueInfoProto with tensor(float) type in input,
+                # output and value_info to tensor(float16), except map and
+                # seq(map), and save them in value_info_list for further
+                # processing
+                for n in itertools.chain(q.input, q.output, q.value_info):
+                    if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                        if n.name not in graph_io_to_skip:
+                            n.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+                            value_info_list.append(n)
+        queue = next_level
+
+    # process the blocked nodes, which do not support tensor(float16)
+    for node in node_list:
+        # if an input name is in value_info_list, the input is now
+        # tensor(float16): insert a float16-to-float Cast node before the node,
+        # rename the node's input and create a new value_info for the new name
+        for i in range(len(node.input)):
+            input = node.input[i]
+            for value_info in value_info_list:
+                if input == value_info.name:
+                    # create new value_info for current node's new input name
+                    graph = node_dict[
+                        node.name
+                    ]  # get the correct graph instead of the global graph
+                    new_value_info = graph.value_info.add()
+                    new_value_info.CopyFrom(value_info)
+                    output_name = node.name + "_input_cast_" + str(i)
+                    new_value_info.name = output_name
+                    new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
+                    # add Cast node (from tensor(float16) to tensor(float))
+                    # before the current node
+                    node_name = node.name + "_input_cast" + str(i)
+                    new_node = [
+                        helper.make_node("Cast", [input], [output_name], to=1, name=node_name)
+                    ]
+                    graph.node.extend(new_node)
+                    # change current node's input name
+                    node.input[i] = output_name
+                    break
+        # if an output name is in value_info_list, the output is now
+        # tensor(float16): insert a float-to-float16 Cast node after the node,
+        # rename the node's output and create a new value_info for the new name
+        for i in range(len(node.output)):
+            output = node.output[i]
+            for value_info in value_info_list:
+                if output == value_info.name:
+                    # create new value_info for current node's new output
+                    graph = node_dict[
+                        node.name
+                    ]  # get the correct graph instead of the global graph
+                    new_value_info = graph.value_info.add()
+                    new_value_info.CopyFrom(value_info)
+                    input_name = node.name + "_output_cast_" + str(i)
+                    new_value_info.name = input_name
+                    new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
+                    # add Cast node (from tensor(float) to tensor(float16))
+                    # after the current node
+                    node_name = node.name + "_output_cast" + str(i)
+                    new_node = [
+                        helper.make_node("Cast", [input_name], [output], to=10, name=node_name)
+                    ]
+                    graph.node.extend(new_node)
+                    # change current node's output name
+                    node.output[i] = input_name
+                    break
+
+    sort_topology(model.graph)
+    return model
+
+
+def sort_graph_node(graph_proto: onnx_proto.GraphProto) -> None:
+    # find the first node whose inputs are not produced by any node still in
+    # output2node_dict
+    def find_first_node(
+        output2node_dict: dict[str, onnx_proto.NodeProto],
+    ) -> onnx_proto.NodeProto | None:
+        for node in org_nodes:
+            is_not_first_node = any(item in output2node_dict for item in node.input)
+            if not is_not_first_node:
+                return node
+        return None
+
+    # remove the node's outputs from output2node_dict
+    def remove_first_node_from_dict2(first_node: onnx_proto.NodeProto) -> None:
+        for output in first_node.output:
+            output2node_dict.pop(output, None)
+
+    org_nodes = graph_proto.node
+    # create a dict that maps each output name to the node producing it
+    output2node_dict = {}
+    for node in org_nodes:
+        for output in node.output:
+            output2node_dict[output] = node
+
+    # collect the nodes in topologically sorted order
+    sorted_node = []
+    # repeatedly extract the next ready node until none remain
+    while len(output2node_dict) > 0:
+        first_node = find_first_node(output2node_dict)
+        sorted_node.append(first_node)
+        if first_node is not None:
+            remove_first_node_from_dict2(first_node)
+            # remove the node from the original list so it is not visited again
+            org_nodes.remove(first_node)
+
+    for new_node in sorted_node:
+        graph_proto.node.extend([new_node])
+
+
+# The input graph should be model.graph
+# Recursively sort the topology of each sub-graph
+def sort_topology(graph_proto: onnx_proto.GraphProto) -> None:
+    sort_graph_node(graph_proto)  # sort global graph
+    for node in graph_proto.node:
+        for attr in node.attribute:
+            if isinstance(attr.g, onnx_proto.GraphProto) and len(attr.g.node) > 0:
+                sort_topology(attr.g)  # sort sub-graph
+            for g in attr.graphs:
+                if isinstance(g, onnx_proto.GraphProto):
+                    sort_topology(g)  # sort sub-graph
diff --git a/tests/onnx_float_16_test.py b/tests/onnx_float_16_test.py
new file mode 100644
index 00000000..b3d05107
--- /dev/null
+++ b/tests/onnx_float_16_test.py
@@ -0,0 +1,62 @@
+import os
+import typing
+
+import numpy as np
+import onnx
+import onnxruntime as _ort
+from onnx import onnx_pb as onnx_proto
+
+from onnx_ir.passes.common import onnx_float_16
+
+
+def _ort_inference(
+    mdl: onnx_proto.ModelProto, inputs: dict[str, typing.Any]
+) -> typing.Sequence[typing.Any]:
+    sess = _ort.InferenceSession(mdl.SerializeToString())
+    return sess.run(None, inputs)
+
+
+def test_convert_to_float16() -> None:
+    model32_name = "image_classifier32.onnx"
+    working_path = os.path.abspath(os.path.dirname(__file__))
+    data_path = os.path.join(working_path, "data")
+    model_path = os.path.join(data_path, model32_name)
+    onnx_model32 = onnx.load(model_path)
+    input_x = np.random.rand(1, 3, 32, 32).astype(np.float32)
+    output_32 = _ort_inference(onnx_model32, {"modelInput": input_x})
+
+    onnx_model16 = onnx_float_16.convert_float_to_float16(onnx_model32, keep_io_types=False)
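+    # with keep_io_types=False the converted graph itself consumes and produces
+    # float16, so the same input is fed as float16 below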
+ output_16 = _ort_inference(onnx_model16, {"modelInput": input_x.astype(np.float16)}) + assert np.allclose(output_16, output_32, atol=1e-2) + + onnx_model16 = onnx_float_16.convert_float_to_float16(onnx_model32, keep_io_types=True) + output_16 = _ort_inference(onnx_model16, {"modelInput": input_x}) + assert np.allclose(output_16, output_32, atol=1e-2) + + +def test_convert_to_float16_with_truncated() -> None: + np_array = np.array([1e-10, -2.0, 15, -1e-9, 65536.1, -100000]) + onnx_float_16.convert_np_to_float16(np_array) + + +def test_convert_to_float16_with_subgraph() -> None: + model32_name = "test_subgraph.onnx" + working_path = os.path.abspath(os.path.dirname(__file__)) + data_path = os.path.join(working_path, "data") + model_path = os.path.join(data_path, model32_name) + onnx_model32 = onnx.load(model_path) + x = np.array([1.0], dtype=np.float32) + y = np.array([2.0], dtype=np.float32) + output_32 = _ort_inference(onnx_model32, {"x": x, "y": y}) + + onnx_model16 = onnx_float_16.convert_float_to_float16(onnx_model32, keep_io_types=True) + actual = _ort_inference(onnx_model16, {"x": x, "y": y}) + assert np.allclose(actual, output_32, atol=1e-2) + assert actual[0].dtype == np.float32 + + onnx_model16 = onnx_float_16.convert_float_to_float16(onnx_model32, keep_io_types=False) + actual = _ort_inference( + onnx_model16, {"x": x.astype(np.float16), "y": y.astype(np.float16)} + ) + assert np.allclose(actual, output_32, atol=1e-2) + assert actual[0].dtype == np.float16 From 95678659ae965efc150668c9e5e6fc24cf67a968 Mon Sep 17 00:00:00 2001 From: bjeffrey92 <36240394+bjeffrey92@users.noreply.github.com> Date: Tue, 17 Jun 2025 12:25:33 +0100 Subject: [PATCH 2/3] abide by all linting rules Signed-off-by: bjeffrey92 <36240394+bjeffrey92@users.noreply.github.com> --- src/onnx_ir/passes/common/onnx_float_16.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/onnx_ir/passes/common/onnx_float_16.py b/src/onnx_ir/passes/common/onnx_float_16.py index 7f9a30d1..88757ba8 100644 --- a/src/onnx_ir/passes/common/onnx_float_16.py +++ b/src/onnx_ir/passes/common/onnx_float_16.py @@ -287,14 +287,12 @@ def convert_float_to_float16( if attr.name == "to" and attr.i == 1: attr.i = 10 break - for attr in n.attribute: - next_level.append(attr) + next_level.extend(list(n.attribute)) # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto) # and process node.attribute.t and node.attribute.tensors (TensorProto) if isinstance(q, onnx_proto.AttributeProto): next_level.append(q.g) - for n in q.graphs: - next_level.append(n) + next_level.extend(list(q.graphs)) q.t.CopyFrom( convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val) ) From 365554f37c33f47c1c8c2fc91182fe8182657c44 Mon Sep 17 00:00:00 2001 From: bjeffrey92 <36240394+bjeffrey92@users.noreply.github.com> Date: Tue, 17 Jun 2025 13:13:18 +0100 Subject: [PATCH 3/3] add MIT license information to source code Signed-off-by: bjeffrey92 <36240394+bjeffrey92@users.noreply.github.com> --- src/onnx_ir/passes/common/onnx_float_16.py | 8 ++++++++ tests/onnx_float_16_test.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/src/onnx_ir/passes/common/onnx_float_16.py b/src/onnx_ir/passes/common/onnx_float_16.py index 88757ba8..2a42de42 100644 --- a/src/onnx_ir/passes/common/onnx_float_16.py +++ b/src/onnx_ir/passes/common/onnx_float_16.py @@ -1,3 +1,11 @@ +# Portions of this file are derived from work by Microsoft Corporation under the MIT License. This was modified by bjeffrey92 for use in ir-py. 
+# See below for original license and copyright. + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSES/MIT.txt in the project root for +# license information. +# Original source: https://github.com/microsoft/onnxconverter-common/blob/209a47e18e6a4c3474273a0b2a5e8f1fda481643/onnxconverter_common/float16.py + import itertools import warnings diff --git a/tests/onnx_float_16_test.py b/tests/onnx_float_16_test.py index b3d05107..768613f0 100644 --- a/tests/onnx_float_16_test.py +++ b/tests/onnx_float_16_test.py @@ -1,3 +1,11 @@ +# Portions of this file are derived from work by Microsoft Corporation under the MIT License. This was modified by bjeffrey92 for use in ir-py. +# See below for original license and copyright. + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSES/MIT.txt in the project root for +# license information. +# Original source: https://github.com/microsoft/onnxconverter-common/blob/209a47e18e6a4c3474273a0b2a5e8f1fda481643/tests/test_float16.py + import os import typing
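
For reference, a minimal end-to-end use of the pass added by this series might look like the sketch below (illustrative only: "model.onnx" stands in for any float32 model on disk, and the module path is the one introduced in PATCH 1/3):

    import onnx

    from onnx_ir.passes.common import onnx_float_16

    model = onnx.load("model.onnx")
    # keep_io_types=True leaves the graph's float32 inputs/outputs intact and
    # inserts boundary Cast nodes, so existing callers need no dtype changes
    model_fp16 = onnx_float_16.convert_float_to_float16(model, keep_io_types=True)
    onnx.save(model_fp16, "model_fp16.onnx")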