Skip to content

Conversation

michel2323
Copy link
Contributor

It passes some tests. GPT 5 helped quite a bit here.

Copy link
Contributor

github-actions bot commented Aug 19, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/mkl/fft.jl b/lib/mkl/fft.jl
index ceefda6..66329a9 100644
--- a/lib/mkl/fft.jl
+++ b/lib/mkl/fft.jl
@@ -20,29 +20,29 @@ export MKLFFTPlan
 # (We can just re-use integer constants; C wrappers return 0 on success.)
 const DFT_PREC_SINGLE = 0
 const DFT_PREC_DOUBLE = 1
-const DFT_DOM_REAL    = 0
+const DFT_DOM_REAL = 0
 const DFT_DOM_COMPLEX = 1
 
 # Configuration parameter indices (must match onemkl_dft.h enum ordering)
-const DFT_PARAM_DIMENSION            = 1
-const DFT_PARAM_LENGTHS              = 2
-const DFT_PARAM_PRECISION            = 3
-const DFT_PARAM_FORWARD_SCALE        = 4
-const DFT_PARAM_BACKWARD_SCALE       = 5
+const DFT_PARAM_DIMENSION = 1
+const DFT_PARAM_LENGTHS = 2
+const DFT_PARAM_PRECISION = 3
+const DFT_PARAM_FORWARD_SCALE = 4
+const DFT_PARAM_BACKWARD_SCALE = 5
 const DFT_PARAM_NUMBER_OF_TRANSFORMS = 6
-const DFT_PARAM_COMPLEX_STORAGE      = 7
-const DFT_PARAM_PLACEMENT            = 8
-const DFT_PARAM_INPUT_STRIDES        = 9
-const DFT_PARAM_OUTPUT_STRIDES       = 10
-const DFT_PARAM_FWD_DISTANCE         = 11
-const DFT_PARAM_BWD_DISTANCE         = 12
-const DFT_PARAM_WORKSPACE            = 13
+const DFT_PARAM_COMPLEX_STORAGE = 7
+const DFT_PARAM_PLACEMENT = 8
+const DFT_PARAM_INPUT_STRIDES = 9
+const DFT_PARAM_OUTPUT_STRIDES = 10
+const DFT_PARAM_FWD_DISTANCE = 11
+const DFT_PARAM_BWD_DISTANCE = 12
+const DFT_PARAM_WORKSPACE = 13
 const DFT_PARAM_WORKSPACE_ESTIMATE_BYTES = 14
-const DFT_PARAM_WORKSPACE_BYTES      = 15
-const DFT_PARAM_FWD_STRIDES          = 16
-const DFT_PARAM_BWD_STRIDES          = 17
+const DFT_PARAM_WORKSPACE_BYTES = 15
+const DFT_PARAM_FWD_STRIDES = 16
+const DFT_PARAM_BWD_STRIDES = 17
 # Config value logical indices (ordering per onemkl_dft.h)
-const DFT_CFG_INPLACE     = 4
+const DFT_CFG_INPLACE = 4
 const DFT_CFG_NOT_INPLACE = 5
 
 # Opaque descriptor type alias to Ptr{Nothing} (generated wrapper not yet exposed)
