diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 79582390d1294..2736da374a687 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -35,6 +35,10 @@ static llvm::cl::opt forceMatmulAsElemental( llvm::cl::desc("Expand hlfir.matmul as elemental operation"), llvm::cl::init(false)); +static llvm::cl::opt forceComplexDivAsArithmetic( + "flang-complex-div-converter", llvm::cl::init(false), llvm::cl::Hidden, + llvm::cl::desc("Force complex div as arithmetic calculation.")); + namespace { // Helper class to generate operations related to computing @@ -2320,6 +2324,98 @@ class ReshapeAsElementalConversion } }; +/// This rewrite pattern class performs a custom transformation on FIR +/// 'fir.call' operation that invoke the '__divdc3' runtime function, which is +/// typically used to perform double-precision complex division. +/// +/// If the 'forceComplexDivAsArithmetic' flag option is true, this pattern +/// matches call to '__divdc3', extracts the real and imaginary components of +/// the numerator and denominator, and replaces the function call with an +/// explicit computation using MLIR's arithmetic operations. +/// Specifically, it replaces the call to '__divdc3(x0, y0, x1, y1)' —where +/// (x0 + y0i) / (x1 + y1i) is the intended operation—with the mathematically +/// equivalent expression: +/// real_part = (x0*x1 + y0*y1) / (x1^2 + y1^2) +/// imag_part = (y0*x1 - x0*y1) / (x1^2 + y1^2) +/// The result is then reassembled into a 'complex' value using FIR's +/// 'InsertValueOp' instructions. +class ComplexDivisionConversion : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + llvm::LogicalResult + matchAndRewrite(fir::CallOp callOp, + mlir::PatternRewriter &rewriter) const override { + if (!forceComplexDivAsArithmetic) { + LLVM_DEBUG(llvm::dbgs() + << "Complex division with arithmetic calculation support is " + "currently disabled \n"); + return mlir::failure(); + } + fir::FirOpBuilder builder{rewriter, callOp.getOperation()}; + const mlir::Location &loc = callOp.getLoc(); + if (!callOp.getCallee()) { + LLVM_DEBUG(llvm::dbgs() + << "No callee found for CallOp at " << loc << "\n"); + return mlir::failure(); + } + + const mlir::SymbolRefAttr &callee = *callOp.getCallee(); + const auto &fctName = callee.getRootReference().getValue(); + if (fctName != "__divdc3") + return mlir::failure(); + + const mlir::Type &eleTy = callOp.getOperands()[0].getType(); + const mlir::Type &resTy = callOp.getResult(0).getType(); + + auto x0 = callOp.getOperands()[0]; // real part of numerator + auto y0 = callOp.getOperands()[1]; // imaginary part of numerator + auto x1 = callOp.getOperands()[2]; // real part of denominator + auto y1 = callOp.getOperands()[3]; // imaginary part of denominator + + // standard complex division formula: + // (x0 + y0i)/(x1 + y1i) = ((x0*x1 + y0*y1)/(x1^2 + y1^2)) + ((y0*x1 - + // x0*y1)/(x1^2 + y1^2))i + auto x0x1 = + rewriter.create(loc, eleTy, x0, x1); // x0 * x1 + auto x1Squared = + rewriter.create(loc, eleTy, x1, x1); // x1^2 + auto y0x1 = + rewriter.create(loc, eleTy, y0, x1); // y0 * x1 + auto x0y1 = + rewriter.create(loc, eleTy, x0, y1); // x0 * y1 + auto y0y1 = + rewriter.create(loc, eleTy, y0, y1); // y0 * y1 + auto y1Squared = + rewriter.create(loc, eleTy, y1, y1); // y1^2 + + auto denom = rewriter.create(loc, eleTy, x1Squared, + y1Squared); // x1^2 + y1^2 + auto realNumerator = rewriter.create( + loc, eleTy, x0x1, y0y1); // x0*x1 + y0*y1 + auto imagNumerator = rewriter.create( + loc, eleTy, y0x1, x0y1); // y0*x1 - x0*y1 + + // compute final real and imaginary parts + auto realResult = + rewriter.create(loc, eleTy, realNumerator, denom); + auto imagResult = + rewriter.create(loc, eleTy, imagNumerator, denom); + + // construct the result complex number + auto undefComplex = rewriter.create(loc, resTy); + auto index0 = builder.getArrayAttr( + {builder.getI32IntegerAttr(0)}); // index for real part + auto index1 = builder.getArrayAttr( + {builder.getI32IntegerAttr(1)}); // index for imag part + auto complexWithReal = rewriter.create( + loc, resTy, undefComplex, realResult, index0); // Insert real part + auto resComplex = rewriter.create( + loc, resTy, complexWithReal, imagResult, + index1); // Insert imaginary part + rewriter.replaceOp(callOp, resComplex.getResult()); + return mlir::success(); + } +}; + class SimplifyHLFIRIntrinsics : public hlfir::impl::SimplifyHLFIRIntrinsicsBase { public: @@ -2366,6 +2462,13 @@ class SimplifyHLFIRIntrinsics patterns.insert(context); patterns.insert(context); + /// If the 'forceComplexDivAsArithmetic' flag option is true, this pattern + /// matches call to '__divdc3', extracts the real and imaginary components + /// of the numerator and denominator, and replaces the function call with an + /// explicit computation using MLIR's arithmetic operations. + if (forceComplexDivAsArithmetic) + patterns.insert(context); + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), diff --git a/flang/test/Fir/target-rewrite-complex-division.fir b/flang/test/Fir/target-rewrite-complex-division.fir new file mode 100644 index 0000000000000..24744a37acb3d --- /dev/null +++ b/flang/test/Fir/target-rewrite-complex-division.fir @@ -0,0 +1,31 @@ +// RUN: fir-opt %s --simplify-hlfir-intrinsics --flang-complex-div-converter | FileCheck %s + +func.func @test_double_complex_div(%arg0: !fir.ref>, %arg1: !fir.ref>, %arg2: !fir.ref>) { + %0 = fir.load %arg1 : !fir.ref> + %1 = fir.load %arg2 : !fir.ref> + %2 = fir.extract_value %0, [0 : index] : (complex) -> f64 + %3 = fir.extract_value %0, [1 : index] : (complex) -> f64 + %4 = fir.extract_value %1, [0 : index] : (complex) -> f64 + %5 = fir.extract_value %1, [1 : index] : (complex) -> f64 + %6 = fir.call @__divdc3(%2, %3, %4, %5) fastmath : (f64, f64, f64, f64) -> complex + fir.store %6 to %arg0 : !fir.ref> + return +} + +// CHECK-LABEL: func.func @test_double_complex_div +// CHECK-NOT: fir.call @__divdc3 +// CHECK: %[[R1:.*]] = arith.mulf %2, %4 : f64 +// CHECK: %[[R2:.*]] = arith.mulf %4, %4 : f64 +// CHECK: %[[R3:.*]] = arith.mulf %3, %4 : f64 +// CHECK: %[[R4:.*]] = arith.mulf %2, %5 : f64 +// CHECK: %[[R5:.*]] = arith.mulf %3, %5 : f64 +// CHECK: %[[R6:.*]] = arith.mulf %5, %5 : f64 +// CHECK: %[[DENOM:.*]] = arith.addf %[[R2]], %[[R6]] : f64 +// CHECK: %[[NUM_RE:.*]] = arith.addf %[[R1]], %[[R5]] : f64 +// CHECK: %[[NUM_IM:.*]] = arith.subf %[[R3]], %[[R4]] : f64 +// CHECK: %[[RES_RE:.*]] = arith.divf %[[NUM_RE]], %[[DENOM]] : f64 +// CHECK: %[[RES_IM:.*]] = arith.divf %[[NUM_IM]], %[[DENOM]] : f64 +// CHECK: %[[U:.*]] = fir.undefined complex +// CHECK: %[[C0:.*]] = fir.insert_value %[[U]], %[[RES_RE]], [0 : i32] : (complex, f64) -> complex +// CHECK: %[[C1:.*]] = fir.insert_value %[[C0]], %[[RES_IM]], [1 : i32] : (complex, f64) -> complex +// CHECK: fir.store %[[C1]] to %arg0 : !fir.ref> \ No newline at end of file