Xinchi/fuser #75

Open · wants to merge 9 commits into main
2 changes: 1 addition & 1 deletion configs/wan_t2v_dist.json
@@ -10,6 +10,6 @@
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"parallel_attn_type": "ulysses",
"parallel_attn_type": "pipefusion",
"parallel_vae": true
}
9 changes: 9 additions & 0 deletions lightx2v/dist/wrappers/distrifusion/cache.py
@@ -0,0 +1,9 @@
class DistriFusionKVCacheManager:
def __init__(self):
pass

def update_cache(self):
pass

def get_cache(self):
pass
91 changes: 91 additions & 0 deletions lightx2v/dist/wrappers/distrifusion/comm.py
@@ -0,0 +1,91 @@
import torch
import torch.distributed as dist
from loguru import logger


class DistriFusionCommManager:
def __init__(self):
self.verbose = True

self.device = "cuda"

self.rank = dist.get_rank()
self.world_size = dist.get_world_size()

self.torch_dtype = None
self.numel = 0
self.numel_dict = {}

self.buffer_list = None

self.starts = []
self.ends = []
self.shapes = []

self.idx_queue = []

self.handles = None

    def register_tensor(
        self, shape: tuple[int, ...] | list[int], torch_dtype: torch.dtype, layer_type: str | None = None
    ) -> int:
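        # Reserve the [start, end) slice for this tensor in the flat per-rank buffer and return its index.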
if self.torch_dtype is None:
self.torch_dtype = torch_dtype
else:
assert self.torch_dtype == torch_dtype
self.starts.append(self.numel)
numel = 1
for dim in shape:
numel *= dim
self.numel += numel
if layer_type is not None:
if layer_type not in self.numel_dict:
self.numel_dict[layer_type] = 0
self.numel_dict[layer_type] += numel

self.ends.append(self.numel)
self.shapes.append(shape)
return len(self.starts) - 1

def create_buffer(self):
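        # Allocate one flat buffer per rank; each registered tensor occupies its [start, end) slice.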
if self.rank == 0 and self.verbose:
logger.info(
f"Create buffer with {self.numel / 1e6:.3f}M parameters for {len(self.starts)} tensors on each device."
)
for layer_type, numel in self.numel_dict.items():
logger.info(f" {layer_type}: {numel / 1e6:.3f}M parameters")

self.buffer_list = [
torch.empty(self.numel, dtype=self.torch_dtype, device=self.device) for _ in range(self.world_size)
]
        self.handles = [None for _ in range(len(self.starts))]

def get_buffer_list(self, idx: int) -> list[torch.Tensor]:
buffer_list = [t[self.starts[idx] : self.ends[idx]].view(self.shapes[idx]) for t in self.buffer_list]
return buffer_list

def communicate(self):
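        # One async all_gather over the contiguous slice spanning all currently queued tensors.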
start = self.starts[self.idx_queue[0]]
end = self.ends[self.idx_queue[-1]]
tensor = self.buffer_list[self.rank][start:end]
buffer_list = [t[start:end] for t in self.buffer_list]
handle = dist.all_gather(buffer_list, tensor, async_op=True)
for i in self.idx_queue:
self.handles[i] = handle
self.idx_queue = []

def enqueue(self, idx: int, tensor: torch.Tensor):
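        # Wrapping back to index 0 means a new pass has started: flush whatever is still queued.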
if idx == 0 and len(self.idx_queue) > 0:
self.communicate()
assert len(self.idx_queue) == 0 or self.idx_queue[-1] == idx - 1
self.idx_queue.append(idx)
self.buffer_list[self.rank][self.starts[idx] : self.ends[idx]].copy_(tensor.flatten())

def clear(self):
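        # Flush any remaining queued tensors, then wait on all outstanding all_gather handles.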
if len(self.idx_queue) > 0:
self.communicate()
if self.handles is not None:
for i in range(len(self.handles)):
if self.handles[i] is not None:
self.handles[i].wait()
self.handles[i] = None
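A minimal usage sketch for DistriFusionCommManager (illustrative, not part of this PR; it assumes torch.distributed is already initialized, e.g. via torchrun, and the per-layer activations are made up):

import torch

from lightx2v.dist.wrappers.distrifusion.comm import DistriFusionCommManager


def gather_activations(activations: list[torch.Tensor]) -> list[list[torch.Tensor]]:
    # One tensor per layer, produced on this rank; all layers are assumed to
    # share one dtype (the manager asserts this).
    manager = DistriFusionCommManager()
    indices = [
        manager.register_tensor(tuple(t.shape), t.dtype, layer_type=f"layer_{i}")
        for i, t in enumerate(activations)
    ]
    manager.create_buffer()

    # Stage each local tensor; enqueue() batches contiguous indices and issues
    # one async all_gather when the index wraps back to 0 (or on clear()).
    for idx, t in zip(indices, activations):
        manager.enqueue(idx, t)
    manager.clear()  # flush the queue and wait on the outstanding handles

    # get_buffer_list(idx) returns one view per rank of gathered tensor idx.
    return [manager.get_buffer_list(idx) for idx in indices]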
8 changes: 8 additions & 0 deletions lightx2v/dist/wrappers/distrifusion/infer.py
@@ -0,0 +1,8 @@
from lightx2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)


class DistriFusionWanTransformerInferWrapper:
def __init__(self, transformer_infer: WanTransformerInfer, config):
pass
30 changes: 30 additions & 0 deletions lightx2v/dist/wrappers/distrifusion/model.py
@@ -0,0 +1,30 @@
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.dist.wrappers.distrifusion.infer import DistriFusionWanTransformerInferWrapper
from lightx2v.dist.wrappers.distrifusion.weights import DistriFusionWanTransformerWeightsWrapper


class DistriFusionWanModelWrapper:
def __init__(self, model: WanModel, config):
self.model = model
self.config = config

self._wrap_transformer(self.model, self.config)

    def __getattr__(self, name: str):
        # Only invoked when normal attribute lookup fails, so delegate to the wrapped model.
        return getattr(self.model, name)

def __delattr__(self, name: str):
if name in self.__dict__:
del self.__dict__[name]
else:
del self.model.__dict__[name]

def _wrap_transformer(self, model, config):
model.transformer_weights = DistriFusionWanTransformerWeightsWrapper(model.transformer_weights, config)
model.transformer_infer = DistriFusionWanTransformerInferWrapper(model.transformer_infer, config)

def infer(self, inputs, is_warmup=True):
pass
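A toy sketch of the attribute delegation this wrapper relies on (the stand-in model below is hypothetical; the real WanModel constructor and config are not shown, and the names used are the ones imported at the top of model.py):

class _ToyWanModel:
    def __init__(self):
        self.transformer_weights = object()
        self.transformer_infer = object()
        self.scheduler = "a scheduler"


wrapped = DistriFusionWanModelWrapper(_ToyWanModel(), config={})
# Attributes missing on the wrapper fall through __getattr__ to the wrapped model:
assert wrapped.scheduler == "a scheduler"
# _wrap_transformer() has replaced the transformer modules with DistriFusion wrappers:
assert isinstance(wrapped.transformer_infer, DistriFusionWanTransformerInferWrapper)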
17 changes: 17 additions & 0 deletions lightx2v/dist/wrappers/distrifusion/runner.py
@@ -0,0 +1,17 @@
from lightx2v.dist.wrappers.distrifusion.model import DistriFusionWanModelWrapper
from lightx2v.models.runners.wan.wan_runner import WanRunner


class DistriFusionWanRunnerWrapper:
def __init__(self, runner: WanRunner, config):
self.runner = runner
self.config = config

    def __getattr__(self, name: str):
        # Only invoked when normal attribute lookup fails, so delegate to the wrapped runner.
        return getattr(self.runner, name)

def _wrap(self, runner, config):
runner.model = DistriFusionWanModelWrapper(runner.model, config)
self.runner.run = self.run
9 changes: 9 additions & 0 deletions lightx2v/dist/wrappers/distrifusion/weights.py
@@ -0,0 +1,9 @@
import torch.distributed as dist
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)


class DistriFusionWanTransformerWeightsWrapper:
def __init__(self, transformer_weights: WanTransformerWeights, config):
pass