Skip to content

Create the convenience methods i() and o() on Node #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,56 @@
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
)

@property
def outputs(self) -> Sequence[Value]:
    """Sequence of values produced by this node.

    This sequence cannot be reassigned. To change a graph's connectivity,
    build a replacement node and redirect the consumers of the current
    outputs via :meth:`replace_input_with` on those consumer nodes.
    """
    produced = self._outputs
    return produced

@outputs.setter
def outputs(self, _: Sequence[Value]) -> None:
    # Assignment is rejected outright; a node's outputs are fixed at construction.
    raise AttributeError("outputs is immutable. Please create a new node instead.")

Check warning on line 1561 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1561

Added line #L1561 was not covered by tests

def i(self, index: int = 0) -> Value | None:
    """Return the input value at *index* (default 0).

    Convenience shorthand for ``node.inputs[index]``::

        node.inputs[0] == node.i(0) == node.i()  # Default index is 0
        node.inputs[index] == node.i(index)

    Returns:
        The input value at the given index.

    Raises:
        IndexError: If the index is out of range.
    """
    all_inputs = self.inputs
    return all_inputs[index]

Check warning on line 1579 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1579

Added line #L1579 was not covered by tests

def o(self, index: int = 0) -> Value:
    """Get the output value at the given index.

    This is a convenience method that is equivalent to ``node.outputs[index]``.

    The following is equivalent::

        node.outputs[0] == node.o(0) == node.o()  # Default index is 0
        node.outputs[index] == node.o(index)

    Returns:
        The output value at the given index.

    Raises:
        IndexError: If the index is out of range.
    """
    return self.outputs[index]

