Skip to content

feat: intel xpu register #1260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
steps:
- group: ":test_tube: Tests"
- group: ":test_tube: CUDA Tests"
steps:
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}} -- {{matrix.runtime}}"
- label: ":julia: :linux: Julia v{{matrix.version}} -- CUDA -- {{matrix.group}} -- {{matrix.runtime}}"
matrix:
setup:
version:
Expand Down Expand Up @@ -48,6 +48,54 @@ steps:
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 120

- group: ":test_tube: Intel GPU (SYCL) Tests"
steps:
- label: ":julia: :linux: Julia v{{matrix.version}} -- Intel GPU -- {{matrix.group}} -- {{matrix.runtime}}"
matrix:
setup:
version:
- "1.10"
group:
- core
runtime:
- "PJRT"
- "IFRT"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.version}}"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
- lib/ReactantCore/src
commands: |
touch LocalPreferences.toml

echo "[Reactant]" >> LocalPreferences.toml
echo "xla_runtime = \"{{matrix.runtime}}\"" >> LocalPreferences.toml

cat LocalPreferences.toml

julia --project=. -e 'println("--- :julia: Instantiating project")
using Pkg
Pkg.develop([PackageSpec(path="lib/ReactantCore")])'

julia --project=. -e 'println("--- :julia: Run Tests")
using Pkg
Pkg.test(; coverage="user")'
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
JULIA_DEBUG: "Reactant,Reactant_jll"
REACTANT_BACKEND_GROUP: "GPU"
REACTANT_TEST_ONLY_PLUGIN: "true"
agents:
queue: "juliagpu"
intel: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60


# - group: ":racehorse: Benchmarks"
# steps:
# - label: "CPU: Run Benchmarks"
Expand Down
1 change: 1 addition & 0 deletions src/accelerators/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ module Accelerators

include("TPU.jl")
include("Metal.jl")
include("IntelXPU.jl")

end
80 changes: 80 additions & 0 deletions src/accelerators/IntelXPU.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
module IntelXPU

using Reactant: Reactant
using Scratch: @get_scratch!
using Downloads
using Libdl

const intel_xpu_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)

function __init__()
@static if Sys.ARCH === :x86_64
if !Reactant.precompiling()
setup_intel_xpu_pjrt_plugin!()

try
Libdl.dlopen(joinpath(get_intel_xpu_pjrt_plugin_dir(), "sycl_onednn.so");)
catch e
@debug "Failed to load sycl_onednn.so: $e"
end
end
end
end

function setup_intel_xpu_pjrt_plugin!()
path_from_env = get(ENV, "INTEL_XPU_LIBRARY_PATH", nothing)
if path_from_env !== nothing && ispath(path_from_env)
intel_xpu_pjrt_plugin_dir[] = path_from_env
else
intel_xpu_pjrt_plugin_dir[] = @get_scratch!("pjrt_intel_xpu_plugin")
end
download_intel_xpu_pjrt_plugin_if_needed(intel_xpu_pjrt_plugin_dir[])
return nothing
end

get_intel_xpu_pjrt_plugin_dir() = intel_xpu_pjrt_plugin_dir[]

function get_intel_xpu_pjrt_plugin_path()
return joinpath(get_intel_xpu_pjrt_plugin_dir(), "pjrt_plugin_xpu.so")
end

function download_intel_xpu_pjrt_plugin_if_needed(path=nothing)
path === nothing && (path = get_intel_xpu_pjrt_plugin_dir())
@assert path !== nothing "intel_xpu_pjrt_plugin_dir is not set!"

