diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl
index e076681979..cbee050d3f 100644
--- a/ext/ReactantNNlibExt/Implementations.jl
+++ b/ext/ReactantNNlibExt/Implementations.jl
@@ -6,6 +6,21 @@ for (jlop, hloop) in (
     @eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
 end
 
+# See https://github.com/EnzymeAD/Reactant.jl/issues/1420
+# Without this we will never fuse the gelu into gemm
+if isdefined(NNlib, :gelu_tanh)
+    function NNlib.gelu_tanh(x::TracedRNumber)
+        return Reactant.Ops.gelu(x, Reactant.NNLIB_GELU_APPROXIMATION[])
+    end
+
+    NNlib.gelu_erf(x::TracedRNumber) = Reactant.Ops.gelu(x, "NONE")
+else
+    # Older versions of NNlib do not have gelu_tanh (gelu refers to the tanh version)
+    function NNlib.gelu(x::TracedRNumber)
+        return Reactant.Ops.gelu(x, Reactant.NNLIB_GELU_APPROXIMATION[])
+    end
+end
+
 function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
     x = T.(Reactant.materialize_traced_array(x))
     max_ = maximum(x; dims)
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 5cff4174aa..bd836648d4 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1623,6 +1623,7 @@ function compile_mlir!(
     blas_int_width = sizeof(BLAS.BlasInt) * 8
     lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
                                    blas_int_width=$blas_int_width}"
+    lower_enzymexla_ml_pass = "lower-enzymexla-ml"
 
     if compile_options.optimization_passes === :all
         run_pass_pipeline!(
@@ -1650,6 +1651,7 @@
                 )...,
                 opt_passes2,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         else
@@ -1674,6 +1676,7 @@
                 kern,
                 raise_passes,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         end,
@@ -1863,6 +1866,7 @@
                 )...,
                 opt_passes2,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         else
@@ -1884,6 +1888,7 @@
                 kern,
                 raise_passes,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         end,
@@ -1906,6 +1911,7 @@
                 enzyme_pass,
                 "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         else
@@ -1919,6 +1925,7 @@
                 kern,
                 raise_passes,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         end,
diff --git a/src/Configuration.jl b/src/Configuration.jl
index 5b0eaa00af..991b35d5da 100644
--- a/src/Configuration.jl
+++ b/src/Configuration.jl
@@ -20,6 +20,8 @@ scope will use the provided values.
     `ApproxTopK` for TPUs unless `fallback_approx_top_k_lowering` is set to `true`.
   - `fallback_approx_top_k_lowering`: Whether to lower `Ops.approx_top_k` to
     `stablehlo.top_k` if the XLA backend doesn't support `ApproxTopK`. Defaults to `true`.
+  - `nnlib_gelu_approximation`: Controls the approximation used for `NNlib.gelu_tanh`. Can
+    be `"TANH"` or `"SIGMOID"`. Defaults to `"SIGMOID"`.
 
 ### DotGeneral
 
@@ -38,6 +40,7 @@ function with_config(
     convolution_precision=missing,
     lower_partialsort_to_approx_top_k=missing,
     fallback_approx_top_k_lowering=missing,
+    nnlib_gelu_approximation=missing,
 )
     config_vars = ()
     dot_general_algorithm !== missing &&
@@ -58,6 +61,10 @@
             FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering,
         )
     )
+    if nnlib_gelu_approximation !== missing
+        @assert nnlib_gelu_approximation in ("TANH", "SIGMOID") "Invalid nnlib_gelu_approximation: $nnlib_gelu_approximation. Expected \"TANH\" or \"SIGMOID\"."
+        config_vars = (config_vars..., NNLIB_GELU_APPROXIMATION => nnlib_gelu_approximation)
+    end
     return ScopedValues.with(f, config_vars...)
 end
 
@@ -65,6 +72,7 @@
 # Lower to ApproxTopK
 const LOWER_PARTIALSORT_TO_APPROX_TOP_K = ScopedValue(false)
 const FALLBACK_APPROX_TOP_K_LOWERING = ScopedValue(true)
+const NNLIB_GELU_APPROXIMATION = ScopedValue("SIGMOID")
 
 # DotGeneral Attributes Configuration
 """
diff --git a/src/Ops.jl b/src/Ops.jl
index 75cd1d9fc8..d2cf2d80c3 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -3,7 +3,7 @@
 # Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
 module Ops
 using ..MLIR: MLIR
-using ..MLIR.Dialects: stablehlo, chlo, enzyme
+using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
 using ..Reactant:
     Reactant,
     TracedRArray,
@@ -3003,7 +3003,7 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
     permutation_shape = vcat(batch_shape, size(x, ndims(x) - 1))
     info_shape = batch_shape
 
-    op = MLIR.Dialects.enzymexla.linalg_lu(
+    op = enzymexla.linalg_lu(
         x.mlir_data;
         output=MLIR.IR.TensorType(output_shape, MLIR.IR.Type(unwrapped_eltype(T))),
         pivots=MLIR.IR.TensorType(pivots_shape, MLIR.IR.Type(pT)),
@@ -3210,4 +3210,22 @@ end
     end
 end
 
+@noinline function gelu(
+    x::Union{TracedRArray{T,N},TracedRNumber{T}},
+    approximation::String;
+    location=mlir_stacktrace("gelu", @__FILE__, @__LINE__),
+) where {T,N}
+    @assert approximation in ("NONE", "TANH", "SIGMOID")
+
+    res = MLIR.IR.result(
+        enzymexla.ml_gelu(x.mlir_data; gelu_approximation=approximation, location), 1
+    )
+
+    if x isa TracedRArray
+        return TracedRArray{T,N}((), res, size(x))
+    else
+        return TracedRNumber{T}((), res)
+    end
+end
+
 end # module Ops
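
For reference, a minimal usage sketch of the new configuration option (not part of the patch). It assumes Reactant's public `Reactant.to_rarray` and `@compile` entry points and an NNlib version that defines `gelu_tanh`; exact names may differ across versions.

    using NNlib, Reactant

    # Broadcasting gelu_tanh over a traced array hits the new TracedRNumber overload,
    # which lowers to enzymexla.ml_gelu with the scoped approximation.
    f(x) = NNlib.gelu_tanh.(x)

    x = Reactant.to_rarray(randn(Float32, 16, 16))

    # The scoped value is read while tracing, so with_config must wrap compilation.
    # The default approximation is "SIGMOID"; "TANH" selects the tanh-based one.
    compiled = Reactant.with_config(; nnlib_gelu_approximation="TANH") do
        Reactant.@compile f(x)
    end
    compiled(x)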