From 874825857ffc09923407ada36814e11adb66c352 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 25 Oct 2023 03:18:20 +0100 Subject: [PATCH] Add distributed support for StreamingDataset (#18850) Co-authored-by: Luca Antiga Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/data/streaming/cache.py | 8 +- src/lightning/data/streaming/dataset.py | 53 ++--- .../data/streaming/dataset_optimizer.py | 2 +- src/lightning/data/streaming/item_loader.py | 4 +- src/lightning/data/streaming/reader.py | 2 +- src/lightning/data/streaming/sampler.py | 4 +- src/lightning/data/streaming/shuffle.py | 192 ++++++++++++++++ tests/tests_data/streaming/test_cache.py | 8 +- tests/tests_data/streaming/test_dataset.py | 212 ++++++++++++++++++ tests/tests_data/streaming/test_sampler.py | 2 +- 10 files changed, 450 insertions(+), 37 deletions(-) create mode 100644 src/lightning/data/streaming/shuffle.py create mode 100644 tests/tests_data/streaming/test_dataset.py diff --git a/src/lightning/data/streaming/cache.py b/src/lightning/data/streaming/cache.py index 4c43b1a18740a..8945305bfcd15 100644 --- a/src/lightning/data/streaming/cache.py +++ b/src/lightning/data/streaming/cache.py @@ -71,6 +71,10 @@ def __init__( if has_index_file and (chunk_size is None and chunk_bytes is None): chunk_size = 2 + # Add the version to the cache_dir to avoid collisions. + if remote_dir and os.path.basename(remote_dir).startswith("version_"): + cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir)) + if cache_dir: os.makedirs(cache_dir, exist_ok=True) @@ -116,8 +120,8 @@ def _merge_no_wait(self, node_rank: Optional[int] = None) -> None: def __len__(self) -> int: return self._reader.get_length() - def get_chunk_interval(self) -> List[Tuple[int, int]]: - return self._reader.get_chunk_interval() + def get_chunk_intervals(self) -> List[Tuple[int, int]]: + return self._reader.get_chunk_intervals() def _get_chunk_index_from_index(self, index: int) -> int: return self._reader._get_chunk_index_from_index(index) diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index bb30be4f99075..7aafc33cb54da 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -20,6 +20,7 @@ from lightning.data.streaming import Cache from lightning.data.streaming.item_loader import BaseItemLoader from lightning.data.streaming.sampler import ChunkedIndex +from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle, TruncatedShuffle class StreamingDataset(IterableDataset): @@ -31,7 +32,7 @@ def __init__( version: Optional[Union[int, Literal["latest"]]] = "latest", cache_dir: Optional[str] = None, item_loader: Optional[BaseItemLoader] = None, - shuffle: bool = True, + shuffle: Union[bool, Literal["truncated", "full"]] = "truncated", seed: int = 42, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -53,13 +54,21 @@ def __init__( if not self.cache.filled: raise ValueError(f"The provided dataset `{name}` isn't filled up.") - self.shuffle = shuffle self.distributed_env = _DistributedEnv.detect() - self.worker_env: Optional[_WorkerEnv] = None - chunk_intervals = self.cache.get_chunk_interval() - self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals]) + if isinstance(shuffle, bool): + _shuffle = TruncatedShuffle(self.cache, seed) if shuffle else NoShuffle(self.cache, seed) + + if isinstance(shuffle, str): + if shuffle == "truncated": + _shuffle = TruncatedShuffle(self.cache, seed) + elif shuffle == "full": + _shuffle = FullShuffle(self.cache, seed) + else: + raise ValueError(f"The provided shuffle doesn't exist. Found {shuffle}") + self.shuffle: Shuffle = _shuffle + self.worker_env: Optional[_WorkerEnv] = None self.worker_chunks: List[int] = [] self.worker_intervals: List[List[int]] = [] self.current_indexes: List[int] = [] @@ -68,26 +77,16 @@ def __init__( self.has_triggered_download = False self.min_items_per_replica: Optional[int] = None self.seed = seed - self.num_iter = 0 + self.current_epoch = 0 self.random_state = None def __len__(self) -> int: - return self.L + return self.shuffle.get_len(self.distributed_env, self.current_epoch) def __iter__(self) -> "StreamingDataset": - self.random_state = np.random.RandomState(seed=self.seed + self.num_iter) # type: ignore - chunk_intervals = self.cache.get_chunk_interval() - indexes = range(len(chunk_intervals)) - shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes) - shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] - - chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)] - intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)] - for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): - replica_index = index % self.distributed_env.world_size - chunks_per_replica[replica_index].append(chunk_index) - intervals_per_replica[replica_index].append(chunk_interval) - + chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_process( + self.distributed_env, self.current_epoch + ) current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] @@ -105,7 +104,7 @@ def __iter__(self) -> "StreamingDataset": self.current_indexes = [] self.chunk_index = 0 - self.num_iter += 1 + self.index = 0 return self @@ -115,16 +114,20 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: return self.cache[index] def __next__(self) -> Any: + # Prevent to create more batch on a given process + if self.index >= len(self): + self.current_epoch += 1 + raise StopIteration + # Lazily re-populate the interval to reduce memory usage. if len(self.current_indexes) == 0: if self.chunk_index == len(self.worker_intervals): + self.current_epoch += 1 raise StopIteration interval = self.worker_intervals[self.chunk_index] - current_indexes = np.arange(0, interval[1] - interval[0]) - if self.shuffle: - current_indexes = self.random_state.permutation(current_indexes) - self.current_indexes = current_indexes.tolist() + current_indexes = np.arange(interval[0], interval[1]) + self.current_indexes = self.shuffle(current_indexes) self.chunk_index += 1 # Get the first index diff --git a/src/lightning/data/streaming/dataset_optimizer.py b/src/lightning/data/streaming/dataset_optimizer.py index 083e941eba7c5..eb10bfef780b3 100644 --- a/src/lightning/data/streaming/dataset_optimizer.py +++ b/src/lightning/data/streaming/dataset_optimizer.py @@ -485,7 +485,7 @@ def __init__( self.name = name self.src_dir = str(src_dir) self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4) - self.num_downloaders = num_downloaders or (1 if fast_dev_run else 2) + self.num_downloaders = num_downloaders or 1 if chunk_size is not None and chunk_bytes is not None: raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.") self.chunk_size = chunk_size diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py index d22db33a6c3a3..da35a4c6e9283 100644 --- a/src/lightning/data/streaming/item_loader.py +++ b/src/lightning/data/streaming/item_loader.py @@ -122,7 +122,7 @@ def generate_intervals(self) -> List[Tuple[int, int]]: return self._intervals - def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, _: int) -> torch.Tensor: + def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor: while not os.path.exists(chunk_filepath): sleep(0.0001) @@ -137,5 +137,5 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str assert self._dtype buffer: bytes = self._buffers[chunk_index] - offset = self._dtype.itemsize * index + offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1) return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py index 4bb05cf3cf659..b95f88d931d07 100644 --- a/src/lightning/data/streaming/reader.py +++ b/src/lightning/data/streaming/reader.py @@ -166,7 +166,7 @@ def get_length(self) -> int: return len(self.config) - def get_chunk_interval(self) -> List[Tuple[int, int]]: + def get_chunk_intervals(self) -> List[Tuple[int, int]]: """Get the index interval of each chunk.""" if self._config is None and self._try_load_config() is None: raise Exception("The reader index isn't defined.") diff --git a/src/lightning/data/streaming/sampler.py b/src/lightning/data/streaming/sampler.py index fe88a2cf2c316..cec70466fcada 100644 --- a/src/lightning/data/streaming/sampler.py +++ b/src/lightning/data/streaming/sampler.py @@ -146,13 +146,13 @@ def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: yield from self.__iter_indices_per_workers__(worker_indices_batches) def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: - chunk_intervals = self._cache.get_chunk_interval() + chunk_intervals = self._cache.get_chunk_intervals() shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals) def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: - chunk_intervals = self._cache.get_chunk_interval() + chunk_intervals = self._cache.get_chunk_intervals() shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] diff --git a/src/lightning/data/streaming/shuffle.py b/src/lightning/data/streaming/shuffle.py new file mode 100644 index 0000000000000..3225bc4197b71 --- /dev/null +++ b/src/lightning/data/streaming/shuffle.py @@ -0,0 +1,192 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, List + +import numpy as np + +from lightning.data.datasets.env import _DistributedEnv +from lightning.data.streaming import Cache + + +class Shuffle(ABC): + """Shuffle describe how to distribute chunked datasets across processes and workers.""" + + def __init__(self, cache: Cache, seed: int): + self.cache = cache + self.seed = seed + self.random_state = None + + @abstractmethod + def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int: + pass + + @abstractmethod + def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any: + pass + + @abstractmethod + def __call__(self, array: np.ndarray) -> List[int]: + pass + + +class NoShuffle(Shuffle): + """NoShuffle doesn't shuffle the items and ensure all the processes receive the same number of items.""" + + @lru_cache(maxsize=10) + def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int: + _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch) + min_items_per_process = min( + [sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process] + ) + return min_items_per_process + + @lru_cache(maxsize=10) + def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any: + self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore + chunk_intervals = self.cache.get_chunk_intervals() + indexes = list(range(len(chunk_intervals))) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes] + + chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)] + intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)] + for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)): + replica_index = index % distributed_env.world_size + chunks_per_process[replica_index].append(chunk_index) + intervals_per_process[replica_index].append(chunk_interval) + + return chunks_per_process, intervals_per_process + + def __call__(self, array: np.ndarray) -> List[int]: + return array.tolist() + + +class TruncatedShuffle(Shuffle): + """TruncatedShuffle shuffles the chunks and associates them to the ranks. + + As the number of items in a chunk varies, it is possible for a rank to end up with more or less items. + + To ensure the same fixed dataset length for all ranks, we compute the minimum number of items across all ranks. + + For the ranks with more items than the minimum, the remaining items are dropped. + + Note: This is the fastest sampling strategy but at the cost of losing items. + + """ + + @lru_cache(maxsize=10) + def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int: + _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch) + min_items_per_process = min( + [sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process] + ) + return min_items_per_process + + @lru_cache(maxsize=10) + def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any: + self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore + chunk_intervals = self.cache.get_chunk_intervals() + indexes = range(len(chunk_intervals)) + shuffled_indexes = self.random_state.permutation(indexes) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] + + chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)] + intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)] + for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): + replica_index = index % distributed_env.world_size + chunks_per_process[replica_index].append(chunk_index) + intervals_per_process[replica_index].append(chunk_interval) + + return chunks_per_process, intervals_per_process + + def __call__(self, array: np.ndarray) -> List[int]: + assert self.random_state + return self.random_state.permutation(array).tolist() + + +class FullShuffle(Shuffle): + """FullShuffle shuffles the chunks and associates them to the ranks. + + As the number of items in a chunk varies, it is possible for a rank to end up with more or less items. + + To ensure the same fixed dataset length for all ranks while dropping as few items as possible, + + we adopt the following strategy. + + We compute the maximum number of items per rank (M) and iterate through the chunks and ranks + + until we have associated at least M items per rank. + + As a result, we lose at most (number of ranks) items. However, as some chunks are shared across ranks. This leads to + the same chunk to be downloaded multiple times. + + """ + + @lru_cache(maxsize=10) + def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int: + _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch) + min_items_per_process = min([sum([(i[-1] - i[0]) for i in intervals]) for intervals in intervals_per_process]) + return min_items_per_process + + @lru_cache(maxsize=10) + def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any: + self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore + chunk_intervals = self.cache.get_chunk_intervals() + indexes = range(len(chunk_intervals)) + shuffled_indexes = self.random_state.permutation(indexes) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] + + num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals]) + num_items_per_process: List[int] = [ + num_items // distributed_env.world_size for _ in range(distributed_env.world_size) + ] + chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)] + intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)] + for chunk_index, chunk_interval in zip(shuffled_indexes, shuffled_chunk_intervals): + process_index = 0 + + while True: + if process_index == len(num_items_per_process): + break + + items_left_to_assign = num_items_per_process[process_index] + + if items_left_to_assign == 0: + process_index += 1 + continue + + items_in_chunk = chunk_interval[-1] - chunk_interval[0] + + if items_in_chunk == 0: + break + + if items_in_chunk > items_left_to_assign: + chunks_per_process[process_index].append(chunk_index) + begin, end = chunk_interval + intervals_per_process[process_index].append([begin, begin + items_left_to_assign]) + chunk_interval = (begin + items_left_to_assign + 1, end) + num_items_per_process[process_index] = 0 + process_index += 1 + else: + chunks_per_process[process_index].append(chunk_index) + intervals_per_process[process_index].append(chunk_interval) + num_items_per_process[process_index] -= items_in_chunk + break + + return chunks_per_process, intervals_per_process + + def __call__(self, array: np.ndarray) -> List[int]: + assert self.random_state + return self.random_state.permutation(array).tolist() diff --git a/tests/tests_data/streaming/test_cache.py b/tests/tests_data/streaming/test_cache.py index 5c9bb1c6fee8a..9933f0f7211a6 100644 --- a/tests/tests_data/streaming/test_cache.py +++ b/tests/tests_data/streaming/test_cache.py @@ -227,11 +227,13 @@ def test_cache_with_name(tmpdir, monkeypatch): os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True) monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: os.path.join(tmpdir, name)) - monkeypatch.setattr(cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir"), True)) + monkeypatch.setattr( + cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir", "version_0"), True) + ) cache = Cache(name="something") assert cache._writer._chunk_size == 2 - assert cache._writer._cache_dir == os.path.join(tmpdir, "something") - assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir") + assert cache._writer._cache_dir == os.path.join(tmpdir, "something", "version_0") + assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir", "version_0") def test_streaming_dataset(tmpdir, monkeypatch): diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py new file mode 100644 index 0000000000000..5303b723c3969 --- /dev/null +++ b/tests/tests_data/streaming/test_dataset.py @@ -0,0 +1,212 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch +from lightning import seed_everything +from lightning.data.datasets.env import _DistributedEnv +from lightning.data.streaming import Cache +from lightning.data.streaming import cache as cache_module +from lightning.data.streaming.dataloader import StreamingDataLoader +from lightning.data.streaming.dataset import StreamingDataset +from lightning.data.streaming.item_loader import TokensLoader +from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, TruncatedShuffle +from lightning.pytorch.demos.boring_classes import RandomDataset +from torch.utils.data import DataLoader + + +def test_streaming_dataset(tmpdir, monkeypatch): + seed_everything(42) + + os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True) + monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: tmpdir) + + with pytest.raises(ValueError, match="The provided dataset `choco` isn't filled up."): + dataset = StreamingDataset(name="choco", cache_dir=tmpdir) + + dataset = RandomDataset(128, 64) + dataloader = StreamingDataLoader(dataset, cache_dir=tmpdir, chunk_bytes=2 << 12) + for batch in dataloader: + assert isinstance(batch, torch.Tensor) + + dataset = StreamingDataset(name="choco", cache_dir=tmpdir, item_loader=TokensLoader(block_size=10)) + + assert len(dataset) == 816 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 816 + + dataloader = DataLoader(dataset, num_workers=2, batch_size=2) + assert len(dataloader) == 408 + + +def test_streaming_dataset_distributed_min_shuffle(tmpdir): + seed_everything(42) + + cache = Cache(tmpdir, chunk_size=10) + for i in range(101): + cache[i] = i + + cache.done() + cache.merge() + + dataset = StreamingDataset(name="choco", cache_dir=tmpdir, shuffle=True) + + assert isinstance(dataset.shuffle, TruncatedShuffle) + + for i in range(101): + assert dataset[i] == i + + dataset.distributed_env = _DistributedEnv(2, 0) + assert len(dataset) == 41 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 41 + process_1_1 = list(dataset_iter) + assert process_1_1[:10] == [50, 56, 59, 51, 58, 55, 52, 53, 54, 57] + assert len(process_1_1) == 41 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 50 + process_1_2 = list(dataset_iter) + assert process_1_2[:10] == [100, 68, 66, 64, 61, 65, 69, 62, 63, 60] + assert len(process_1_2) == 50 + + dataset = StreamingDataset(name="choco", cache_dir=tmpdir) + dataset.distributed_env = _DistributedEnv(2, 1) + assert len(dataset) == 41 + dataset_iter = iter(dataset) + process_2_1 = list(dataset_iter) + assert process_2_1[:10] == [0, 6, 9, 1, 8, 5, 2, 3, 4, 7] + assert len(process_2_1) == 41 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 50 + process_2_2 = list(dataset_iter) + assert process_2_2[:10] == [78, 76, 74, 71, 75, 79, 72, 73, 70, 77] + assert len(process_2_2) == 50 + + assert len([i for i in process_1_1 if i in process_2_1]) == 0 + assert len([i for i in process_1_2 if i in process_2_2]) == 0 + + +def test_streaming_dataset_distributed_no_shuffle(tmpdir): + seed_everything(42) + + cache = Cache(tmpdir, chunk_size=10) + for i in range(101): + cache[i] = i + + cache.done() + cache.merge() + + dataset = StreamingDataset(name="choco", cache_dir=tmpdir, shuffle=False) + + assert isinstance(dataset.shuffle, NoShuffle) + + for i in range(101): + assert dataset[i] == i + + dataset.distributed_env = _DistributedEnv(2, 0) + assert len(dataset) == 50 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 50 + process_1_1 = list(dataset_iter) + assert len(process_1_1) == 50 + assert process_1_1[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + dataset_iter = iter(dataset) + assert len(dataset_iter) == 50 + process_1_2 = list(dataset_iter) + assert process_1_2[:10] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert len(process_1_2) == 50 + + dataset = StreamingDataset(name="choco", cache_dir=tmpdir, shuffle=False) + dataset.distributed_env = _DistributedEnv(2, 1) + assert len(dataset) == 50 + dataset_iter = iter(dataset) + process_2_1 = list(dataset_iter) + assert process_2_1[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + assert len(process_2_1) == 50 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 50 + process_2_2 = list(dataset_iter) + assert process_2_2[:10] == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + + assert len(process_2_2) == 50 + + _, intervals_per_process = dataset.shuffle.get_chunks_and_intervals_per_process( + dataset.distributed_env, dataset.current_epoch + ) + + assert process_1_1 == process_1_2 + + found_list = [] + for i in process_1_1: + found = False + for interval in intervals_per_process[0]: + if interval[0] <= i <= interval[1]: + found = True + break + found_list.append(found) + + assert all(found_list) is True + + found_list = [] + for i in process_2_1: + found = False + for interval in intervals_per_process[1]: + if interval[0] <= i <= interval[1]: + found = True + break + found_list.append(found) + + assert all(found_list) is True + + assert len([i for i in process_1_1 if i in process_2_1]) == 0 + assert len([i for i in process_1_2 if i in process_2_2]) == 0 + + +def test_streaming_dataset_distributed_full_shuffle(tmpdir): + seed_everything(42) + + cache = Cache(tmpdir, chunk_size=10) + for i in range(1097): + cache[i] = i + + cache.done() + cache.merge() + + dataset = StreamingDataset(name="choco", cache_dir=tmpdir, shuffle="full") + + assert isinstance(dataset.shuffle, FullShuffle) + + for i in range(1097): + assert dataset[i] == i + + dataset.distributed_env = _DistributedEnv(2, 0) + assert len(dataset) == 548 + dataset_iter = iter(dataset) + assert len(dataset_iter) == 548 + process_1_1 = list(dataset_iter) + assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780] + assert len(process_1_1) == 548 + + dataset_2 = StreamingDataset(name="choco", cache_dir=tmpdir, shuffle="full") + assert isinstance(dataset_2.shuffle, FullShuffle) + dataset_2.distributed_env = _DistributedEnv(2, 1) + assert len(dataset_2) == 548 + dataset_2_iter = iter(dataset_2) + assert len(dataset_2_iter) == 548 + process_2_1 = list(dataset_2_iter) + assert process_2_1[:10] == [939, 255, 258, 252, 253, 259, 257, 256, 251, 254] + assert len(process_2_1) == 548 + + assert len([i for i in process_1_1 if i in process_2_1]) == 0 diff --git a/tests/tests_data/streaming/test_sampler.py b/tests/tests_data/streaming/test_sampler.py index ee2f0e3968337..d379b3591896e 100644 --- a/tests/tests_data/streaming/test_sampler.py +++ b/tests/tests_data/streaming/test_sampler.py @@ -65,7 +65,7 @@ def test_cache_batch_sampler(params): chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] cache.filled = True - cache.get_chunk_interval.return_value = chunks_interval + cache.get_chunk_intervals.return_value = chunks_interval seed_everything(42)