Skip to content

[Flang] Canonicalize divdc3 calls into arithmetic-based complex division #146017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ static llvm::cl::opt<bool> forceMatmulAsElemental(
llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
llvm::cl::init(false));

// Hidden command-line switch, off by default, that turns on the rewrite of
// '__divdc3' calls into explicit arithmetic.
static llvm::cl::opt<bool> forceComplexDivAsArithmetic(
    "flang-complex-div-converter",
    llvm::cl::desc("Force complex div as arithmetic calculation."),
    llvm::cl::Hidden, llvm::cl::init(false));

namespace {

// Helper class to generate operations related to computing
Expand Down Expand Up @@ -2320,6 +2324,98 @@ class ReshapeAsElementalConversion
}
};

/// Rewrites a 'fir.call' to the '__divdc3' runtime function (double-precision
/// complex division) into inline arithmetic.
///
/// The pattern only fires when the 'forceComplexDivAsArithmetic' option is
/// set. A call '__divdc3(x0, y0, x1, y1)', which computes
/// (x0 + y0i) / (x1 + y1i), is replaced with the textbook formula:
///   real_part = (x0*x1 + y0*y1) / (x1^2 + y1^2)
///   imag_part = (y0*x1 - x0*y1) / (x1^2 + y1^2)
/// and the two parts are packed into the call's complex result type with
/// 'fir.insert_value'.
///
/// NOTE: the generated 'arith' operations are created without fast-math
/// flags, so flags attached to the original call are not carried over, and
/// the formula applies no scaling to the intermediate products.
class ComplexDivisionConversion : public mlir::OpRewritePattern<fir::CallOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(fir::CallOp callOp,
                  mlir::PatternRewriter &rewriter) const override {
    if (!forceComplexDivAsArithmetic) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Complex division with arithmetic calculation support is "
                    "currently disabled \n");
      return mlir::failure();
    }

    const mlir::Location loc = callOp.getLoc();

    // Keep the optional itself alive: binding a reference to the result of
    // dereferencing the temporary returned by getCallee() would dangle.
    const auto callee = callOp.getCallee();
    if (!callee) {
      LLVM_DEBUG(llvm::dbgs()
                 << "No callee found for CallOp at " << loc << "\n");
      return mlir::failure();
    }
    if (callee->getRootReference().getValue() != "__divdc3")
      return mlir::failure();

    // Validate the signature before indexing operands/results: four
    // floating-point operands of one type and a single complex result whose
    // element type matches them.
    auto args = callOp.getOperands();
    if (args.size() != 4 || callOp->getNumResults() != 1)
      return rewriter.notifyMatchFailure(
          callOp, "unexpected number of operands or results for __divdc3");

    const mlir::Type eleTy = args[0].getType();
    const mlir::Type resTy = callOp.getResult(0).getType();
    if (!mlir::isa<mlir::FloatType>(eleTy))
      return rewriter.notifyMatchFailure(
          callOp, "__divdc3 operands are not floating-point");
    for (mlir::Value arg : args)
      if (arg.getType() != eleTy)
        return rewriter.notifyMatchFailure(
            callOp, "__divdc3 operands have mismatched types");
    auto complexTy = mlir::dyn_cast<mlir::ComplexType>(resTy);
    if (!complexTy || complexTy.getElementType() != eleTy)
      return rewriter.notifyMatchFailure(
          callOp, "__divdc3 result is not a matching complex type");

    mlir::Value x0 = args[0]; // real part of numerator
    mlir::Value y0 = args[1]; // imaginary part of numerator
    mlir::Value x1 = args[2]; // real part of denominator
    mlir::Value y1 = args[3]; // imaginary part of denominator

    auto mul = [&](mlir::Value lhs, mlir::Value rhs) -> mlir::Value {
      return rewriter.create<mlir::arith::MulFOp>(loc, eleTy, lhs, rhs);
    };

    // Standard complex division formula:
    //   (x0 + y0i)/(x1 + y1i) = ((x0*x1 + y0*y1)/(x1^2 + y1^2))
    //                         + ((y0*x1 - x0*y1)/(x1^2 + y1^2))i
    mlir::Value x0x1 = mul(x0, x1);
    mlir::Value x1Squared = mul(x1, x1);
    mlir::Value y0x1 = mul(y0, x1);
    mlir::Value x0y1 = mul(x0, y1);
    mlir::Value y0y1 = mul(y0, y1);
    mlir::Value y1Squared = mul(y1, y1);

    mlir::Value denom = rewriter.create<mlir::arith::AddFOp>(
        loc, eleTy, x1Squared, y1Squared); // x1^2 + y1^2
    mlir::Value realNumerator = rewriter.create<mlir::arith::AddFOp>(
        loc, eleTy, x0x1, y0y1); // x0*x1 + y0*y1
    mlir::Value imagNumerator = rewriter.create<mlir::arith::SubFOp>(
        loc, eleTy, y0x1, x0y1); // y0*x1 - x0*y1

    // Final real and imaginary parts.
    mlir::Value realResult =
        rewriter.create<mlir::arith::DivFOp>(loc, eleTy, realNumerator, denom);
    mlir::Value imagResult =
        rewriter.create<mlir::arith::DivFOp>(loc, eleTy, imagNumerator, denom);

    // Assemble the complex result: index 0 holds the real part, index 1 the
    // imaginary part.
    mlir::Value undefComplex = rewriter.create<fir::UndefOp>(loc, resTy);
    auto realIndex = rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)});
    auto imagIndex = rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)});
    mlir::Value complexWithReal = rewriter.create<fir::InsertValueOp>(
        loc, resTy, undefComplex, realResult, realIndex);
    auto resComplex = rewriter.create<fir::InsertValueOp>(
        loc, resTy, complexWithReal, imagResult, imagIndex);

    rewriter.replaceOp(callOp, resComplex.getResult());
    return mlir::success();
  }
};

