40 changes: 26 additions & 14 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -60,6 +60,7 @@
    _prefetch_embeddings,
    _rewrite_model,
    _start_data_dist,
    _prepare_data_dist,
    _start_embedding_lookup,
    _to_device,
    _wait_for_batch,
@@ -646,6 +647,10 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
logger.info("fill_pipeline: failed to load batch i+1")
return

def _data_processing_worker(self) -> None:
if len(self.batches) >= 2:
self.start_sparse_data_dist(self.batches[1], self.contexts[1], async_op=True)

def _wait_for_batch(self) -> None:
batch_id = self.contexts[0].index if len(self.contexts) > 0 else "?"
with record_function(f"## wait_for_batch {batch_id} ##"):
@@ -688,10 +693,6 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
        # the input_dist of batches[0] has been invoked in previous iter. TODO: fact check
        self._wait_for_batch()

        if len(self.batches) >= 2:
            # invoke splits all_to_all comms (first part of input_dist)
            self.start_sparse_data_dist(self.batches[1], self.contexts[1])

        if not self._enqueue_batch_after_forward:
            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
            self.enqueue_batch(dataloader_iter)
@@ -701,20 +702,26 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
            self._state = PipelineState.CALL_FWD
            losses, output = self._model_fwd(self.batches[0])

        if self._enqueue_batch_after_forward:
            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
            self.enqueue_batch(dataloader_iter)

        if len(self.batches) >= 2:
            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
            self.wait_sparse_data_dist(self.contexts[1])
        async_op = True
        if async_op:
            _data_processing_future = self._data_processing_executor.submit(
                self._data_processing_worker,
            )
        else:
            self._data_processing_worker()

        if self._model.training:
            # backward
            self._state = PipelineState.CALL_BWD
            self._backward(losses)
            if async_op:
                _data_processing_future.result()
            if len(self.batches) >= 2:
                _fuse_input_dist_splits(self.contexts[1])

            # batch i+2
            self.enqueue_batch(dataloader_iter)

            self.sync_embeddings(
                self._model,
@@ -840,7 +847,7 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
        return batch

    def start_sparse_data_dist(
        self, batch: Optional[In], context: TrainPipelineContext
        self, batch: Optional[In], context: TrainPipelineContext, async_op: bool = False,
    ) -> None:
        """
        Waits for batch to finish getting copied to GPU, then starts the input dist.
@@ -853,8 +860,13 @@ def start_sparse_data_dist(

        # Temporarily set context for next iter to populate cache
        with use_context_for_postprocs(self._pipelined_postprocs, context):
            if async_op:
                _prepare_data_dist(self._pipelined_modules, batch, context)
            else:
                _start_data_dist(self._pipelined_modules, batch, context)

    def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
        """
        Waits on the input dist splits requests to get the input dist tensors requests,
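For readers following the new control flow in `progress()`: the compute-only preparation of the next batch's input dist is submitted to a worker thread, the current batch's backward runs on the main thread in parallel, and only after the worker joins is the splits all_to_all issued from the main thread. Below is a minimal, self-contained sketch of that overlap pattern, assuming `self._data_processing_executor` is a single-worker `concurrent.futures.ThreadPoolExecutor`; the helpers (`prepare_data_dist`, `fuse_input_dist_splits`, `backward`) are illustrative stand-ins, not the torchrec APIs.

```python
# Sketch only: shows the submit -> backward -> join -> communicate ordering that
# the pipeline change relies on. All helpers below are placeholders.
import time
from concurrent.futures import ThreadPoolExecutor


def prepare_data_dist(batch: dict) -> None:
    # Compute-only preparation for batch i+1; safe to run off the main thread.
    time.sleep(0.01)
    batch["prepared"] = True


def fuse_input_dist_splits(batch: dict) -> None:
    # Stand-in for the splits all_to_all; kept on the main thread.
    assert batch["prepared"], "prepare must finish before communication starts"
    batch["splits_sent"] = True


def backward(losses: list) -> None:
    # Stand-in for the real backward pass of batch i.
    time.sleep(0.01)
    losses.clear()


executor = ThreadPoolExecutor(max_workers=1)
losses, next_batch = [1.0], {}

future = executor.submit(prepare_data_dist, next_batch)  # overlaps with backward
backward(losses)                                         # main thread
future.result()                                          # join the worker first
fuse_input_dist_splits(next_batch)                       # then communicate
```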
12 changes: 10 additions & 2 deletions torchrec/distributed/train_pipeline/utils.py
@@ -118,8 +118,8 @@ def _wait_for_events(
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(stream)


def _start_data_dist(
# We only asynchronously move the computation and blocking wait parts, while keeping the communication synchronous because Torch's communication is not thread-safe.
def _prepare_data_dist(
pipelined_modules: List[ShardedModule],
batch: Pipelineable,
context: TrainPipelineContext,
@@ -157,6 +157,14 @@ def _start_data_dist(
        context.input_dist_splits_requests[forward.name] = module.input_dist(
            module_ctx, *args, **kwargs
        )


def _start_data_dist(
    pipelined_modules: List[ShardedModule],
    batch: Pipelineable,
    context: TrainPipelineContext,
) -> None:
    _prepare_data_dist(pipelined_modules, batch, context)
    _fuse_input_dist_splits(context)


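The utils.py change follows a familiar decomposition: the original entry point becomes a thin wrapper over a compute-only prepare step plus a communication step, so synchronous callers keep their existing behavior while the async path can run only the prepare step on a worker thread. A generic sketch of that shape, using hypothetical names rather than the torchrec helpers:

```python
# Generic shape of the refactor (hypothetical names): split a "start" routine into
# a thread-safe "prepare" phase and a main-thread-only "communicate" phase, and keep
# the old entry point as prepare + communicate so existing callers are unchanged.
from typing import Dict, List


def prepare(modules: List[str], batch: Dict[str, int], context: Dict[str, int]) -> None:
    # Compute-only: record one request per module into the shared context.
    for name in modules:
        context[name] = batch.get(name, 0)


def communicate(context: Dict[str, int]) -> None:
    # Communication stand-in: consume the recorded requests in one fused call.
    context["fused_total"] = sum(v for k, v in context.items() if k != "fused_total")


def start(modules: List[str], batch: Dict[str, int], context: Dict[str, int]) -> None:
    # Original synchronous path, preserved as prepare + communicate.
    prepare(modules, batch, context)
    communicate(context)


if __name__ == "__main__":
    ctx: Dict[str, int] = {}
    start(["ebc", "ec"], {"ebc": 3, "ec": 5}, ctx)
    print(ctx)  # {'ebc': 3, 'ec': 5, 'fused_total': 8}
```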