Skip to content

Create a name fix pass to ensure unique names for all values and nodes #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/onnx_ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"InlinePass",
"LiftConstantsToInitializersPass",
"LiftSubgraphInitializersToMainGraphPass",
"NameFixPass",
"RemoveInitializersFromInputsPass",
"RemoveUnusedFunctionsPass",
"RemoveUnusedNodesPass",
Expand Down Expand Up @@ -38,6 +39,7 @@
DeduplicateInitializersPass,
)
from onnx_ir.passes.common.inliner import InlinePass
from onnx_ir.passes.common.naming import NameFixPass
from onnx_ir.passes.common.onnx_checker import CheckerPass
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
Expand Down
264 changes: 264 additions & 0 deletions src/onnx_ir/passes/common/naming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Name fix pass for ensuring unique names for all values and nodes."""

from __future__ import annotations

__all__ = [
"NameFixPass",
]

import logging

import onnx_ir as ir

logger = logging.getLogger(__name__)


class NameFixPass(ir.passes.InPlacePass):
"""Pass for fixing names to ensure all values and nodes have unique names.

This pass ensures that:
1. Graph inputs and outputs have unique names (take precedence)
2. All intermediate values have unique names (assign names to unnamed values)
3. All values in subgraphs have unique names
4. All nodes have unique names (assign names to unnamed nodes)

The pass maintains global uniqueness across the entire model.
"""

def call(self, model: ir.Model) -> ir.passes.PassResult:
modified = False

# Use sets to track seen names globally
seen_value_names: set[str] = set()
seen_node_names: set[str] = set()

# Dictionary to track which values have been assigned names
value_to_name: dict[ir.Value, str] = {}

# Counters for generating unique names (using list to pass by reference)
value_counter = [0]
node_counter = [0]

# Process the main graph
if _fix_graph_names(
model.graph,
seen_value_names,
seen_node_names,
value_to_name,
value_counter,
node_counter,
):
modified = True

# Process functions
for function in model.functions.values():
if _fix_function_names(
function,
seen_value_names,
seen_node_names,
value_to_name,
value_counter,
node_counter,
):
modified = True

Check warning on line 65 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L65

Added line #L65 was not covered by tests

if modified:
logger.info("Name fix pass modified the model")

return ir.passes.PassResult(model, modified=modified)


def _fix_graph_names(
graph: ir.Graph,
seen_value_names: set[str],
seen_node_names: set[str],
value_to_name: dict[ir.Value, str],
value_counter: list[int],
node_counter: list[int],
) -> bool:
"""Fix names in a graph and return whether modifications were made."""
modified = False

# Step 1: Fix graph input names first (they have precedence)
for input_value in graph.inputs:
if _process_value(input_value, seen_value_names, value_to_name, value_counter):
modified = True

# Step 2: Fix graph output names (they have precedence)
for output_value in graph.outputs:
if _process_value(output_value, seen_value_names, value_to_name, value_counter):
modified = True

# Step 3: Fix initializer names
for initializer in graph.initializers.values():
if _process_value(initializer, seen_value_names, value_to_name, value_counter):
modified = True

Check warning on line 97 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L97

Added line #L97 was not covered by tests

# Step 4: Process all nodes and their values
for node in ir.traversal.RecursiveGraphIterator(graph):
# Fix node name
if node.name is None or node.name == "":
if _assign_node_name(node, seen_node_names, node_counter):
modified = True
else:
if _fix_duplicate_node_name(node, seen_node_names):
modified = True

# Fix input value names (only if not already processed)
for input_value in node.inputs:

Check failure

Code scanning / lintrunner

MYPY/assignment Error

Incompatible types in assignment (expression has type "Value | None", variable has type "Value") To disable, use # type: ignore[assignment]
if input_value is not None:
if _process_value(input_value, seen_value_names, value_to_name, value_counter):
modified = True

# Fix output value names (only if not already processed)
for output_value in node.outputs:
if _process_value(output_value, seen_value_names, value_to_name, value_counter):
modified = True

return modified


def _fix_function_names(
function: ir.Function,
seen_value_names: set[str],
seen_node_names: set[str],
value_to_name: dict[ir.Value, str],
value_counter: list[int],
node_counter: list[int],
) -> bool:
"""Fix names in a function and return whether modifications were made."""
modified = False

Check warning on line 132 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L132

Added line #L132 was not covered by tests

# Process function inputs first (they have precedence)
for input_value in function.inputs:
if _process_value(input_value, seen_value_names, value_to_name, value_counter):
modified = True

Check warning on line 137 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L137

Added line #L137 was not covered by tests

# Process function outputs (they have precedence)
for output_value in function.outputs:
if _process_value(output_value, seen_value_names, value_to_name, value_counter):
modified = True

Check warning on line 142 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L142

Added line #L142 was not covered by tests

# Process all nodes and their values
for node in ir.traversal.RecursiveGraphIterator(function):
# Fix node name
if node.name is None or node.name == "":
if _assign_node_name(node, seen_node_names, node_counter):
modified = True

Check warning on line 149 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L149

Added line #L149 was not covered by tests
else:
if _fix_duplicate_node_name(node, seen_node_names):
modified = True

Check warning on line 152 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L152

Added line #L152 was not covered by tests

# Fix input value names (only if not already processed)
for input_value in node.inputs:

Check failure

Code scanning / lintrunner

MYPY/assignment Error

Incompatible types in assignment (expression has type "Value | None", variable has type "Value") To disable, use # type: ignore[assignment]
if input_value is not None:
if _process_value(input_value, seen_value_names, value_to_name, value_counter):
modified = True

Check warning on line 158 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L158

Added line #L158 was not covered by tests

# Fix output value names (only if not already processed)
for output_value in node.outputs:
if _process_value(output_value, seen_value_names, value_to_name, value_counter):
modified = True

Check warning on line 163 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L163

Added line #L163 was not covered by tests

return modified

Check warning on line 165 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L165

Added line #L165 was not covered by tests


def _process_value(
value: ir.Value,
seen_value_names: set[str],
value_to_name: dict[ir.Value, str],
value_counter: list[int],
) -> bool:
"""Process a value only if it hasn't been processed before."""
if value in value_to_name:
return False