@@ -68,34 +68,34 @@ ccall_set_int(desc, param::Int32, value::Int64) = ccall((:onemklDftSetValueInt64
 ccall_set_int64_array(desc, param::Int32, values::Vector{Int64}) = ccall((:onemklDftSetValueInt64Array, lib), Cint, (Ptr{Cvoid}, Cint, Ptr{Int64}, Int64), desc, param, pointer(values), length(values))
 ccall_set_cfg(desc, param::Int32, value::Int32) = ccall((:onemklDftSetValueConfigValue, lib), Cint, (Ptr{Cvoid}, Cint, Cint), desc, param, value)
 
-abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
+abstract type MKLFFTPlan{T, K, inplace} <: AbstractFFTs.Plan{T} end
 
-Base.eltype(::MKLFFTPlan{T}) where T = T
-is_inplace(::MKLFFTPlan{<:Any,<:Any,inplace}) where inplace = inplace
+Base.eltype(::MKLFFTPlan{T}) where {T} = T
+is_inplace(::MKLFFTPlan{<:Any, <:Any, inplace}) where {inplace} = inplace
 
 # Forward / inverse flags
 const MKLFFT_FORWARD = true
 const MKLFFT_INVERSE = false
 
-mutable struct cMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct cMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
     handle::Ptr{Cvoid}
     queue::syclQueue_t
-    sz::NTuple{N,Int}
-    osz::NTuple{N,Int}
+    sz::NTuple{N, Int}
+    osz::NTuple{N, Int}
     realdomain::Bool
-    region::NTuple{R,Int}
+    region::NTuple{R, Int}
     buffer::B
     pinv::Any
 end
 
 # Real transforms use separate struct (mirroring AMDGPU style) for buffer staging
-mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct rMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
     handle::Ptr{Cvoid}
     queue::syclQueue_t
-    sz::NTuple{N,Int}
-    osz::NTuple{N,Int}
+    sz::NTuple{N, Int}
+    osz::NTuple{N, Int}
     xtype::Symbol
-    region::NTuple{R,Int}
+    region::NTuple{R, Int}
     buffer::B
     pinv::Any
 end
@@ -103,40 +103,44 @@ end
 # Inverse plan constructors (derive from existing plan)
 function normalization_factor(sz, region)
     # AbstractFFTs expects inverse to scale by 1/prod(lengths along region)
-    prod(ntuple(i-> sz[region[i]], length(region)))
+    return prod(ntuple(i -> sz[region[i]], length(region)))
 end
 
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
 
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:brfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :brfft, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
-    q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:rfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+    q = rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :rfft, p.region, p.buffer, p)
     p.pinv = q
-    ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+    return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
 end
 
-function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace}
+function Base.show(io::IO, p::MKLFFTPlan{T, K, inplace}) where {T, K, inplace}
     print(io, inplace ? "oneMKL FFT in-place " : "oneMKL FFT ", K ? "forward" : "inverse", " plan for ")
-    if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
-    print(io, " oneArray of ", T)
+    if isempty(p.sz)
+        print(io, "0-dimensional")
+    else
+        print(io, join(p.sz, "×"))
+    end
+    return print(io, " oneArray of ", T)
 end
 
 # Plan constructors
-function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
-    prec = T<:Float64 || T<:ComplexF64 ? DFT_PREC_DOUBLE : DFT_PREC_SINGLE
+function _create_descriptor(sz::NTuple{N, Int}, T::Type, complex::Bool) where {N}
+    prec = T <: Float64 || T <: ComplexF64 ? DFT_PREC_DOUBLE : DFT_PREC_SINGLE
     dom = complex ? DFT_DOM_COMPLEX : DFT_DOM_REAL
     desc_ref = Ref{Ptr{Cvoid}}()
     # Create descriptor for the full array dimensions
@@ -156,8 +160,8 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
 end
 
 # Complex plans
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -166,20 +170,20 @@ function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,Co
     ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
     if N > 1
         # Column-major strides: stride along dimension i is product of sizes of previous dims
-        strides = Vector{Int64}(undef, N+1); strides[1]=0
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0
         prod = 1
         @inbounds for i in 1:N
-            strides[i+1] = prod
-            prod *= size(X,i)
+            strides[i + 1] = prod
+            prod *= size(X, i)
         end
         ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
         ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
     end
     stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
-    return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -187,59 +191,59 @@ function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
     desc, q = _create_descriptor(size(X), T, true)
     ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
     if N > 1
-        strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
         @inbounds for i in 1:N
-            strides[i+1]=prod; prod*=size(X,i)
+            strides[i + 1] = prod; prod *= size(X, i)
         end
         ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
         ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
     end
     stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
