With release 0.3.0 of Intel Extension for OpenXLA, I am unable to get mpi4jax to run. I am using this branch from an Intel fork of mpi4jax: https://github.com/jczaja/mpi4jax/tree/jczaja/xpu-support. This is running on Argonne's Sunspot cluster with Intel Max 1550 GPUs.
I installed intel-extension-for-openxla 0.3.0 via pip. I have oneAPI 2024.1 and the agama 803.29 driver. Here is what I see when I import jax and then import mpi4jax:
>>> import jax
>>> jax.local_devices()
INFO: Intel Extension for OpenXLA version: 0.3.0, commit: 9a484818
Platform 'xpu' is experimental and not all JAX functionality may be correctly supported!
[xpu(id=0), xpu(id=1), xpu(id=2), xpu(id=3), xpu(id=4), xpu(id=5), xpu(id=6), xpu(id=7), xpu(id=8), xpu(id=9), xpu(id=10), xpu(id=11)]
>>> import mpi4jax
Registering b'mpi_allgather' and function <capsule object "xla._CUSTOM_CALL_TARGET" at 0x1458e2532ca0>
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/__init__.py", line 9, in <module>
from ._src import ( # noqa: E402
File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/__init__.py", line 11, in <module>
from . import xla_bridge # noqa: E402
File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/xla_bridge/__init__.py", line 42, in <module>
xla_client.register_custom_call_target(name, fn, platform="SYCL")
File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/jaxlib/xla_client.py", line 588, in register_custom_call_target
_custom_callback_handler[xla_platform_name](name, fn, xla_platform_name)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: API version 1986225522 not supported for PJRT GPU plugin. Supported versions are 0 and 1.
>>>
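For anyone trying to reproduce this, the versions described above (and the mpi4jax build shown in the traceback paths) can be collected with a short script like the one below. This is only a sketch. The distribution names in `PACKAGES` are assumptions based on the pip packages mentioned here and may need adjusting for a given environment:

```python
from importlib import metadata

# Distribution names assumed for this environment; adjust if they differ.
PACKAGES = ("jax", "jaxlib", "intel-extension-for-openxla", "mpi4jax", "mpi4py")


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None
    return versions


if __name__ == "__main__":
    for pkg, ver in installed_versions(PACKAGES).items():
        print(f"{pkg}: {ver or 'not installed'}")
```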
Do I need to target a specific API version in mpi4jax to make this work? Or do I need to build JAX from source?
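Before choosing an API version, one quick check is whether the installed jaxlib's `xla_client.register_custom_call_target` (the function in the traceback) declares an `api_version` parameter at all. The helper name `has_parameter` below is hypothetical, and the sketch only inspects the signature. It does not change how mpi4jax registers its targets:

```python
import inspect


def has_parameter(fn, name):
    """Return True if callable `fn` declares a parameter called `name`."""
    try:
        return name in inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some C-extension callables expose no introspectable signature.
        return False


if __name__ == "__main__":
    try:
        from jaxlib import xla_client
    except ImportError:
        print("jaxlib is not installed")
    else:
        fn = xla_client.register_custom_call_target
        print("accepts api_version:", has_parameter(fn, "api_version"))
```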
Thanks!
Corey