Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions torchrec/distributed/train_pipeline/runtime_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def detach_embeddings(
class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
"""
This pipeline is used in PrefetchTrainPipelineSparseDist
OR in TrainPipelineCustomizedOrderSparseDist, when prefetch is enabled but pipeline_embedding_lookup_fwd is disabled
"""

def __init__(
Expand Down Expand Up @@ -267,6 +268,67 @@ def __call__(self, *input, **kwargs) -> Awaitable:
return self._module.compute_and_output_dist(ctx, data)


class PrefetchEmbeddingPipelinedForward(PrefetchPipelinedForward):
    """
    This pipeline forward is used in TrainPipelineCustomizedOrderSparseDist when
    prefetch is enabled and pipelined_sparse_lookup_fwd is enabled.

    compute_and_output_dist for batch N is called at the end of step N - 1 and
    stores the resulting awaitable; __call__ (invoked during the model's
    forward for batch N) simply hands that awaitable back.
    """

    def __init__(
        self,
        name: str,
        args: CallArgs,
        module: ShardedModule,
        context: PrefetchTrainPipelineContext,
        prefetch_stream: Optional[torch.Stream] = None,
    ) -> None:
        super().__init__(
            name=name,
            args=args,
            module=module,
            context=context,
            prefetch_stream=prefetch_stream,
        )
        # Awaitable produced by compute_and_output_dist(); None until that
        # method has run for the current batch.
        self._compute_and_output_dist_awaitable: Optional[
            Awaitable[Multistreamable]
        ] = None

    def compute_and_output_dist(self) -> None:
        """
        Launch compute + output dist for the prefetched batch of this module.

        Pops the post-prefetch input and context for ``self._name`` from the
        pipeline context, synchronizes them onto the current stream, and stores
        the resulting awaitable for a later ``__call__``.

        Raises:
            AssertionError: if prefetch has not populated the context for this
                module (e.g. ``model.forward()`` was called directly).
        """
        assert (
            self._name in self._context.module_input_post_prefetch
        ), "Invalid PrefetchEmbeddingPipelinedForward usage, please do not directly call model.forward()"
        data = self._context.module_input_post_prefetch.pop(self._name)
        ctx = self._context.module_contexts_post_prefetch.pop(self._name)

        # Make sure that both result of input_dist and context
        # are properly transferred to the current stream.
        if self._stream is not None:
            torch.get_device_module(self._device).current_stream().wait_stream(
                self._stream
            )
            cur_stream = torch.get_device_module(self._device).current_stream()

            assert isinstance(
                data, (torch.Tensor, Multistreamable)
            ), f"{type(data)} must implement Multistreamable interface"
            data.record_stream(cur_stream)

            ctx.record_stream(cur_stream)

        self._compute_and_output_dist_awaitable = self._module.compute_and_output_dist(
            ctx, data
        )

    # pyre-ignore [2, 24]
    def __call__(self, *input, **kwargs) -> Awaitable:
        """
        Return the awaitable created by the earlier compute_and_output_dist().

        Raises:
            RuntimeError: if compute_and_output_dist() has not been called yet.
        """
        # Compare against None explicitly: a falsy-but-valid awaitable must not
        # be mistaken for "not yet computed".
        if self._compute_and_output_dist_awaitable is None:
            raise RuntimeError(
                "compute_and_output_dist must be called before __call__",
            )
        return self._compute_and_output_dist_awaitable


class KJTAllToAllForward:
def __init__(
self, pg: dist.ProcessGroup, splits: List[int], stagger: int = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from unittest.mock import MagicMock

from torchrec.distributed.train_pipeline.pipeline_context import (
PrefetchTrainPipelineContext,
)
from torchrec.distributed.train_pipeline.runtime_forwards import (
PrefetchEmbeddingPipelinedForward,
PrefetchPipelinedForward,
)
from torchrec.distributed.train_pipeline.types import CallArgs


class TestPrefetchEmbeddingPipelinedForward(unittest.TestCase):
    """Test PrefetchEmbeddingPipelinedForward key functionality."""

    def setUp(self) -> None:
        """Set up shared mock module, pipeline context, and call args."""
        self.mock_module = MagicMock()
        self.prefetch_context = PrefetchTrainPipelineContext()
        self.mock_args = CallArgs(args=[], kwargs={})

    def test_is_prefetch_pipelined_forward_subclass(self) -> None:
        """Test that the forward is a PrefetchPipelinedForward (shares its prefetch path)."""
        forward = PrefetchEmbeddingPipelinedForward(
            name="test_prefetch",
            args=self.mock_args,
            module=self.mock_module,
            context=self.prefetch_context,
        )

        # Subclassing guarantees the prefetch machinery of the base class is reused.
        self.assertIsInstance(forward, PrefetchPipelinedForward)

    def test_call_fails_without_compute_and_output_dist(self) -> None:
        """Test that __call__ fails if compute_and_output_dist is not called first."""
        forward = PrefetchEmbeddingPipelinedForward(
            name="test_call_error",
            args=self.mock_args,
            module=self.mock_module,
            context=self.prefetch_context,
        )

        # Should raise when called without a preceding compute_and_output_dist.
        with self.assertRaises(Exception) as context:
            forward()

        self.assertIn(
            "compute_and_output_dist must be called before __call__",
            str(context.exception),
        )

    def test_call_succeeds_after_compute_and_output_dist(self) -> None:
        """Test that __call__ returns the awaitable produced by compute_and_output_dist."""
        forward = PrefetchEmbeddingPipelinedForward(
            name="test_call_success",
            args=self.mock_args,
            module=self.mock_module,
            context=self.prefetch_context,
        )

        # Seed the post-prefetch slots that compute_and_output_dist pops from.
        test_data = MagicMock()
        test_ctx = MagicMock()
        self.prefetch_context.module_input_post_prefetch = {
            "test_call_success": test_data
        }
        self.prefetch_context.module_contexts_post_prefetch = {
            "test_call_success": test_ctx
        }

        # Mock the sharded module's compute_and_output_dist method.
        mock_awaitable = MagicMock()
        self.mock_module.compute_and_output_dist.return_value = mock_awaitable

        # Call compute_and_output_dist first.
        forward.compute_and_output_dist()

        # It must consume the context entries and forward (ctx, data) to the module.
        self.assertNotIn(
            "test_call_success", self.prefetch_context.module_input_post_prefetch
        )
        self.assertNotIn(
            "test_call_success", self.prefetch_context.module_contexts_post_prefetch
        )
        self.mock_module.compute_and_output_dist.assert_called_once_with(
            test_ctx, test_data
        )

        # Now __call__ should succeed and return the stored awaitable.
        result = forward()
        self.assertEqual(result, mock_awaitable)


# Allow running this test module directly: `python <this_file>.py`.
if __name__ == "__main__":
    unittest.main()
Loading