-    return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
 
 # In-place (provide separate methods)
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
     end
-    desc,q = _create_descriptor(size(X),T,true)
+    desc, q = _create_descriptor(size(X), T, true)
     ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
     if N > 1
-        strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
         @inbounds for i in 1:N
-            strides[i+1]=prod; prod*=size(X,i)
+            strides[i + 1] = prod; prod *= size(X, i)
         end
         ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
         ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
     end
     stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
-    cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_FORWARD, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
-    R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+    R = length(region); reg = NTuple{R, Int}(region)
     # For now, only support full transforms (all dimensions)
     if reg != ntuple(identity, N)
         error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
     end
-    desc,q = _create_descriptor(size(X),T,true)
+    desc, q = _create_descriptor(size(X), T, true)
     ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
     if N > 1
-        strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+        strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
         @inbounds for i in 1:N
-            strides[i+1]=prod; prod*=size(X,i)
+            strides[i + 1] = prod; prod *= size(X, i)
         end
         ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
         ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
     end
     stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
-    cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+    return cMKLFFTPlan{T, MKLFFT_INVERSE, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
 end
 
 # Real forward (out-of-place) - only support 1D transforms for now
-function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_rfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
     # Convert region to tuple if it's a range
     if isa(region, AbstractUnitRange)
         # For real FFTs, if region is 1:ndims(X), treat it as (1,) like FFTW
@@ -249,7 +253,7 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
             region = tuple(region...)
         end
     end
-    R = length(region); reg = NTuple{R,Int}(region)
+    R = length(region); reg = NTuple{R, Int}(region)
     # Only support single dimension transforms for now
     if R != 1
         error("Multi-dimensional real FFT not yet supported")
@@ -260,10 +264,10 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
     end
 
     # Create 1D descriptor for the transform dimension
-    desc,q = _create_descriptor((size(X, reg[1]),), T, false)
+    desc, q = _create_descriptor((size(X, reg[1]),), T, false)
     xdims = size(X)
     # output along first dim becomes N/2+1
-    ydims = Base.setindex(xdims, div(xdims[1],2)+1, 1)
+    ydims = Base.setindex(xdims, div(xdims[1], 2) + 1, 1)
     buffer = oneAPI.oneArray{Complex{T}}(undef, ydims)
     ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
 
@@ -278,11 +282,11 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
     end
 
     stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
-    rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:rfft,reg,buffer,nothing)
+    return rMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :rfft, reg, buffer, nothing)
 end
 
 # Real inverse (complex->real) requires complex input shape
-function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union{ComplexF32,ComplexF64},N}
+function plan_brfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T <: Union{ComplexF32, ComplexF64}, N}
     # Convert region to tuple if it's a range
     if isa(region, AbstractUnitRange)
         # For real FFTs, if region is 1:ndims(X), treat it as (1,) like FFTW
@@ -294,7 +298,7 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
     end
     # Debug: print what we received
     # @show region, typeof(region), length(region)
-    R = length(region); reg = NTuple{R,Int}(region)
+    R = length(region); reg = NTuple{R, Int}(region)
     # Only support single dimension transforms for now
     if R != 1
         error("Multi-dimensional real FFT not yet supported. Region: $region, R: $R")
@@ -309,7 +313,7 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
     RT = T.parameters[1]
 
     # Create 1D descriptor for the transform dimension
-    desc,q = _create_descriptor((d,), RT, false)
+    desc, q = _create_descriptor((d,), RT, false)
     xdims = size(X)
     ydims = Base.setindex(xdims, d, 1)
     buffer = oneAPI.oneArray{T}(undef, xdims) # copy for safety
@@ -322,7 +326,7 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
     end
 
     stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
-    rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:brfft,reg,buffer,nothing)
+    return rMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :brfft, reg, buffer, nothing)
 end
 
 # Convenience no-region methods use all dimensions in order
