
Give an option to either provide dataset or dataset_size in distributed sampler #1479

Open · wants to merge 4 commits into main
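For context, a brief usage sketch of the API this PR proposes. The constructor arguments mirror the diff and tests below; the `RangeDataset` stand-in and the import of `StatefulDataLoader` from `torchdata.stateful_dataloader` are assumptions made for the example, not part of this change.

```python
from torch.utils.data import Dataset

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler


class RangeDataset(Dataset):
    """Tiny map-style stand-in for the MockDataset used in the tests."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> int:
        return idx


dataset = RangeDataset(100)

# Existing path: pass the dataset; its length is read via __len__.
sampler_a = StatefulDistributedSampler(dataset=dataset, num_replicas=2, rank=0, shuffle=False)

# Path added by this PR: pass only the size, no dataset object required.
sampler_b = StatefulDistributedSampler(dataset_size=len(dataset), num_replicas=2, rank=0, shuffle=False)

assert list(iter(sampler_a)) == list(iter(sampler_b))  # identical index streams: 0, 2, 4, ...

loader = StatefulDataLoader(dataset, batch_size=10, sampler=sampler_b)
```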
72 changes: 71 additions & 1 deletion test/stateful_dataloader/test_sampler.py
@@ -54,7 +54,7 @@ def test_initialization_StatefulDistributedSampler(self):
seed=42,
drop_last=False,
)
-        self.assertEqual(sampler.dataset, self.dataset)
+        self.assertEqual(sampler.dataset_size, len(self.dataset))
self.assertEqual(sampler.num_replicas, 10)
self.assertEqual(sampler.rank, 0)
self.assertFalse(sampler.shuffle)
@@ -232,6 +232,76 @@ def test_seed_replicability(self):
self.assertEqual(results1, results2, "Data should be replicable with same seed")
self.assertNotEqual(results1, results3, "Data should not be replicable with different seed")

def test_StatefulDistributedSampler_initialization_with_dataset_size(self):
sampler = StatefulDistributedSampler(dataset_size=100, num_replicas=2, rank=0, shuffle=False)
self.assertEqual(sampler.dataset_size, 100)
indices = list(iter(sampler))
expected_indices = list(range(0, 100, 2))
self.assertEqual(indices, expected_indices)

def test_StatefulDistributedSampler_mismatched_dataset_and_dataset_size(self):
dataset = MockDataset(100)
with self.assertRaises(ValueError):
StatefulDistributedSampler(dataset=dataset, dataset_size=50)

def test_StatefulDistributedSampler_no_dataset_or_dataset_size(self):
with self.assertRaises(ValueError):
StatefulDistributedSampler()

def test_StatefulDistributedSampler_drop_last_with_dataset_size(self):
dataset_size = 100
num_replicas = 3
sampler = StatefulDistributedSampler(
dataset_size=dataset_size,
num_replicas=num_replicas,
rank=0,
drop_last=True,
shuffle=False,
)
self.assertEqual(sampler.num_samples, 33)
indices = list(iter(sampler))
self.assertEqual(len(indices), 33)
expected_indices = list(range(0, 99, num_replicas))
self.assertEqual(indices, expected_indices)

def test_StatefulDistributedSampler_dataloader_state_dict_with_dataset_size(self):
dataset_size = 100
sampler = StatefulDistributedSampler(dataset_size=dataset_size, num_replicas=1, rank=0, shuffle=False)
dataset = MockDataset(dataset_size)
dataloader = StatefulDataLoader(dataset, batch_size=10, sampler=sampler)
iter_count = 5
for i, _ in enumerate(dataloader):
if i == iter_count - 1:
break
state_dict = dataloader.state_dict()
new_sampler = StatefulDistributedSampler(dataset_size=dataset_size, num_replicas=1, rank=0, shuffle=False)
new_dataloader = StatefulDataLoader(MockDataset(dataset_size), batch_size=10, sampler=new_sampler)
new_dataloader.load_state_dict(state_dict)
resumed_data = []
for data in new_dataloader:
resumed_data.append(data.tolist())
expected_data = []
full_dataloader = StatefulDataLoader(MockDataset(dataset_size), batch_size=10, sampler=sampler)
for data in full_dataloader:
expected_data.append(data.tolist())
self.assertEqual(resumed_data, expected_data[iter_count:])

def test_StatefulDistributedSampler_dataset_size_zero(self):
sampler = StatefulDistributedSampler(dataset_size=0, num_replicas=1, rank=0)
self.assertEqual(len(sampler), 0)
indices = list(iter(sampler))
self.assertEqual(len(indices), 0)

def test_StatefulDistributedSampler_shuffle_with_dataset_size(self):
dataset_size = 100
sampler = StatefulDistributedSampler(dataset_size=dataset_size, num_replicas=1, rank=0, shuffle=True, seed=42)
indices = list(iter(sampler))
self.assertEqual(len(indices), dataset_size)
self.assertEqual(sorted(indices), list(range(dataset_size)))
sampler.set_epoch(1)
indices_epoch_1 = list(iter(sampler))
self.assertNotEqual(indices, indices_epoch_1)


if __name__ == "__main__":
run_tests()
99 changes: 95 additions & 4 deletions torchdata/stateful_dataloader/sampler.py
@@ -5,8 +5,11 @@
# LICENSE file in the root directory of this source tree.

import itertools
import math
from typing import Any, Dict, Iterator, List, Optional, Sized

import torch.distributed as dist

import torch.utils.data.sampler
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _InfiniteConstantSampler
@@ -179,19 +182,66 @@ def __iter__(self):
)


-class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
+class StatefulDistributedSampler(Sampler[int]):
Contributor

I think we should continue subclassing DistributedSampler for StatefulDistributedSampler: the relationship is easy to understand from the naming alone, and we might trigger many type-checking issues in downstream code that uses StatefulDistributedSampler and expects a variant of DistributedSampler.

Since DistributedSampler is a common utility in PyTorch, StatefulDistributedSampler should be expected to be an extension of it.
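
To spell out the type-checking concern, here is a hypothetical downstream pattern (not code from this repository):

```python
from torch.utils.data.distributed import DistributedSampler


def maybe_set_epoch(sampler, epoch: int) -> None:
    # Common downstream pattern: call set_epoch() only on distributed samplers.
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(epoch)


# With this PR, StatefulDistributedSampler derives from Sampler[int] rather than
# DistributedSampler, so the isinstance check above no longer matches and
# set_epoch() is silently skipped in code written against the old hierarchy.
```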

Contributor Author

I decided to fork it instead of subclassing because I do not want to upstream these changes into torch.utils.data.distributed.DistributedSampler, as that might break other users' code.
Besides, it is redundant to have Dataset as an arg when we just need its length.
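
A minimal sketch of that point (the manifest-derived size is a hypothetical scenario, not part of this PR): once the length is known, the sampler can be constructed without holding a dataset object at all.

```python
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

# Hypothetical: the number of examples is known ahead of time, e.g. from a data
# manifest, so no Dataset object is needed to set up sharding across ranks.
num_examples = 1_000_000

sampler = StatefulDistributedSampler(
    dataset_size=num_examples, num_replicas=8, rank=3, shuffle=True, seed=0
)
print(len(sampler))  # 125000 indices assigned to rank 3
```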

_YIELDED = "yielded"

def __init__(
self,
-        dataset: Dataset,
+        dataset: Optional[Dataset] = None,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
dataset_size: Optional[int] = None,
) -> None:
-        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)

# Validate inputs
if dataset is None and dataset_size is None:
raise ValueError("Either dataset or dataset_size must be provided.")

if dataset_size is not None:
if dataset is not None and (hasattr(dataset, "__len__") and dataset_size != len(dataset)):
raise ValueError(
f"dataset_size must match the length of the dataset. {dataset_size=} and {len(dataset)=}"
)
self.dataset_size = dataset_size
else:
if dataset is not None and hasattr(dataset, "__len__"):
self.dataset_size = len(dataset)
else:
raise ValueError("Either a dataset with the __len__ method or dataset_size must be provided.")

if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")

self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and self.dataset_size % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(self.dataset_size - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(self.dataset_size / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed

self.yielded = 0
self.next_yielded = None

Expand All @@ -200,11 +250,52 @@ def __iter__(self):
if self.next_yielded is not None:
self.yielded = self.next_yielded
self.next_yielded = None
-        it = super().__iter__()
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(self.dataset_size, generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(self.dataset_size)) # type: ignore[arg-type]

if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size

# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples

it = iter(indices)

for idx in itertools.islice(it, self.yielded, None):
self.yielded += 1
yield idx

def __len__(self) -> int:
return self.num_samples

def set_epoch(self, epoch: int) -> None:
r"""
Set the epoch for this sampler.

When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.

Args:
epoch (int): Epoch number.
"""
self.epoch = epoch

def state_dict(self) -> Dict[str, Any]:
return {self._YIELDED: self.yielded}
