diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index f93dd901535c3..20889558be314 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1490,9 +1490,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { }]; let arguments = (ins - Tosa_I1Tensor:$input1, - Tosa_Tensor:$input2, - Tosa_Tensor:$input3 + Tosa_I1Tensor:$input1, // pred + Tosa_Tensor:$input2, // on true + Tosa_Tensor:$input3 // on false ); let results = (outs @@ -1512,6 +1512,13 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3) `)` `->` type($output) }]; + + let extraClassDeclaration = [{ + // Custom getters for readability + ::mlir::TypedValue<::mlir::TensorType> getPred() { return getInput1(); } + ::mlir::TypedValue<::mlir::TensorType> getOnTrue() { return getInput2(); } + ::mlir::TypedValue<::mlir::TensorType> getOnFalse() { return getInput3(); } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c19d3733769b7..1d21096e8920b 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -344,7 +344,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { return failure(); rewriter.modifyOpInPlace(op, [&]() { op.getOperation()->setOperands( - {notOp.getInput1(), op.getInput3(), op.getInput2()}); + {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()}); }); return success(); } @@ -1510,8 +1510,8 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { } OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { - if (getInput2() == getInput3()) - return getInput2(); + if (getOnTrue() == getOnFalse()) + return getOnTrue(); auto predicate = llvm::dyn_cast_if_present(adaptor.getInput1()); @@ -1520,8 +1520,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { if (!predicate.isSplat()) return {}; - return predicate.getSplatValue().getBoolValue() ? getInput2() - : getInput3(); + return predicate.getSplatValue().getBoolValue() ? getOnTrue() + : getOnFalse(); } OpFoldResult TileOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index a32e4ccbed594..1dd392a9b8099 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3819,16 +3819,16 @@ LogicalResult ReverseOp::verify() { LogicalResult tosa::SelectOp::verify() { // verify input2 and input3 have same element type as output - if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(), + if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(), /* outType = */ getOutput().getType()) .failed() || - verifySameElementTypes(*this, /* inType = */ getInput3().getType(), + verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(), /* outType = */ getOutput().getType()) .failed()) { return failure(); } // verify input1 has element type of bool - auto predicateType = llvm::dyn_cast(getInput1().getType()); + auto predicateType = llvm::dyn_cast(getPred().getType()); if (!predicateType) { return emitOpError("expect shaped tensor for input1, got ") << getInput1().getType(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp index 02a3ad83bdefa..7997753469527 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -169,9 +169,9 @@ struct ConvertTosaOp : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::SelectOp tosaOp, PatternRewriter &rewriter) const override { - Value input1 = tosaOp.getInput1(); - Value input2 = tosaOp.getInput2(); - Value input3 = tosaOp.getInput3(); + Value input1 = tosaOp.getPred(); + Value input2 = tosaOp.getOnTrue(); + Value input3 = tosaOp.getOnFalse(); Value output = tosaOp.getResult(); auto outputType = dyn_cast(output.getType()); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index de08e7e9a4394..a4edccfd4c9c7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -188,8 +188,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { template <> LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) { - addValue(op.getInput2()); - addValue(op.getInput3()); + addValue(op.getOnTrue()); + addValue(op.getOnFalse()); addValue(op.getOutput()); return success(); }