-
Notifications
You must be signed in to change notification settings - Fork 71
[IR] Introduce node convenience functions #2303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
db475d9
d240255
142e47f
e89df33
9d61a04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
"replace_all_uses_with", | ||
"create_value_mapping", | ||
"replace_nodes_and_values", | ||
"insert_nodes_in_value", | ||
"remove_nodes", | ||
] | ||
|
||
from typing import Mapping, Sequence, Union | ||
|
@@ -335,6 +337,18 @@ | |
return values | ||
|
||
|
||
def _update_graph_or_function_outputs( | ||
graph_or_function: _core.Graph | _core.Function, | ||
old_values: Sequence[_core.Value], | ||
new_values: Sequence[_core.Value], | ||
): | ||
"""Update graph/function outputs""" | ||
replacement_mapping = dict(zip(old_values, new_values)) | ||
for idx, graph_or_function_output in enumerate(graph_or_function.outputs): | ||
if graph_or_function_output in replacement_mapping: | ||
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] | ||
|
||
|
||
def replace_nodes_and_values( | ||
graph_or_function: _core.Graph | _core.Function, | ||
/, | ||
|
@@ -367,11 +381,171 @@ | |
# Reconnect the users of the deleted values to use the new values | ||
replace_all_uses_with(old_values, new_values) | ||
# Update graph/function outputs if the node generates output | ||
replacement_mapping = dict(zip(old_values, new_values)) | ||
for idx, graph_or_function_output in enumerate(graph_or_function.outputs): | ||
if graph_or_function_output in replacement_mapping: | ||
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] | ||
_update_graph_or_function_outputs(graph_or_function, old_values, new_values) | ||
|
||
# insert new nodes after the index node | ||
graph_or_function.insert_after(insertion_point, new_nodes) | ||
graph_or_function.remove(old_nodes, safe=True) | ||
|
||
|
||
def _find_inputs_outputs( | ||
nodes: Sequence[_core.Node], | ||
) -> tuple[Sequence[_core.Value], Sequence[_core.Value]]: | ||
"""Find the values that are considered as inputs and outputs in a sequence of nodes""" | ||
# Search the unique inputs/outputs in new_nodes, keeping the order. | ||
all_inputs = dict.fromkeys(sum([node.inputs for node in nodes], ())) | ||
all_outputs = dict.fromkeys(sum([node.outputs for node in nodes], ())) | ||
# A value is considered as input if it is not any output. | ||
Check noticeCode scanning / lintrunner RUFF/C419 Note
Unnecessary list comprehension.
See https://docs.astral.sh/ruff/rules/unnecessary-comprehension-in-call |
||
inputs = tuple(val for val in all_inputs if val not in all_outputs) | ||
# A value is considered as output if it is not any input. | ||
outputs = tuple(val for val in all_outputs if val not in all_inputs) | ||
return inputs, outputs | ||
|
||
|
||
def insert_nodes_in_value( | ||
values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node] | ||
) -> None: | ||
"""Inserts a sequence of nodes into the provided value(s). | ||
|
||
This allows to insert a list of LINKED nodes (over the same context) at | ||
a specific point in the graph. | ||
|
||
For example, suppose we have the following graph:: | ||
|
||
input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output | ||
|
||
We want to insert [node_M, node_N] at B value:: | ||
|
||
>>> from onnxscript import ir | ||
>>> input = ir.Input("input") | ||
>>> node_A = ir.node("op_A", [input]) | ||
>>> B = ir.Value(name="B") | ||
>>> node_B = ir.node("op_B", node_A.outputs, outputs=[B]) | ||
>>> node_C = ir.node("op_C", node_B.outputs) | ||
>>> # Create a new sequence to insert | ||
>>> input_2 = ir.Input("input_2") | ||
>>> node_M = ir.node("op_M", [input_2]) | ||
>>> node_N = ir.node("op_N", node_M.outputs) | ||
>>> # Insert nodes in B | ||
>>> insert_nodes_before_value(node_B.outputs, [node_M, node_N]) | ||
>>> len(node_B.outputs) | ||
1 | ||
>>> node_B.outputs[0].consumers()[0].op_type | ||
'op_M' | ||
>>> len(node_C.inputs) | ||
1 | ||
>>> node_C.inputs[0].producer().op_type | ||
'op_N' | ||
>>> node_C.inputs[0].name | ||
'B' | ||
|
||
When values is a sequence, the set of nodes must have the same number | ||
of inputs and outputs, then they are zipped into pairs: first value is | ||
replaced with the first input/output, and so on. | ||
|
||
Args: | ||
values: The value(s) where to insert the nodes. | ||
new_nodes: The nodes to insert in the graph. | ||
""" | ||
if not isinstance(values, Sequence): | ||
values = (values,) | ||
|
||
# Search the unique inputs/outputs in new_nodes, keeping the order. | ||
inputs, outputs = _find_inputs_outputs(new_nodes) | ||
|
||
# Sanity check. | ||
if len(values) != len(inputs): | ||
raise ValueError(f"The number of values and inputs ({inputs}) in new_nodes must match.") | ||
if len(values) != len(outputs): | ||
raise ValueError(f"The number of values and outputs ({outputs}) in new_nodes must match.") | ||
|
||
# Propagate relevant info. | ||
for val, in_val, out_val in zip(values, inputs, outputs): | ||
# Propagate relevant info from value to out_value. | ||
# TODO(Rama): Perhaps this should be a separate utility function. | ||
out_val.type = val.type | ||
out_val.shape = val.shape | ||
out_val.name = val.name | ||
# Propagate relevant info from value to in_value. | ||
# TODO(Rama): Perhaps this should be a separate utility function. | ||
in_val.type = val.type | ||
in_val.shape = val.shape | ||
# Rename each value, following each input. | ||
val.name = in_val.name | ||
|
||
# Insert the new nodes in two steps: | ||
# 1. Reconnect the users of values to the outputs | ||
replace_all_uses_with(values, outputs) | ||
# 2. Reconnect the users of inputs to values | ||
replace_all_uses_with(inputs, values) | ||
|
||
# Update graph if there is one: | ||
if (graph := values[-1].graph) is not None: | ||
# Update graph/function outputs if the node generates output | ||
_update_graph_or_function_outputs(graph, values, outputs) | ||
|
||
# Insert new nodes if there is a graph | ||
graph.extend(new_nodes) | ||
graph.sort() | ||
|
||
|
||
def remove_nodes(nodes: Sequence[_core.Node]) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is very useful, thanks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe call it remove_connected_nodes ? Also could you move this PR to https://github.com/onnx/ir-py now that we finished migration? (sorry about the extra effort) I recommend creating two PRs for the two functions so they can be reviewed individually |
||
"""Remove a sequence of nodes. | ||
|
||
This allows to delete a list of LINKED nodes (over the same context). | ||
|
||
For example, suppose we have the following graph:: | ||
|
||
input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output | ||
|
||
We want to prune [node_B]:: | ||
|
||
>>> from onnxscript import ir | ||
>>> input = ir.Input("input") | ||
>>> node_A = ir.node("op_A", [input]) | ||
>>> node_B = ir.node("op_B", node_A.outputs) | ||
>>> node_C = ir.node("op_C", node_B.outputs) | ||
>>> # Delete node_B | ||
>>> remove_nodes([node_B]) | ||
>>> len(node_A.outputs[0].consumers()) | ||
1 | ||
>>> node_A.outputs[0].consumers()[0].op_type | ||
'op_C' | ||
>>> len(node_C.inputs) | ||
1 | ||
>>> node_C.inputs[0].producer().op_type | ||
'op_A' | ||
>>> node_B.inputs | ||
(None,) | ||
>>> len(node_B.outputs) | ||
1 | ||
>>> len(node_B.outputs[0].consumers()) | ||
0 | ||
|
||
Args: | ||
nodes: The nodes to remove. | ||
""" | ||
# Search the unique inputs/outputs in new_nodes, keeping the order. | ||
inputs, outputs = _find_inputs_outputs(nodes) | ||
|
||
# Sanity check. | ||
if len(inputs) != len(outputs): | ||
raise ValueError( | ||
f"The number of inputs ({inputs}) and outputs ({outputs}) in nodes must match." | ||
) | ||
|
||
# Remove nodes, in several steps: | ||
# 1. Reconnect the users of outputs with inputs | ||
replace_all_uses_with(outputs, inputs) | ||
# 2. Detach nodes for their inputs | ||
for node in nodes: | ||
for i in range(len(node.inputs)): | ||
node.replace_input_with(i, None) | ||
|
||
# Update graph if there is one: | ||
if (graph := inputs[-1].graph) is not None: | ||
# Update graph/function outputs if the node generates output | ||
_update_graph_or_function_outputs(graph, outputs, inputs) | ||
|
||
# Drop nodes from graph | ||
graph.remove(nodes, safe=True) |
Check notice
Code scanning / lintrunner
RUFF/C419 Note