class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
Expand Down Expand Up @@ -2366,6 +2462,13 @@ class SimplifyHLFIRIntrinsics
patterns.insert<DotProductConversion>(context);
patterns.insert<ReshapeAsElementalConversion>(context);

// Register the '__divdc3' rewrite only when the
// 'forceComplexDivAsArithmetic' option is set. The pattern replaces calls
// to '__divdc3' with an explicit computation of the complex quotient using
// MLIR arithmetic operations.
if (forceComplexDivAsArithmetic)
patterns.insert<ComplexDivisionConversion>(context);

if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
Expand Down
31 changes: 31 additions & 0 deletions flang/test/Fir/target-rewrite-complex-division.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: fir-opt %s --simplify-hlfir-intrinsics --flang-complex-div-converter | FileCheck %s

// Checks that, with --flang-complex-div-converter, a call to __divdc3 is
// replaced by inline arithmetic on the real and imaginary parts. SSA values
// are captured with FileCheck variables so the test does not depend on the
// printer's value numbering.

func.func @test_double_complex_div(%arg0: !fir.ref<complex<f64>>, %arg1: !fir.ref<complex<f64>>, %arg2: !fir.ref<complex<f64>>) {
%0 = fir.load %arg1 : !fir.ref<complex<f64>>
%1 = fir.load %arg2 : !fir.ref<complex<f64>>
%2 = fir.extract_value %0, [0 : index] : (complex<f64>) -> f64
%3 = fir.extract_value %0, [1 : index] : (complex<f64>) -> f64
%4 = fir.extract_value %1, [0 : index] : (complex<f64>) -> f64
%5 = fir.extract_value %1, [1 : index] : (complex<f64>) -> f64
%6 = fir.call @__divdc3(%2, %3, %4, %5) fastmath<contract> : (f64, f64, f64, f64) -> complex<f64>
fir.store %6 to %arg0 : !fir.ref<complex<f64>>
return
}

// CHECK-LABEL: func.func @test_double_complex_div
// CHECK: %[[NUM:.*]] = fir.load %arg1 : !fir.ref<complex<f64>>
// CHECK: %[[DEN:.*]] = fir.load %arg2 : !fir.ref<complex<f64>>
// CHECK: %[[X0:.*]] = fir.extract_value %[[NUM]], [0 : index] : (complex<f64>) -> f64
// CHECK: %[[Y0:.*]] = fir.extract_value %[[NUM]], [1 : index] : (complex<f64>) -> f64
// CHECK: %[[X1:.*]] = fir.extract_value %[[DEN]], [0 : index] : (complex<f64>) -> f64
// CHECK: %[[Y1:.*]] = fir.extract_value %[[DEN]], [1 : index] : (complex<f64>) -> f64
// CHECK-NOT: fir.call @__divdc3
// CHECK: %[[R1:.*]] = arith.mulf %[[X0]], %[[X1]] : f64
// CHECK: %[[R2:.*]] = arith.mulf %[[X1]], %[[X1]] : f64
// CHECK: %[[R3:.*]] = arith.mulf %[[Y0]], %[[X1]] : f64
// CHECK: %[[R4:.*]] = arith.mulf %[[X0]], %[[Y1]] : f64
// CHECK: %[[R5:.*]] = arith.mulf %[[Y0]], %[[Y1]] : f64
// CHECK: %[[R6:.*]] = arith.mulf %[[Y1]], %[[Y1]] : f64
// CHECK: %[[DENOM:.*]] = arith.addf %[[R2]], %[[R6]] : f64
// CHECK: %[[NUM_RE:.*]] = arith.addf %[[R1]], %[[R5]] : f64
// CHECK: %[[NUM_IM:.*]] = arith.subf %[[R3]], %[[R4]] : f64
// CHECK: %[[RES_RE:.*]] = arith.divf %[[NUM_RE]], %[[DENOM]] : f64
// CHECK: %[[RES_IM:.*]] = arith.divf %[[NUM_IM]], %[[DENOM]] : f64
// CHECK: %[[U:.*]] = fir.undefined complex<f64>
// CHECK: %[[C0:.*]] = fir.insert_value %[[U]], %[[RES_RE]], [0 : i32] : (complex<f64>, f64) -> complex<f64>
// CHECK: %[[C1:.*]] = fir.insert_value %[[C0]], %[[RES_IM]], [1 : i32] : (complex<f64>, f64) -> complex<f64>
// CHECK-NOT: fir.call @__divdc3
// CHECK: fir.store %[[C1]] to %arg0 : !fir.ref<complex<f64>>