diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h index 83c936b7dcada..ae1efb9833649 100644 --- a/flang/include/flang/Optimizer/Support/Utils.h +++ b/flang/include/flang/Optimizer/Support/Utils.h @@ -27,6 +27,8 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" +#include "flang/Optimizer/CodeGen/TypeConverter.h" + namespace fir { /// Return the integer value of a arith::ConstantOp. inline std::int64_t toInt(mlir::arith::ConstantOp cop) { @@ -198,6 +200,42 @@ std::optional> getComponentLowerBoundsIfNonDefault( fir::RecordType recordType, llvm::StringRef component, mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr); +// Convert FIR type to LLVM without turning fir.box into memory +// reference. +mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, + mlir::Type firType); + +/// Generate a LLVM constant value of type `ity`, using the provided offset. +mlir::LLVM::ConstantOp +genConstantIndex(mlir::Location loc, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter, + std::int64_t offset); + +/// Helper function for generating the LLVM IR that computes the distance +/// in bytes between adjacent elements pointed to by a pointer +/// of type \p ptrTy. The result is returned as a value of \p idxTy integer +/// type. +mlir::Value computeElementDistance(mlir::Location loc, + mlir::Type llvmObjectType, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + const mlir::DataLayout &dataLayout); + +// Compute the alloc scale size (constant factors encoded in the array type). +// We do this for arrays without a constant interior or arrays of character with +// dynamic length arrays, since those are the only ones that get decayed to a +// pointer to the element type. +mlir::Value genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy, + mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter); + +/// Perform an extension or truncation as needed on an integer value. Lowering +/// to the specific target may involve some sign-extending or truncation of +/// values, particularly to fit them from abstract box types to the +/// appropriate reified structures. +mlir::Value integerCast(const fir::LLVMTypeConverter &converter, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, bool fold = false); } // namespace fir #endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index ecc04a6c9a2be..50c1765b12409 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -85,14 +85,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) { return mlir::IntegerType::get(context, 8); } -static mlir::LLVM::ConstantOp -genConstantIndex(mlir::Location loc, mlir::Type ity, - mlir::ConversionPatternRewriter &rewriter, - std::int64_t offset) { - auto cattr = rewriter.getI64IntegerAttr(offset); - return rewriter.create(loc, ity, cattr); -} - static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter, mlir::Block *insertBefore) { assert(insertBefore && "expected valid insertion block"); @@ -203,39 +195,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op, TODO(op.getLoc(), "did not find allocation function"); } -// Compute the alloc scale size (constant factors encoded in the array type). 
-// We do this for arrays without a constant interior or arrays of character with -// dynamic length arrays, since those are the only ones that get decayed to a -// pointer to the element type. -template -static mlir::Value -genAllocationScaleSize(OP op, mlir::Type ity, - mlir::ConversionPatternRewriter &rewriter) { - mlir::Location loc = op.getLoc(); - mlir::Type dataTy = op.getInType(); - auto seqTy = mlir::dyn_cast(dataTy); - fir::SequenceType::Extent constSize = 1; - if (seqTy) { - int constRows = seqTy.getConstantRows(); - const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); - if (constRows != static_cast(shape.size())) { - for (auto extent : shape) { - if (constRows-- > 0) - continue; - if (extent != fir::SequenceType::getUnknownExtent()) - constSize *= extent; - } - } - } - - if (constSize != 1) { - mlir::Value constVal{ - genConstantIndex(loc, ity, rewriter, constSize).getResult()}; - return constVal; - } - return nullptr; -} - namespace { struct DeclareOpConversion : public fir::FIROpConversion { public: @@ -270,7 +229,7 @@ struct AllocaOpConversion : public fir::FIROpConversion { auto loc = alloc.getLoc(); mlir::Type ity = lowerTy().indexType(); unsigned i = 0; - mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult(); + mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult(); mlir::Type firObjType = fir::unwrapRefType(alloc.getType()); mlir::Type llvmObjectType = convertObjectType(firObjType); if (alloc.hasLenParams()) { @@ -302,7 +261,8 @@ struct AllocaOpConversion : public fir::FIROpConversion { << scalarType << " with type parameters"; } } - if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter)) + if (auto scaleSize = fir::genAllocationScaleSize( + alloc.getLoc(), alloc.getInType(), ity, rewriter)) size = rewriter.createOrFold(loc, ity, size, scaleSize); if (alloc.hasShapeOperands()) { @@ -479,7 +439,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion { auto loc = boxisarray.getLoc(); TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType()); mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter); - mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0); + mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0); rewriter.replaceOpWithNewOp( boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0); return mlir::success(); @@ -815,7 +775,7 @@ struct ConvertOpConversion : public fir::FIROpConversion { // Do folding for constant inputs. if (auto constVal = fir::getIntIfConstant(op0)) { mlir::Value normVal = - genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0); + fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0); rewriter.replaceOp(convert, normVal); return mlir::success(); } @@ -828,7 +788,7 @@ struct ConvertOpConversion : public fir::FIROpConversion { } // Compare the input with zero. - mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0); + mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0); auto isTrue = rewriter.create( loc, mlir::LLVM::ICmpPredicate::ne, op0, zero); @@ -1075,21 +1035,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op, return getMallocInModule(mod, op, rewriter, indexType); } -/// Helper function for generating the LLVM IR that computes the distance -/// in bytes between adjacent elements pointed to by a pointer -/// of type \p ptrTy. The result is returned as a value of \p idxTy integer -/// type. 
-static mlir::Value -computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, - mlir::Type idxTy, - mlir::ConversionPatternRewriter &rewriter, - const mlir::DataLayout &dataLayout) { - llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); - unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); - std::int64_t distance = llvm::alignTo(size, alignment); - return genConstantIndex(loc, idxTy, rewriter, distance); -} - /// Return value of the stride in bytes between adjacent elements /// of LLVM type \p llTy. The result is returned as a value of /// \p idxTy integer type. @@ -1098,7 +1043,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy, mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy, const mlir::DataLayout &dataLayout) { // Create a pointer type and use computeElementDistance(). - return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); + return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); } namespace { @@ -1117,7 +1062,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion { if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) TODO(loc, "fir.allocmem codegen of derived type with length parameters"); mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); - if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) + if (auto scaleSize = + fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter)) size = rewriter.create(loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands()) size = rewriter.create( @@ -1140,7 +1086,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion { mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy) const { - return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout()); + return fir::computeElementDistance(loc, llTy, idxTy, rewriter, + getDataLayout()); } }; } // namespace @@ -1324,7 +1271,7 @@ genCUFAllocDescriptor(mlir::Location loc, mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; mlir::Value sizeInBytes = - genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); + fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; return rewriter .create(loc, fctTy, RTNAME_STRING(CUFAllocDescriptor), @@ -1580,7 +1527,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion { // representation of derived types with pointer/allocatable components. // This has been seen in hashing algorithms using TRANSFER. 
mlir::Value zero = - genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0); + fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0); descriptor = insertField(rewriter, loc, descriptor, {getLenParamFieldId(boxTy), 0}, zero); } @@ -1923,8 +1870,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion { bool hasSlice = !xbox.getSlice().empty(); unsigned sliceOffset = xbox.getSliceOperandIndex(); mlir::Location loc = xbox.getLoc(); - mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0); - mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1); + mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0); + mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1); mlir::Value prevPtrOff = one; mlir::Type eleTy = boxTy.getEleTy(); const unsigned rank = xbox.getRank(); @@ -1973,7 +1920,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion { prevDimByteStride = getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams()); } else { - prevDimByteStride = genConstantIndex( + prevDimByteStride = fir::genConstantIndex( loc, i64Ty, rewriter, charTy.getLen() * lowerTy().characterBitsize(charTy) / 8); } @@ -2131,7 +2078,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { if (auto charTy = mlir::dyn_cast(inputEleTy)) { if (charTy.hasConstantLen()) { mlir::Value len = - genConstantIndex(loc, idxTy, rewriter, charTy.getLen()); + fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen()); lenParams.emplace_back(len); } else { mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair, @@ -2140,7 +2087,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { assert(!isInGlobalOp(rewriter) && "character target in global op must have constant length"); mlir::Value width = - genConstantIndex(loc, idxTy, rewriter, charTy.getFKind()); + fir::genConstantIndex(loc, idxTy, rewriter, charTy.getFKind()); len = rewriter.create(loc, idxTy, len, width); } lenParams.emplace_back(len); @@ -2194,8 +2141,9 @@ struct XReboxOpConversion : public EmboxCommonConversion { mlir::ConversionPatternRewriter &rewriter) const { mlir::Location loc = rebox.getLoc(); mlir::Value zero = - genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); - mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1); + fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); + mlir::Value one = + fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1); for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) { mlir::Value extent = std::get<0>(iter.value()); unsigned dim = iter.index(); @@ -2227,7 +2175,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { mlir::Location loc = rebox.getLoc(); mlir::Type byteTy = ::getI8Type(rebox.getContext()); mlir::Type idxTy = lowerTy().indexType(); - mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0); // Apply subcomponent and substring shift on base address. if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) { // Cast to inputEleTy* so that a GEP can be used. @@ -2255,7 +2203,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { // and strides. 
llvm::SmallVector slicedExtents; llvm::SmallVector slicedStrides; - mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1); const bool sliceHasOrigins = !rebox.getShift().empty(); unsigned sliceOps = rebox.getSliceOperandIndex(); unsigned shiftOps = rebox.getShiftOperandIndex(); @@ -2328,7 +2276,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { // which may be OK if all new extents are ones, the stride does not // matter, use one. mlir::Value stride = inputStrides.empty() - ? genConstantIndex(loc, idxTy, rewriter, 1) + ? fir::genConstantIndex(loc, idxTy, rewriter, 1) : inputStrides[0]; for (unsigned i = 0; i < rebox.getShape().size(); ++i) { mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i]; @@ -2563,9 +2511,9 @@ struct XArrayCoorOpConversion unsigned shiftOffset = coor.getShiftOperandIndex(); unsigned sliceOffset = coor.getSliceOperandIndex(); auto sliceOps = coor.getSlice().begin(); - mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1); mlir::Value prevExt = one; - mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0); const bool isShifted = !coor.getShift().empty(); const bool isSliced = !coor.getSlice().empty(); const bool baseIsBoxed = @@ -2895,7 +2843,7 @@ struct CoordinateOpConversion // of lower bound aspects. This both accounts for dynamically sized // types and non contiguous arrays. auto idxTy = lowerTy().indexType(); - mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0); unsigned arrayDim = arrTy.getDimension(); for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) { mlir::Value stride = @@ -3808,7 +3756,7 @@ struct IsPresentOpConversion : public fir::FIROpConversion { ptr = rewriter.create(loc, ptr, 0); } mlir::LLVM::ConstantOp c0 = - genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0); + fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0); auto addr = rewriter.create(loc, idxTy, ptr); rewriter.replaceOpWithNewOp( isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0); diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 37f1c9f97e1ce..b2c6b880c6f52 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -21,6 +21,7 @@ #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Support/FatalError.h" #include "flang/Optimizer/Support/InternalNames.h" +#include "flang/Optimizer/Support/Utils.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -125,10 +126,49 @@ struct PrivateClauseOpConversion return mlir::success(); } }; + +// FIR Op specific conversion for TargetAllocMemOp +struct TargetAllocMemOpConversion + : public OpenMPFIROpConversion { + using OpenMPFIROpConversion::OpenMPFIROpConversion; + + llvm::LogicalResult + matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type heapTy = allocmemOp.getAllocatedType(); + mlir::Location loc = allocmemOp.getLoc(); + auto ity = lowerTy().indexType(); + mlir::Type dataTy = fir::unwrapRefType(heapTy); + mlir::Type llvmObjectTy = 
fir::convertObjectType(lowerTy(), dataTy); + if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) + TODO(loc, "omp.target_allocmem codegen of derived type with length " + "parameters"); + mlir::Value size = fir::computeElementDistance( + loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); + if (auto scaleSize = fir::genAllocationScaleSize( + loc, allocmemOp.getInType(), ity, rewriter)) + size = rewriter.create(loc, ity, size, scaleSize); + for (mlir::Value opnd : adaptor.getOperands().drop_front()) + size = rewriter.create( + loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + auto mallocTy = + mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); + if (mallocTyWidth != ity.getIntOrFloatBitWidth()) + size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); + rewriter.modifyOpInPlace(allocmemOp, [&]() { + allocmemOp.setInType(rewriter.getI8Type()); + allocmemOp.getTypeparamsMutable().clear(); + allocmemOp.getTypeparamsMutable().append(size); + }); + return mlir::success(); + } +}; } // namespace void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add(converter); patterns.add(converter); + patterns.add(converter); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index b6bf2753b80ce..958fc46c9e41c 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { } /// Parser shared by Alloca and Allocmem -/// /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type /// ( `(` $typeparams `)` )? ( `,` $shape )? 
/// attr-dict-without-keyword diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp index 5d663e28336c0..6dc80ff8d18a6 100644 --- a/flang/lib/Optimizer/Support/Utils.cpp +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -50,3 +50,81 @@ std::optional> fir::getComponentLowerBoundsIfNonDefault( return componentInfo.getLowerBounds(); return std::nullopt; } + +mlir::Type fir::convertObjectType(const fir::LLVMTypeConverter &converter, + mlir::Type firType) { + if (auto boxTy = mlir::dyn_cast(firType)) + return converter.convertBoxTypeAsStruct(boxTy); + return converter.convertType(firType); +} + +mlir::LLVM::ConstantOp +fir::genConstantIndex(mlir::Location loc, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter, + std::int64_t offset) { + auto cattr = rewriter.getI64IntegerAttr(offset); + return rewriter.create(loc, ity, cattr); +} + +mlir::Value +fir::computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, + mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + const mlir::DataLayout &dataLayout) { + llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); + unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); + std::int64_t distance = llvm::alignTo(size, alignment); + return fir::genConstantIndex(loc, idxTy, rewriter, distance); +} + +mlir::Value +fir::genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy, + mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter) { + auto seqTy = mlir::dyn_cast(dataTy); + fir::SequenceType::Extent constSize = 1; + if (seqTy) { + int constRows = seqTy.getConstantRows(); + const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); + if (constRows != static_cast(shape.size())) { + for (auto extent : shape) { + if (constRows-- > 0) + continue; + if (extent != fir::SequenceType::getUnknownExtent()) + constSize *= extent; + } + } + } + + if (constSize != 1) { + mlir::Value constVal{ + fir::genConstantIndex(loc, ity, rewriter, constSize).getResult()}; + return constVal; + } + return nullptr; +} + +mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, bool fold) { + auto valTy = val.getType(); + // If the value was not yet lowered, lower its type so that it can + // be used in getPrimitiveTypeSizeInBits. 
+ if (!mlir::isa(valTy)) + valTy = converter.convertType(valTy); + auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); + auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); + if (fold) { + if (toSize < fromSize) + return rewriter.createOrFold(loc, ty, val); + if (toSize > fromSize) + return rewriter.createOrFold(loc, ty, val); + } else { + if (toSize < fromSize) + return rewriter.create(loc, ty, val); + if (toSize > fromSize) + return rewriter.create(loc, ty, val); + } + return val; +} diff --git a/flang/test/Fir/omp_target_allocmem_freemem.fir b/flang/test/Fir/omp_target_allocmem_freemem.fir new file mode 100644 index 0000000000000..03eb94acb1ac7 --- /dev/null +++ b/flang/test/Fir/omp_target_allocmem_freemem.fir @@ -0,0 +1,294 @@ +// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s + +// UNSUPPORTED: system-windows +// Disabled on 32-bit targets due to the additional `trunc` opcodes required +// UNSUPPORTED: target-x86 +// UNSUPPORTED: target=sparc-{{.*}} +// UNSUPPORTED: target=sparcel-{{.*}} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_nonchar() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 4, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_nonchar() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, i32 + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalars_nonchar() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 400, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalars_nonchar() -> () { + %device = arith.constant 0 : i32 + %0 = arith.constant 100 : index + %1 = omp.target_allocmem %device : i32, i32, %0 + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_char() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 10, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_char() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<1,10> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_char_kind() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 20, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_char_kind() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<2,10> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_dynchar( +// CHECK-SAME: i32 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// 
CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_dynchar(%l : i32) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<1,?>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_dynchar_kind( +// CHECK-SAME: i32 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 2, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_dynchar_kind(%l : i32) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<2,?>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 36, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_of_nonchar() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3xi32> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_char() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 90, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_of_char() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar( +// CHECK-SAME: i32 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 9, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_nonchar( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 12, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0) +// CHECK-NEXT: [[TMP5:%.*]] = 
ptrtoint ptr [[TMP4]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_nonchar(%e: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?xi32>, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_nonchar2( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 4, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_nonchar2(%e: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array, %e, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_char( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 60, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0) +// CHECK-NEXT: [[TMP5:%.*]] = ptrtoint ptr [[TMP4]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_char(%e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x!fir.char<2,10>>, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_char2( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 20, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_char2(%e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array>, %e, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_dynchar( +// CHECK-SAME: i32 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 6, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], [[TMP1]] +// CHECK-NEXT: [[TMP6:%.*]] = mul i64 1, [[TMP5]] +// CHECK-NEXT: [[TMP7:%.*]] = call ptr @omp_target_alloc(i64 [[TMP6]], i32 0) +// CHECK-NEXT: [[TMP8:%.*]] = ptrtoint ptr [[TMP7]] to i64 +// CHECK-NEXT: [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP9]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_dynchar(%l: i32, %e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x!fir.char<2,?>>(%l : i32), %e + 
omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_dynchar2( +// CHECK-SAME: i32 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 2, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], [[TMP1]] +// CHECK-NEXT: [[TMP6:%.*]] = mul i64 [[TMP5]], [[TMP1]] +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 1, [[TMP6]] +// CHECK-NEXT: [[TMP8:%.*]] = call ptr @omp_target_alloc(i64 [[TMP7]], i32 0) +// CHECK-NEXT: [[TMP9:%.*]] = ptrtoint ptr [[TMP8]] to i64 +// CHECK-NEXT: [[TMP10:%.*]] = inttoptr i64 [[TMP9]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP10]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_dynchar2(%l: i32, %e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array>(%l : i32), %e, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_nonchar( +// CHECK-SAME: i64 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 240, [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP3]], [[TMP1]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 1, [[TMP4]] +// CHECK-NEXT: [[TMP6:%.*]] = call ptr @omp_target_alloc(i64 [[TMP5]], i32 0) +// CHECK-NEXT: [[TMP7:%.*]] = ptrtoint ptr [[TMP6]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP8]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_with_holes_nonchar(%0 : index, %1 : index) -> () { + %device = arith.constant 0 : i32 + %2 = omp.target_allocmem %device : i32, !fir.array<4x?x3x?x5xi32>, %0, %1 + omp.target_freemem %device, %2 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_char( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 240, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0) +// CHECK-NEXT: [[TMP5:%.*]] = ptrtoint ptr [[TMP4]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_with_holes_char(%e: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x4x!fir.char<2,10>>, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_dynchar( +// CHECK-SAME: i64 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 24, [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP3]], [[TMP1]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 1, [[TMP4]] +// CHECK-NEXT: [[TMP6:%.*]] = call ptr @omp_target_alloc(i64 [[TMP5]], i32 0) +// CHECK-NEXT: [[TMP7:%.*]] = ptrtoint ptr [[TMP6]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP8]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_with_holes_dynchar(%arg0: index, %arg1: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x4x!fir.char<2,?>>(%arg0 : index), %arg1 + omp.target_freemem %device, %1 : i32, i64 + return +} diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 
8cf18b43450ab..57ddc41e4ed9b 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2113,4 +2113,98 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +def TargetAllocMemOp : OpenMP_Op<"target_allocmem", + [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { + let summary = "allocate storage on an openmp device for an object of a given type"; + + let description = [{ + Allocates memory on the specified OpenMP device for an object of the given type. + Returns an integer value representing the device pointer to the allocated memory. + The memory is uninitialized after allocation. Operations must be paired with + `omp.target_freemem` to avoid memory leaks. + + * `$device`: The integer ID of the OpenMP device where the memory will be allocated. + * `$in_type`: The type of the object for which memory is being allocated. + For arrays, this can be a static or dynamic array type. + * `$uniq_name`: An optional unique name for the allocated memory. + * `$bindc_name`: An optional name used for C interoperability. + * `$typeparams`: Runtime type parameters for polymorphic or parameterized types. + These are typically integer values that define aspects of a type not fixed at compile time. + * `$shape`: Runtime shape operands for dynamic arrays. + Each operand is an integer value representing the extent of a specific dimension. + + ```mlir + // Allocate a static 3x3 integer vector on device 0 + %device_0 = arith.constant 0 : i32 + %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32> + // ... use %ptr_static ... + omp.target_freemem %device_0, %ptr_static : i32, i64 + + // Allocate a dynamic 2D Fortran array (fir.array) on device 1 + %device_1 = arith.constant 1 : i32 + %rows = arith.constant 10 : index + %cols = arith.constant 20 : index + %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array, %rows, %cols : index, index + // ... use %ptr_dynamic ... + omp.target_freemem %device_1, %ptr_dynamic : i32, i64 + ``` + }]; + + let arguments = (ins + Arg:$device, + TypeAttr:$in_type, + OptionalAttr:$uniq_name, + OptionalAttr:$bindc_name, + Variadic:$typeparams, + Variadic:$shape + ); + let results = (outs I64); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + mlir::Type getAllocatedType(); + }]; +} + +//===----------------------------------------------------------------------===// +// TargetFreeMemOp +//===----------------------------------------------------------------------===// + +def TargetFreeMemOp : OpenMP_Op<"target_freemem", + [MemoryEffects<[MemFree]>]> { + let summary = "free memory on an openmp device"; + + let description = [{ + Deallocates memory on the specified OpenMP device that was previously + allocated by an `omp.target_allocmem` operation. After this operation, the + deallocated memory is in an undefined state and should not be accessed. + It is crucial to ensure that all accesses to the memory region are completed + before `omp.target_freemem` is called to avoid undefined behavior. + + * `$device`: The integer ID of the OpenMP device from which the memory will be freed. + * `$heapref`: The integer value representing the device pointer to the memory + to be deallocated, which was previously returned by `omp.target_allocmem`. 
+ + ```mlir + // Example of allocating and freeing memory on an OpenMP device + %device_id = arith.constant 0 : i32 + %allocated_ptr = omp.target_allocmem %device_id : i32, vector<3x3xi32> + // ... operations using %allocated_ptr on the device ... + omp.target_freemem %device_id, %allocated_ptr : i32, i64 + ``` + }]; + + let arguments = (ins + Arg:$device, + Arg:$heapref + ); + let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 769aee64e1695..49a26d8cd156f 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3878,6 +3878,107 @@ LogicalResult AllocateDirOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +mlir::Type omp::TargetAllocMemOp::getAllocatedType() { + return getInTypeAttr().getValue(); +} + +/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype, +/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +/// attr-dict-without-keyword +static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + + // Parse device number as a new operand + mlir::OpAsmParser::UnresolvedOperand deviceOperand; + mlir::Type deviceType; + if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) + return mlir::failure(); + if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) + return mlir::failure(); + if (parser.parseComma()) + return mlir::failure(); + + mlir::Type intype; + if (parser.parseType(intype)) + return mlir::failure(); + result.addAttribute("in_type", mlir::TypeAttr::get(intype)); + llvm::SmallVector operands; + llvm::SmallVector typeVec; + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. 
( : ) + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeVec) || parser.parseRParen()) + return mlir::failure(); + typeparamsSize = operands.size(); + hasOperands = true; + } + std::int32_t shapeSize = 0; + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) + return mlir::failure(); + shapeSize = operands.size() - typeparamsSize; + auto idxTy = builder.getIndexType(); + for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) + typeVec.push_back(idxTy); + hasOperands = true; + } + if (hasOperands && + parser.resolveOperands(operands, typeVec, parser.getNameLoc(), + result.operands)) + return mlir::failure(); + + mlir::Type restype = builder.getIntegerType(64); + if (!restype) { + parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; + return mlir::failure(); + } + llvm::SmallVector segmentSizes{1, typeparamsSize, shapeSize}; + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr(segmentSizes)); + if (parser.parseOptionalAttrDict(result.attributes) || + parser.addTypeToList(restype, result.types)) + return mlir::failure(); + return mlir::success(); +} + +mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseTargetAllocMemOp(parser, result); +} + +void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getDevice()); + p << " : "; + p << getDevice().getType(); + p << ", "; + p << getInType(); + if (!getTypeparams().empty()) { + p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; + } + for (auto sh : getShape()) { + p << ", "; + p.printOperand(sh); + } + p.printOptionalAttrDict((*this)->getAttrs(), + {"in_type", "operandSegmentSizes"}); +} + +llvm::LogicalResult omp::TargetAllocMemOp::verify() { + mlir::Type outType = getType(); + if (!mlir::dyn_cast(outType)) + return emitOpError("must be a integer type"); + return mlir::success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 3185f28fe6681..145433faaf5c9 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5858,6 +5858,85 @@ static bool isTargetDeviceOp(Operation *op) { return false; } +static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *i64Ty = builder.getInt64Ty(); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *returnType = builder.getPtrTy(0); + llvm::FunctionType *fnType = + llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false); + llvm::Function *func = cast( + llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto allocMemOp = cast(opInst); + if (!allocMemOp) + return failure(); + + // Get "omp_target_alloc" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule); + // Get the corresponding device value in llvm 
+ mlir::Value deviceNum = allocMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the allocation size. + llvm::DataLayout dataLayout = llvmModule->getDataLayout(); + mlir::Type heapTy = allocMemOp.getAllocatedType(); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy); + llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); + for (auto typeParam : allocMemOp.getTypeparams()) + allocSize = + builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam)); + // Create call to "omp_target_alloc" with the args as translated llvm values. + llvm::CallInst *call = + builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); + llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty()); + + // Map the result + moduleTranslation.mapValue(allocMemOp.getResult(), resultI64); + return success(); +} + +static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *ptrTy = builder.getPtrTy(0); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *voidTy = builder.getVoidTy(); + llvm::FunctionType *fnType = + llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false); + llvm::Function *func = dyn_cast( + llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto freeMemOp = cast(opInst); + if (!freeMemOp) + return failure(); + + // Get "omp_target_free" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = freeMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the corresponding heapref value in llvm + mlir::Value heapref = freeMemOp.getHeapref(); + llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref); + // Convert heapref int to ptr and call "omp_target_free" + llvm::Value *intToPtr = + builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0)); + builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum}); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including /// OpenMP runtime calls). static LogicalResult @@ -6032,6 +6111,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // the omp.canonical_loop. 
return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TargetAllocMemOp) { + return convertTargetAllocMemOp(*op, builder, moduleTranslation); + }) + .Case([&](omp::TargetFreeMemOp) { + return convertTargetFreeMemOp(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir new file mode 100644 index 0000000000000..1bc97609ccff4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -convert-openmp-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s + +// This file contains MLIR test cases for omp.target_allocmem and omp.target_freemem + +// CHECK-LABEL: test_alloc_free_i64 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 8, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_i64() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, i64 + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_1d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 64, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_1d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_2d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 1024, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_2d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16x16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +}
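
Note (not part of the patch): the change above promotes several helpers that were previously `static` in `CodeGen.cpp` into the `fir::` namespace via `flang/Optimizer/Support/Utils.h`, so that `CodeGenOpenMP.cpp` can reuse them for `omp.target_allocmem`. The sketch below illustrates how a downstream conversion pattern might combine those exported helpers to compute an allocation size in bytes. The wrapper name `computeAllocationSizeInBytes` and the exact include set are illustrative assumptions; the helper signatures match the declarations added in this patch, and the surrounding pattern is assumed to provide the converter, rewriter, and data layout (as `AllocMemOpConversion` and `TargetAllocMemOpConversion` do).

```cpp
// Illustrative sketch only; not part of this patch.
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"

// Hypothetical helper: byte size of one allocation of `firDataTy`, scaled by
// the constant extents encoded in the array type (if any). Dynamic extents
// and length parameters would still be multiplied in by the caller, as the
// AllocMemOp/TargetAllocMemOp conversions do.
static mlir::Value
computeAllocationSizeInBytes(mlir::Location loc, mlir::Type firDataTy,
                             const fir::LLVMTypeConverter &converter,
                             mlir::ConversionPatternRewriter &rewriter,
                             const mlir::DataLayout &dataLayout) {
  mlir::Type idxTy = converter.indexType();
  // Element size in bytes, padded to the ABI alignment of the element type.
  mlir::Type llvmObjectTy = fir::convertObjectType(converter, firDataTy);
  mlir::Value size = fir::computeElementDistance(loc, llvmObjectTy, idxTy,
                                                 rewriter, dataLayout);
  // Fold in the constant factors of the sequence type, when present.
  if (mlir::Value scale =
          fir::genAllocationScaleSize(loc, firDataTy, idxTy, rewriter))
    size = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, size, scale);
  return size;
}
```

A caller would then extend `size` with one `llvm.mul` per dynamic extent or length parameter, using `fir::integerCast` to widen each operand to the index type first, which is exactly the pattern `TargetAllocMemOpConversion::matchAndRewrite` follows above.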