mpi4jax API version mismatch #32

Open
@coreyjadams

With release 0.3.0 of intel_extension_for_open_xla, I am unable to get mpi4jax to run; the import itself fails. I am using the xpu-support branch of an Intel fork of mpi4jax: https://github.com/jczaja/mpi4jax/tree/jczaja/xpu-support. This is on Argonne's Sunspot cluster, which has Intel Max 1550 GPUs.

I installed intel_extension_for_open_xla version 0.3.0 via pip, and I have oneAPI 2024.1 and agama 803.29. Here is what I see when I import jax and then import mpi4jax:

>>> import jax
>>> jax.local_devices()
INFO: Intel Extension for OpenXLA version: 0.3.0, commit: 9a484818
Platform 'xpu' is experimental and not all JAX functionality may be correctly supported!
[xpu(id=0), xpu(id=1), xpu(id=2), xpu(id=3), xpu(id=4), xpu(id=5), xpu(id=6), xpu(id=7), xpu(id=8), xpu(id=9), xpu(id=10), xpu(id=11)]
>>> import mpi4jax
Registering b'mpi_allgather' and function <capsule object "xla._CUSTOM_CALL_TARGET" at 0x1458e2532ca0>
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/__init__.py", line 9, in <module>
    from ._src import (  # noqa: E402
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/__init__.py", line 11, in <module>
    from . import xla_bridge  # noqa: E402
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/mpi4jax-0+untagged.386.gcb25ca5.dirty-py3.10-linux-x86_64.egg/mpi4jax/_src/xla_bridge/__init__.py", line 42, in <module>
    xla_client.register_custom_call_target(name, fn, platform="SYCL")
  File "/soft/datascience/jax/0.4.24/miniconda3/lib/python3.10/site-packages/jaxlib/xla_client.py", line 588, in register_custom_call_target
    _custom_callback_handler[xla_platform_name](name, fn, xla_platform_name)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: API version 1986225522 not supported for PJRT GPU plugin. Supported versions are 0 and 1.
>>> 
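
One detail in the error above may be worth noting: 1986225522 does not look like a real version number. The following is a minimal sketch, using only the Python standard library, that reinterprets the reported value as four little-endian bytes. The helper name `decode_api_version` is hypothetical and is not part of jaxlib or mpi4jax.

```python
def decode_api_version(value: int) -> bytes:
    """Return the four little-endian bytes of a 32-bit unsigned integer."""
    return value.to_bytes(4, byteorder="little")


# Value copied from the XlaRuntimeError message in the traceback above.
reported = 1986225522

raw = decode_api_version(reported)
print(hex(reported), raw)  # prints: 0x76636572 b'recv'
```

The bytes spell the ASCII text `recv`. This is only a guess, but it suggests the plugin is reading the API version from an argument or memory location that actually holds something else (possibly part of an op name), so the problem may be a calling-convention mismatch between jaxlib and the plugin rather than a version that simply needs to be set differently.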

Do I need to target a specific API version in mpi4jax to make this work, or do I need to build JAX from source?
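
In case it helps answer this, here is a small sketch for reporting the versions of the packages involved, so they can be compared against what the plugin expects. It uses only `importlib.metadata` from the standard library. The distribution names in `PACKAGES` are assumptions about how each package was installed (the exact pip name of the Intel extension in particular may differ), and `report_versions` is a hypothetical helper name.

```python
from importlib import metadata

# Assumed distribution names; adjust to match the local installation.
PACKAGES = ("jax", "jaxlib", "intel_extension_for_openxla", "mpi4jax", "mpi4py")


def report_versions(packages):
    """Map each distribution name to its installed version string."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Do not guess: record that the distribution was not found.
            versions[name] = "not installed"
    return versions


for name, version in report_versions(PACKAGES).items():
    print(f"{name}: {version}")
```

A package listed as "not installed" may still be importable if it was installed under a different distribution name, so that result should be read as "not found under this name".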

Thanks!
Corey
