31 changes: 3 additions & 28 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -17,29 +17,8 @@
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)

raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
# this is kept at the application level, when mpirun is used to run the application
def initialize_distributed_env(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
@@ -50,9 +29,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
@@ -66,13 +42,12 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
return device_mesh, world_size, rank


def cleanup_distributed_env():
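For orientation, here is a minimal sketch of how an example script drives the trimmed helper after this change. It assumes the script is launched with mpirun (as in the example docstrings below) and that cleanup_distributed_env tears the process group down at the end; the exact body of the model section is elided.

import torch.distributed as dist
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

if not dist.is_initialized():
    # Maps OMPI_COMM_WORLD_* into RANK/WORLD_SIZE/MASTER_* and brings up the
    # process group (per the guard used in the updated examples).
    initialize_distributed_env()

import torch_tensorrt  # imported after init so distributed setup is already in place

# ... build, parallelize, and run the model ...

cleanup_distributed_env()
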
17 changes: 13 additions & 4 deletions examples/distributed_inference/tensor_parallel_rotary_embedding.py
@@ -14,21 +14,30 @@
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
import torch.distributed as dist
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
if not dist.is_initialized():
initialize_distributed_env()

import torch_tensorrt
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")

from rotary_embedding import RotaryAttention, parallel_rotary_block

"""
This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with 2 GPUs: mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
15 changes: 11 additions & 4 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -25,22 +25,29 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)

if not dist.is_initialized():
initialize_distributed_env()
import torch_tensorrt
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example")


"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
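For readers unfamiliar with the tensor-parallel API imported above, a small illustrative sketch of how the device mesh and the Colwise/Rowwise styles fit together; the ToyMLP module and its in_proj/out_proj attribute names are hypothetical stand-ins, not the example's exact model, and an initialized process group plus CUDA devices are assumed.

import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from torch_tensorrt.dynamo.distributed.utils import get_tensor_parallel_device_mesh

class ToyMLP(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim)
        self.out_proj = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
model = ToyMLP().cuda()
# Column-parallel first layer, row-parallel second layer across the mesh.
parallelize_module(
    model,
    device_mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)
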
76 changes: 45 additions & 31 deletions py/torch_tensorrt/_utils.py
@@ -143,13 +143,55 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
)


def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
# this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM
from torch.distributed import barrier, get_rank, is_initialized

if not is_initialized():
# Single process case, just unzip
is_master = True
else:
is_master = get_rank() == 0 # only rank 0 does the unzip

if is_master:
try:
import zipfile
except ImportError as e:
raise ImportError(
"zipfile module is required but not found. Please install zipfile"
)
try:
with zipfile.ZipFile(wheel_path) as zip_ref:
zip_ref.extractall(extract_dir)
logger.debug(f"Extracted wheel to {extract_dir}")

except FileNotFoundError as e:
# This should capture the errors in the download failure above
logger.error(f"Wheel file not found at {wheel_path}: {e}")
raise RuntimeError(
f"Failed to find downloaded wheel file at {wheel_path}"
) from e
except zipfile.BadZipFile as e:
logger.error(f"Invalid or corrupted wheel file: {e}")
raise RuntimeError(
"Downloaded wheel file is corrupted or not a valid zip archive"
) from e
except Exception as e:
logger.error(f"Unexpected error while extracting wheel: {e}")
raise RuntimeError(
"Unexpected error during extraction of TensorRT-LLM wheel"
) from e

# Make sure others wait until unzip is done
if is_initialized():
barrier()


def download_and_get_plugin_lib_path() -> Optional[str]:
"""
Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
Args:
platform (str): Platform identifier (e.g., 'linux_x86_64')
Returns:
Optional[str]: Path to shared library or None if operation fails.
"""
@@ -194,32 +236,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
except OSError as e:
logger.error(f"Local file write error: {e}")

try:
import zipfile
except ImportError as e:
raise ImportError(
"zipfile module is required but not found. Please install zipfile"
)
try:
with zipfile.ZipFile(wheel_path) as zip_ref:
zip_ref.extractall(extract_dir)
logger.debug(f"Extracted wheel to {extract_dir}")
except FileNotFoundError as e:
# This should capture the errors in the download failure above
logger.error(f"Wheel file not found at {wheel_path}: {e}")
raise RuntimeError(
f"Failed to find downloaded wheel file at {wheel_path}"
) from e
except zipfile.BadZipFile as e:
logger.error(f"Invalid or corrupted wheel file: {e}")
raise RuntimeError(
"Downloaded wheel file is corrupted or not a valid zip archive"
) from e
except Exception as e:
logger.error(f"Unexpected error while extracting wheel: {e}")
raise RuntimeError(
"Unexpected error during extraction of TensorRT-LLM wheel"
) from e
extract_wheel_file(wheel_path, extract_dir)

try:
wheel_path.unlink(missing_ok=True)
@@ -238,10 +255,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
"""
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
Args:
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
Returns:
bool: True if successful, False otherwise.
"""
@@ -293,7 +308,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
Attempts to load the TensorRT-LLM plugin and initialize it.
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
Returns:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
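The new extract_wheel_file helper above follows a standard "rank 0 does the work, everyone else waits at a barrier" pattern. A generic, stripped-down sketch of that pattern, independent of the wheel-extraction details, looks like this:

import torch.distributed as dist

def run_once_across_ranks(fn):
    """Run fn on rank 0 only; every rank returns after fn has completed."""
    is_master = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_master:
        fn()  # e.g. download or unzip a shared artifact into a common directory
    if dist.is_initialized():
        dist.barrier()  # other ranks block here until rank 0 finishes
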
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/distributed/__init__.py
@@ -0,0 +1 @@

41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/distributed/utils.py
@@ -0,0 +1,41 @@
import logging
import os

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh

logger = logging.getLogger(__name__)


def check_tensor_parallel_device_number(world_size: int) -> None:
if world_size % 2 != 0:
raise ValueError(
f"TP examples require even number of GPUs, but got {world_size} gpus"
)


def get_tensor_parallel_device_mesh(
rank: int = 0, world_size: int = 1
) -> tuple[DeviceMesh, int, int]:
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank


def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger
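
A short usage sketch for the two new helpers, assuming the process group has already been initialized by the caller; the log-file prefix here is illustrative, not a fixed name.

from torch_tensorrt.dynamo.distributed.utils import (
    get_tensor_parallel_device_mesh,
    initialize_distributed_logger,
)

device_mesh, world_size, rank = get_tensor_parallel_device_mesh()
# Writes INFO-level logs to "<prefix>_<rank>.log" in the working directory.
logger = initialize_distributed_logger(rank, "tensor_parallel_example")
logger.info(f"rank {rank}/{world_size} using mesh {device_mesh}")
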
1 change: 1 addition & 0 deletions setup.py
@@ -450,6 +450,7 @@ def run(self):
"torch_tensorrt.dynamo.conversion.impl.unary",
"torch_tensorrt.dynamo.conversion.plugins",
"torch_tensorrt.dynamo.debug",
"torch_tensorrt.dynamo.distributed",
"torch_tensorrt.dynamo.lowering",
"torch_tensorrt.dynamo.lowering.passes",
"torch_tensorrt.dynamo.partitioning",
36 changes: 11 additions & 25 deletions tests/py/dynamo/distributed/distributed_utils.py
@@ -1,5 +1,6 @@
import logging
import os
import random

import numpy as np
import tensorrt as trt
@@ -8,24 +9,21 @@
from torch.distributed._tensor.device_mesh import init_device_mesh


def set_environment_variables_pytest():
# the below two functions are used to set the environment variables for the pytest single and multi process
# this is for the github CI where we use pytest
def set_environment_variables_pytest_single_process():
port = 29500 + random.randint(1, 1000)
os.environ["WORLD_SIZE"] = str(1)
os.environ["RANK"] = str(0)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(29500)


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger
os.environ["MASTER_PORT"] = str(port)


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
def set_environment_variables_pytest_multi_process(
rank: int = 0, world_size: int = 1
) -> None:
port = 29500 + random.randint(1, 1000)
# these variables are set by mpirun -n 2
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
@@ -36,7 +34,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
@@ -46,14 +43,3 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950

# set a manual seed for reproducibility
torch.manual_seed(1111)

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
26 changes: 20 additions & 6 deletions tests/py/dynamo/distributed/test_nccl_ops.py
@@ -5,11 +5,27 @@
import torch.distributed as dist
import torch.nn as nn
from conversion.harness import DispatchTestCase
from distributed_utils import set_environment_variables_pytest

# The distributed env initialization has to be before import of torchTRT, since it uses barrier for installation
from distributed_utils import (
set_environment_variables_pytest_multi_process,
set_environment_variables_pytest_single_process,
)
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt._utils import is_platform_supported_for_trtllm

if "OMPI_COMM_WORLD_SIZE" in os.environ:
set_environment_variables_pytest_multi_process()
else:
set_environment_variables_pytest_single_process()

if not dist.is_initialized():
dist.init_process_group(
backend="nccl",
init_method="env://",
)


class DistributedGatherModel(nn.Module):
def __init__(self, input_dim, world_size, group_name):
@@ -48,11 +64,9 @@ class TestNcclOpsConverter(DispatchTestCase):
)
@classmethod
def setUpClass(cls):
set_environment_variables_pytest()
cls.world_size = 1
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
cls.group = dist.new_group(ranks=[0])
cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
cls.group = dist.new_group(ranks=list(range(cls.world_size)))
cls.group_name = cls.group.group_name

@classmethod
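Presumably the multi-process path of this test is exercised with an MPI launch along the lines of the example commands above (for instance, mpirun -n 2 --allow-run-as-root python -m pytest test_nccl_ops.py), while invoking pytest directly falls back to the single-process environment setup.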