From 59bc1c921519967d71837a0833023a6dbccf9045 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 11 Apr 2025 13:30:38 +0100 Subject: [PATCH 1/2] [Flang][MLIR][OpenMP] Improve use_device_* handling This patch updates MLIR op verifiers for operations taking arguments that must always be defined by an `omp.map.info` operation to check this requirement. It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the custom MLIR printer and parser for these clauses, to support initializing it to `OMP_MAP_RETURN_PARAM` and represent this in the MLIR representation as `return_param`. This internal mapping flag is what eventually is used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation. --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 6 +-- flang/lib/Lower/OpenMP/Utils.cpp | 8 ++-- .../Fir/convert-to-llvm-openmp-and-fir.fir | 5 +- flang/test/Lower/OpenMP/target.f90 | 2 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 47 +++++++++++++++---- mlir/test/Dialect/OpenMP/ops.mlir | 10 ++-- 6 files changed, 57 insertions(+), 21 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index f4876256a378f..02454543d0a60 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1407,8 +1407,7 @@ bool ClauseProcessor::processUseDeviceAddr( const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, useDeviceSyms); @@ -1429,8 +1428,7 @@ bool ClauseProcessor::processUseDevicePtr( const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, useDeviceSyms); diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 3f4cfb8c11a9d..173dceb07b193 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -398,14 +398,16 @@ mlir::Value createParentSymAndGenIntermediateMaps( interimBounds, treatIndexAsSection); } - // Remove all map TO, FROM and TOFROM bits, from the intermediate - // allocatable maps, we simply wish to alloc or release them. It may be - // safer to just pass OMP_MAP_NONE as the map type, but we may still + // Remove all map-type bits (e.g. TO, FROM, etc.) from the intermediate + // allocatable maps, as we simply wish to alloc or release them. It may + // be safer to just pass OMP_MAP_NONE as the map type, but we may still // need some of the other map types the mapped member utilises, so for // now it's good to keep an eye on this. llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits; interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + interimMapType &= + ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; // Create a map for the intermediate member and insert it and it's // indices into the parentMemberIndices list to track it. diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir index 8019ecf7f6a05..b13921f822b4d 100644 --- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir +++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir @@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() { func.func @_QPomp_target_data_empty() { %0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"} - omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref>) { + %1 = omp.map.info var_ptr(%0 : !fir.ref>, !fir.ref>) map_clauses(return_param) capture(ByRef) -> !fir.ref> {name = ""} + omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref>) { omp.terminator } return } // CHECK-LABEL: llvm.func @_QPomp_target_data_empty -// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) { +// CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) { // CHECK: } // ----- diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90 index 4815e6564fc7e..f04aacc63fc2b 100644 --- a/flang/test/Lower/OpenMP/target.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -544,7 +544,7 @@ subroutine omp_target_device_addr !CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) !CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) -> !fir.llvm_ptr> {name = ""} !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} - !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) -> !fir.llvm_ptr> {name = ""} + !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) -> !fir.llvm_ptr> {name = ""} !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref>>, !fir.llvm_ptr>) { !$omp target data map(tofrom: a) use_device_addr(a) diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 2bf7aaa46db11..deff86d5c5ecb 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1521,6 +1521,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { if (mapTypeMod == "delete") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + if (mapTypeMod == "return_param") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + return success(); }; @@ -1583,6 +1586,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op, emitAllocRelease = false; mapTypeStrs.push_back("delete"); } + if (mapTypeToBitFlag( + mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) { + emitAllocRelease = false; + mapTypeStrs.push_back("return_param"); + } if (emitAllocRelease) mapTypeStrs.push_back("exit_release_or_enter_alloc"); @@ -1777,6 +1786,17 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) { // MapInfoOp //===----------------------------------------------------------------------===// +static LogicalResult verifyMapInfoDefinedArgs(Operation *op, + StringRef clauseName, + OperandRange vars) { + for (Value var : vars) + if (!llvm::isa_and_present(var.getDefiningOp())) + return op->emitOpError() + << "'" << clauseName + << "' arguments must be defined by 'omp.map.info' ops"; + return success(); +} + LogicalResult MapInfoOp::verify() { if (getMapperId() && !SymbolTable::lookupNearestSymbolFrom( @@ -1784,6 +1804,9 @@ LogicalResult MapInfoOp::verify() { return emitError("invalid mapper id"); } + if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers()))) + return failure(); + return success(); } @@ -1805,6 +1828,15 @@ LogicalResult TargetDataOp::verify() { "At least one of map, use_device_ptr_vars, or " "use_device_addr_vars operand must be present"); } + + if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr", + getUseDevicePtrVars()))) + return failure(); + + if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr", + getUseDeviceAddrVars()))) + return failure(); + return verifyMapClause(*this, getMapVars()); } @@ -1889,16 +1921,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, } LogicalResult TargetOp::verify() { - LogicalResult verifyDependVars = - verifyDependVarList(*this, getDependKinds(), getDependVars()); - - if (failed(verifyDependVars)) - return verifyDependVars; + if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars()))) + return failure(); - LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars()); + if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr", + getHasDeviceAddrVars()))) + return failure(); - if (failed(verifyMapVars)) - return verifyMapVars; + if (failed(verifyMapClause(*this, getMapVars()))) + return failure(); return verifyPrivateVarsMapping(*this); } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index b7e16b7ec35e2..a9e4af035dbd7 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref, tensor) map_clauses(always, from) capture(ByRef) -> memref {name = ""} omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref){} - // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} - // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref) + // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} + // CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + // CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref) %mapv2 = omp.map.info var_ptr(%map1 : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} - omp.target_data map_entries(%mapv2 : memref) use_device_addr(%device_addr -> %arg0 : memref) use_device_ptr(%device_ptr -> %arg1 : memref) { + %device_addrv1 = omp.map.info var_ptr(%device_addr : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + %device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref, tensor) map_clauses(return_param) capture(ByRef) -> memref {name = ""} + omp.target_data map_entries(%mapv2 : memref) use_device_addr(%device_addrv1 -> %arg0 : memref) use_device_ptr(%device_ptrv1 -> %arg1 : memref) { omp.terminator } From c6954b3120a87eef7d9cf86f18d4ef342b2e7b25 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 11 Apr 2025 13:40:14 +0100 Subject: [PATCH 2/2] [MLIR][OpenMP] Assert on map translation functions, NFC This patch adds assertions to map-related MLIR to LLVM IR translation functions and utils to explicitly document whether they are intended for host or device compilation only. Over time, map-related handling has increased in complexity. This is compounded by the fact that some handling is device-specific and some is host-specific. By explicitly asserting on these functions on the expected compilation pass, the flow should become slighlty easier to follow. --- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9f7b5605556e6..010c46358f7df 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3720,6 +3720,9 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) { + assert(!ompBuilder.Config.isTargetDevice() && + "function only supported for host device codegen"); + // Map the first segment of our structure combinedInfo.Types.emplace_back( isTargetParams @@ -3828,6 +3831,8 @@ static void processMapMembersWithParent( llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) { + assert(!ompBuilder.Config.isTargetDevice() && + "function only supported for host device codegen"); auto parentClause = llvm::cast(mapData.MapClause[mapDataIndex]); @@ -3941,6 +3946,9 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) { + assert(!ompBuilder.Config.isTargetDevice() && + "function only supported for host device codegen"); + auto parentClause = llvm::cast(mapData.MapClause[mapDataIndex]); @@ -3982,6 +3990,8 @@ static void createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder) { + assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && + "function only supported for host device codegen"); for (size_t i = 0; i < mapData.MapClause.size(); ++i) { // if it's declare target, skip it, it's handled separately. if (!mapData.IsDeclareTarget[i]) { @@ -4046,6 +4056,9 @@ static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, bool isTargetParams = false) { + assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && + "function only supported for host device codegen"); + // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can // involve generating new loads and stores, which changes the @@ -4057,8 +4070,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder, // kernel arg structure. It primarily becomes relevant in cases like // bycopy, or byref range'd arrays. In the default case, we simply // pass thee pointer byref as both basePointer and pointer. - if (!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice()) - createAlteredByCaptureMap(mapData, moduleTranslation, builder); + createAlteredByCaptureMap(mapData, moduleTranslation, builder); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); @@ -4092,6 +4104,8 @@ emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, static llvm::Expected getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { + assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && + "function only supported for host device codegen"); auto declMapperOp = cast(op); std::string mapperFuncName = moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName( @@ -4108,6 +4122,8 @@ static llvm::Expected emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::StringRef mapperFuncName) { + assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && + "function only supported for host device codegen"); auto declMapperOp = cast(op); auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo(); DataLayout dl = DataLayout(declMapperOp->getParentOfType()); @@ -4597,6 +4613,8 @@ static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::Function *func) { + assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && + "function only supported for target device codegen"); for (size_t i = 0; i < mapData.MapClause.size(); ++i) { // In the case of declare target mapped variables, the basePointer is // the reference pointer generated by the convertDeclareTargetAttr @@ -4689,6 +4707,8 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase::InsertPoint allocaIP, llvm::IRBuilderBase::InsertPoint codeGenIP) { + assert(ompBuilder.Config.isTargetDevice() && + "function only supported for target device codegen"); builder.restoreIP(allocaIP); omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;