
[WIP] Create symbolic type/shape inference logic #117


Draft · wants to merge 31 commits into base: main

Conversation

justinchuby (Member)

No description provided.


codecov bot commented Jun 30, 2025

Codecov Report

Attention: Patch coverage is 26.68760% with 467 lines in your changes missing coverage. Please review.

Project coverage is 68.80%. Comparing base (676cda1) to head (9256233).

✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/onnx_ir/_shape_type_inference/_engine.py | 18.18% | 108 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/ops/unsqueeze.py | 22.22% | 63 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/ops/squeeze.py | 25.33% | 56 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/ops/reshape.py | 13.33% | 39 Missing ⚠️ |
| .../onnx_ir/_shape_type_inference/ops/standard_ops.py | 30.18% | 37 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/ops/concat.py | 12.19% | 36 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/ops/matmul.py | 19.51% | 33 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/_common.py | 53.62% | 32 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/ops/transpose.py | 25.00% | 21 Missing ⚠️ |
| src/onnx_ir/_shape_type_inference/factory.py | 38.70% | 19 Missing ⚠️ |
| ... and 2 more | | |

❗ The number of uploaded reports differs between BASE (676cda1) and HEAD (9256233): HEAD has 9 fewer uploads than BASE.

| Flag | BASE (676cda1) | HEAD (9256233) |
|---|---|---|
| | 18 | 9 |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #117      +/-   ##
==========================================
- Coverage   74.51%   68.80%   -5.71%     
==========================================
  Files          38       50      +12     
  Lines        4693     5325     +632     
  Branches      958     1085     +127     
==========================================
+ Hits         3497     3664     +167     
- Misses        843     1307     +464     
- Partials      353      354       +1     



# Reconcile based on policy
if self.reconciliation_policy == ReconciliationPolicy.OVERWRITE:
    node.outputs[i] = inferred_value

Check failure · Code scanning / lintrunner · MYPY/index Error: Unsupported target for indexed assignment ("Sequence[Value]"). To disable, use `# type: ignore[index]`

elif self.reconciliation_policy == ReconciliationPolicy.IGNORE:
    # Keep existing output if it has shape/type info
    if existing_output.shape is None and existing_output.type is None:
        node.outputs[i] = inferred_value

Check failure · Code scanning / lintrunner · MYPY/index Error: Unsupported target for indexed assignment ("Sequence[Value]"). To disable, use `# type: ignore[index]`

elif self.reconciliation_policy == ReconciliationPolicy.RECONCILE:
    reconciled_output = self._reconcile_value(existing_output, inferred_value)
    node.outputs[i] = reconciled_output

Check failure · Code scanning / lintrunner · MYPY/index Error: Unsupported target for indexed assignment ("Sequence[Value]"). To disable, use `# type: ignore[index]`

elif isinstance(dim2, int) and dim2 > 0:
    reconciled_dims.append(dim2)
elif dim1 is not None:
    reconciled_dims.append(dim1)

Check failure · Code scanning / lintrunner · MYPY/arg-type Error: Argument 1 to "append" of "list" has incompatible type "int | SymbolicDim"; expected "int". To disable, use `# type: ignore[arg-type]`
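The two kinds of lintrunner failures above reduce to (a) assigning into `node.outputs`, which is typed as an immutable `Sequence[Value]`, and (b) appending dimensions into a list annotated as `list[int]` even though a dimension may be a `SymbolicDim`. A minimal sketch of one way to satisfy both, assuming `Value.shape`/`Value.type` are writable and that `ir.SymbolicDim` is the symbolic dimension type (assumptions based on the error messages, not on this PR's final code):

```python
from __future__ import annotations

import onnx_ir as ir


def apply_inferred(existing_output: ir.Value, inferred: ir.Value) -> None:
    # Rather than replacing the Value inside the immutable Sequence[Value],
    # copy the inferred metadata onto the existing output in place
    # (assumes the .shape and .type setters exist on ir.Value).
    existing_output.shape = inferred.shape
    existing_output.type = inferred.type


def reconcile_dim(
    dim1: int | ir.SymbolicDim | None, dim2: int | ir.SymbolicDim | None
) -> list[int | ir.SymbolicDim]:
    # Annotating the list with the full dimension union avoids the
    # arg-type error when a SymbolicDim is appended.
    reconciled_dims: list[int | ir.SymbolicDim] = []
    if isinstance(dim2, int) and dim2 > 0:
        reconciled_dims.append(dim2)
    elif dim1 is not None:
        reconciled_dims.append(dim1)
    return reconciled_dims
```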
I have successfully updated the entire InferenceResult system across all files in the src/onnx_ir/_shape_type_inference/ops directory:

✅ What Was Accomplished:

1. Converted InferenceResult from a dataclass to a normal class with string-based status initialization
2. Updated all validation decorators in _common.py to use string status
3. Updated the engine in _engine.py to handle the different status types appropriately
4. Updated all 8 operation files in the ops directory:
   - standard_ops.py (BinaryInferrer)
   - matmul.py (MatMulInferrer)
   - concat.py (ConcatInferrer)
   - reshape.py (ReshapeInferrer)
   - constant.py (ConstantInferrer)
   - squeeze.py (Squeeze12Inferrer, Squeeze13Inferrer)
   - transpose.py (TransposeInferrer)
   - unsqueeze.py (Unsqueeze12Inferrer, Unsqueeze13Inferrer)
5. Updated exports in __init__.py to include InferenceStatus
6. Updated documentation in README.md with examples

✅ Key Benefits:

- More convenient API: status="missing_info" instead of status=InferenceStatus.MISSING_INFO
- Type safety: automatic enum conversion with clear error messages for invalid strings
- Better categorization: proper error classification (missing_info, invalid_node, partial, success)
- Cleaner code: fewer imports needed, more readable error handling
- Graceful degradation: the engine can handle partial inference and missing information

The refactoring is now complete, and all files consistently use the improved InferenceResult class with string-based status initialization.
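For reference, a minimal sketch of the string-initialized `InferenceResult` pattern described above, assuming an `InferenceStatus` enum with the four listed members; the `values` parameter and the exact constructor layout here are illustrative, not taken from the PR:

```python
from __future__ import annotations

import enum

import onnx_ir as ir


class InferenceStatus(enum.Enum):
    SUCCESS = "success"
    PARTIAL = "partial"
    MISSING_INFO = "missing_info"
    INVALID_NODE = "invalid_node"


class InferenceResult:
    """Result of shape/type inference with a string-friendly status."""

    def __init__(
        self,
        values: list[ir.Value] | None = None,
        status: InferenceStatus | str = "success",
        msg: str | None = None,
    ) -> None:
        # Accept either the enum member or its string value; an invalid
        # string raises ValueError listing the valid InferenceStatus values.
        self.status = InferenceStatus(status)
        self.values = values
        self.msg = msg


# Usage matching the examples above:
result = InferenceResult(status="missing_info", msg="Input shape is unknown.")
```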

)

# Get first input shape as base
first_shape = node.inputs[0].shape

Check failure · Code scanning / lintrunner · MYPY/union-attr Error: Item "None" of "Value | None" has no attribute "shape". To disable, use `# type: ignore[union-attr]`

    return _common.InferenceResult(
        status="missing_info", msg="Concat input shapes cannot be None."
    )
first_type = node.inputs[0].type

Check failure · Code scanning / lintrunner · MYPY/union-attr Error: Item "None" of "Value | None" has no attribute "type". To disable, use `# type: ignore[union-attr]`
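Both union-attr failures come from dereferencing `node.inputs[0]` (typed `Value | None`) without first narrowing it. A common pattern, sketched below with a hypothetical helper name, is to bind the optional input to a local and guard it once; this may or may not be how the PR ultimately resolves the warning:

```python
from __future__ import annotations

import onnx_ir as ir


def first_input_shape(node: ir.Node) -> ir.Shape | None:
    """Return the first input's shape, or None when it is unavailable."""
    first_input = node.inputs[0]
    # Guard the Optional once so mypy narrows Value | None to Value before
    # .shape (or .type) is accessed.
    if first_input is None:
        return None
    return first_input.shape
```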
)

# Create shape from the tensor dimensions
output_shape = ir.Shape(tensor.shape)

Check failure · Code scanning / lintrunner · MYPY/arg-type Error: Argument 1 to "Shape" has incompatible type "ShapeProtocol"; expected "Iterable[int | SupportsInt | SymbolicDim | str | None]". To disable, use `# type: ignore[arg-type]`
logger.warning(
    "Squeeze operation has symbolic dimension %s, assuming it is not 1.", dim
)
output_dims.append(dim)

Check failure · Code scanning / lintrunner · MYPY/arg-type Error: Argument 1 to "append" of "list" has incompatible type "SymbolicDim"; expected "int". To disable, use `# type: ignore[arg-type]`
    output_shape = _compute_output_shape_no_axes(input_shape)
else:
    try:
        axes = _normalize_axes(axes, rank)

Check failure · Code scanning / lintrunner · MYPY/assignment Error: Incompatible types in assignment (expression has type "set[int]", variable has type "Sequence[int]"). To disable, use `# type: ignore[assignment]`

        axes = _normalize_axes(axes, rank)
    except ValueError as e:
        return _common.InferenceResult(status="invalid_node", msg=str(e))
    output_shape = _compute_output_shape_with_axes(input_shape, axes)

Check failure · Code scanning / lintrunner · MYPY/arg-type Error: Argument 2 to "_compute_output_shape_with_axes" has incompatible type "Sequence[int]"; expected "set[int]". To disable, use `# type: ignore[arg-type]`
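These two failures come from reusing the name `axes` for a value of a different type: `_normalize_axes` evidently returns a `set[int]` while `axes` was annotated as `Sequence[int]`. A hedged sketch of the usual remedy, binding the result to a new name; `_normalize_axes` here is a stand-in whose body is illustrative, since only its name appears in the PR excerpt:

```python
from __future__ import annotations

from collections.abc import Sequence


def _normalize_axes(axes: Sequence[int], rank: int) -> set[int]:
    # Illustrative stand-in: map possibly-negative axes into [0, rank).
    normalized = {axis + rank if axis < 0 else axis for axis in axes}
    if any(axis < 0 or axis >= rank for axis in normalized):
        raise ValueError(f"Axes {list(axes)} out of range for rank {rank}")
    return normalized


axes: Sequence[int] = [0, -1]
# Binding the set to a distinct name keeps the Sequence[int] annotation of
# `axes` stable, which avoids both the assignment and the arg-type error.
normalized_axes: set[int] = _normalize_axes(axes, rank=4)
```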
rank1 = len(shape1)
rank2 = len(shape2)
new_rank = max(rank1, rank2)
new_dims = []

Check failure · Code scanning / lintrunner · MYPY/var-annotated Error: Need type annotation for "new_dims" (hint: "new_dims: list[<type>] = ..."). To disable, use `# type: ignore[var-annotated]`
I've successfully updated all operation files to use direct slicing on ir.Shape objects instead of accessing .dims and then taking slices. Here are the specific changes made:

1. matmul.py
   - Before: `output_dims = [*rhs_shape.dims[:-2], rhs_shape.dims[-1]]`
   - After: `output_dims = [*rhs_shape[:-2], rhs_shape[-1]]`
   - Before: `output_dims = list(lhs_shape.dims[:-1])`
   - After: `output_shape = ir.Shape(lhs_shape[:-1])`
   - Before: `lhs_batch = lhs_shape.dims[:-2]`
   - After: `lhs_batch = lhs_shape[:-2]`
   - Before: `output_dims = [*batch_shape.dims, lhs_shape.dims[-2], rhs_shape.dims[-1]]`
   - After: `output_dims = [*batch_shape, lhs_shape[-2], rhs_shape[-1]]`

2. transpose.py
   - Before: `output_dims.append(input_shape.dims[axis])`
   - After: `output_dims.append(input_shape[axis])`

3. concat.py
   - Before: `output_dims = list(first_shape.dims)`
   - After: `output_dims = list(first_shape)`

4. unsqueeze.py
   - Before: `output_dims.append(input_shape.dims[input_axis])`
   - After: `output_dims.append(input_shape[input_axis])`

5. squeeze.py
   - Before: `for dim in input_shape.dims:`
   - After: `for dim in input_shape:`
   - Before: `output_dims = [dim for i, dim in enumerate(input_shape.dims) if i not in axes]`
   - After: `output_dims = [dim for i, dim in enumerate(input_shape) if i not in axes]`

Key Benefits

1. Cleaner code: direct slicing is more readable and pythonic
2. Less verbose: removes the need to access .dims explicitly
3. Consistent API: uses the shape object's built-in slicing support
4. Type safety: shape slicing returns the correct types directly

The code is now significantly cleaner and more intuitive, taking advantage of the ir.Shape object's built-in slicing capabilities!
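As a quick illustration of the slicing style above (assuming `ir.Shape` supports indexing, slicing, and iteration the way the before/after pairs indicate; the concrete shape values here are made up):

```python
from __future__ import annotations

import onnx_ir as ir

# A rank-4 shape with one symbolic ("N") batch dimension.
shape = ir.Shape(["N", 3, 224, 224])

last_dim = shape[-1]        # index the Shape directly instead of shape.dims[-1]
leading_dims = shape[:-2]   # slice off the trailing two dimensions
dims_as_list = list(shape)  # Shape is iterable, so list() works as well

# Slices can be spliced back into a new Shape, as in the matmul example above.
output_shape = ir.Shape([*leading_dims, shape[-2], shape[-1]])
```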

    output_dims.append(1)
else:
    # Copy dimension from input
    output_dims.append(input_shape[input_axis])

Check failure · Code scanning / lintrunner · MYPY/arg-type Error: Argument 1 to "append" of "list" has incompatible type "int | SymbolicDim"; expected "int". To disable, use `# type: ignore[arg-type]`
for op_type in binary_ops:
    inferrers.append(BinaryInferrer(op_type))

return SymbolicInferenceEngine(inferrers, reconciliation_policy)

Check failure · Code scanning / lintrunner · MYPY/arg-type Error: Argument 2 to "SymbolicInferenceEngine" has incompatible type "ReconciliationPolicy"; expected "str". To disable, use `# type: ignore[arg-type]`

    BinaryInferrer("Mul"),
]

return SymbolicInferenceEngine(inferrers, reconciliation_policy)

Check failure · Code scanning / lintrunner · MYPY/arg-type Error: Argument 2 to "SymbolicInferenceEngine" has incompatible type "ReconciliationPolicy"; expected "str". To disable, use `# type: ignore[arg-type]`
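Both failures report the same mismatch: the `SymbolicInferenceEngine` constructor annotates `reconciliation_policy` as `str` while the factory passes the `ReconciliationPolicy` enum. One hedged option, mirroring the string-based status handling of `InferenceResult`, is to accept either form and normalize it; the class below is a stripped-down stand-in, and the enum's string values are assumed:

```python
from __future__ import annotations

import enum


class ReconciliationPolicy(enum.Enum):
    OVERWRITE = "overwrite"
    IGNORE = "ignore"
    RECONCILE = "reconcile"


class SymbolicInferenceEngine:
    def __init__(
        self,
        inferrers: list[object],
        reconciliation_policy: ReconciliationPolicy | str = ReconciliationPolicy.RECONCILE,
    ) -> None:
        self.inferrers = inferrers
        # Accept the enum member or its string value and store the enum,
        # so factory call sites can pass either without a cast.
        self.reconciliation_policy = ReconciliationPolicy(reconciliation_policy)
```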
        Dictionary mapping operation types to inferrer counts.
    """
    info = {}
    for (op_type, domain), inferrers in self._inferrer_registry.items():

Check failure · Code scanning / lintrunner · MYPY/misc Error: Too many values to unpack (2 expected, 3 provided). To disable, use `# type: ignore[misc]`
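The unpack error indicates the inferrer registry is keyed by a 3-tuple rather than `(op_type, domain)`. A hedged sketch of the corresponding loop, assuming the third key element is an opset version (purely a guess prompted by names like `Squeeze13Inferrer`; the actual key layout is not shown in this excerpt):

```python
from __future__ import annotations

# Hypothetical registry keyed by (op_type, domain, opset_version).
registry: dict[tuple[str, str, int], list[object]] = {
    ("Squeeze", "", 13): [object()],
}

info: dict[str, int] = {}
for (op_type, domain, version), inferrers in registry.items():
    # Unpacking all three key elements avoids "Too many values to unpack".
    info[f"{domain}::{op_type}::{version}"] = len(inferrers)
```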