modified = False
if value.name is None or value.name == "":
modified = _assign_value_name(value, seen_value_names, value_counter)
else:
modified = _fix_duplicate_value_name(value, seen_value_names)

# Record the final name for this value
value_to_name[value] = value.name

Check failure

Code scanning / lintrunner

MYPY/assignment Error

Incompatible types in assignment (expression has type "str | None", target has type "str") To disable, use # type: ignore[assignment]
return modified


def _assign_value_name(value: ir.Value, seen_names: set[str], counter: list[int]) -> bool:
"""Assign a name to an unnamed value. Returns True if modified."""
while True:
new_name = f"val_{counter[0]}"
counter[0] += 1
if new_name not in seen_names:
value.name = new_name
seen_names.add(new_name)
logger.debug("Assigned name %s to unnamed value", new_name)
return True


def _assign_node_name(node: ir.Node, seen_names: set[str], counter: list[int]) -> bool:
"""Assign a name to an unnamed node. Returns True if modified."""
while True:
new_name = f"node_{counter[0]}"
counter[0] += 1
if new_name not in seen_names:
node.name = new_name
seen_names.add(new_name)
logger.debug("Assigned name %s to unnamed node", new_name)
return True


def _fix_duplicate_value_name(value: ir.Value, seen_names: set[str]) -> bool:
"""Fix a value's name if it conflicts with existing names. Returns True if modified."""
original_name = value.name

if original_name is None or original_name == "":
return False # Should not happen if called correctly

Check warning on line 218 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L218

Added line #L218 was not covered by tests

# If name is already seen, make it unique
if original_name in seen_names:
base_name = original_name
suffix = 1
while True:
new_name = f"{base_name}_{suffix}"
if new_name not in seen_names:
value.name = new_name
seen_names.add(new_name)
logger.debug(
"Renamed value from %s to %s for uniqueness", original_name, new_name
)
return True
suffix += 1
else:
# Name is unique, just record it
seen_names.add(original_name)
return False


def _fix_duplicate_node_name(node: ir.Node, seen_names: set[str]) -> bool:
"""Fix a node's name if it conflicts with existing names. Returns True if modified."""
original_name = node.name

if original_name is None or original_name == "":
return False # Should not happen if called correctly

Check warning on line 245 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L245

Added line #L245 was not covered by tests

# If name is already seen, make it unique
if original_name in seen_names:
base_name = original_name
suffix = 1
while True:
new_name = f"{base_name}_{suffix}"
if new_name not in seen_names:
node.name = new_name
seen_names.add(new_name)
logger.debug(
"Renamed node from %s to %s for uniqueness", original_name, new_name
)
return True
suffix += 1

Check warning on line 260 in src/onnx_ir/passes/common/naming.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/passes/common/naming.py#L260

Added line #L260 was not covered by tests
else:
# Name is unique, just record it
seen_names.add(original_name)
return False
Loading
Loading