Add distributed support for StreamingDataset (#18850)
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
5 people authored Oct 25, 2023
1 parent fd0bc59 commit 8748258
Showing 10 changed files with 450 additions and 37 deletions.
8 changes: 6 additions & 2 deletions src/lightning/data/streaming/cache.py
@@ -71,6 +71,10 @@ def __init__(
if has_index_file and (chunk_size is None and chunk_bytes is None):
chunk_size = 2

# Add the version to the cache_dir to avoid collisions.
if remote_dir and os.path.basename(remote_dir).startswith("version_"):
cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir))

if cache_dir:
os.makedirs(cache_dir, exist_ok=True)

@@ -116,8 +120,8 @@ def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
def __len__(self) -> int:
return self._reader.get_length()

def get_chunk_interval(self) -> List[Tuple[int, int]]:
return self._reader.get_chunk_interval()
def get_chunk_intervals(self) -> List[Tuple[int, int]]:
return self._reader.get_chunk_intervals()

def _get_chunk_index_from_index(self, index: int) -> int:
return self._reader._get_chunk_index_from_index(index)
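
A minimal sketch of the cache-dir resolution added in the first hunk above; the helper name resolve_cache_dir and the paths are illustrative, not part of this commit. When the remote directory points at a specific dataset version, the local cache is namespaced by that version so different versions don't collide:

import os

def resolve_cache_dir(cache_dir: str, remote_dir: str) -> str:
    # Mirror the added logic: append the remote dir's "version_*" suffix to the local cache dir.
    if remote_dir and os.path.basename(remote_dir).startswith("version_"):
        return os.path.join(cache_dir, os.path.basename(remote_dir))
    return cache_dir

print(resolve_cache_dir("/cache/my_dataset", "s3://bucket/my_dataset/version_0"))
# /cache/my_dataset/version_0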
53 changes: 28 additions & 25 deletions src/lightning/data/streaming/dataset.py
@@ -20,6 +20,7 @@
from lightning.data.streaming import Cache
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle, TruncatedShuffle


class StreamingDataset(IterableDataset):
@@ -31,7 +32,7 @@ def __init__(
version: Optional[Union[int, Literal["latest"]]] = "latest",
cache_dir: Optional[str] = None,
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = True,
shuffle: Union[bool, Literal["truncated", "full"]] = "truncated",
seed: int = 42,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -53,13 +54,21 @@ def __init__(
if not self.cache.filled:
raise ValueError(f"The provided dataset `{name}` isn't filled up.")

self.shuffle = shuffle
self.distributed_env = _DistributedEnv.detect()
self.worker_env: Optional[_WorkerEnv] = None

chunk_intervals = self.cache.get_chunk_interval()
self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
if isinstance(shuffle, bool):
_shuffle = TruncatedShuffle(self.cache, seed) if shuffle else NoShuffle(self.cache, seed)

if isinstance(shuffle, str):
if shuffle == "truncated":
_shuffle = TruncatedShuffle(self.cache, seed)
elif shuffle == "full":
_shuffle = FullShuffle(self.cache, seed)
else:
raise ValueError(f"The provided shuffle doesn't exist. Found {shuffle}")

self.shuffle: Shuffle = _shuffle
self.worker_env: Optional[_WorkerEnv] = None
self.worker_chunks: List[int] = []
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
@@ -68,26 +77,16 @@ def __init__(
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.seed = seed
self.num_iter = 0
self.current_epoch = 0
self.random_state = None

def __len__(self) -> int:
return self.L
return self.shuffle.get_len(self.distributed_env, self.current_epoch)

def __iter__(self) -> "StreamingDataset":
self.random_state = np.random.RandomState(seed=self.seed + self.num_iter) # type: ignore
chunk_intervals = self.cache.get_chunk_interval()
indexes = range(len(chunk_intervals))
shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)]
intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
replica_index = index % self.distributed_env.world_size
chunks_per_replica[replica_index].append(chunk_index)
intervals_per_replica[replica_index].append(chunk_interval)

chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_process(
self.distributed_env, self.current_epoch
)
current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]

@@ -105,7 +104,7 @@ def __iter__(self) -> "StreamingDataset":

self.current_indexes = []
self.chunk_index = 0
self.num_iter += 1
self.index = 0

return self

@@ -115,16 +114,20 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
return self.cache[index]

def __next__(self) -> Any:
# Prevent a given process from producing more batches than its assigned length
if self.index >= len(self):
self.current_epoch += 1
raise StopIteration

# Lazily re-populate the interval to reduce memory usage.
if len(self.current_indexes) == 0:
if self.chunk_index == len(self.worker_intervals):
self.current_epoch += 1
raise StopIteration

interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(0, interval[1] - interval[0])
if self.shuffle:
current_indexes = self.random_state.permutation(current_indexes)
self.current_indexes = current_indexes.tolist()
current_indexes = np.arange(interval[0], interval[1])
self.current_indexes = self.shuffle(current_indexes)
self.chunk_index += 1

# Get the first index
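
A minimal usage sketch of the new shuffle argument; the dataset name is made up, the data is assumed to have been prepared beforehand, and the import path is inferred from the file layout above rather than stated in this diff:

from torch.utils.data import DataLoader

from lightning.data.streaming.dataset import StreamingDataset

# "truncated" (the new default) drops the items a rank holds beyond the minimum across ranks;
# "full" drops fewer items but may download a shared chunk on several ranks.
dataset = StreamingDataset(name="my_dataset", version="latest", shuffle="full", seed=42)
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)

for batch in dataloader:
    ...  # each rank iterates only over the chunks assigned to it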
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/dataset_optimizer.py
@@ -485,7 +485,7 @@ def __init__(
self.name = name
self.src_dir = str(src_dir)
self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
self.num_downloaders = num_downloaders or (1 if fast_dev_run else 2)
self.num_downloaders = num_downloaders or 1
if chunk_size is not None and chunk_bytes is not None:
raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.")
self.chunk_size = chunk_size
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/item_loader.py
@@ -122,7 +122,7 @@ def generate_intervals(self) -> List[Tuple[int, int]]:

return self._intervals

def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, _: int) -> torch.Tensor:
def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
while not os.path.exists(chunk_filepath):
sleep(0.0001)

@@ -137,5 +137,5 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
assert self._dtype

buffer: bytes = self._buffers[chunk_index]
offset = self._dtype.itemsize * index
offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1)
return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
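
A worked example of why the offset is now computed relative to `begin`: with the distributed chunk assignment the incoming index is global, while the buffer only holds a single chunk's items. The dtype and interval below are invented numbers, not taken from this diff:

import numpy as np

itemsize = np.dtype("float32").itemsize   # 4 bytes per item
begin, index = 100, 103                   # chunk covers global indices [100, 150)
offset = itemsize * (index - begin)       # 12: read the 4th item of this chunk's buffer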
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/reader.py
@@ -166,7 +166,7 @@ def get_length(self) -> int:

return len(self.config)

def get_chunk_interval(self) -> List[Tuple[int, int]]:
def get_chunk_intervals(self) -> List[Tuple[int, int]]:
"""Get the index interval of each chunk."""
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/sampler.py
@@ -146,13 +146,13 @@ def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
yield from self.__iter_indices_per_workers__(worker_indices_batches)

def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
chunk_intervals = self._cache.get_chunk_interval()
chunk_intervals = self._cache.get_chunk_intervals()
shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals)

def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
chunk_intervals = self._cache.get_chunk_interval()
chunk_intervals = self._cache.get_chunk_intervals()
shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

192 changes: 192 additions & 0 deletions src/lightning/data/streaming/shuffle.py
@@ -0,0 +1,192 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, List

import numpy as np

from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming import Cache


class Shuffle(ABC):
"""Shuffle describe how to distribute chunked datasets across processes and workers."""

def __init__(self, cache: Cache, seed: int):
self.cache = cache
self.seed = seed
self.random_state = None

@abstractmethod
def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
pass

@abstractmethod
def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
pass

@abstractmethod
def __call__(self, array: np.ndarray) -> List[int]:
pass


class NoShuffle(Shuffle):
"""NoShuffle doesn't shuffle the items and ensure all the processes receive the same number of items."""

@lru_cache(maxsize=10)
def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
_, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
min_items_per_process = min(
[sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process]
)
return min_items_per_process

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore
chunk_intervals = self.cache.get_chunk_intervals()
indexes = list(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes]

chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)):
replica_index = index % distributed_env.world_size
chunks_per_process[replica_index].append(chunk_index)
intervals_per_process[replica_index].append(chunk_interval)

return chunks_per_process, intervals_per_process

def __call__(self, array: np.ndarray) -> List[int]:
return array.tolist()


class TruncatedShuffle(Shuffle):
"""TruncatedShuffle shuffles the chunks and associates them to the ranks.
As the number of items in a chunk varies, it is possible for a rank to end up with more or less items.
To ensure the same fixed dataset length for all ranks, we compute the minimum number of items across all ranks.
For the ranks with more items than the minimum, the remaining items are dropped.
Note: This is the fastest sampling strategy but at the cost of losing items.
"""

@lru_cache(maxsize=10)
def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
_, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
min_items_per_process = min(
[sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process]
)
return min_items_per_process

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore
chunk_intervals = self.cache.get_chunk_intervals()
indexes = range(len(chunk_intervals))
shuffled_indexes = self.random_state.permutation(indexes)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
replica_index = index % distributed_env.world_size
chunks_per_process[replica_index].append(chunk_index)
intervals_per_process[replica_index].append(chunk_interval)

return chunks_per_process, intervals_per_process

def __call__(self, array: np.ndarray) -> List[int]:
assert self.random_state
return self.random_state.permutation(array).tolist()


class FullShuffle(Shuffle):
"""FullShuffle shuffles the chunks and associates them to the ranks.
As the number of items in a chunk varies, it is possible for a rank to end up with more or less items.
To ensure the same fixed dataset length for all ranks while dropping as few items as possible,
we adopt the following strategy.
We compute the maximum number of items per rank (M) and iterate through the chunks and ranks
until we have associated at least M items per rank.
As a result, we lose at most (number of ranks) items. However, since some chunks are shared across ranks,
the same chunk may be downloaded multiple times.
"""

@lru_cache(maxsize=10)
def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
_, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
min_items_per_process = min([sum([(i[-1] - i[0]) for i in intervals]) for intervals in intervals_per_process])
return min_items_per_process

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore
chunk_intervals = self.cache.get_chunk_intervals()
indexes = range(len(chunk_intervals))
shuffled_indexes = self.random_state.permutation(indexes)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
num_items_per_process: List[int] = [
num_items // distributed_env.world_size for _ in range(distributed_env.world_size)
]
chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
for chunk_index, chunk_interval in zip(shuffled_indexes, shuffled_chunk_intervals):
process_index = 0

while True:
if process_index == len(num_items_per_process):
break

items_left_to_assign = num_items_per_process[process_index]

if items_left_to_assign == 0:
process_index += 1
continue

items_in_chunk = chunk_interval[-1] - chunk_interval[0]

if items_in_chunk == 0:
break

if items_in_chunk > items_left_to_assign:
chunks_per_process[process_index].append(chunk_index)
begin, end = chunk_interval
intervals_per_process[process_index].append([begin, begin + items_left_to_assign])
chunk_interval = (begin + items_left_to_assign + 1, end)
num_items_per_process[process_index] = 0
process_index += 1
else:
chunks_per_process[process_index].append(chunk_index)
intervals_per_process[process_index].append(chunk_interval)
num_items_per_process[process_index] -= items_in_chunk
break

return chunks_per_process, intervals_per_process

def __call__(self, array: np.ndarray) -> List[int]:
assert self.random_state
return self.random_state.permutation(array).tolist()
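
A toy sketch (not part of the commit) contrasting the two strategies for a 2-rank setup; the chunk intervals below are invented for illustration:

# Four chunks with uneven sizes, expressed as (begin, end) item intervals.
world_size = 2
chunk_intervals = [(0, 10), (10, 12), (12, 30), (30, 35)]  # 10, 2, 18 and 5 items

# TruncatedShuffle-style round-robin: chunks are dealt to ranks in turn.
per_rank = [[] for _ in range(world_size)]
for i, interval in enumerate(chunk_intervals):
    per_rank[i % world_size].append(interval)

items = [sum(end - begin for begin, end in intervals) for intervals in per_rank]
print(items)  # [28, 7] -> every rank is truncated to min(items) = 7, dropping 21 items
# FullShuffle instead targets 45 // 2 = 22 items per rank and splits chunk (12, 30)
# across both ranks, so far fewer items are dropped, at the cost of downloading
# that shared chunk on more than one rank.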
8 changes: 5 additions & 3 deletions tests/tests_data/streaming/test_cache.py
@@ -227,11 +227,13 @@ def test_cache_with_name(tmpdir, monkeypatch):
os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True)
monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: os.path.join(tmpdir, name))

monkeypatch.setattr(cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir"), True))
monkeypatch.setattr(
cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir", "version_0"), True)
)
cache = Cache(name="something")
assert cache._writer._chunk_size == 2
assert cache._writer._cache_dir == os.path.join(tmpdir, "something")
assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir")
assert cache._writer._cache_dir == os.path.join(tmpdir, "something", "version_0")
assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir", "version_0")


def test_streaming_dataset(tmpdir, monkeypatch):