def predecessors(self) -> Sequence[Node]:
"""Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
# Use the ordered nature of a dictionary to deduplicate the nodes
Expand Down Expand Up @@ -1616,20 +1666,6 @@
raise ValueError("The node to append to does not belong to any graph.")
self._graph.insert_after(self, nodes)

@property
def outputs(self) -> Sequence[Value]:
    """The output values of the node.

    The outputs are immutable. To change a node's outputs, create a new
    node, then call :meth:`replace_input_with` on each consumer node of the
    old outputs to redirect it to the new node's outputs.
    """
    return self._outputs

@outputs.setter
def outputs(self, _: Sequence[Value]) -> None:
    # Outputs are fixed at construction; reject any reassignment attempt.
    raise AttributeError("outputs is immutable. Please create a new node instead.")

@property
def attributes(self) -> _graph_containers.Attributes:
"""The attributes of the node."""
Expand Down
4 changes: 2 additions & 2 deletions src/onnx_ir/external_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,13 @@ def _simple_model(self) -> ir.Model:
node_1 = ir.Node(
"",
"Op_1",
inputs=[node_0.outputs[0]],
inputs=[node_0.o()],
num_outputs=1,
name="node_1",
)
graph = ir.Graph(
inputs=node_0.inputs, # type: ignore
outputs=[node_1.outputs[0]],
outputs=[node_1.o()],
initializers=[
ir.Value(name="tensor1", const_value=tensor1),
ir.Value(name="tensor2", const_value=tensor2),
Expand Down
50 changes: 37 additions & 13 deletions src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,41 @@ def test_pass_with_clear_metadata_and_docstring(self):
)
mul_node = ir.node(
"Mul",
inputs=[add_node.outputs[0], inputs[1]],
inputs=[add_node.o(), inputs[1]],
num_outputs=1,
metadata_props={"mul_key": "mul_value"},
doc_string="This is a Mul node",
)
func_inputs = [
ir.Value(
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
),
ir.Value(
name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
),
]
function = ir.Function(
graph=ir.Graph(
name="my_function",
inputs=func_inputs,
outputs=mul_node.outputs,
nodes=[add_node, mul_node],
inputs=[
input_a := ir.Value(
name="input_a",
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape((2, 3)),
),
input_b := ir.Value(
name="input_b",
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape((2, 3)),
),
],
nodes=[
add_node_func := ir.node(
"Add",
inputs=[input_a, input_b],
metadata_props={"add_key": "add_value"},
doc_string="This is an Add node",
),
mul_node_func := ir.node(
"Mul",
inputs=[add_node_func.o(), input_b],
metadata_props={"mul_key": "mul_value"},
doc_string="This is a Mul node",
),
],
outputs=mul_node_func.outputs,
opset_imports={"": 20},
doc_string="This is a function docstring",
metadata_props={"function_key": "function_value"},
Expand All @@ -57,6 +73,14 @@ def test_pass_with_clear_metadata_and_docstring(self):
domain="my_domain",
attributes=[],
)
func_node = ir.node(
"my_function",
inputs=[inputs[0], mul_node.o()],
domain="my_domain",
metadata_props={"mul_key": "mul_value"},
doc_string="This is a Mul node",
)
# TODO(justinchuby): This graph is broken. The output of the function cannot be a input to a node
# Create a model with the graph and function
constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy()))
const_node = ir.node(
Expand All @@ -69,7 +93,7 @@ def test_pass_with_clear_metadata_and_docstring(self):
)
sub_node = ir.node(
"Sub",
inputs=[function.outputs[0], const_node.outputs[0]],
inputs=[func_node.o(), const_node.o()],
num_outputs=1,
metadata_props={"sub_key": "sub_value"},
doc_string="This is a Sub node",
Expand Down
6 changes: 3 additions & 3 deletions src/onnx_ir/passes/common/constant_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
assert node.graph is not None
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
continue
if node.outputs[0].is_graph_output():
if node.o().is_graph_output():
logger.debug(
"Constant node '%s' is used as output, so it can't be lifted.", node.name
)
Expand All @@ -54,7 +54,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
continue

attr_name, attr_value = next(iter(node.attributes.items()))
initializer_name = node.outputs[0].name
initializer_name = node.o().name
assert initializer_name is not None
assert isinstance(attr_value, ir.Attr)
tensor = self._constant_node_attribute_to_tensor(
Expand All @@ -73,7 +73,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
assert node.graph is not None
node.graph.register_initializer(initializer)
# Replace the constant node with the initializer
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
ir.convenience.replace_all_uses_with(node.o(), initializer)
node.graph.remove(node, safe=True)
count += 1
logger.debug(
Expand Down
28 changes: 13 additions & 15 deletions src/onnx_ir/passes/common/constant_manipulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
const_node = ir.node(
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
)
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
add_node = ir.node("Add", inputs=[inputs[0], const_node.o()])
mul_node = ir.node("Mul", inputs=[add_node.o(), inputs[1]])

model = ir.Model(
graph=ir.Graph(
Expand Down Expand Up @@ -92,21 +92,21 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
)
# then branch adds the constant to the input
# else branch multiplies the input by the constant
add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
add_node = ir.node("Add", inputs=[input_value, then_const_node.o()])
then_graph = ir.Graph(
inputs=[],
outputs=[add_node.outputs[0]],
outputs=[add_node.o()],
nodes=[then_const_node, add_node],
opset_imports={"": 20},
)
else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
else_const_node = ir.node(
"Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
)
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.o()])
else_graph = ir.Graph(
inputs=[],
outputs=[mul_node.outputs[0]],
outputs=[mul_node.o()],
nodes=[else_const_node, mul_node],
opset_imports={"": 20},
)
Expand Down Expand Up @@ -178,15 +178,13 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
attributes={constant_attribute: constant_value},
num_outputs=1,
)
identity_node_constant = ir.node(
"Identity", inputs=[const_node.outputs[0]], num_outputs=1
)
identity_node_constant = ir.node("Identity", inputs=[const_node.o()], num_outputs=1)
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)

model = ir.Model(
graph=ir.Graph(
inputs=[input_value],
outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
outputs=[identity_node_input.o(), identity_node_constant.o()],
nodes=[identity_node_input, const_node, identity_node_constant],
opset_imports={"": 20},
),
Expand Down Expand Up @@ -232,7 +230,7 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self):
model = ir.Model(
graph=ir.Graph(
inputs=[input_value],
outputs=[identity_node_input.outputs[0], const_node.outputs[0]],
outputs=[identity_node_input.o(), const_node.o()],
nodes=[identity_node_input, const_node],
opset_imports={"": 20},
),
Expand Down Expand Up @@ -272,7 +270,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
add_node = ir.node("Add", inputs=[input_value, then_initializer_value])
then_graph = ir.Graph(
inputs=[],
outputs=[add_node.outputs[0]],
outputs=[add_node.o()],
nodes=[add_node],
opset_imports={"": 20},
initializers=[then_initializer_value],
Expand All @@ -287,7 +285,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value])
else_graph = ir.Graph(
inputs=[],
outputs=[mul_node.outputs[0]],
outputs=[mul_node.o()],
nodes=[mul_node],
opset_imports={"": 20},
initializers=[else_initializer_value],
Expand Down Expand Up @@ -351,7 +349,7 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph(
# The initializer is also an input. We don't lift it to the main graph
# to preserve the graph signature
inputs=[then_initializer_value],
outputs=[add_node.outputs[0]],
outputs=[add_node.o()],
nodes=[add_node],
opset_imports={"": 20},
initializers=[then_initializer_value],
Expand All @@ -366,7 +364,7 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph(
mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value])
else_graph = ir.Graph(
inputs=[],
outputs=[mul_node.outputs[0]],
outputs=[mul_node.o()],
nodes=[mul_node],
opset_imports={"": 20},
initializers=[else_initializer_value],
Expand Down