
[CIR][Lowering] Fix Vector Comparison Lowering with -fno-signed-char/unsigned operand #1770

Open · wants to merge 1 commit into base: main
39 changes: 30 additions & 9 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -731,8 +731,9 @@ mlir::Value CirAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) {
}
auto resTy = addrOp.getType();
auto eltTy = converter->convertType(sourceType);
-addrOp = rewriter.create<mlir::LLVM::GEPOp>(loc, resTy, eltTy, addrOp,
-    indices, mlir::LLVM::GEPNoWrapFlags::inbounds);
+addrOp = rewriter.create<mlir::LLVM::GEPOp>(
+    loc, resTy, eltTy, addrOp, indices,
+    mlir::LLVM::GEPNoWrapFlags::inbounds);
}

if (auto intTy = mlir::dyn_cast<cir::IntType>(globalAttr.getType())) {
@@ -1205,8 +1206,9 @@ mlir::LogicalResult CIRToLLVMVTTAddrPointOpLowering::matchAndRewrite(
offsets.push_back(0);
offsets.push_back(adaptor.getOffset());
}
-rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, resultType, eltType,
-    llvmAddr, offsets, mlir::LLVM::GEPNoWrapFlags::inbounds);
+rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
+    op, resultType, eltType, llvmAddr, offsets,
+    mlir::LLVM::GEPNoWrapFlags::inbounds);
return mlir::success();
}

@@ -2052,9 +2054,24 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
auto elementType = elementTypeIfVector(op.getLhs().getType());
mlir::Value bitResult;
if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) {

+auto isCIRZeroVector = [](mlir::Value value) {
+  if (auto constantOp = value.getDefiningOp<cir::ConstantOp>())
+    if (auto zeroAttr =
+            mlir::dyn_cast<cir::ZeroAttr>(constantOp.getValue()))
+      return true;
+  return false;
+};

+bool shouldUseSigned = intType.isSigned();
+// Special treatment for sign-bit extraction patterns (lt comparison with
+// zero): always use signed comparison to preserve the semantic intent.
+if (op.getKind() == cir::CmpOpKind::lt && isCIRZeroVector(op.getRhs()))
+  shouldUseSigned = true;
Review thread on the lines above:

Member:
This seems to be forcing signedness for no good reason; as you previously stated in the description, we know this comes from unsigned. This is a fair IR difference to live with. The question is whether this is too aggressive for the canonicalizer to be doing, or if we want to move this into CIR simplify. I think the current behavior is good enough. Can you instead add a C source test for both the unsigned and signed versions and capture that the canonicalizer kicks in for one and not for the other?

Author (Collaborator):
Thanks a lot for the detailed response. Just to confirm: in the context of this PR, we want to preserve the current lowering logic, and my next step would be to add tests documenting the current behavior. Is that correct?

bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
op.getLoc(),
-    convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
+    convertCmpKindToICmpPredicate(op.getKind(), shouldUseSigned),
adaptor.getLhs(), adaptor.getRhs());
} else if (mlir::isa<cir::FPTypeInterface>(elementType)) {
bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
@@ -3881,8 +3898,9 @@ mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite(
op.getAddressPointAttr().getOffset()};

assert(eltType && "Shouldn't ever be missing an eltType here");
-rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, targetType, eltType,
-    symAddr, offsets, mlir::LLVM::GEPNoWrapFlags::inbounds);
+rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
+    op, targetType, eltType, symAddr, offsets,
+    mlir::LLVM::GEPNoWrapFlags::inbounds);

return mlir::success();
}
@@ -3908,7 +3926,8 @@ mlir::LogicalResult CIRToLLVMVTableGetVirtualFnAddrOpLowering::matchAndRewrite(
llvm::SmallVector<mlir::LLVM::GEPArg> offsets =
llvm::SmallVector<mlir::LLVM::GEPArg>{op.getIndex()};
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
-    op, targetType, eltType, adaptor.getVptr(), offsets, mlir::LLVM::GEPNoWrapFlags::inbounds);
+    op, targetType, eltType, adaptor.getVptr(), offsets,
+    mlir::LLVM::GEPNoWrapFlags::inbounds);
return mlir::success();
}

@@ -4000,7 +4019,9 @@ mlir::LogicalResult CIRToLLVMInlineAsmOpLowering::matchAndRewrite(
op, llResTy, llvmOperands, op.getAsmStringAttr(), op.getConstraintsAttr(),
op.getSideEffectsAttr(),
/*is_align_stack*/ mlir::UnitAttr(),
-    /*tail_call_kind*/ mlir::LLVM::TailCallKindAttr::get(getContext(), mlir::LLVM::tailcallkind::TailCallKind::None),
+    /*tail_call_kind*/
+    mlir::LLVM::TailCallKindAttr::get(
+        getContext(), mlir::LLVM::tailcallkind::TailCallKind::None),
mlir::LLVM::AsmDialectAttr::get(getContext(), llDialect),
rewriter.getArrayAttr(opAttrs));

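A rough C-source sketch of the test the reviewer asks for above, covering both an unsigned and a signed element type, is shown below. The function names, typedefs, and file placement are illustrative assumptions, and the RUN/FileCheck plumbing is omitted because the exact CIR and LLVM output has not been verified here; the point is simply to have one function per signedness so FileCheck can capture that the canonicalizer kicks in for one and not for the other, as the reviewer suggests.

```c
// Sketch of a C source test (names and layout are assumptions, not part of
// the PR). Each function performs an element-wise "less than zero" compare.

typedef unsigned char uchar16 __attribute__((vector_size(16)));
typedef signed char schar16 __attribute__((vector_size(16)));
typedef signed char mask16 __attribute__((vector_size(16)));

// Unsigned elements: v < 0 can never be true, so a canonicalizer is free to
// fold the comparison away; CHECK lines would record whatever the current
// pipeline produces for this case.
mask16 lt_zero_unsigned(uchar16 v) {
  uchar16 zero = {0};
  return (mask16)(v < zero);
}

// Signed elements: a genuine sign-bit test that should survive as a signed
// "lt" vector compare through lowering.
mask16 lt_zero_signed(schar16 v) {
  schar16 zero = {0};
  return (mask16)(v < zero);
}
```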
12 changes: 12 additions & 0 deletions clang/test/CIR/Lowering/vec-cmp.cir
@@ -14,3 +14,15 @@ cir.func @vec_cmp(%0: !cir.vector<!s16i x 16>, %1: !cir.vector<!s16i x 16>) -> ()
// MLIR-NEXT: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %arg1 : vector<16xi16>
// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16
// MLIR-NEXT: llvm.return

+cir.func @vec_cmp_zero(%0: !cir.vector<!u8i x 16>) -> () {
+  %1 = cir.const #cir.zero : !cir.vector<!u8i x 16>
+  %2 = cir.vec.cmp(lt, %0, %1) : !cir.vector<!u8i x 16>, !cir.vector<!cir.int<u, 1> x 16>
+  %3 = cir.cast(bitcast, %2 : !cir.vector<!cir.int<u, 1> x 16>), !cir.int<u, 16>

+  cir.return
+}

+// MLIR: llvm.func @vec_cmp_zero
+// MLIR: %{{[0-9]+}} = llvm.icmp "slt" %arg0, %{{[0-9]+}} : vector<16xi8>
+// MLIR-NEXT: %{{[0-9]+}} = llvm.bitcast %{{[0-9]+}} : vector<16xi1> to i16
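For the -fno-signed-char scenario in the PR title, the plain-char variant below is roughly the C-level counterpart of the vec_cmp_zero test added above. This is a sketch; the assumption that plain char maps to !u8i vector elements under -fno-signed-char is inferred from that test rather than verified here.

```c
// Sketch: compiled with -fno-signed-char, plain char is unsigned, so this
// vector presumably carries !u8i elements at the CIR level (as in the
// vec_cmp_zero test above), even though the source spells the usual
// sign-bit idiom of comparing against zero.
typedef char char16 __attribute__((vector_size(16)));
typedef signed char cmask16 __attribute__((vector_size(16)));

cmask16 lt_zero_plain_char(char16 v) {
  char16 zero = {0};
  return (cmask16)(v < zero);
}
```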