diff --git a/src/sedpack/io/dataset_filler.py b/src/sedpack/io/dataset_filler.py
index 1df0bc90..4e0afbde 100644
--- a/src/sedpack/io/dataset_filler.py
+++ b/src/sedpack/io/dataset_filler.py
@@ -18,7 +18,7 @@
 import dataclasses from pathlib import Path from types import TracebackType
-from typing import Any, Optional, Type, TYPE_CHECKING
+from typing import Any, Type, TYPE_CHECKING
 import uuid from sedpack.io.file_info import FileInfo
@@ -121,7 +121,7 @@ def _get_new_shard(self, split: SplitT) -> Shard:
 def write_example(self, values: ExampleT, split: SplitT,
- custom_metadata: Optional[dict[str, Any]] = None) -> None:
+ custom_metadata: dict[str, Any] | None = None) -> None:
 """Write an example. Opens a new shard if necessary. Args:
@@ -130,7 +130,7 @@ def write_example(self,
 split (SplitT): Which split to write this example into.
- custom_metadata (Optional[dict[str, Any]]): Optional metadata saved
+ custom_metadata (dict[str, Any] | None): Optional metadata saved
 with in the shard info. The shards then can be filtered using these metadata. When a value is changed a new shard is open and the current shard is closed. TODO there is no check if a shard with the
@@ -255,20 +255,20 @@ def __enter__(self) -> _DatasetFillerContext:
 assert not self._updated_infos return self._dataset_filler_context
- def __exit__(self, exc_type: Optional[Type[BaseException]],
- exc_value: Optional[BaseException],
- exc_tb: Optional[TracebackType]) -> None:
+ def __exit__(self, exc_type: Type[BaseException] | None,
+ exc_value: BaseException | None,
+ exc_tb: TracebackType | None) -> None:
 """Make sure to close the last shard. Args:
- exc_type (Optional[Type[BaseException]]): None if no exception,
+ exc_type (Type[BaseException] | None): None if no exception,
 otherwise the exception type.
- exc_value (Optional[BaseException]): None if no exception, otherwise
+ exc_value (BaseException | None): None if no exception, otherwise
 the exception value.
- exc_tb (Optional[TracebackType]): None if no exception, otherwise the
+ exc_tb (TracebackType | None): None if no exception, otherwise the
 traceback. """ # Close the shard only if there was an example written.
diff --git a/src/sedpack/io/dataset_iteration.py b/src/sedpack/io/dataset_iteration.py
index cc481c50..35f804cc 100644
--- a/src/sedpack/io/dataset_iteration.py
+++ b/src/sedpack/io/dataset_iteration.py
@@ -20,7 +20,6 @@
 AsyncIterator, Callable, Iterable,
- Optional,
 ) import os
@@ -49,9 +48,9 @@ class DatasetIteration(DatasetBase):
 def shard_paths_dataset( self, split: SplitT,
- shards: Optional[int] = None,
+ shards: int | None = None,
 custom_metadata_type_limit: int | None = None,
- shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
+ shard_filter: Callable[[ShardInfo], bool] | None = None,
 ) -> list[str]: """Return a list of shard filenames.
@@ -59,7 +58,7 @@ def shard_paths_dataset(
 split (SplitT): Split, see SplitT.
- shards (Optional[int]): If specified limits the dataset to the
+ shards (int | None): If specified limits the dataset to the
 first `shards` shards. custom_metadata_type_limit (int | None): Ignored when None. If non-zero take only this many shards with the concrete `custom_metadata`. This is best effort for different `custom_metadata` (hashed as a tuple of sorted items).
- shard_filter (Optional[Callable[[ShardInfo], bool]): If present
+ shard_filter (Callable[[ShardInfo], bool] | None): If present
 this is a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise.
@@ -120,10 +119,9 @@ def shard_paths_dataset(
 return shard_paths
- def read_and_decode(self, tf_dataset: TFDatasetT,
- cycle_length: Optional[int],
- num_parallel_calls: Optional[int],
- parallelism: Optional[int]) -> TFDatasetT:
+ def read_and_decode(self, tf_dataset: TFDatasetT, cycle_length: int | None,
+ num_parallel_calls: int | None,
+ parallelism: int | None) -> TFDatasetT:
 """Read the shard files and decode them. Args:
@@ -131,18 +129,18 @@ def read_and_decode(self, tf_dataset: TFDatasetT,
 tf_dataset (tf.data.Dataset): Dataset containing shard paths as strings.
- cycle_length (Optional[int]): How many files to read at once.
+ cycle_length (int | None): How many files to read at once.
- num_parallel_calls (Optional[int]): Number of parallel reading calls.
+ num_parallel_calls (int | None): Number of parallel reading calls.
- parallelism (Optional[int]): Decoding parallelism.
+ parallelism (int | None): Decoding parallelism.
 Returns: tf.data.Dataset containing decoded examples. """ # If the cycle_length is None it is determined automatically but we do # use determinism. See documentation # https://www.tensorflow.org/api_docs/python/tf/data/Dataset#interleave
- deterministic: Optional[bool] = True
+ deterministic: bool | None = True
 if isinstance(cycle_length, int): # Use tf.data.Options.deterministic to decide `deterministic` if # cycle_length is <= 1.
@@ -188,15 +186,15 @@ def as_tfdataset( # pylint: disable=too-many-arguments
 self, split: SplitT, *,
- process_record: Optional[Callable[[ExampleT], T]] = None,
- shards: Optional[int] = None,
+ process_record: Callable[[ExampleT], T] | None = None,
+ shards: int | None = None,
 custom_metadata_type_limit: int | None = None,
- shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
+ shard_filter: Callable[[ShardInfo], bool] | None = None,
 repeat: bool = True, batch_size: int = 32, prefetch: int = 2,
- file_parallelism: Optional[int] = os.cpu_count(),
- parallelism: Optional[int] = os.cpu_count(),
+ file_parallelism: int | None = os.cpu_count(),
+ parallelism: int | None = os.cpu_count(),
 shuffle: int = 1_000) -> TFDatasetT: """"Dataset as tfdataset
@@ -204,10 +202,10 @@ def as_tfdataset( # pylint: disable=too-many-arguments
 split (SplitT): Split, see SplitT.
- process_record (Optional[Callable[[ExampleT], T]]): Optional
+ process_record (Callable[[ExampleT], T] | None): Optional
 function that processes a single record.
- shards (Optional[int]): If specified limits the dataset to the
+ shards (int | None): If specified limits the dataset to the
 first `shards` shards. custom_metadata_type_limit (int | None): Ignored when None. If non-zero take only this many shards with the concrete `custom_metadata`. This is best effort for different `custom_metadata` (hashed as a tuple of sorted items).
- shard_filter (Optional[Callable[[ShardInfo], bool]): If present
+ shard_filter (Callable[[ShardInfo], bool] | None): If present
 this is a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise.
@@ -228,9 +226,9 @@ def as_tfdataset( # pylint: disable=too-many-arguments
 prefetch (int): Prefetch this many batches.
- file_parallelism (Optional[int]): IO parallelism.
+ file_parallelism (int | None): IO parallelism.
- parallelism (Optional[int]): Parallelism of trace decoding and
+ parallelism (int | None): Parallelism of trace decoding and
 processing (ignored if shuffle is zero). shuffle (int): How many examples should be shuffled across shards.
@@ -322,9 +320,9 @@ async def as_numpy_iterator_async(
 self, *, split: SplitT,
- process_record: Optional[Callable[[ExampleT], T]] = None,
- shards: Optional[int] = None,
- shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
+ process_record: Callable[[ExampleT], T] | None = None,
+ shards: int | None = None,
+ shard_filter: Callable[[ShardInfo], bool] | None = None,
 repeat: bool = True, file_parallelism: int = os.cpu_count() or 4, shuffle: int = 1_000,
@@ -336,13 +334,13 @@ async def as_numpy_iterator_async(
 split (SplitT): Split, see SplitT.
- process_record (Optional[Callable[[ExampleT], T]]): Optional
+ process_record (Callable[[ExampleT], T] | None): Optional
 function that processes a single record.
- shards (Optional[int]): If specified limits the dataset to the
+ shards (int | None): If specified limits the dataset to the
 first `shards` shards.
- shard_filter (Optional[Callable[[ShardInfo], bool]): If present
+ shard_filter (Callable[[ShardInfo], bool] | None): If present
 this is a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise.
@@ -417,9 +415,9 @@ def _as_numpy_common(
 self, *, split: SplitT,
- shards: Optional[int] = None,
+ shards: int | None = None,
 custom_metadata_type_limit: int | None = None,
- shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
+ shard_filter: Callable[[ShardInfo], bool] | None = None,
 repeat: bool = True, shuffle: int = 1_000, ) -> Iterable[str]:
@@ -429,7 +427,7 @@ def _as_numpy_common(
 split (SplitT): Split, see SplitT.
- shards (Optional[int]): If specified limits the dataset to the
+ shards (int | None): If specified limits the dataset to the
 first `shards` shards. custom_metadata_type_limit (int | None): Ignored when None. If non-zero take only this many shards with the concrete `custom_metadata`. This is best effort for different `custom_metadata` (hashed as a tuple of sorted items).
- shard_filter (Optional[Callable[[ShardInfo], bool]): If present
+ shard_filter (Callable[[ShardInfo], bool] | None): If present
 this is a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise.
@@ -476,10 +474,10 @@ def as_numpy_iterator_concurrent(
 self, *, split: SplitT,
- process_record: Optional[Callable[[ExampleT], T]] = None,
- shards: Optional[int] = None,
+ process_record: Callable[[ExampleT], T] | None = None,
+ shards: int | None = None,
 custom_metadata_type_limit: int | None = None,
- shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
+ shard_filter: Callable[[ShardInfo], bool] | None = None,
 repeat: bool = True, file_parallelism: int = os.cpu_count() or 1, shuffle: int = 1_000,
@@ -491,10 +489,10 @@ def as_numpy_iterator_concurrent(
 split (SplitT): Split, see SplitT.
- process_record (Optional[Callable[[ExampleT], T]]): Optional
+ process_record (Callable[[ExampleT], T] | None): Optional
 function that processes a single record.
- shards (Optional[int]): If specified limits the dataset to the
+ shards (int | None): If specified limits the dataset to the
 first `shards` shards. custom_metadata_type_limit (int | None): Ignored when None. If non-zero take only this many shards with the concrete `custom_metadata`. This is best effort for different `custom_metadata` (hashed as a tuple of sorted items).
- shard_filter (Optional[Callable[[ShardInfo], bool]): If present
+ shard_filter (Callable[[ShardInfo], bool] | None): If present
 this is a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise.
@@ -593,10 +591,10 @@ def as_numpy_iterator(
 self, *, split: SplitT,
- process_record: Optional[Callable[[ExampleT], T]] = None,
- shards: Optional[int] = None,
+ process_record: Callable[[ExampleT], T] | None = None,
+ shards: int | None = None,
 custom_metadata_type_limit: int | None = None,
- shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
+ shard_filter: Callable[[ShardInfo], bool] | None = None,
 repeat: bool = True, shuffle: int = 1_000, ) -> Iterable[ExampleT] | Iterable[T]:
@@ -607,10 +605,10 @@ def as_numpy_iterator(
 split (SplitT): Split, see SplitT.
- process_record (Optional[Callable[[ExampleT], T]]): Optional
+ process_record (Callable[[ExampleT], T] | None): Optional
 function that processes a single record.
- shards (Optional[int]): If specified limits the dataset to the
+ shards (int | None): If specified limits the dataset to the
 first `shards` shards. custom_metadata_type_limit (int | None): Ignored when None. If non-zero take only this many shards with the concrete `custom_metadata`. This is best effort for different `custom_metadata` (hashed as a tuple of sorted items).
- shard_filter (Optional[Callable[[ShardInfo], bool]): If present
+ shard_filter (Callable[[ShardInfo], bool] | None): If present
 this is a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise.
@@ -687,9 +685,9 @@ def as_numpy_iterator_rust( # pylint: disable=too-many-arguments
 self, *, split: SplitT,
- process_record: Optional[Callable[[ExampleT], T]] = None,
- shards: Optional[int] = None,
- shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
+ process_record: Callable[[ExampleT], T] | None = None,
+ shards: int | None = None,
+ shard_filter: Callable[[ShardInfo], bool] | None = None,
 repeat: bool = True, file_parallelism: int = os.cpu_count() or 1, shuffle: int = 1_000,
@@ -701,13 +699,13 @@ def as_numpy_iterator_rust( # pylint: disable=too-many-arguments
 split (SplitT): Split, see SplitT.
- process_record (Optional[Callable[[ExampleT], T]]): Optional
+ process_record (Callable[[ExampleT], T] | None): Optional
 function that processes a single record.
- shards (Optional[int]): If specified limits the dataset to the
+ shards (int | None): If specified limits the dataset to the
 first `shards` shards.
- shard_filter (Optional[Callable[[ShardInfo], bool]): If present
+ shard_filter (Callable[[ShardInfo], bool] | None): If present
 this is a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise.
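Note on the annotation used throughout the docstrings above: the union applies to the whole callable, i.e. `Callable[[ShardInfo], bool] | None`, not to its return type. A minimal sketch of such a filter, assuming a shard metadata record with a `custom_metadata` attribute as the docstrings describe (the "device" key and the helper name are illustrative, not taken from this diff):

```python
from typing import Any, Callable

# `ShardInfo` is sedpack's shard metadata record (its import path is not shown
# in this diff); a forward-reference string keeps the sketch self-contained.
def keep_shard(shard_info: "ShardInfo") -> bool:
    # Illustrative filter: keep only shards whose custom_metadata carries a
    # made-up "device" key with a specific value.
    metadata: dict[str, Any] = getattr(shard_info, "custom_metadata", None) or {}
    return metadata.get("device") == "device_a"

# The corrected docstring type: the whole callable is optional, and passing
# None (the default) keeps every shard.
shard_filter: "Callable[[ShardInfo], bool] | None" = keep_shard
```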
diff --git a/src/sedpack/io/dataset_writing.py b/src/sedpack/io/dataset_writing.py
index 1a4b9605..81a44b2f 100644
--- a/src/sedpack/io/dataset_writing.py
+++ b/src/sedpack/io/dataset_writing.py
@@ -19,11 +19,7 @@
 from typing import ( Any, Callable,
- Dict,
- List,
 Mapping,
- Optional,
- Tuple,
 get_args, ) import uuid
@@ -50,10 +46,10 @@ def write_multiprocessing( # pylint: disable=too-many-arguments
 self, *, feed_writer: Callable[..., Any],
- custom_arguments: List[Any],
- custom_kwarguments: Optional[List[Dict[str, Any]]] = None,
+ custom_arguments: list[Any],
+ custom_kwarguments: list[dict[str, Any]] | None = None,
 consistency_check: bool = True,
- single_process: bool = False) -> List[Any]:
+ single_process: bool = False) -> list[Any]:
 """Multiprocessing write support. Spawn `len(custom_arguments)` processes to write examples in parallel. Note that all computation is run on the CPU (using `tf.device("CPU")`) in order to prevent each
@@ -68,11 +64,11 @@ def write_multiprocessing( # pylint: disable=too-many-arguments
 `custom_kwarguments` are provided (defaults to no keyword arguments).
- custom_arguments (List[Any]): A list of arguments which are passed
+ custom_arguments (list[Any]): A list of arguments which are passed
 to `feed_writer` instances. A pool of `len(custom_kwarguments)` processes is being created to do this.
- custom_kwarguments (Optional[List[Dict[str, Any]]]): A list of
+ custom_kwarguments (list[dict[str, Any]] | None): A list of
 keyword arguments which are passed to `feed_writer` instances. Defaults to no keyword arguments. Needs to have the same length as `custom_arguments`.
@@ -227,10 +223,10 @@ def check(self, show_progressbar: bool = True) -> None:
 # We want to get results back. def _wrapper_func(
- feed_writer_dataset_filler_args_kwargs: Tuple[Callable[..., Any], DatasetFiller, Any, Mapping[str, Any]]
-) -> Tuple[DatasetFiller, Any]:
+ feed_writer_dataset_filler_args_kwargs: tuple[Callable[..., Any], DatasetFiller, Any, Mapping[str, Any]]
+) -> tuple[DatasetFiller, Any]:
 """Helper function for write_multiprocessing. Needs to be pickleable. """ # Prevent each process from hoarding the whole GPU memory.
diff --git a/src/sedpack/io/merge_shard_infos.py b/src/sedpack/io/merge_shard_infos.py
index fedae560..98038946 100644
--- a/src/sedpack/io/merge_shard_infos.py
+++ b/src/sedpack/io/merge_shard_infos.py
@@ -39,7 +39,7 @@ def merge_shard_infos(updates: list[ShardListInfo], dataset_root: Path,
 thus one should set `common=1`. It is not guaranteed to update `shards_list.json` files all the way to the split when `common>1`.
- hashes (tuple[HashChecksumT, ...]): List of hash checksum algorithms.
+ hashes (tuple[HashChecksumT, ...]): A tuple of hash checksum algorithms.
 """ assert updates, "Nothing to update."
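As a reading aid for the `write_multiprocessing` signature above: `custom_arguments` and `custom_kwarguments` are per-process lists and, when keyword arguments are given, the docstring requires them to have the same length. A hypothetical set-up under those assumptions (the argument values and the commented-out call are placeholders, not part of this diff):

```python
from typing import Any

# One entry per writer process.
custom_arguments: list[Any] = ["part-0", "part-1", "part-2"]
custom_kwarguments: list[dict[str, Any]] | None = [
    {"seed": 0},
    {"seed": 1},
    {"seed": 2},
]

# The docstring requires matching lengths when keyword arguments are provided.
assert custom_kwarguments is None or len(custom_kwarguments) == len(custom_arguments)

# dataset.write_multiprocessing(
#     feed_writer=my_feed_writer,  # hypothetical callable writing one part
#     custom_arguments=custom_arguments,
#     custom_kwarguments=custom_kwarguments,
# )
```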
diff --git a/src/sedpack/io/shard/shard.py b/src/sedpack/io/shard/shard.py
index 025f61ab..407dd691 100644
--- a/src/sedpack/io/shard/shard.py
+++ b/src/sedpack/io/shard/shard.py
@@ -15,7 +15,6 @@
 """ from pathlib import Path
-from typing import Optional
 import sedpack from sedpack.io.metadata import DatasetStructure
@@ -47,7 +46,7 @@ def __init__(self, shard_info: ShardInfo,
 self.dataset_structure: DatasetStructure = dataset_structure self._dataset_path: Path = dataset_root_path
- self._shard_writer: Optional[ShardWriterBase] = get_shard_writer(
+ self._shard_writer: ShardWriterBase | None = get_shard_writer(
 dataset_structure=dataset_structure, shard_file=self._get_full_path(), )
diff --git a/src/sedpack/io/shard/shard_writer_tfrec.py b/src/sedpack/io/shard/shard_writer_tfrec.py
index 22a991b2..bc6e5557 100644
--- a/src/sedpack/io/shard/shard_writer_tfrec.py
+++ b/src/sedpack/io/shard/shard_writer_tfrec.py
@@ -18,7 +18,7 @@
 """ from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 import tensorflow as tf
@@ -52,7 +52,7 @@ def __init__(self, dataset_structure: DatasetStructure,
 # Open the tf.io.TFRecordWriter only with the first `write` call. Make # it None immediately during a call to `close`.
- self._tf_shard_writer: Optional[Any] = None
+ self._tf_shard_writer: Any | None = None
 def _write(self, values: ExampleT) -> None: """Write an example on disk. Writing may be buffered.
diff --git a/src/sedpack/io/tfrec/tfdata.py b/src/sedpack/io/tfrec/tfdata.py
index e1cc713e..2f6b6e0b 100644
--- a/src/sedpack/io/tfrec/tfdata.py
+++ b/src/sedpack/io/tfrec/tfdata.py
@@ -17,7 +17,7 @@
 https://www.tensorflow.org/tutorials/load_data/tfrecord """
-from typing import Any, Callable, cast, Dict, List, Tuple
+from typing import Any, Callable, cast
 import numpy as np import tensorflow as tf
@@ -57,7 +57,7 @@ def int64_feature(value: Any) -> Any:
 def get_from_tfrecord(
- saved_data_description: List[Attribute]) -> Callable[[Any], Any]:
+ saved_data_description: list[Attribute]) -> Callable[[Any], Any]:
 """Construct the from_tfrecord function. """
@@ -81,7 +81,7 @@ def get_from_tfrecord(
 "float64": tf.float64, }[attribute.dtype]
- shape: Tuple[int, ...] = attribute.shape
+ shape: tuple[int, ...] = attribute.shape
 if attribute.dtype == "float16": # We parse from bytes so no shape shape = ()
@@ -103,16 +103,16 @@ def from_tfrecord(tf_record: Any) -> Any:
 return from_tfrecord
-def to_tfrecord(saved_data_description: List[Attribute],
- values: Dict[str, Any]) -> bytes:
+def to_tfrecord(saved_data_description: list[Attribute],
+ values: dict[str, Any]) -> bytes:
 """Convert example data into a tfrecord example Args:
- saved_data_description (List[Attribute]): Descriptions of all saved
+ saved_data_description (list[Attribute]): Descriptions of all saved
 data.
- values (Dict): The name and value to be saved (corresponding to
+ values (dict): The name and value to be saved (corresponding to
 saved_data_description).
 Returns: TF.train.Example
diff --git a/tests/io/test_as_tfdataset.py b/tests/io/test_as_tfdataset.py
index a2be2f83..24f49a2a 100644
--- a/tests/io/test_as_tfdataset.py
+++ b/tests/io/test_as_tfdataset.py
@@ -16,7 +16,7 @@
 import itertools from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Callable, Union
 import pytest import numpy as np
@@ -39,7 +39,7 @@ def end2end(
 dtype: npt.DTypeLike, shard_file_type: ShardFileTypeT, compression: CompressionT,
- process_record: Optional[Callable[[ExampleT], T]],
+ process_record: Callable[[ExampleT], T] | None,
 ) -> None: array_of_values = np.random.random((1024, 138)) array_of_values = array_of_values.astype(dtype)
@@ -150,7 +150,7 @@ def test_end2end_as_tfdataset(
 shard_file_type: str, compression: str, dtype: str,
- process_record: Optional[Callable[[ExampleT], T]],
+ process_record: Callable[[ExampleT], T] | None,
 tmp_path: Union[str, Path], ) -> None: end2end(
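The pattern across all of the files above is the same mechanical modernization: PEP 604 unions (`X | None`) replace `typing.Optional` and PEP 585 builtin generics (`list`, `dict`, `tuple`) replace the `typing` aliases. A small self-contained illustration of the equivalence (not code from this repository); note that the new spellings need Python 3.10 and 3.9 respectively when evaluated at runtime, unless `from __future__ import annotations` is in effect:

```python
from typing import List, Optional

# Spelling removed by this diff:
def first_old(xs: List[int], default: Optional[int] = None) -> Optional[int]:
    return xs[0] if xs else default

# Spelling introduced by this diff:
def first_new(xs: list[int], default: int | None = None) -> int | None:
    return xs[0] if xs else default

assert first_old([1, 2]) == first_new([1, 2]) == 1
assert first_old([], default=7) == first_new([], default=7) == 7
```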