Commit

Merge branch 'master' into bugfix/18394_batch_size_finder_max_val_batches

BoringDonut authored Oct 25, 2023
2 parents a86878b + 8748258 commit 75c218b
Showing 14 changed files with 483 additions and 39 deletions.
2 changes: 1 addition & 1 deletion _notebooks
17 changes: 17 additions & 0 deletions src/lightning/__init__.py
@@ -1,5 +1,6 @@
"""Root package info."""
import logging
import sys

# explicitly don't set root logger's propagation and leave this to subpackages to manage
_logger = logging.getLogger(__name__)
@@ -28,3 +29,19 @@
"seed_everything",
"Fabric",
]


def _cli_entry_point() -> None:
from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache

if not (
ModuleAvailableCache("lightning.app")
if RequirementCache("lightning-utilities<0.10.0")
else RequirementCache(module="lightning.app") # type: ignore[call-arg]
):
print("The `lightning` command requires additional dependencies: `pip install lightning[app]`")
sys.exit(1)

from lightning.app.cli.lightning_cli import main

main()
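For context, the guarded entry point above boils down to the following stand-alone pattern. This is a minimal sketch that substitutes a plain try/except for the lightning_utilities requirement caches; `cli_entry_point` is a hypothetical name, not the library function:

import sys

def cli_entry_point() -> None:
    # Only hand off to the heavy optional CLI package if it is actually installed.
    try:
        import lightning.app  # noqa: F401
    except ImportError:
        print("The `lightning` command requires additional dependencies: `pip install lightning[app]`")
        sys.exit(1)

    # Imported lazily so that a bare `import lightning` stays cheap.
    from lightning.app.cli.lightning_cli import main
    main()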
2 changes: 1 addition & 1 deletion src/lightning/__setup__.py
@@ -115,7 +115,7 @@ def _setup_args() -> Dict[str, Any]:
"python_requires": ">=3.8", # todo: take the lowes based on all packages
"entry_points": {
"console_scripts": [
"lightning = lightning.app.cli.lightning_cli:main",
"lightning = lightning:_cli_entry_point",
],
},
"setup_requires": [],
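At install time, the `console_scripts` declaration above makes the installer generate a small `lightning` launcher on PATH; a rough sketch of that generated stub is shown here (the exact shim text varies by installer):

import sys
from lightning import _cli_entry_point

if __name__ == "__main__":
    sys.exit(_cli_entry_point())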
8 changes: 6 additions & 2 deletions src/lightning/data/streaming/cache.py
@@ -71,6 +71,10 @@ def __init__(
if has_index_file and (chunk_size is None and chunk_bytes is None):
chunk_size = 2

+# Add the version to the cache_dir to avoid collisions.
+if remote_dir and os.path.basename(remote_dir).startswith("version_"):
+    cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir))

if cache_dir:
os.makedirs(cache_dir, exist_ok=True)

@@ -116,8 +120,8 @@ def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
def __len__(self) -> int:
return self._reader.get_length()

-def get_chunk_interval(self) -> List[Tuple[int, int]]:
-    return self._reader.get_chunk_interval()
+def get_chunk_intervals(self) -> List[Tuple[int, int]]:
+    return self._reader.get_chunk_intervals()

def _get_chunk_index_from_index(self, index: int) -> int:
return self._reader._get_chunk_index_from_index(index)
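A worked example of the collision-avoidance rule added above, using hypothetical paths: when the remote directory ends in a `version_*` folder, that folder name is appended to the local cache directory, so two dataset versions never share cached chunks.

import os

remote_dir = "s3://my-bucket/my-dataset/version_3"  # hypothetical remote location
cache_dir = "/tmp/my-dataset-cache"                 # hypothetical local cache root

if remote_dir and os.path.basename(remote_dir).startswith("version_"):
    cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir))

print(cache_dir)  # -> /tmp/my-dataset-cache/version_3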
53 changes: 28 additions & 25 deletions src/lightning/data/streaming/dataset.py
@@ -20,6 +20,7 @@
from lightning.data.streaming import Cache
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
+from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle, TruncatedShuffle


class StreamingDataset(IterableDataset):
@@ -31,7 +32,7 @@ def __init__(
version: Optional[Union[int, Literal["latest"]]] = "latest",
cache_dir: Optional[str] = None,
item_loader: Optional[BaseItemLoader] = None,
-shuffle: bool = True,
+shuffle: Union[bool, Literal["truncated", "full"]] = "truncated",
seed: int = 42,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -53,13 +54,21 @@ def __init__(
if not self.cache.filled:
raise ValueError(f"The provided dataset `{name}` isn't filled up.")

-self.shuffle = shuffle
self.distributed_env = _DistributedEnv.detect()
-self.worker_env: Optional[_WorkerEnv] = None

-chunk_intervals = self.cache.get_chunk_interval()
-self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
+if isinstance(shuffle, bool):
+    _shuffle = TruncatedShuffle(self.cache, seed) if shuffle else NoShuffle(self.cache, seed)

+if isinstance(shuffle, str):
+    if shuffle == "truncated":
+        _shuffle = TruncatedShuffle(self.cache, seed)
+    elif shuffle == "full":
+        _shuffle = FullShuffle(self.cache, seed)
+    else:
+        raise ValueError(f"The provided shuffle doesn't exist. Found {shuffle}")

+self.shuffle: Shuffle = _shuffle
+self.worker_env: Optional[_WorkerEnv] = None
self.worker_chunks: List[int] = []
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
@@ -68,26 +77,16 @@ def __init__(
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.seed = seed
-self.num_iter = 0
+self.current_epoch = 0
self.random_state = None

def __len__(self) -> int:
-    return self.L
+    return self.shuffle.get_len(self.distributed_env, self.current_epoch)

def __iter__(self) -> "StreamingDataset":
self.random_state = np.random.RandomState(seed=self.seed + self.num_iter) # type: ignore
chunk_intervals = self.cache.get_chunk_interval()
indexes = range(len(chunk_intervals))
shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)]
intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
replica_index = index % self.distributed_env.world_size
chunks_per_replica[replica_index].append(chunk_index)
intervals_per_replica[replica_index].append(chunk_interval)

chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_process(
self.distributed_env, self.current_epoch
)
current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]

@@ -105,7 +104,7 @@ def __iter__(self) -> "StreamingDataset":

self.current_indexes = []
self.chunk_index = 0
-self.num_iter += 1
+self.index = 0

return self

@@ -115,16 +114,20 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
return self.cache[index]

def __next__(self) -> Any:
+    # Prevent to create more batch on a given process
+    if self.index >= len(self):
+        self.current_epoch += 1
+        raise StopIteration

    # Lazily re-populate the interval to reduce memory usage.
    if len(self.current_indexes) == 0:
        if self.chunk_index == len(self.worker_intervals):
+            self.current_epoch += 1
            raise StopIteration

        interval = self.worker_intervals[self.chunk_index]
-        current_indexes = np.arange(0, interval[1] - interval[0])
-        if self.shuffle:
-            current_indexes = self.random_state.permutation(current_indexes)
-        self.current_indexes = current_indexes.tolist()
+        current_indexes = np.arange(interval[0], interval[1])
+        self.current_indexes = self.shuffle(current_indexes)
        self.chunk_index += 1

# Get the first index
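Putting the dataset changes together: `shuffle` now accepts `False`, `"truncated"` (the new default), or `"full"`, and per-epoch shuffling is delegated to a `Shuffle` strategy object. A hypothetical usage sketch (the dataset name and cache directory are placeholders, and the dataset must already have been optimised and filled, per the `ValueError` above; the strategies' exact semantics live in shuffle.py, which is not shown in this diff):

from lightning.data.streaming.dataset import StreamingDataset

# shuffle can be False, "truncated" (default), or "full".
dataset = StreamingDataset(name="my-dataset", cache_dir="/tmp/my-dataset-cache", shuffle="full", seed=42)

for item in dataset:  # __next__ bumps current_epoch at exhaustion, so each epoch reshuffles
    ...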
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/dataset_optimizer.py
@@ -485,7 +485,7 @@ def __init__(
self.name = name
self.src_dir = str(src_dir)
self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
-self.num_downloaders = num_downloaders or (1 if fast_dev_run else 2)
+self.num_downloaders = num_downloaders or 1
if chunk_size is not None and chunk_bytes is not None:
raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.")
self.chunk_size = chunk_size
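For reference, the worker/downloader defaults in this constructor evaluate as follows when run standalone (same expressions as above; the change lowers the downloader default from 2 to 1):

import os

fast_dev_run = False
num_workers = None
num_downloaders = None

num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
num_downloaders = num_downloaders or 1
print(num_workers, num_downloaders)  # e.g. 32 1 on an 8-core machine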
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/item_loader.py
@@ -122,7 +122,7 @@ def generate_intervals(self) -> List[Tuple[int, int]]:

return self._intervals

-def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, _: int) -> torch.Tensor:
+def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
while not os.path.exists(chunk_filepath):
sleep(0.0001)

@@ -137,5 +137,5 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
assert self._dtype

buffer: bytes = self._buffers[chunk_index]
-offset = self._dtype.itemsize * index
+offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1)
return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
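A worked example of the corrected offset arithmetic, with illustrative values: the item's byte offset into the chunk buffer is now computed relative to the chunk's first index `begin` rather than the global item index.

import numpy as np

dtype = np.dtype("float32")  # itemsize == 4 bytes (illustrative)
begin, index = 100, 103      # chunk covers global indexes [100, 200)

offset = dtype.itemsize * ((index - begin) if index >= begin else index + 1)
assert offset == 12          # item 103 sits 3 items into the chunk's buffer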
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/reader.py
@@ -166,7 +166,7 @@ def get_length(self) -> int:

return len(self.config)

-def get_chunk_interval(self) -> List[Tuple[int, int]]:
+def get_chunk_intervals(self) -> List[Tuple[int, int]]:
"""Get the index interval of each chunk."""
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/sampler.py
@@ -146,13 +146,13 @@ def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
yield from self.__iter_indices_per_workers__(worker_indices_batches)

def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
-chunk_intervals = self._cache.get_chunk_interval()
+chunk_intervals = self._cache.get_chunk_intervals()
shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals)

def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
-chunk_intervals = self._cache.get_chunk_interval()
+chunk_intervals = self._cache.get_chunk_intervals()
shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

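For reference, the chunk-level shuffling these sampler methods perform can be sketched as follows, mirroring the permutation lines above and the round-robin assignment removed from dataset.py (interval values are hypothetical):

import numpy as np

chunk_intervals = [(0, 50), (50, 105), (105, 160), (160, 200)]  # hypothetical
world_size = 2

# Permute the chunk order, then deal chunks out round-robin across replicas.
shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

chunks_per_replica = [[] for _ in range(world_size)]
intervals_per_replica = [[] for _ in range(world_size)]
for i, (chunk_index, interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
    chunks_per_replica[i % world_size].append(int(chunk_index))
    intervals_per_replica[i % world_size].append(interval.tolist())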
