Skip to content

feat: add support for Intel XPU backend in device selection #1553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/lerobot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def auto_select_torch_device() -> torch.device:
elif torch.backends.mps.is_available():
logging.info("Metal backend detected, using mps.")
return torch.device("mps")
elif torch.xpu.is_available():
logging.info("Intel XPU backend detected, using xpu.")
return torch.device("xpu")
else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
return torch.device("cpu")
Expand All @@ -66,6 +69,9 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
case "mps":
assert torch.backends.mps.is_available()
device = torch.device("mps")
case "xpu":
assert torch.xpu.is_available()
device = torch.device("xpu")
case "cpu":
device = torch.device("cpu")
if log:
Expand All @@ -86,6 +92,18 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
device = device.type
if device == "mps" and dtype == torch.float64:
return torch.float32
if device == "xpu" and dtype == torch.float64:
if hasattr(torch.xpu, "get_device_capability"):
device_capability = torch.xpu.get_device_capability()
if not device_capability.get("has_fp64", False):
logging.warning(f"Device {device} does not support float64, using float32 instead.")
return torch.float32
else:
logging.warning(
f"Device {device} capability check failed. Assuming no support for float64, using float32 instead."
)
return torch.float32
return dtype
else:
return dtype

Expand All @@ -96,14 +114,16 @@ def is_torch_device_available(try_device: str) -> bool:
return torch.cuda.is_available()
elif try_device == "mps":
return torch.backends.mps.is_available()
elif try_device == "xpu":
return torch.xpu.is_available()
elif try_device == "cpu":
return True
else:
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.")


def is_amp_available(device: str):
if device in ["cuda", "cpu"]:
if device in ["cuda", "xpu", "cpu"]:
return True
elif device == "mps":
return False
Expand Down