intel_xpu_pjrt_plugin_path = joinpath(path, "pjrt_plugin_xpu.so")
if !isfile(intel_xpu_pjrt_plugin_path)
zip_file_path = joinpath(path, "pjrt-plugin-intel-xpu.zip")
tmp_dir = joinpath(path, "tmp")
Downloads.download(
if Sys.ARCH === :x86_64
"https://files.pythonhosted.org/packages/42/28/26564ea0937ec11755e63ab3c85d6d4b96201131a69c6fddf8b985e7f9ae/intel_extension_for_openxla-0.6.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
else
error("Unsupported architecture: $(Sys.ARCH)")
end,
zip_file_path,
)
run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
mv(
joinpath(
tmp_dir, "jax_plugins", "intel_extension_for_openxla", "pjrt_plugin_xpu.so"
),
intel_xpu_pjrt_plugin_path;
)
mv(
joinpath(
tmp_dir,
"jax_plugins",
"intel_extension_for_openxla",
"service",
"gpu",
"sycl_onednn.so",
),
joinpath(path, "sycl_onednn.so");
)
rm(tmp_dir; recursive=true)
rm(zip_file_path; recursive=true)
end
end

end
18 changes: 18 additions & 0 deletions src/xla/IFRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,14 @@ const cpu_client_count = Ref(0)
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const sycl_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:SYCLClient, :sycl_client_count),
)
main_fn = Symbol(:MakeIFRTPJRT, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -219,6 +221,22 @@ function MakeIFRTPJRTMetalClient(;
)
end

function MakeIFRTPJRTSYCLClient(;
sycl_pjrt_plugin_path::String,
node_id::Integer=0,
num_nodes::Integer=1,
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
return MakeIFRTPJRTClientViaPluginAPI(
sycl_pjrt_plugin_path,
"sycl",
"SYCL";
node_id,
num_nodes,
distributed_runtime_client,
)
end

function MakeIFRTPJRTClientViaPluginAPI(
library_path::String,
device_type::String,
Expand Down
16 changes: 16 additions & 0 deletions src/xla/PJRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@ const cpu_client_count = Ref(0)
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const sycl_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:SYCLClient, :sycl_client_count),
)
main_fn = Symbol(:Make, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -207,6 +209,20 @@ function MakeMetalClient(;
return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL")
end

function MakeSYCLClient(;
sycl_pjrt_plugin_path::String,
node_id::Integer=0,
num_nodes::Integer=1,
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
@assert node_id == 0 "`PJRT.MakeSYCLClient` does not support node_id"
@assert num_nodes == 1 "`PJRT.MakeSYCLClient` does not support num_nodes > 1"
@assert distributed_runtime_client === nothing "`PJRT.MakeSYCLClient` does not support \
distributed_runtime_client"

return MakeClientUsingPluginAPI(sycl_pjrt_plugin_path, "sycl", "SYCL")
end

function MakeClientUsingPluginAPI(
library_path::String, device_type::String, client_name::String=uppercase(device_type)
)
Expand Down
35 changes: 24 additions & 11 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,10 @@ const global_state = State()

function client(backend::String)
if backend == "gpu"
if haskey(global_backend_state.clients, "cuda")
backend = "cuda"
elseif haskey(global_backend_state.clients, "metal")
backend = "metal"
else
error("No GPU client found")
end
backend = findfirst(
Base.Fix1(haskey, global_backend_state.clients), ("cuda", "metal", "sycl")
)
@assert backend !== nothing "No GPU client found"
end
return global_backend_state.clients[backend]
end
Expand Down Expand Up @@ -216,7 +213,7 @@ for runtime in (:PJRT, :IFRT)
state.clients["tpu"] = tpu
state.default_client = tpu
catch e
println(stdout, e)
@debug "Failed to load TPU client: $e"
end
else
try
Expand All @@ -231,11 +228,28 @@ for runtime in (:PJRT, :IFRT)
state.clients["cuda"] = gpu
state.default_client = gpu
catch e
println(stdout, e)
@debug "Failed to load CUDA client: $e"
end

if Sys.ARCH == :x86_64
try
if was_initialized && haskey(state.clients, "sycl")
XLA.free_client(state.clients["sycl"])
XLA.$(runtime).sycl_client_count[] -= 1
end
gpu = $(runtime).SYCLClient(;
sycl_pjrt_plugin_path=Accelerators.IntelXPU.get_intel_xpu_pjrt_plugin_path(),
common_kwargs...,
)
state.clients["sycl"] = gpu
catch e
@debug "Failed to load SYCL client: $e"
end
end
end
else
try
# XXX: Metal PJRT plugin is not yet compatible with latest OpenXLA
#=
if was_initialized && haskey(state.clients, "metal")
XLA.free_client(state.clients["metal"])
Expand All @@ -249,9 +263,8 @@ for runtime in (:PJRT, :IFRT)
# Don't put this in the default_client since metal support is fairly
# limited
=#
# Metal PJRT plugin is not yet compatible with latest OpenXLA
catch e
println(stdout, e)
@debug "Failed to load Metal client: $e"
end
end
end
Expand Down
46 changes: 46 additions & 0 deletions test/plugins/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using Reactant, Enzyme, Lux, Random, Test

sincos(x) = sin(cos(x))
sumsincos(x) = sum(sincos, x)

@testset "Simple Function" begin
x = reshape(collect(Float32, 1:40), 10, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(sincos(x_ra)) ≈ sincos(x)
end

@testset "Autodiff" begin
x = reshape(collect(Float32, 1:40), 10, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(sumsincos(x_ra)) ≈ sum(sincos, x)

@test @jit(Enzyme.gradient(Enzyme.Reverse, sumsincos, x_ra))[2] ≈
Enzyme.gradient(Enzyme.Reverse, sumsincos, x)[2]
@test @jit(Enzyme.gradient(Enzyme.Forward, sumsincos, x_ra))[2] ≈
Enzyme.gradient(Enzyme.Forward, sumsincos, x)[2]
end

@testset "CNN" begin
model = Chain(
Conv((5, 5), 1 => 6, relu),
MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu),
MaxPool((2, 2)),
FlattenLayer(3),
Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 28, 28, 1, 4)

st_test = Lux.testmode(st)

ps_ra = Reactant.to_rarray(ps)
st_ra = Reactant.to_rarray(st)
x_ra = Reactant.to_rarray(x)

st_ra_test = Lux.testmode(st_ra)

@test @jit(model(x_ra, ps_ra, st_ra_test))[1] ≈ model(x, ps, st_test)[1]
end
47 changes: 2 additions & 45 deletions test/plugins/metal.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,8 @@
using Reactant, Enzyme, Lux, Random, Test
using Reactant

original_backend = Reactant.XLA.default_backend()
Reactant.set_default_backend("metal")

sincos(x) = sin(cos(x))
sumsincos(x) = sum(sincos, x)

@testset "Simple Function" begin
x = reshape(collect(Float32, 1:40), 10, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(sincos(x_ra)) ≈ sincos(x)
end

@testset "Autodiff" begin
x = reshape(collect(Float32, 1:40), 10, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(sumsincos(x_ra)) ≈ sum(sincos, x)

@test @jit(Enzyme.gradient(Enzyme.Reverse, sumsincos, x_ra))[2] ≈
Enzyme.gradient(Enzyme.Reverse, sumsincos, x)[2]
@test @jit(Enzyme.gradient(Enzyme.Forward, sumsincos, x_ra))[2] ≈
Enzyme.gradient(Enzyme.Forward, sumsincos, x)[2]
end

@testset "CNN" begin
model = Chain(
Conv((5, 5), 1 => 6, relu),
MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu),
MaxPool((2, 2)),
FlattenLayer(3),
Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 28, 28, 1, 4)

st_test = Lux.testmode(st)

ps_ra = Reactant.to_rarray(ps)
st_ra = Reactant.to_rarray(st)
x_ra = Reactant.to_rarray(x)

st_ra_test = Lux.testmode(st_ra)

@test @jit(model(x_ra, ps_ra, st_ra_test))[1] ≈ model(x, ps, st_test)[1]
end
include("common.jl")

Reactant.set_default_backend(original_backend)
Loading
Loading