@@ -337,90 +341,98 @@ plan_brfft(X::oneAPI.oneArray, d::Integer) = plan_brfft(X, d, (1,))
 const plan_ifft = plan_bfft
 const plan_ifft! = plan_bfft!
 # plan_irfft should be normalized, unlike plan_brfft
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T,N} = begin
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T, N} = begin
     p = plan_brfft(X, d, region)
-    ScaledPlan(p, 1/normalization_factor(p.sz, p.region))
+    ScaledPlan(p, 1 / normalization_factor(p.sz, p.region))
 end
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer) where {T,N} = plan_irfft(X, d, (1,))
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer) where {T, N} = plan_irfft(X, d, (1,))
 
 # Inversion
 Base.inv(p::MKLFFTPlan) = plan_inv(p)
 
 # High-level wrappers operating like CPU FFTW versions.
-function fft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
-    (plan_fft(X) * X)
+function fft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
+    return (plan_fft(X) * X)
 end
-function ifft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
     p = plan_bfft(X)
     # Apply normalization for ifft (unlike bfft which is unnormalized)
     scaling = 1.0 / normalization_factor(size(X), ntuple(identity, ndims(X)))
-    scaling * (p * X)
+    return scaling * (p * X)
 end
-function fft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function fft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
     (plan_fft!(X) * X; X)
 end
-function ifft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
     p = plan_bfft!(X)
     # Apply normalization for ifft! (unlike bfft! which is unnormalized)
     scaling = 1.0 / normalization_factor(size(X), ntuple(identity, ndims(X)))
     p * X
     X .*= scaling
-    X
+    return X
 end
-function rfft(X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
-    (plan_rfft(X) * X)
+function rfft(X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+    return (plan_rfft(X) * X)
 end
-function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T<:Union{ComplexF32,ComplexF64}}
+function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T <: Union{ComplexF32, ComplexF64}}
     # Use the normalized plan_irfft instead of unnormalized plan_brfft
-    (plan_irfft(X, d) * X)
+    return (plan_irfft(X, d) * X)
 end
 
 # Execution helpers
