Skip to content

[IR] Introduce node convenience functions #2303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/ir/ir_api/ir_convenience.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
.. autofunction:: replace_all_uses_with
.. autofunction:: replace_nodes_and_values
.. autofunction:: create_value_mapping
.. autofunction:: insert_nodes_in_value
.. autofunction:: remove_nodes
```
182 changes: 178 additions & 4 deletions onnxscript/ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"replace_all_uses_with",
"create_value_mapping",
"replace_nodes_and_values",
"insert_nodes_in_value",
"remove_nodes",
]

from typing import Mapping, Sequence, Union
Expand Down Expand Up @@ -335,6 +337,18 @@
return values


def _update_graph_or_function_outputs(
graph_or_function: _core.Graph | _core.Function,
old_values: Sequence[_core.Value],
new_values: Sequence[_core.Value],
):
"""Update graph/function outputs"""
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]


def replace_nodes_and_values(
graph_or_function: _core.Graph | _core.Function,
/,
Expand Down Expand Up @@ -367,11 +381,171 @@
# Reconnect the users of the deleted values to use the new values
replace_all_uses_with(old_values, new_values)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
_update_graph_or_function_outputs(graph_or_function, old_values, new_values)

# insert new nodes after the index node
graph_or_function.insert_after(insertion_point, new_nodes)
graph_or_function.remove(old_nodes, safe=True)


def _find_inputs_outputs(
nodes: Sequence[_core.Node],
) -> tuple[Sequence[_core.Value], Sequence[_core.Value]]:
"""Find the values that are considered as inputs and outputs in a sequence of nodes"""
# Search the unique inputs/outputs in new_nodes, keeping the order.
all_inputs = dict.fromkeys(sum([node.inputs for node in nodes], ()))
all_outputs = dict.fromkeys(sum([node.outputs for node in nodes], ()))

Check notice

Code scanning / lintrunner

RUFF/C419 Note

# A value is considered as input if it is not any output.

Check notice

Code scanning / lintrunner

RUFF/C419 Note

inputs = tuple(val for val in all_inputs if val not in all_outputs)
# A value is considered as output if it is not any input.
outputs = tuple(val for val in all_outputs if val not in all_inputs)
return inputs, outputs


def insert_nodes_in_value(
values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node]
) -> None:
"""Inserts a sequence of nodes into the provided value(s).

This allows to insert a list of LINKED nodes (over the same context) at
a specific point in the graph.

For example, suppose we have the following graph::

input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output

We want to insert [node_M, node_N] at B value::

>>> from onnxscript import ir
>>> input = ir.Input("input")
>>> node_A = ir.node("op_A", [input])
>>> B = ir.Value(name="B")
>>> node_B = ir.node("op_B", node_A.outputs, outputs=[B])
>>> node_C = ir.node("op_C", node_B.outputs)
>>> # Create a new sequence to insert
>>> input_2 = ir.Input("input_2")
>>> node_M = ir.node("op_M", [input_2])
>>> node_N = ir.node("op_N", node_M.outputs)
>>> # Insert nodes in B
>>> insert_nodes_before_value(node_B.outputs, [node_M, node_N])
>>> len(node_B.outputs)
1
>>> node_B.outputs[0].consumers()[0].op_type
'op_M'
>>> len(node_C.inputs)
1
>>> node_C.inputs[0].producer().op_type
'op_N'
>>> node_C.inputs[0].name
'B'

When values is a sequence, the set of nodes must have the same number
of inputs and outputs, then they are zipped into pairs: first value is
replaced with the first input/output, and so on.

Args:
values: The value(s) where to insert the nodes.
new_nodes: The nodes to insert in the graph.
"""
if not isinstance(values, Sequence):
values = (values,)

# Search the unique inputs/outputs in new_nodes, keeping the order.
inputs, outputs = _find_inputs_outputs(new_nodes)

# Sanity check.
if len(values) != len(inputs):
raise ValueError(f"The number of values and inputs ({inputs}) in new_nodes must match.")
if len(values) != len(outputs):
raise ValueError(f"The number of values and outputs ({outputs}) in new_nodes must match.")

# Propagate relevant info.
for val, in_val, out_val in zip(values, inputs, outputs):
# Propagate relevant info from value to out_value.
# TODO(Rama): Perhaps this should be a separate utility function.
out_val.type = val.type
out_val.shape = val.shape
out_val.name = val.name
# Propagate relevant info from value to in_value.
# TODO(Rama): Perhaps this should be a separate utility function.
in_val.type = val.type
in_val.shape = val.shape
# Rename each value, following each input.
val.name = in_val.name

# Insert the new nodes in two steps:
# 1. Reconnect the users of values to the outputs
replace_all_uses_with(values, outputs)
# 2. Reconnect the users of inputs to values
replace_all_uses_with(inputs, values)

# Update graph if there is one:
if (graph := values[-1].graph) is not None:
# Update graph/function outputs if the node generates output
_update_graph_or_function_outputs(graph, values, outputs)

# Insert new nodes if there is a graph
graph.extend(new_nodes)
graph.sort()


def remove_nodes(nodes: Sequence[_core.Node]) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very useful, thanks

Copy link
Collaborator

@justinchuby justinchuby May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call it remove_connected_nodes ? Also could you move this PR to https://github.com/onnx/ir-py now that we finished migration? (sorry about the extra effort) I recommend creating two PRs for the two functions so they can be reviewed individually

"""Remove a sequence of nodes.

This allows to delete a list of LINKED nodes (over the same context).

For example, suppose we have the following graph::

input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output

We want to prune [node_B]::

>>> from onnxscript import ir
>>> input = ir.Input("input")
>>> node_A = ir.node("op_A", [input])
>>> node_B = ir.node("op_B", node_A.outputs)
>>> node_C = ir.node("op_C", node_B.outputs)
>>> # Delete node_B
>>> remove_nodes([node_B])
>>> len(node_A.outputs[0].consumers())
1
>>> node_A.outputs[0].consumers()[0].op_type
'op_C'
>>> len(node_C.inputs)
1
>>> node_C.inputs[0].producer().op_type
'op_A'
>>> node_B.inputs
(None,)
>>> len(node_B.outputs)
1
>>> len(node_B.outputs[0].consumers())
0

Args:
nodes: The nodes to remove.
"""
# Search the unique inputs/outputs in new_nodes, keeping the order.
inputs, outputs = _find_inputs_outputs(nodes)

# Sanity check.
if len(inputs) != len(outputs):
raise ValueError(
f"The number of inputs ({inputs}) and outputs ({outputs}) in nodes must match."
)

# Remove nodes, in several steps:
# 1. Reconnect the users of outputs with inputs
replace_all_uses_with(outputs, inputs)
# 2. Detach nodes for their inputs
for node in nodes:
for i in range(len(node.inputs)):
node.replace_input_with(i, None)

# Update graph if there is one:
if (graph := inputs[-1].graph) is not None:
# Update graph/function outputs if the node generates output
_update_graph_or_function_outputs(graph, outputs, inputs)

# Drop nodes from graph
graph.remove(nodes, safe=True)
Loading
Loading