Lazy data files resolution and offline cache reload (#6493)
* lazy data files resolution

* fix tests

* minor

* don't use expand_info=False yet

* fix

* retrieve cached datasets that were pushed to hub

* minor

* style

* tests

* fix win test

* fix tests

* fix tests again

* remove unused code

* allow load from cache in streaming mode

* remove comment

* more tests

* fix tests

* fix more tests

* fix tests

* fix tests

* fix cache on config change

* simpler

* fix tests

* make both PRs compatible

* style

* fix tests

* fix tests

* fix tests

* fix test

* update hash when loading from parquet export too

* fix modify files

* fix base_path

* just use the commit sha as hash

* use commit sha in parquet export dataset cache directories too

* use version from parquet export dataset info

* fix cache reload when config name and version are not the default ones

* fix tests
lhoestq authored Dec 21, 2023
1 parent cf71653 commit ef3b5dd
Showing 12 changed files with 632 additions and 182 deletions.
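
The two features named in the title work together: `load_dataset` no longer lists every data file up front when picking a config, and a dataset that has already been prepared can be reloaded from the local cache without contacting the Hub. A minimal sketch of the user-facing behavior (the repo id is hypothetical and is assumed to have been downloaded once while online):

```python
from datasets import load_dataset

# First run (online): data file patterns are resolved lazily and the
# prepared dataset is written to the local cache.
ds = load_dataset("username/my_dataset")

# A later run with no network access (e.g. with HF_DATASETS_OFFLINE=1 set
# in the environment) reloads the same prepared dataset from the cache
# instead of raising a connection error, including for datasets that were
# pushed to the Hub with push_to_hub().
ds = load_dataset("username/my_dataset")
```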
25 changes: 18 additions & 7 deletions src/datasets/builder.py
```diff
@@ -46,12 +46,12 @@
     ReadInstruction,
 )
 from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter, SchemaInferenceError
-from .data_files import DataFilesDict, sanitize_patterns
+from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns
 from .dataset_dict import DatasetDict, IterableDatasetDict
 from .download.download_config import DownloadConfig
 from .download.download_manager import DownloadManager, DownloadMode
 from .download.mock_download_manager import MockDownloadManager
-from .download.streaming_download_manager import StreamingDownloadManager, xopen
+from .download.streaming_download_manager import StreamingDownloadManager, xjoin, xopen
 from .exceptions import DatasetGenerationCastError, DatasetGenerationError, FileFormatError, ManualDownloadError
 from .features import Features
 from .filesystems import (
```
```diff
@@ -115,7 +115,7 @@ class BuilderConfig:
     name: str = "default"
     version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0")
     data_dir: Optional[str] = None
-    data_files: Optional[DataFilesDict] = None
+    data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None
     description: Optional[str] = None
 
     def __post_init__(self):
```
```diff
@@ -126,7 +126,7 @@ def __post_init__(self):
                     f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. "
                     f"They could create issues when creating a directory for this config on Windows filesystem."
                 )
-        if self.data_files is not None and not isinstance(self.data_files, DataFilesDict):
+        if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)):
             raise ValueError(f"Expected a DataFilesDict in data_files but got {self.data_files}")
 
     def __eq__(self, o):
```
```diff
@@ -200,6 +200,11 @@ def create_config_id(
         else:
             return self.name
 
+    def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None:
+        if isinstance(self.data_files, DataFilesPatternsDict):
+            base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path
+            self.data_files = self.data_files.resolve(base_path, download_config)
+
 
 class DatasetBuilder:
     """Abstract base class for all datasets.
```
```diff
@@ -504,7 +509,7 @@ def _create_builder_config(
         builder_config = None
 
         # try default config
-        if config_name is None and self.BUILDER_CONFIGS and not config_kwargs:
+        if config_name is None and self.BUILDER_CONFIGS:
             if self.DEFAULT_CONFIG_NAME is not None:
                 builder_config = self.builder_configs.get(self.DEFAULT_CONFIG_NAME)
                 logger.info(f"No config specified, defaulting to: {self.dataset_name}/{builder_config.name}")
```
```diff
@@ -542,7 +547,7 @@ def _create_builder_config(
 
         # otherwise use the config_kwargs to overwrite the attributes
         else:
-            builder_config = copy.deepcopy(builder_config)
+            builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config
             for key, value in config_kwargs.items():
                 if value is not None:
                     if not hasattr(builder_config, key):
```
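
Two related tweaks here: the default config is now selected even when `config_kwargs` are passed (previously any keyword argument skipped the default-config path), and the chosen config is only deep-copied when there are kwargs to apply, so the pre-defined entries in `BUILDER_CONFIGS` are never mutated and no copy is made in the common no-kwargs case.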
```diff
@@ -552,6 +557,12 @@ def _create_builder_config(
         if not builder_config.name:
             raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}")
 
+        # resolve data files if needed
+        builder_config._resolve_data_files(
+            base_path=self.base_path,
+            download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
+        )
+
         # compute the config id that is going to be used for caching
         config_id = builder_config.create_config_id(
             config_kwargs,
```
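
Note the ordering: data files are resolved just before `create_config_id` runs, so the config id, and with it the cache directory, is computed from the resolved file list rather than from the raw patterns. This is what lets a previously prepared dataset be found in the cache again later, since the same resolved files produce the same config id.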
```diff
@@ -577,7 +588,7 @@
     @classproperty
     @classmethod
     @memoize()
-    def builder_configs(cls):
+    def builder_configs(cls) -> Dict[str, BuilderConfig]:
         """Dictionary of pre-defined configurations for this builder class."""
         configs = {config.name: config for config in cls.BUILDER_CONFIGS}
         if len(configs) != len(cls.BUILDER_CONFIGS):
```
91 changes: 91 additions & 0 deletions src/datasets/data_files.py
```diff
@@ -702,3 +702,94 @@ def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
         for key, data_files_list in self.items():
             out[key] = data_files_list.filter_extensions(extensions)
         return out
+
+
+class DataFilesPatternsList(List[str]):
+    """
+    List of data files patterns (absolute local paths or URLs).
+    For each pattern there should also be a list of allowed extensions
+    to keep, or None to keep all the files for the pattern.
"""

def __init__(
self,
patterns: List[str],
allowed_extensions: List[Optional[List[str]]],
):
super().__init__(patterns)
self.allowed_extensions = allowed_extensions

def __add__(self, other):
+        return DataFilesPatternsList([*self, *other], self.allowed_extensions + other.allowed_extensions)
+
+    @classmethod
+    def from_patterns(
+        cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
+    ) -> "DataFilesPatternsList":
+        return cls(patterns, [allowed_extensions] * len(patterns))
+
+    def resolve(
+        self,
+        base_path: str,
+        download_config: Optional[DownloadConfig] = None,
+    ) -> "DataFilesList":
+        base_path = base_path if base_path is not None else Path().resolve().as_posix()
+        data_files = []
+        for pattern, allowed_extensions in zip(self, self.allowed_extensions):
+            try:
+                data_files.extend(
+                    resolve_pattern(
+                        pattern,
+                        base_path=base_path,
+                        allowed_extensions=allowed_extensions,
+                        download_config=download_config,
+                    )
+                )
+            except FileNotFoundError:
+                if not has_magic(pattern):
+                    raise
+        origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
+        return DataFilesList(data_files, origin_metadata)
+
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
+        return DataFilesPatternsList(
+            self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions]
+        )
+
+
+class DataFilesPatternsDict(Dict[str, DataFilesPatternsList]):
+    """
+    Dict of split_name -> list of data files patterns (absolute local paths or URLs).
+    """
+
+    @classmethod
+    def from_patterns(
+        cls, patterns: Dict[str, List[str]], allowed_extensions: Optional[List[str]] = None
+    ) -> "DataFilesPatternsDict":
+        out = cls()
+        for key, patterns_for_key in patterns.items():
+            out[key] = (
+                DataFilesPatternsList.from_patterns(
+                    patterns_for_key,
+                    allowed_extensions=allowed_extensions,
+                )
+                if not isinstance(patterns_for_key, DataFilesPatternsList)
+                else patterns_for_key
+            )
+        return out
+
+    def resolve(
+        self,
+        base_path: str,
+        download_config: Optional[DownloadConfig] = None,
+    ) -> "DataFilesDict":
+        out = DataFilesDict()
+        for key, data_files_patterns_list in self.items():
+            out[key] = data_files_patterns_list.resolve(base_path, download_config)
+        return out
+
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsDict":
+        out = type(self)()
+        for key, data_files_patterns_list in self.items():
+            out[key] = data_files_patterns_list.filter_extensions(extensions)
+        return out
```
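
Taken together, the two classes let a dict of per-split patterns travel with a config at almost no cost and expand only on demand. Note the error handling in `resolve`: a glob pattern that matches nothing is skipped silently, while a missing literal path (no glob magic) still raises `FileNotFoundError`. A small usage sketch on local files (the layout is hypothetical):

```python
import tempfile
from pathlib import Path

from datasets.data_files import DataFilesPatternsDict

tmp = Path(tempfile.mkdtemp())
(tmp / "train.csv").write_text("id\n1\n")
(tmp / "test.csv").write_text("id\n2\n")

# Only the pattern strings are stored; nothing touches the filesystem yet.
patterns = DataFilesPatternsDict.from_patterns({"train": ["train.*"], "test": ["test.*"]})

# resolve() lists the matching files (locally here, or remotely via fsspec)
# and returns a concrete DataFilesDict.
data_files = patterns.resolve(base_path=str(tmp))
print(data_files["test"])  # ['/tmp/.../test.csv'] (path abbreviated)
```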
[The remaining 10 changed files in this commit are not shown here.]