-_rawptr(a::oneAPI.oneArray{T}) where T = reinterpret(Ptr{Cvoid}, pointer(a))
+_rawptr(a::oneAPI.oneArray{T}) where {T} = reinterpret(Ptr{Cvoid}, pointer(a))
 
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_FORWARD,true}, X::oneAPI.oneArray{T}) where T
-    st = ccall_fwd(p.handle, _rawptr(X)); st==0 || error("forward FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_FORWARD, true}, X::oneAPI.oneArray{T}) where {T}
+    st = ccall_fwd(p.handle, _rawptr(X)); st == 0 || error("forward FFT failed ($st)")
+    return X
 end
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_INVERSE,true}, X::oneAPI.oneArray{T}) where T
-    st = ccall_bwd(p.handle, _rawptr(X)); st==0 || error("inverse FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_INVERSE, true}, X::oneAPI.oneArray{T}) where {T}
+    st = ccall_bwd(p.handle, _rawptr(X)); st == 0 || error("inverse FFT failed ($st)")
+    return X
 end
-function _exec!(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T,K}
-    st = (K==MKLFFT_FORWARD ? ccall_fwd_oop : ccall_bwd_oop)(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("FFT failed ($st)"); Y
+function _exec!(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T, K}
+    st = (K == MKLFFT_FORWARD ? ccall_fwd_oop : ccall_bwd_oop)(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("FFT failed ($st)")
+    return Y
 end
 
 # Real forward
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where T
-    st = ccall_fwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("rfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where {T}
+    st = ccall_fwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("rfft failed ($st)")
+    return Y
 end
 # Real inverse (complex -> real)
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R,T<:Complex{R}}
-    st = ccall_bwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("brfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R, T <: Complex{R}}
+    st = ccall_bwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("brfft failed ($st)")
+    return Y
 end
 
 # Public API similar to AMDGPU
-function Base.:*(p::cMKLFFTPlan{T,K,true}, X::oneAPI.oneArray{T}) where {T,K}
-    _exec!(p,X)
+function Base.:*(p::cMKLFFTPlan{T, K, true}, X::oneAPI.oneArray{T}) where {T, K}
+    return _exec!(p, X)
 end
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
-    Y = oneAPI.oneArray{T}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+    Y = oneAPI.oneArray{T}(undef, p.osz)
+    return _exec!(p, X, Y)
 end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
-    _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+    return _exec!(p, X, Y)
 end
 
 # Real forward
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
-    Y = oneAPI.oneArray{Complex{T}}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+    Y = oneAPI.oneArray{Complex{T}}(undef, p.osz)
+    return _exec!(p, X, Y)
 end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
-    _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+    return _exec!(p, X, Y)
 end
 # Real inverse
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
-    Y = oneAPI.oneArray{R}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+    Y = oneAPI.oneArray{R}(undef, p.osz)
+    return _exec!(p, X, Y)
 end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
-    _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+    return _exec!(p, X, Y)
 end
 
 end # module FFT
diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl
index 06d8bee..0ea694b 100644
--- a/lib/support/liboneapi_support.jl
+++ b/lib/support/liboneapi_support.jl
@@ -7111,122 +7111,160 @@ mutable struct onemklDftDescriptor_st end
 const onemklDftDescriptor_t = Ptr{onemklDftDescriptor_st}
 
 function onemklDftCreate1D(desc, precision, domain, length)
-    @ccall liboneapi_support.onemklDftCreate1D(desc::Ptr{onemklDftDescriptor_t},
-                                               precision::onemklDftPrecision,
-                                               domain::onemklDftDomain, length::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftCreate1D(
+        desc::Ptr{onemklDftDescriptor_t},
+        precision::onemklDftPrecision,
+        domain::onemklDftDomain, length::Int64
+    )::Cint
 end
 
 function onemklDftCreateND(desc, precision, domain, dim, lengths)
-    @ccall liboneapi_support.onemklDftCreateND(desc::Ptr{onemklDftDescriptor_t},
-                                               precision::onemklDftPrecision,
-                                               domain::onemklDftDomain, dim::Int64,
-                                               lengths::Ptr{Int64})::Cint
+    return @ccall liboneapi_support.onemklDftCreateND(
+        desc::Ptr{onemklDftDescriptor_t},
+        precision::onemklDftPrecision,
+        domain::onemklDftDomain, dim::Int64,
+        lengths::Ptr{Int64}
+    )::Cint
 end
 
 function onemklDftDestroy(desc)
-    @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
+    return @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
 end
 
 function onemklDftCommit(desc, queue)
-    @ccall liboneapi_support.onemklDftCommit(desc::onemklDftDescriptor_t,
-                                             queue::syclQueue_t)::Cint
+    return @ccall liboneapi_support.onemklDftCommit(
+        desc::onemklDftDescriptor_t,
+        queue::syclQueue_t
+    )::Cint
 end
 
 function onemklDftSetValueInt64(desc, param, value)
-    @ccall liboneapi_support.onemklDftSetValueInt64(desc::onemklDftDescriptor_t,
-                                                    param::onemklDftConfigParam,
-                                                    value::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueInt64(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Int64
+    )::Cint
 end
 
 function onemklDftSetValueDouble(desc, param, value)
-    @ccall liboneapi_support.onemklDftSetValueDouble(desc::onemklDftDescriptor_t,
-                                                     param::onemklDftConfigParam,
-                                                     value::Cdouble)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueDouble(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Cdouble
+    )::Cint
 end
 
 function onemklDftSetValueInt64Array(desc, param, values, n)
-    @ccall liboneapi_support.onemklDftSetValueInt64Array(desc::onemklDftDescriptor_t,
-                                                         param::onemklDftConfigParam,
-                                                         values::Ptr{Int64}, n::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueInt64Array(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        values::Ptr{Int64}, n::Int64
+    )::Cint
 end
 
 function onemklDftSetValueConfigValue(desc, param, value)
-    @ccall liboneapi_support.onemklDftSetValueConfigValue(desc::onemklDftDescriptor_t,
-                                                          param::onemklDftConfigParam,
-                                                          value::onemklDftConfigValue)::Cint
+    return @ccall liboneapi_support.onemklDftSetValueConfigValue(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::onemklDftConfigValue
+    )::Cint
 end
 
 function onemklDftGetValueInt64(desc, param, value)
-    @ccall liboneapi_support.onemklDftGetValueInt64(desc::onemklDftDescriptor_t,
-                                                    param::onemklDftConfigParam,
-                                                    value::Ptr{Int64})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueInt64(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Ptr{Int64}
+    )::Cint
 end
 
 function onemklDftGetValueDouble(desc, param, value)
-    @ccall liboneapi_support.onemklDftGetValueDouble(desc::onemklDftDescriptor_t,
-                                                     param::onemklDftConfigParam,
-                                                     value::Ptr{Cdouble})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueDouble(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Ptr{Cdouble}
+    )::Cint
 end
 
 function onemklDftGetValueInt64Array(desc, param, values, n)
-    @ccall liboneapi_support.onemklDftGetValueInt64Array(desc::onemklDftDescriptor_t,
-                                                         param::onemklDftConfigParam,
-                                                         values::Ptr{Int64},
-                                                         n::Ptr{Int64})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueInt64Array(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        values::Ptr{Int64},
+        n::Ptr{Int64}
+    )::Cint
 end
 
 function onemklDftGetValueConfigValue(desc, param, value)
-    @ccall liboneapi_support.onemklDftGetValueConfigValue(desc::onemklDftDescriptor_t,
-                                                          param::onemklDftConfigParam,
-                                                          value::Ptr{onemklDftConfigValue})::Cint
+    return @ccall liboneapi_support.onemklDftGetValueConfigValue(
+        desc::onemklDftDescriptor_t,
+        param::onemklDftConfigParam,
+        value::Ptr{onemklDftConfigValue}
+    )::Cint
 end
 
 function onemklDftComputeForward(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeForward(desc::onemklDftDescriptor_t,
-                                                     inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForward(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeForwardOutOfPlace(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(desc::onemklDftDescriptor_t,
-                                                               in::Ptr{Cvoid},
-                                                               out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackward(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeBackward(desc::onemklDftDescriptor_t,
-                                                      inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackward(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackwardOutOfPlace(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(desc::onemklDftDescriptor_t,
-                                                                in::Ptr{Cvoid},
-                                                                out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeForwardBuffer(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeForwardBuffer(desc::onemklDftDescriptor_t,
-                                                           inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForwardBuffer(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeForwardOutOfPlaceBuffer(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
-                                                                     in::Ptr{Cvoid},
-                                                                     out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackwardBuffer(desc, inout)
-    @ccall liboneapi_support.onemklDftComputeBackwardBuffer(desc::onemklDftDescriptor_t,
-                                                            inout::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackwardBuffer(
+        desc::onemklDftDescriptor_t,
+        inout::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftComputeBackwardOutOfPlaceBuffer(desc, in, out)
-    @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
-                                                                      in::Ptr{Cvoid},
-                                                                      out::Ptr{Cvoid})::Cint
+    return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(
+        desc::onemklDftDescriptor_t,
+        in::Ptr{Cvoid},
+        out::Ptr{Cvoid}
+    )::Cint
 end
 
 function onemklDftQueryParamIndices(out, n)
-    @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
+    return @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
 end
 
 const ONEMKL_DFT_STATUS_SUCCESS = 0
diff --git a/res/wrap.jl b/res/wrap.jl
index 1d48315..2e9b29f 100644
--- a/res/wrap.jl
+++ b/res/wrap.jl
@@ -112,14 +112,14 @@ using oneAPI_Level_Zero_Headers_jll
 
 function main()
     wrap("ze", oneAPI_Level_Zero_Headers_jll.ze_api)
-    wrap(
-            "support",
-            joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
-            joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
-            joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
-            dependents=false,
-            include_dirs=[dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
-        )
+    return wrap(
+        "support",
+        joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
+        joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
+        joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
+        dependents = false,
+        include_dirs = [dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
+    )
 end
 
 isinteractive() || main()
diff --git a/test/fft.jl b/test/fft.jl
index 321ea9c..8a0cb2b 100644
--- a/test/fft.jl
+++ b/test/fft.jl
@@ -5,21 +5,21 @@ using AbstractFFTs
 using FFTW
 
 # Helper to move data to GPU
-gpu(A::AbstractArray{T}) where T = oneAPI.oneArray{T}(A)
+gpu(A::AbstractArray{T}) where {T} = oneAPI.oneArray{T}(A)
 
-const MYRTOL = 1e-5
-const MYATOL = 1e-8
+const MYRTOL = 1.0e-5
+const MYATOL = 1.0e-8
 
-function cmp(a,b; rtol=MYRTOL, atol=MYATOL)
-    @test isapprox(Array(a), Array(b); rtol=rtol, atol=atol)
+function cmp(a, b; rtol = MYRTOL, atol = MYATOL)
+    return @test isapprox(Array(a), Array(b); rtol = rtol, atol = atol)
 end
 
-function cmp_broken(a,b; rtol=MYRTOL, atol=MYATOL)
-    @test_broken isapprox(Array(a), Array(b); rtol=rtol, atol=atol)
+function cmp_broken(a, b; rtol = MYRTOL, atol = MYATOL)
+    return @test_broken isapprox(Array(a), Array(b); rtol = rtol, atol = atol)
 end
 
 @testset "FFT" begin
-    Ns = (8,32,64,8)
+    Ns = (8, 32, 64, 8)
 
     # Complex tests
     for T in (ComplexF32, ComplexF64)
@@ -61,8 +61,8 @@ end
             dX = gpu(X)
             p = plan_fft!(dX, 1)
             p * dX
-            cmp_broken(dX, fft(X,1))
-            pinv = plan_ifft!(dX,1)
+            cmp_broken(dX, fft(X, 1))
+            pinv = plan_ifft!(dX, 1)
             pinv * dX
             cmp_broken(dX, X)
         end
@@ -76,7 +76,7 @@ end
             p = plan_rfft(dX)
             dY = p * dX
             cmp(dY, rfft(X))
-            pinv = plan_irfft(dY, size(X,1))
+            pinv = plan_irfft(dY, size(X, 1))
             dZ = pinv * dY
             cmp(dZ, X)
 
@@ -86,7 +86,7 @@ end
             p = plan_rfft(dX)
             dY = p * dX
             cmp(dY, rfft(X, (1,)))  # Compare with 1D FFT along first dim, not multi-dimensional FFT
-            pinv = plan_irfft(dY, size(X,1))
+            pinv = plan_irfft(dY, size(X, 1))
             dZ = pinv * dY
             cmp_broken(dZ, X)
         end
@@ -105,7 +105,7 @@ end
         X = gpu(rand(T, Ns[1], Ns[2]))
         Y = rfft(X)
         cmp(Y, rfft(Array(X), (1,)))  # Compare with 1D FFT along first dim, not multi-dimensional FFT
-        Z = irfft(Y, size(X,1))
+        Z = irfft(Y, size(X, 1))
         cmp_broken(Z, Array(X))
     end
 end

}
*out = desc;
return 0;
} catch (...) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a catch-all.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants