Skip to content

Implement graph composition with callable Graph interface #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2781,6 +2781,172 @@
self._metadata_props = {}
return self._metadata_props

def __call__(self, *args: Value) -> tuple[Value, ...]:
"""Create a copy of this graph and connect it with the provided input values.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
This enables graph composition by creating a copy of the graph with new
values connected as inputs. All nodes from this graph are cloned and
added to the graph that owns the input values.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Args:
*args: Input values to connect to the graph inputs. The number of
arguments must match the number of graph inputs.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Returns:
A tuple of output values from the cloned graph.

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
Raises:
ValueError: If the number of input arguments doesn't match the graph inputs.
ValueError: If the input values don't all belong to the same graph.
ValueError: If any input value doesn't belong to a graph.
"""
# Validate inputs
if len(args) != len(self.inputs):
raise ValueError(
f"Expected {len(self.inputs)} input arguments, got {len(args)}"
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
if not args:
# Handle the case of a graph with no inputs
target_graph = None
else:
# Validate that all input values belong to a graph and the same graph
target_graph = args[0].graph
if target_graph is None:
raise ValueError(f"Input value {args[0]} does not belong to any graph")

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
for i, arg in enumerate(args[1:], 1):
if arg.graph is None:
raise ValueError(f"Input value {arg} does not belong to any graph")

Check warning on line 2820 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2820

Added line #L2820 was not covered by tests
if arg.graph is not target_graph:
raise ValueError(

Check warning on line 2822 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2822

Added line #L2822 was not covered by tests
f"All input values must belong to the same graph. "
f"Value at index {i} belongs to a different graph."
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Create value mapping from original inputs to provided inputs
value_map: dict[Value, Value] = {}
for original_input, new_input in zip(self.inputs, args):
value_map[original_input] = new_input

Check warning

Code scanning / lintrunner

RUFF/PERF403 Warning

Use a dictionary comprehension instead of a for-loop.
See https://docs.astral.sh/ruff/rules/manual-dict-comprehension

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Clone all nodes, building the value map as we go
cloned_nodes = []
for node in self:
cloned_node = self._clone_node_for_composition(node, value_map, target_graph)
cloned_nodes.append(cloned_node)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Clone initializers and add them to the target graph
if target_graph is not None:
for init in self.initializers.values():
cloned_init = self._clone_value_for_composition(init, value_map)
target_graph.register_initializer(cloned_init)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Add all cloned nodes to the target graph
if target_graph is not None and cloned_nodes:
target_graph.extend(cloned_nodes)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Return the cloned output values
cloned_outputs = []
for output in self.outputs:
if output in value_map:
cloned_outputs.append(value_map[output])
else:
# This should not happen if the graph is well-formed
raise RuntimeError(f"Output value {output} was not found in value mapping")

Check warning on line 2855 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2855

Added line #L2855 was not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
return tuple(cloned_outputs)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
def _clone_node_for_composition(
self, node: Node, value_map: dict[Value, Value], target_graph: Graph | None
) -> Node:
"""Clone a node for graph composition, updating the value map."""
# Clone input values (or use existing mappings)
cloned_inputs = []

Check failure

Code scanning / lintrunner

MYPY/var-annotated Error

Need type annotation for "cloned_inputs" (hint: "cloned_inputs: list[] = ...") To disable, use # type: ignore[var-annotated]
for input_val in node.inputs:
if input_val is None:
cloned_inputs.append(None)

Check warning on line 2867 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2867

Added line #L2867 was not covered by tests

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "append" of "list" has incompatible type "None"; expected "Value" To disable, use # type: ignore[arg-type]
elif input_val in value_map:
cloned_inputs.append(value_map[input_val])
else:
# This input is not yet mapped, clone it
cloned_input = self._clone_value_for_composition(input_val, value_map)
value_map[input_val] = cloned_input
cloned_inputs.append(cloned_input)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Clone attributes
cloned_attributes = []
for attr in node.attributes.values():
if isinstance(attr, Attr):
cloned_attr = self._clone_attr_for_composition(attr, value_map, target_graph)

Check warning on line 2880 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2880

Added line #L2880 was not covered by tests
if cloned_attr is not None:
cloned_attributes.append(cloned_attr)

Check warning on line 2882 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2882

Added line #L2882 was not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Create the new node
cloned_node = Node(
domain=node.domain,
op_type=node.op_type,
inputs=cloned_inputs,
attributes=cloned_attributes,
overload=node.overload,
num_outputs=len(node.outputs),
graph=target_graph,
name=node.name, # Note: Graph.extend will assign unique names if needed
doc_string=node.doc_string,
metadata_props=node.metadata_props,
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Map the output values
for original_output, cloned_output in zip(node.outputs, cloned_node.outputs):
value_map[original_output] = cloned_output
# Copy relevant properties
cloned_output.name = original_output.name
cloned_output.type = original_output.type
cloned_output.shape = original_output.shape
cloned_output.const_value = original_output.const_value

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
return cloned_node

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
def _clone_value_for_composition(
self, value: Value, value_map: dict[Value, Value]
) -> Value:
"""Clone a value for graph composition."""
if value in value_map:
return value_map[value]

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Create a new value
cloned_value = Value(
name=value.name,
type=value.type,
shape=value.shape,
doc_string=value.doc_string,
const_value=value.const_value,
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
value_map[value] = cloned_value
return cloned_value

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
def _clone_attr_for_composition(
self, attr: Attr, value_map: dict[Value, Value], target_graph: Graph | None
) -> Attr | None:
"""Clone an attribute for graph composition."""
if not attr.is_ref():
if attr.type == _enums.AttributeType.GRAPH:
# Recursively clone subgraphs
subgraph = attr.as_graph()

Check warning on line 2935 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2935

Added line #L2935 was not covered by tests

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable subgraph is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
# For subgraphs, we need to handle them specially
# For now, we'll just return the attribute as-is
# TODO: Implement proper subgraph composition if needed
return attr

Check warning on line 2939 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2939

Added line #L2939 was not covered by tests
elif attr.type == _enums.AttributeType.GRAPHS:
# Handle multiple subgraphs
# For now, we'll just return the attribute as-is
# TODO: Implement proper subgraph composition if needed
return attr
return attr

Check warning on line 2945 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2944-L2945

Added lines #L2944 - L2945 were not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Handle reference attributes - for now just return as-is
return attr

Check warning on line 2948 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L2948

Added line #L2948 was not covered by tests

def __str__(self) -> str:
return _graph_str(self)

Expand Down
Loading
Loading