-
Notifications
You must be signed in to change notification settings - Fork 28
oneMKL DFT support #515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
oneMKL DFT support #515
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/lib/mkl/fft.jl b/lib/mkl/fft.jl
index ceefda6..66329a9 100644
--- a/lib/mkl/fft.jl
+++ b/lib/mkl/fft.jl
@@ -20,29 +20,29 @@ export MKLFFTPlan
# (We can just re-use integer constants; C wrappers return 0 on success.)
const DFT_PREC_SINGLE = 0
const DFT_PREC_DOUBLE = 1
-const DFT_DOM_REAL = 0
+const DFT_DOM_REAL = 0
const DFT_DOM_COMPLEX = 1
# Configuration parameter indices (must match onemkl_dft.h enum ordering)
-const DFT_PARAM_DIMENSION = 1
-const DFT_PARAM_LENGTHS = 2
-const DFT_PARAM_PRECISION = 3
-const DFT_PARAM_FORWARD_SCALE = 4
-const DFT_PARAM_BACKWARD_SCALE = 5
+const DFT_PARAM_DIMENSION = 1
+const DFT_PARAM_LENGTHS = 2
+const DFT_PARAM_PRECISION = 3
+const DFT_PARAM_FORWARD_SCALE = 4
+const DFT_PARAM_BACKWARD_SCALE = 5
const DFT_PARAM_NUMBER_OF_TRANSFORMS = 6
-const DFT_PARAM_COMPLEX_STORAGE = 7
-const DFT_PARAM_PLACEMENT = 8
-const DFT_PARAM_INPUT_STRIDES = 9
-const DFT_PARAM_OUTPUT_STRIDES = 10
-const DFT_PARAM_FWD_DISTANCE = 11
-const DFT_PARAM_BWD_DISTANCE = 12
-const DFT_PARAM_WORKSPACE = 13
+const DFT_PARAM_COMPLEX_STORAGE = 7
+const DFT_PARAM_PLACEMENT = 8
+const DFT_PARAM_INPUT_STRIDES = 9
+const DFT_PARAM_OUTPUT_STRIDES = 10
+const DFT_PARAM_FWD_DISTANCE = 11
+const DFT_PARAM_BWD_DISTANCE = 12
+const DFT_PARAM_WORKSPACE = 13
const DFT_PARAM_WORKSPACE_ESTIMATE_BYTES = 14
-const DFT_PARAM_WORKSPACE_BYTES = 15
-const DFT_PARAM_FWD_STRIDES = 16
-const DFT_PARAM_BWD_STRIDES = 17
+const DFT_PARAM_WORKSPACE_BYTES = 15
+const DFT_PARAM_FWD_STRIDES = 16
+const DFT_PARAM_BWD_STRIDES = 17
# Config value logical indices (ordering per onemkl_dft.h)
-const DFT_CFG_INPLACE = 4
+const DFT_CFG_INPLACE = 4
const DFT_CFG_NOT_INPLACE = 5
# Opaque descriptor type alias to Ptr{Nothing} (generated wrapper not yet exposed)
@@ -68,34 +68,34 @@ ccall_set_int(desc, param::Int32, value::Int64) = ccall((:onemklDftSetValueInt64
ccall_set_int64_array(desc, param::Int32, values::Vector{Int64}) = ccall((:onemklDftSetValueInt64Array, lib), Cint, (Ptr{Cvoid}, Cint, Ptr{Int64}, Int64), desc, param, pointer(values), length(values))
ccall_set_cfg(desc, param::Int32, value::Int32) = ccall((:onemklDftSetValueConfigValue, lib), Cint, (Ptr{Cvoid}, Cint, Cint), desc, param, value)
-abstract type MKLFFTPlan{T,K,inplace} <: AbstractFFTs.Plan{T} end
+abstract type MKLFFTPlan{T, K, inplace} <: AbstractFFTs.Plan{T} end
-Base.eltype(::MKLFFTPlan{T}) where T = T
-is_inplace(::MKLFFTPlan{<:Any,<:Any,inplace}) where inplace = inplace
+Base.eltype(::MKLFFTPlan{T}) where {T} = T
+is_inplace(::MKLFFTPlan{<:Any, <:Any, inplace}) where {inplace} = inplace
# Forward / inverse flags
const MKLFFT_FORWARD = true
const MKLFFT_INVERSE = false
-mutable struct cMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct cMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
handle::Ptr{Cvoid}
queue::syclQueue_t
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
realdomain::Bool
- region::NTuple{R,Int}
+ region::NTuple{R, Int}
buffer::B
pinv::Any
end
# Real transforms use separate struct (mirroring AMDGPU style) for buffer staging
-mutable struct rMKLFFTPlan{T,K,inplace,N,R,B} <: MKLFFTPlan{T,K,inplace}
+mutable struct rMKLFFTPlan{T, K, inplace, N, R, B} <: MKLFFTPlan{T, K, inplace}
handle::Ptr{Cvoid}
queue::syclQueue_t
- sz::NTuple{N,Int}
- osz::NTuple{N,Int}
+ sz::NTuple{N, Int}
+ osz::NTuple{N, Int}
xtype::Symbol
- region::NTuple{R,Int}
+ region::NTuple{R, Int}
buffer::B
pinv::Any
end
@@ -103,40 +103,44 @@ end
# Inverse plan constructors (derive from existing plan)
function normalization_factor(sz, region)
# AbstractFFTs expects inverse to scale by 1/prod(lengths along region)
- prod(ntuple(i-> sz[region[i]], length(region)))
+ return prod(ntuple(i -> sz[region[i]], length(region)))
end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::cMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = cMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,p.realdomain,p.region,p.buffer,p)
+function plan_inv(p::cMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = cMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, p.realdomain, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:brfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :brfft, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function plan_inv(p::rMKLFFTPlan{T,MKLFFT_INVERSE,inplace,N,R,B}) where {T,inplace,N,R,B}
- q = rMKLFFTPlan{T,MKLFFT_FORWARD,inplace,N,R,B}(p.handle,p.queue,p.sz,p.osz,:rfft,p.region,p.buffer,p)
+function plan_inv(p::rMKLFFTPlan{T, MKLFFT_INVERSE, inplace, N, R, B}) where {T, inplace, N, R, B}
+ q = rMKLFFTPlan{T, MKLFFT_FORWARD, inplace, N, R, B}(p.handle, p.queue, p.sz, p.osz, :rfft, p.region, p.buffer, p)
p.pinv = q
- ScaledPlan(q, 1/normalization_factor(p.sz, p.region))
+ return ScaledPlan(q, 1 / normalization_factor(p.sz, p.region))
end
-function Base.show(io::IO, p::MKLFFTPlan{T,K,inplace}) where {T,K,inplace}
+function Base.show(io::IO, p::MKLFFTPlan{T, K, inplace}) where {T, K, inplace}
print(io, inplace ? "oneMKL FFT in-place " : "oneMKL FFT ", K ? "forward" : "inverse", " plan for ")
- if isempty(p.sz); print(io, "0-dimensional") else print(io, join(p.sz, "×")) end
- print(io, " oneArray of ", T)
+ if isempty(p.sz)
+ print(io, "0-dimensional")
+ else
+ print(io, join(p.sz, "×"))
+ end
+ return print(io, " oneArray of ", T)
end
# Plan constructors
-function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
- prec = T<:Float64 || T<:ComplexF64 ? DFT_PREC_DOUBLE : DFT_PREC_SINGLE
+function _create_descriptor(sz::NTuple{N, Int}, T::Type, complex::Bool) where {N}
+ prec = T <: Float64 || T <: ComplexF64 ? DFT_PREC_DOUBLE : DFT_PREC_SINGLE
dom = complex ? DFT_DOM_COMPLEX : DFT_DOM_REAL
desc_ref = Ref{Ptr{Cvoid}}()
# Create descriptor for the full array dimensions
@@ -156,8 +160,8 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
end
# Complex plans
-function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -166,20 +170,20 @@ function plan_fft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,Co
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
if N > 1
# Column-major strides: stride along dimension i is product of sizes of previous dims
- strides = Vector{Int64}(undef, N+1); strides[1]=0
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0
prod = 1
@inbounds for i in 1:N
- strides[i+1] = prod
- prod *= size(X,i)
+ strides[i + 1] = prod
+ prod *= size(X, i)
end
ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
end
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
- return cMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
-function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
@@ -187,59 +191,59 @@ function plan_bfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,C
desc, q = _create_descriptor(size(X), T, true)
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
end
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
- return cMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
# In-place (provide separate methods)
-function plan_fft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_fft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
end
- desc,q = _create_descriptor(size(X),T,true)
+ desc, q = _create_descriptor(size(X), T, true)
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
end
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
- cMKLFFTPlan{T,MKLFFT_FORWARD,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_FORWARD, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
-function plan_bfft!(X::oneAPI.oneArray{T,N}, region) where {T<:Union{ComplexF32,ComplexF64},N}
- R = length(region); reg = NTuple{R,Int}(region)
+function plan_bfft!(X::oneAPI.oneArray{T, N}, region) where {T <: Union{ComplexF32, ComplexF64}, N}
+ R = length(region); reg = NTuple{R, Int}(region)
# For now, only support full transforms (all dimensions)
if reg != ntuple(identity, N)
error("Partial dimension FFT not yet supported. Region $reg must be $(ntuple(identity, N))")
end
- desc,q = _create_descriptor(size(X),T,true)
+ desc, q = _create_descriptor(size(X), T, true)
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_INPLACE))
if N > 1
- strides = Vector{Int64}(undef, N+1); strides[1]=0; prod=1
+ strides = Vector{Int64}(undef, N + 1); strides[1] = 0; prod = 1
@inbounds for i in 1:N
- strides[i+1]=prod; prod*=size(X,i)
+ strides[i + 1] = prod; prod *= size(X, i)
end
ccall_set_int64_array(desc, Int32(DFT_PARAM_FWD_STRIDES), strides)
ccall_set_int64_array(desc, Int32(DFT_PARAM_BWD_STRIDES), strides)
end
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
- cMKLFFTPlan{T,MKLFFT_INVERSE,true,N,R,Nothing}(desc,q,size(X),size(X),false,reg,nothing,nothing)
+ return cMKLFFTPlan{T, MKLFFT_INVERSE, true, N, R, Nothing}(desc, q, size(X), size(X), false, reg, nothing, nothing)
end
# Real forward (out-of-place) - only support 1D transforms for now
-function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Float64},N}
+function plan_rfft(X::oneAPI.oneArray{T, N}, region) where {T <: Union{Float32, Float64}, N}
# Convert region to tuple if it's a range
if isa(region, AbstractUnitRange)
# For real FFTs, if region is 1:ndims(X), treat it as (1,) like FFTW
@@ -249,7 +253,7 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
region = tuple(region...)
end
end
- R = length(region); reg = NTuple{R,Int}(region)
+ R = length(region); reg = NTuple{R, Int}(region)
# Only support single dimension transforms for now
if R != 1
error("Multi-dimensional real FFT not yet supported")
@@ -260,10 +264,10 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
end
# Create 1D descriptor for the transform dimension
- desc,q = _create_descriptor((size(X, reg[1]),), T, false)
+ desc, q = _create_descriptor((size(X, reg[1]),), T, false)
xdims = size(X)
# output along first dim becomes N/2+1
- ydims = Base.setindex(xdims, div(xdims[1],2)+1, 1)
+ ydims = Base.setindex(xdims, div(xdims[1], 2) + 1, 1)
buffer = oneAPI.oneArray{Complex{T}}(undef, ydims)
ccall_set_cfg(desc, Int32(DFT_PARAM_PLACEMENT), Int32(DFT_CFG_NOT_INPLACE))
@@ -278,11 +282,11 @@ function plan_rfft(X::oneAPI.oneArray{T,N}, region) where {T<:Union{Float32,Floa
end
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
- rMKLFFTPlan{T,MKLFFT_FORWARD,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:rfft,reg,buffer,nothing)
+ return rMKLFFTPlan{T, MKLFFT_FORWARD, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :rfft, reg, buffer, nothing)
end
# Real inverse (complex->real) requires complex input shape
-function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union{ComplexF32,ComplexF64},N}
+function plan_brfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T <: Union{ComplexF32, ComplexF64}, N}
# Convert region to tuple if it's a range
if isa(region, AbstractUnitRange)
# For real FFTs, if region is 1:ndims(X), treat it as (1,) like FFTW
@@ -294,7 +298,7 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
end
# Debug: print what we received
# @show region, typeof(region), length(region)
- R = length(region); reg = NTuple{R,Int}(region)
+ R = length(region); reg = NTuple{R, Int}(region)
# Only support single dimension transforms for now
if R != 1
error("Multi-dimensional real FFT not yet supported. Region: $region, R: $R")
@@ -309,7 +313,7 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
RT = T.parameters[1]
# Create 1D descriptor for the transform dimension
- desc,q = _create_descriptor((d,), RT, false)
+ desc, q = _create_descriptor((d,), RT, false)
xdims = size(X)
ydims = Base.setindex(xdims, d, 1)
buffer = oneAPI.oneArray{T}(undef, xdims) # copy for safety
@@ -322,7 +326,7 @@ function plan_brfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T<:Union
end
stc = ccall_commit(desc, q); stc == 0 || error("commit failed ($stc)")
- rMKLFFTPlan{T,MKLFFT_INVERSE,false,N,R,typeof(buffer)}(desc,q,xdims,ydims,:brfft,reg,buffer,nothing)
+ return rMKLFFTPlan{T, MKLFFT_INVERSE, false, N, R, typeof(buffer)}(desc, q, xdims, ydims, :brfft, reg, buffer, nothing)
end
# Convenience no-region methods use all dimensions in order
@@ -337,90 +341,98 @@ plan_brfft(X::oneAPI.oneArray, d::Integer) = plan_brfft(X, d, (1,))
const plan_ifft = plan_bfft
const plan_ifft! = plan_bfft!
# plan_irfft should be normalized, unlike plan_brfft
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer, region) where {T,N} = begin
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer, region) where {T, N} = begin
p = plan_brfft(X, d, region)
- ScaledPlan(p, 1/normalization_factor(p.sz, p.region))
+ ScaledPlan(p, 1 / normalization_factor(p.sz, p.region))
end
-plan_irfft(X::oneAPI.oneArray{T,N}, d::Integer) where {T,N} = plan_irfft(X, d, (1,))
+plan_irfft(X::oneAPI.oneArray{T, N}, d::Integer) where {T, N} = plan_irfft(X, d, (1,))
# Inversion
Base.inv(p::MKLFFTPlan) = plan_inv(p)
# High-level wrappers operating like CPU FFTW versions.
-function fft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
- (plan_fft(X) * X)
+function fft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
+ return (plan_fft(X) * X)
end
-function ifft(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
p = plan_bfft(X)
# Apply normalization for ifft (unlike bfft which is unnormalized)
scaling = 1.0 / normalization_factor(size(X), ntuple(identity, ndims(X)))
- scaling * (p * X)
+ return scaling * (p * X)
end
-function fft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function fft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
(plan_fft!(X) * X; X)
end
-function ifft!(X::oneAPI.oneArray{T}) where {T<:Union{ComplexF32,ComplexF64}}
+function ifft!(X::oneAPI.oneArray{T}) where {T <: Union{ComplexF32, ComplexF64}}
p = plan_bfft!(X)
# Apply normalization for ifft! (unlike bfft! which is unnormalized)
scaling = 1.0 / normalization_factor(size(X), ntuple(identity, ndims(X)))
p * X
X .*= scaling
- X
+ return X
end
-function rfft(X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- (plan_rfft(X) * X)
+function rfft(X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ return (plan_rfft(X) * X)
end
-function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T<:Union{ComplexF32,ComplexF64}}
+function irfft(X::oneAPI.oneArray{T}, d::Integer) where {T <: Union{ComplexF32, ComplexF64}}
# Use the normalized plan_irfft instead of unnormalized plan_brfft
- (plan_irfft(X, d) * X)
+ return (plan_irfft(X, d) * X)
end
# Execution helpers
-_rawptr(a::oneAPI.oneArray{T}) where T = reinterpret(Ptr{Cvoid}, pointer(a))
+_rawptr(a::oneAPI.oneArray{T}) where {T} = reinterpret(Ptr{Cvoid}, pointer(a))
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_FORWARD,true}, X::oneAPI.oneArray{T}) where T
- st = ccall_fwd(p.handle, _rawptr(X)); st==0 || error("forward FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_FORWARD, true}, X::oneAPI.oneArray{T}) where {T}
+ st = ccall_fwd(p.handle, _rawptr(X)); st == 0 || error("forward FFT failed ($st)")
+ return X
end
-function _exec!(p::cMKLFFTPlan{T,MKLFFT_INVERSE,true}, X::oneAPI.oneArray{T}) where T
- st = ccall_bwd(p.handle, _rawptr(X)); st==0 || error("inverse FFT failed ($st)"); X
+function _exec!(p::cMKLFFTPlan{T, MKLFFT_INVERSE, true}, X::oneAPI.oneArray{T}) where {T}
+ st = ccall_bwd(p.handle, _rawptr(X)); st == 0 || error("inverse FFT failed ($st)")
+ return X
end
-function _exec!(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T,K}
- st = (K==MKLFFT_FORWARD ? ccall_fwd_oop : ccall_bwd_oop)(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("FFT failed ($st)"); Y
+function _exec!(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{T}) where {T, K}
+ st = (K == MKLFFT_FORWARD ? ccall_fwd_oop : ccall_bwd_oop)(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("FFT failed ($st)")
+ return Y
end
# Real forward
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where T
- st = ccall_fwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("rfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{Complex{T}}) where {T}
+ st = ccall_fwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("rfft failed ($st)")
+ return Y
end
# Real inverse (complex -> real)
-function _exec!(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R,T<:Complex{R}}
- st = ccall_bwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st==0 || error("brfft failed ($st)"); Y
+function _exec!(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}, Y::oneAPI.oneArray{R}) where {R, T <: Complex{R}}
+ st = ccall_bwd_oop(p.handle, _rawptr(X), _rawptr(Y)); st == 0 || error("brfft failed ($st)")
+ return Y
end
# Public API similar to AMDGPU
-function Base.:*(p::cMKLFFTPlan{T,K,true}, X::oneAPI.oneArray{T}) where {T,K}
- _exec!(p,X)
+function Base.:*(p::cMKLFFTPlan{T, K, true}, X::oneAPI.oneArray{T}) where {T, K}
+ return _exec!(p, X)
end
-function Base.:*(p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
- Y = oneAPI.oneArray{T}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+ Y = oneAPI.oneArray{T}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T,K,false}, X::oneAPI.oneArray{T}) where {T,K}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{T}, p::cMKLFFTPlan{T, K, false}, X::oneAPI.oneArray{T}) where {T, K}
+ return _exec!(p, X, Y)
end
# Real forward
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- Y = oneAPI.oneArray{Complex{T}}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ Y = oneAPI.oneArray{Complex{T}}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T,MKLFFT_FORWARD,false}, X::oneAPI.oneArray{T}) where {T<:Union{Float32,Float64}}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{Complex{T}}, p::rMKLFFTPlan{T, MKLFFT_FORWARD, false}, X::oneAPI.oneArray{T}) where {T <: Union{Float32, Float64}}
+ return _exec!(p, X, Y)
end
# Real inverse
-function Base.:*(p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
- Y = oneAPI.oneArray{R}(undef, p.osz); _exec!(p,X,Y)
+function Base.:*(p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+ Y = oneAPI.oneArray{R}(undef, p.osz)
+ return _exec!(p, X, Y)
end
-function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T,MKLFFT_INVERSE,false}, X::oneAPI.oneArray{T}) where {R,T<:Complex{R}}
- _exec!(p,X,Y)
+function LinearAlgebra.mul!(Y::oneAPI.oneArray{R}, p::rMKLFFTPlan{T, MKLFFT_INVERSE, false}, X::oneAPI.oneArray{T}) where {R, T <: Complex{R}}
+ return _exec!(p, X, Y)
end
end # module FFT
diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl
index 06d8bee..0ea694b 100644
--- a/lib/support/liboneapi_support.jl
+++ b/lib/support/liboneapi_support.jl
@@ -7111,122 +7111,160 @@ mutable struct onemklDftDescriptor_st end
const onemklDftDescriptor_t = Ptr{onemklDftDescriptor_st}
function onemklDftCreate1D(desc, precision, domain, length)
- @ccall liboneapi_support.onemklDftCreate1D(desc::Ptr{onemklDftDescriptor_t},
- precision::onemklDftPrecision,
- domain::onemklDftDomain, length::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftCreate1D(
+ desc::Ptr{onemklDftDescriptor_t},
+ precision::onemklDftPrecision,
+ domain::onemklDftDomain, length::Int64
+ )::Cint
end
function onemklDftCreateND(desc, precision, domain, dim, lengths)
- @ccall liboneapi_support.onemklDftCreateND(desc::Ptr{onemklDftDescriptor_t},
- precision::onemklDftPrecision,
- domain::onemklDftDomain, dim::Int64,
- lengths::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftCreateND(
+ desc::Ptr{onemklDftDescriptor_t},
+ precision::onemklDftPrecision,
+ domain::onemklDftDomain, dim::Int64,
+ lengths::Ptr{Int64}
+ )::Cint
end
function onemklDftDestroy(desc)
- @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
+ return @ccall liboneapi_support.onemklDftDestroy(desc::onemklDftDescriptor_t)::Cint
end
function onemklDftCommit(desc, queue)
- @ccall liboneapi_support.onemklDftCommit(desc::onemklDftDescriptor_t,
- queue::syclQueue_t)::Cint
+ return @ccall liboneapi_support.onemklDftCommit(
+ desc::onemklDftDescriptor_t,
+ queue::syclQueue_t
+ )::Cint
end
function onemklDftSetValueInt64(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueInt64(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueInt64(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Int64
+ )::Cint
end
function onemklDftSetValueDouble(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueDouble(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Cdouble)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueDouble(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Cdouble
+ )::Cint
end
function onemklDftSetValueInt64Array(desc, param, values, n)
- @ccall liboneapi_support.onemklDftSetValueInt64Array(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- values::Ptr{Int64}, n::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueInt64Array(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ values::Ptr{Int64}, n::Int64
+ )::Cint
end
function onemklDftSetValueConfigValue(desc, param, value)
- @ccall liboneapi_support.onemklDftSetValueConfigValue(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::onemklDftConfigValue)::Cint
+ return @ccall liboneapi_support.onemklDftSetValueConfigValue(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::onemklDftConfigValue
+ )::Cint
end
function onemklDftGetValueInt64(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueInt64(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueInt64(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{Int64}
+ )::Cint
end
function onemklDftGetValueDouble(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueDouble(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{Cdouble})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueDouble(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{Cdouble}
+ )::Cint
end
function onemklDftGetValueInt64Array(desc, param, values, n)
- @ccall liboneapi_support.onemklDftGetValueInt64Array(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- values::Ptr{Int64},
- n::Ptr{Int64})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueInt64Array(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ values::Ptr{Int64},
+ n::Ptr{Int64}
+ )::Cint
end
function onemklDftGetValueConfigValue(desc, param, value)
- @ccall liboneapi_support.onemklDftGetValueConfigValue(desc::onemklDftDescriptor_t,
- param::onemklDftConfigParam,
- value::Ptr{onemklDftConfigValue})::Cint
+ return @ccall liboneapi_support.onemklDftGetValueConfigValue(
+ desc::onemklDftDescriptor_t,
+ param::onemklDftConfigParam,
+ value::Ptr{onemklDftConfigValue}
+ )::Cint
end
function onemklDftComputeForward(desc, inout)
- @ccall liboneapi_support.onemklDftComputeForward(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForward(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardOutOfPlace(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlace(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackward(desc, inout)
- @ccall liboneapi_support.onemklDftComputeBackward(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackward(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardOutOfPlace(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlace(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardBuffer(desc, inout)
- @ccall liboneapi_support.onemklDftComputeForwardBuffer(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardBuffer(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeForwardOutOfPlaceBuffer(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeForwardOutOfPlaceBuffer(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardBuffer(desc, inout)
- @ccall liboneapi_support.onemklDftComputeBackwardBuffer(desc::onemklDftDescriptor_t,
- inout::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardBuffer(
+ desc::onemklDftDescriptor_t,
+ inout::Ptr{Cvoid}
+ )::Cint
end
function onemklDftComputeBackwardOutOfPlaceBuffer(desc, in, out)
- @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(desc::onemklDftDescriptor_t,
- in::Ptr{Cvoid},
- out::Ptr{Cvoid})::Cint
+ return @ccall liboneapi_support.onemklDftComputeBackwardOutOfPlaceBuffer(
+ desc::onemklDftDescriptor_t,
+ in::Ptr{Cvoid},
+ out::Ptr{Cvoid}
+ )::Cint
end
function onemklDftQueryParamIndices(out, n)
- @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
+ return @ccall liboneapi_support.onemklDftQueryParamIndices(out::Ptr{Int64}, n::Int64)::Cint
end
const ONEMKL_DFT_STATUS_SUCCESS = 0
diff --git a/res/wrap.jl b/res/wrap.jl
index 1d48315..2e9b29f 100644
--- a/res/wrap.jl
+++ b/res/wrap.jl
@@ -112,14 +112,14 @@ using oneAPI_Level_Zero_Headers_jll
function main()
wrap("ze", oneAPI_Level_Zero_Headers_jll.ze_api)
- wrap(
- "support",
- joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
- joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
- joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
- dependents=false,
- include_dirs=[dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
- )
+ return wrap(
+ "support",
+ joinpath(dirname(@__DIR__), "deps", "src", "sycl.h"),
+ joinpath(dirname(@__DIR__), "deps", "src", "onemkl.h"),
+ joinpath(dirname(@__DIR__), "deps", "src", "onemkl_dft.h");
+ dependents = false,
+ include_dirs = [dirname(dirname(oneAPI_Level_Zero_Headers_jll.ze_api))]
+ )
end
isinteractive() || main()
diff --git a/test/fft.jl b/test/fft.jl
index 321ea9c..8a0cb2b 100644
--- a/test/fft.jl
+++ b/test/fft.jl
@@ -5,21 +5,21 @@ using AbstractFFTs
using FFTW
# Helper to move data to GPU
-gpu(A::AbstractArray{T}) where T = oneAPI.oneArray{T}(A)
+gpu(A::AbstractArray{T}) where {T} = oneAPI.oneArray{T}(A)
-const MYRTOL = 1e-5
-const MYATOL = 1e-8
+const MYRTOL = 1.0e-5
+const MYATOL = 1.0e-8
-function cmp(a,b; rtol=MYRTOL, atol=MYATOL)
- @test isapprox(Array(a), Array(b); rtol=rtol, atol=atol)
+function cmp(a, b; rtol = MYRTOL, atol = MYATOL)
+ return @test isapprox(Array(a), Array(b); rtol = rtol, atol = atol)
end
-function cmp_broken(a,b; rtol=MYRTOL, atol=MYATOL)
- @test_broken isapprox(Array(a), Array(b); rtol=rtol, atol=atol)
+function cmp_broken(a, b; rtol = MYRTOL, atol = MYATOL)
+ return @test_broken isapprox(Array(a), Array(b); rtol = rtol, atol = atol)
end
@testset "FFT" begin
- Ns = (8,32,64,8)
+ Ns = (8, 32, 64, 8)
# Complex tests
for T in (ComplexF32, ComplexF64)
@@ -61,8 +61,8 @@ end
dX = gpu(X)
p = plan_fft!(dX, 1)
p * dX
- cmp_broken(dX, fft(X,1))
- pinv = plan_ifft!(dX,1)
+ cmp_broken(dX, fft(X, 1))
+ pinv = plan_ifft!(dX, 1)
pinv * dX
cmp_broken(dX, X)
end
@@ -76,7 +76,7 @@ end
p = plan_rfft(dX)
dY = p * dX
cmp(dY, rfft(X))
- pinv = plan_irfft(dY, size(X,1))
+ pinv = plan_irfft(dY, size(X, 1))
dZ = pinv * dY
cmp(dZ, X)
@@ -86,7 +86,7 @@ end
p = plan_rfft(dX)
dY = p * dX
cmp(dY, rfft(X, (1,))) # Compare with 1D FFT along first dim, not multi-dimensional FFT
- pinv = plan_irfft(dY, size(X,1))
+ pinv = plan_irfft(dY, size(X, 1))
dZ = pinv * dY
cmp_broken(dZ, X)
end
@@ -105,7 +105,7 @@ end
X = gpu(rand(T, Ns[1], Ns[2]))
Y = rfft(X)
cmp(Y, rfft(Array(X), (1,))) # Compare with 1D FFT along first dim, not multi-dimensional FFT
- Z = irfft(Y, size(X,1))
+ Z = irfft(Y, size(X, 1))
cmp_broken(Z, Array(X))
end
end
}
*out = desc;
return 0;
} catch (...) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a catch-all.
It passes some tests. GPT 5 helped quite a bit here.