[WIP] Create symbolic type/shape inference logic #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Justin Chu <[email protected]>
Codecov Report

Attention: Patch coverage is

✅ All tests successful. No failed tests found.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #117      +/-   ##
==========================================
- Coverage   74.51%   68.80%    -5.71%
==========================================
  Files          38       50       +12
  Lines        4693     5325      +632
  Branches      958     1085      +127
==========================================
+ Hits         3497     3664      +167
- Misses        843     1307      +464
- Partials      353      354        +1
```

☔ View full report in Codecov by Sentry.
```python
# Reconcile based on policy
if self.reconciliation_policy == ReconciliationPolicy.OVERWRITE:
    node.outputs[i] = inferred_value
```

Check failure (Code scanning / lintrunner): MYPY/index Error
```python
elif self.reconciliation_policy == ReconciliationPolicy.IGNORE:
    # Keep existing output if it has shape/type info
    if existing_output.shape is None and existing_output.type is None:
        node.outputs[i] = inferred_value
```

Check failure (Code scanning / lintrunner): MYPY/index Error
```python
elif self.reconciliation_policy == ReconciliationPolicy.RECONCILE:
    reconciled_output = self._reconcile_value(existing_output, inferred_value)
    node.outputs[i] = reconciled_output
```

Check failure (Code scanning / lintrunner): MYPY/index Error
```python
elif isinstance(dim2, int) and dim2 > 0:
    reconciled_dims.append(dim2)
elif dim1 is not None:
    reconciled_dims.append(dim1)
```

Check failure (Code scanning / lintrunner): MYPY/arg-type Error
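From the fragment above, per-dimension reconciliation appears to prefer concrete dimensions over symbolic or unknown ones. The following is a self-contained sketch of that merging logic; the function name `reconcile_dims` and the first and last branches are assumptions, since only the two `elif` branches appear in the diff:

```python
def reconcile_dims(dims1, dims2):
    """Merge two dimension lists of equal rank, preferring concrete values.

    Each entry is a positive int (known), a str (symbolic), or None (unknown).
    A concrete dim on either side wins; otherwise keep whatever information
    dims1 carries, falling back to dims2.
    """
    reconciled = []
    for dim1, dim2 in zip(dims1, dims2):
        if isinstance(dim1, int) and dim1 > 0:
            reconciled.append(dim1)  # concrete dim from the first shape wins
        elif isinstance(dim2, int) and dim2 > 0:
            reconciled.append(dim2)  # otherwise a concrete dim from the second
        elif dim1 is not None:
            reconciled.append(dim1)  # keep symbolic info if that is all we have
        else:
            reconciled.append(dim2)
    return reconciled
```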
I have successfully updated the entire InferenceResult system across all files in the src/onnx_ir/_shape_type_inference/ops directory.

✅ What Was Accomplished:

1. Converted InferenceResult from a dataclass to a normal class with string-based status initialization
2. Updated all validation decorators in _common.py to use string status
3. Updated the engine in _engine.py to handle different status types appropriately
4. Updated all 8 operation files in the ops directory:
   - standard_ops.py (BinaryInferrer)
   - matmul.py (MatMulInferrer)
   - concat.py (ConcatInferrer)
   - reshape.py (ReshapeInferrer)
   - constant.py (ConstantInferrer)
   - squeeze.py (Squeeze12Inferrer, Squeeze13Inferrer)
   - transpose.py (TransposeInferrer)
   - unsqueeze.py (Unsqueeze12Inferrer, Unsqueeze13Inferrer)
5. Updated exports in __init__.py to include InferenceStatus
6. Updated documentation in README.md with examples

✅ Key Benefits:

- More convenient API: status="missing_info" instead of status=InferenceStatus.MISSING_INFO
- Type safety: automatic enum conversion with clear error messages for invalid strings
- Better categorization: proper error classification (missing_info, invalid_node, partial, success)
- Cleaner code: fewer imports needed, more readable error handling
- Graceful degradation: the engine can handle partial inference and missing information

The refactoring is now complete, and all files consistently use the improved InferenceResult class with string-based status initialization.
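As a rough illustration of the string-based status initialization described above, here is a minimal sketch. The four status values and the `msg` field come from the diffs in this PR; everything else (the exact constructor signature and error message) is an assumption, not the PR's actual code:

```python
import enum


class InferenceStatus(enum.Enum):
    """Possible outcomes of a single inference step (values from the PR)."""
    SUCCESS = "success"
    PARTIAL = "partial"
    MISSING_INFO = "missing_info"
    INVALID_NODE = "invalid_node"


class InferenceResult:
    """Inference result whose status may be given as a plain string.

    The string is converted to InferenceStatus eagerly, so typos fail
    fast with a clear error message instead of propagating silently.
    """

    def __init__(self, status="success", msg=None):
        try:
            # Enum lookup by value: "missing_info" -> InferenceStatus.MISSING_INFO
            self.status = InferenceStatus(status)
        except ValueError:
            valid = sorted(s.value for s in InferenceStatus)
            raise ValueError(
                f"Invalid status {status!r}; expected one of {valid}"
            ) from None
        self.msg = msg
```

Callers can then write `InferenceResult(status="missing_info", msg="...")` without importing the enum.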
```python
)

# Get first input shape as base
first_shape = node.inputs[0].shape
```

Check failure (Code scanning / lintrunner): MYPY/union-attr Error
```python
return _common.InferenceResult(
    status="missing_info", msg="Concat input shapes cannot be None."
)
first_type = node.inputs[0].type
```

Check failure (Code scanning / lintrunner): MYPY/union-attr Error
```python
)

# Create shape from the tensor dimensions
output_shape = ir.Shape(tensor.shape)
```

Check failure (Code scanning / lintrunner): MYPY/arg-type Error
```python
logger.warning(
    "Squeeze operation has symbolic dimension %s, assuming it is not 1.", dim
)
output_dims.append(dim)
```

Check failure (Code scanning / lintrunner): MYPY/arg-type Error
```python
    output_shape = _compute_output_shape_no_axes(input_shape)
else:
    try:
        axes = _normalize_axes(axes, rank)
```

Check failure (Code scanning / lintrunner): MYPY/assignment Error
```python
        axes = _normalize_axes(axes, rank)
    except ValueError as e:
        return _common.InferenceResult(status="invalid_node", msg=str(e))
    output_shape = _compute_output_shape_with_axes(input_shape, axes)
```

Check failure (Code scanning / lintrunner): MYPY/arg-type Error
```python
rank1 = len(shape1)
rank2 = len(shape2)
new_rank = max(rank1, rank2)
new_dims = []
```

Check failure (Code scanning / lintrunner): MYPY/var-annotated Error
I've successfully updated all operation files to use direct slicing on ir.Shape objects instead of accessing .dims and then taking slices. Here are the specific changes made:

1. matmul.py
   - Before: `output_dims = [*rhs_shape.dims[:-2], rhs_shape.dims[-1]]`
     After: `output_dims = [*rhs_shape[:-2], rhs_shape[-1]]`
   - Before: `output_dims = list(lhs_shape.dims[:-1])`
     After: `output_shape = ir.Shape(lhs_shape[:-1])`
   - Before: `lhs_batch = lhs_shape.dims[:-2]`
     After: `lhs_batch = lhs_shape[:-2]`
   - Before: `output_dims = [*batch_shape.dims, lhs_shape.dims[-2], rhs_shape.dims[-1]]`
     After: `output_dims = [*batch_shape, lhs_shape[-2], rhs_shape[-1]]`
2. transpose.py
   - Before: `output_dims.append(input_shape.dims[axis])`
     After: `output_dims.append(input_shape[axis])`
3. concat.py
   - Before: `output_dims = list(first_shape.dims)`
     After: `output_dims = list(first_shape)`
4. unsqueeze.py
   - Before: `output_dims.append(input_shape.dims[input_axis])`
     After: `output_dims.append(input_shape[input_axis])`
5. squeeze.py
   - Before: `for dim in input_shape.dims:`
     After: `for dim in input_shape:`
   - Before: `output_dims = [dim for i, dim in enumerate(input_shape.dims) if i not in axes]`
     After: `output_dims = [dim for i, dim in enumerate(input_shape) if i not in axes]`

Key Benefits:

1. Cleaner code: direct slicing is more readable and Pythonic
2. Less verbose: removes the need to access .dims explicitly
3. Consistent API: uses the shape object's built-in slicing support
4. Type safety: shape slicing returns the correct types directly

The code is now significantly cleaner and more intuitive, taking advantage of the ir.Shape object's built-in slicing capabilities.
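To make the Before/After change concrete without depending on onnx_ir, here is a toy stand-in showing why direct indexing and slicing can replace `.dims` access. This is an illustration only, not the actual onnx_ir.Shape implementation:

```python
class Shape:
    """Minimal stand-in for ir.Shape demonstrating the slicing API.

    By delegating __getitem__ to the underlying tuple, both integer
    indexing (shape[-1]) and slicing (shape[:-2]) work directly on the
    shape object, so callers no longer need to reach into .dims.
    """

    def __init__(self, dims):
        self.dims = tuple(dims)

    def __len__(self):
        return len(self.dims)

    def __iter__(self):
        return iter(self.dims)

    def __getitem__(self, index):
        # A tuple already handles both int indices and slice objects.
        return self.dims[index]
```

With this in place, `[*rhs_shape[:-2], rhs_shape[-1]]` behaves exactly like the older `[*rhs_shape.dims[:-2], rhs_shape.dims[-1]]`.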
```python
for op_type in binary_ops:
    inferrers.append(BinaryInferrer(op_type))

return SymbolicInferenceEngine(inferrers, reconciliation_policy)
```

Check failure (Code scanning / lintrunner): MYPY/arg-type Error
```python
    BinaryInferrer("Mul"),
]

return SymbolicInferenceEngine(inferrers, reconciliation_policy)
```

Check failure (Code scanning / lintrunner): MYPY/arg-type Error
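Both factory fragments above finish by constructing a SymbolicInferenceEngine from a list of inferrers. A toy sketch of how such an engine might index inferrers by op_type follows; the registry design, method names, and the simplified BinaryInferrer are assumptions, not the PR's actual code:

```python
class BinaryInferrer:
    """Simplified stand-in for the inferrer in standard_ops.py."""

    def __init__(self, op_type):
        self.op_type = op_type


class SymbolicInferenceEngine:
    """Dispatches nodes to per-op inferrers keyed by op_type."""

    def __init__(self, inferrers, reconciliation_policy="reconcile"):
        # Build a registry so lookup per node is O(1).
        self._registry = {inf.op_type: inf for inf in inferrers}
        self.reconciliation_policy = reconciliation_policy

    def get_inferrer(self, op_type):
        """Return the inferrer for op_type, or None if unsupported."""
        return self._registry.get(op_type)


# Mirror the factory pattern from the diff: one inferrer per binary op.
binary_ops = ["Add", "Sub", "Mul", "Div"]
inferrers = [BinaryInferrer(op) for op in binary_ops]
engine = SymbolicInferenceEngine(inferrers)
```

Nodes whose op_type has no registered inferrer can then simply be skipped, which matches the graceful-degradation goal described earlier in the thread.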
No description provided.