modernize typehints (google#52)
Closes google#49 
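
For reference, the modernization replaces `typing.Optional`/`List`/`Dict`/`Tuple` annotations with the built-in generics of PEP 585 and the `X | None` union syntax of PEP 604. A minimal before/after sketch (hypothetical signature, not taken from this repository):

```python
# Before: typing-module generics and Optional.
from typing import Dict, List, Optional

def load(paths: List[str], limit: Optional[int] = None) -> Dict[str, bytes]:
    ...

# After: built-in generics (PEP 585) and "| None" unions (PEP 604).
def load(paths: list[str], limit: int | None = None) -> dict[str, bytes]:
    ...
```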

Used a Python script to automate the changes:

```python
import re
from pathlib import Path

# Directory of the project to rewrite (placeholder path).
project_dir = Path("path")

# Greedy match from "Optional[" to the last "]" on the line; replace_optional
# below locates the bracket that actually closes the Optional argument.
optional_pattern = re.compile(r"Optional\[(.+)\]")

def replace_optional(match: re.Match) -> str:
    '''Replace "Optional[...]" with "... | None"'''
    content = match.group(1)
    bracket_count = 0
    end_index = -1

    # Find the "]" that closes the first nested type argument, e.g. the "]"
    # of "dict[str, Any]" inside "Optional[dict[str, Any]]".
    for i, char in enumerate(content):
        if char == '[':
            bracket_count += 1
        elif char == ']':
            bracket_count -= 1
            if bracket_count == 0:
                end_index = i
                break

    if end_index != -1:
        # Nested generic: insert " | None" right after its closing bracket.
        return f"{content[:end_index + 1]} | None{content[end_index + 1:]}"
    else:
        # No nested brackets, e.g. "Optional[int]" -> "int | None".
        return f"{content} | None"

# Replace typing.List/Dict/Tuple with the built-in generics (PEP 585).
patterns = [
    (re.compile(r"\bList\b"), "list"),
    (re.compile(r"\bDict\b"), "dict"),
    (re.compile(r"\bTuple\b"), "tuple"),
]


# Rewrite every Python file under the project in place.
for filepath in project_dir.rglob("*.py"):
    with open(filepath, "r", encoding="utf-8") as file:
        content = file.read()
    
    content = optional_pattern.sub(replace_optional, content)

    for pattern, replacement in patterns:
        content = pattern.sub(replacement, content)

    with open(filepath, "w", encoding="utf-8") as file:
        file.write(content)
```
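
As a quick illustration (not part of the original commit), the substitution handles the nested generics that appear in this diff; the snippet below assumes the `optional_pattern` and `replace_optional` definitions above:

```python
# Illustrative check only; relies on optional_pattern and replace_optional
# defined in the script above.
sample = "custom_metadata: Optional[dict[str, Any]] = None"
print(optional_pattern.sub(replace_optional, sample))
# Prints: custom_metadata: dict[str, Any] | None = None
```

Note that the greedy `Optional\[(.+)\]` pattern only handles one `Optional[...]` per line reliably; lines with several such annotations would need a manual pass.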

---------

Co-authored-by: kralka <[email protected]>
Jemeljanov and kralka authored Nov 3, 2024
1 parent 2555cf6 commit 65fc059
Showing 8 changed files with 79 additions and 86 deletions.
18 changes: 9 additions & 9 deletions src/sedpack/io/dataset_filler.py
@@ -18,7 +18,7 @@
import dataclasses
from pathlib import Path
from types import TracebackType
from typing import Any, Optional, Type, TYPE_CHECKING
from typing import Any, Type, TYPE_CHECKING
import uuid

from sedpack.io.file_info import FileInfo
@@ -121,7 +121,7 @@ def _get_new_shard(self, split: SplitT) -> Shard:
def write_example(self,
values: ExampleT,
split: SplitT,
custom_metadata: Optional[dict[str, Any]] = None) -> None:
custom_metadata: dict[str, Any] | None = None) -> None:
"""Write an example. Opens a new shard if necessary.
Args:
@@ -130,7 +130,7 @@ def write_example(self,
split (SplitT): Which split to write this example into.
custom_metadata (Optional[dict[str, Any]]): Optional metadata saved
custom_metadata (dict[str, Any] | None): Optional metadata saved
with in the shard info. The shards then can be filtered using these
metadata. When a value is changed a new shard is open and the
current shard is closed. TODO there is no check if a shard with the
@@ -255,20 +255,20 @@ def __enter__(self) -> _DatasetFillerContext:
assert not self._updated_infos
return self._dataset_filler_context

def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> None:
def __exit__(self, exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
exc_tb: TracebackType | None) -> None:
"""Make sure to close the last shard.
Args:
exc_type (Optional[Type[BaseException]]): None if no exception,
exc_type (Type[BaseException] | None): None if no exception,
otherwise the exception type.
exc_value (Optional[BaseException]): None if no exception, otherwise
exc_value (BaseException | None): None if no exception, otherwise
the exception value.
exc_tb (Optional[TracebackType]): None if no exception, otherwise the
exc_tb (TracebackType | None): None if no exception, otherwise the
traceback.
"""
# Close the shard only if there was an example written.
100 changes: 49 additions & 51 deletions src/sedpack/io/dataset_iteration.py
@@ -20,7 +20,6 @@
AsyncIterator,
Callable,
Iterable,
Optional,
)
import os

@@ -49,17 +48,17 @@ class DatasetIteration(DatasetBase):
def shard_paths_dataset(
self,
split: SplitT,
shards: Optional[int] = None,
shards: int | None = None,
custom_metadata_type_limit: int | None = None,
shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
shard_filter: Callable[[ShardInfo], bool] | None = None,
) -> list[str]:
"""Return a list of shard filenames.
Args:
split (SplitT): Split, see SplitT.
shards (Optional[int]): If specified limits the dataset to the
shards (int | None): If specified limits the dataset to the
first `shards` shards.
custom_metadata_type_limit (int | None): Ignored when None. If
@@ -68,7 +67,7 @@ def shard_paths_dataset(
shards with the concrete `custom_metadata`. This is best effort for
different `custom_metadata` (hashed as a tuple of sorted items).
shard_filter (Optional[Callable[[ShardInfo], bool]): If present
shard_filter (Callable[[ShardInfo], bool | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
@@ -120,29 +119,28 @@ def shard_paths_dataset(

return shard_paths

def read_and_decode(self, tf_dataset: TFDatasetT,
cycle_length: Optional[int],
num_parallel_calls: Optional[int],
parallelism: Optional[int]) -> TFDatasetT:
def read_and_decode(self, tf_dataset: TFDatasetT, cycle_length: int | None,
num_parallel_calls: int | None,
parallelism: int | None) -> TFDatasetT:
"""Read the shard files and decode them.
Args:
tf_dataset (tf.data.Dataset): Dataset containing shard paths as
strings.
cycle_length (Optional[int]): How many files to read at once.
cycle_length (int | None): How many files to read at once.
num_parallel_calls (Optional[int]): Number of parallel reading calls.
num_parallel_calls (int | None): Number of parallel reading calls.
parallelism (Optional[int]): Decoding parallelism.
parallelism (int | None): Decoding parallelism.
Returns: tf.data.Dataset containing decoded examples.
"""
# If the cycle_length is None it is determined automatically but we do
# use determinism. See documentation
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#interleave
deterministic: Optional[bool] = True
deterministic: bool | None = True
if isinstance(cycle_length, int):
# Use tf.data.Options.deterministic to decide `deterministic` if
# cycle_length is <= 1.
@@ -188,26 +186,26 @@ def as_tfdataset( # pylint: disable=too-many-arguments
self,
split: SplitT,
*,
process_record: Optional[Callable[[ExampleT], T]] = None,
shards: Optional[int] = None,
process_record: Callable[[ExampleT], T] | None = None,
shards: int | None = None,
custom_metadata_type_limit: int | None = None,
shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
shard_filter: Callable[[ShardInfo], bool] | None = None,
repeat: bool = True,
batch_size: int = 32,
prefetch: int = 2,
file_parallelism: Optional[int] = os.cpu_count(),
parallelism: Optional[int] = os.cpu_count(),
file_parallelism: int | None = os.cpu_count(),
parallelism: int | None = os.cpu_count(),
shuffle: int = 1_000) -> TFDatasetT:
""""Dataset as tfdataset
Args:
split (SplitT): Split, see SplitT.
process_record (Optional[Callable[[ExampleT], T]]): Optional
process_record (Callable[[ExampleT], T] | None): Optional
function that processes a single record.
shards (Optional[int]): If specified limits the dataset to the
shards (int | None): If specified limits the dataset to the
first `shards` shards.
custom_metadata_type_limit (int | None): Ignored when None. If
@@ -216,7 +214,7 @@ def as_tfdataset( # pylint: disable=too-many-arguments
shards with the concrete `custom_metadata`. This is best effort for
different `custom_metadata` (hashed as a tuple of sorted items).
shard_filter (Optional[Callable[[ShardInfo], bool]): If present
shard_filter (Callable[[ShardInfo], bool | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
@@ -228,9 +226,9 @@ def as_tfdataset( # pylint: disable=too-many-arguments
prefetch (int): Prefetch this many batches.
file_parallelism (Optional[int]): IO parallelism.
file_parallelism (int | None): IO parallelism.
parallelism (Optional[int]): Parallelism of trace decoding and
parallelism (int | None): Parallelism of trace decoding and
processing (ignored if shuffle is zero).
shuffle (int): How many examples should be shuffled across shards.
@@ -322,9 +320,9 @@ async def as_numpy_iterator_async(
self,
*,
split: SplitT,
process_record: Optional[Callable[[ExampleT], T]] = None,
shards: Optional[int] = None,
shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
process_record: Callable[[ExampleT], T] | None = None,
shards: int | None = None,
shard_filter: Callable[[ShardInfo], bool] | None = None,
repeat: bool = True,
file_parallelism: int = os.cpu_count() or 4,
shuffle: int = 1_000,
@@ -336,13 +334,13 @@ async def as_numpy_iterator_async(
split (SplitT): Split, see SplitT.
process_record (Optional[Callable[[ExampleT], T]]): Optional
process_record (Callable[[ExampleT], T] | None): Optional
function that processes a single record.
shards (Optional[int]): If specified limits the dataset to the
shards (int | None): If specified limits the dataset to the
first `shards` shards.
shard_filter (Optional[Callable[[ShardInfo], bool]): If present
shard_filter (Callable[[ShardInfo], bool | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
@@ -417,9 +415,9 @@ def _as_numpy_common(
self,
*,
split: SplitT,
shards: Optional[int] = None,
shards: int | None = None,
custom_metadata_type_limit: int | None = None,
shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
shard_filter: Callable[[ShardInfo], bool] | None = None,
repeat: bool = True,
shuffle: int = 1_000,
) -> Iterable[str]:
@@ -429,7 +427,7 @@ def _as_numpy_common(
split (SplitT): Split, see SplitT.
shards (Optional[int]): If specified limits the dataset to the
shards (int | None): If specified limits the dataset to the
first `shards` shards.
custom_metadata_type_limit (int | None): Ignored when None. If
@@ -438,7 +436,7 @@ def _as_numpy_common(
shards with the concrete `custom_metadata`. This is best effort for
different `custom_metadata` (hashed as a tuple of sorted items).
shard_filter (Optional[Callable[[ShardInfo], bool]): If present
shard_filter (Callable[[ShardInfo], bool | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
@@ -476,10 +474,10 @@ def as_numpy_iterator_concurrent(
self,
*,
split: SplitT,
process_record: Optional[Callable[[ExampleT], T]] = None,
shards: Optional[int] = None,
process_record: Callable[[ExampleT], T] | None = None,
shards: int | None = None,
custom_metadata_type_limit: int | None = None,
shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
shard_filter: Callable[[ShardInfo], bool] | None = None,
repeat: bool = True,
file_parallelism: int = os.cpu_count() or 1,
shuffle: int = 1_000,
@@ -491,10 +489,10 @@ def as_numpy_iterator_concurrent(
split (SplitT): Split, see SplitT.
process_record (Optional[Callable[[ExampleT], T]]): Optional
process_record (Callable[[ExampleT], T] | None): Optional
function that processes a single record.
shards (Optional[int]): If specified limits the dataset to the
shards (int | None): If specified limits the dataset to the
first `shards` shards.
custom_metadata_type_limit (int | None): Ignored when None. If
@@ -503,7 +501,7 @@ def as_numpy_iterator_concurrent(
shards with the concrete `custom_metadata`. This is best effort for
different `custom_metadata` (hashed as a tuple of sorted items).
shard_filter (Optional[Callable[[ShardInfo], bool]): If present
shard_filter (Callable[[ShardInfo], bool | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
@@ -593,10 +591,10 @@ def as_numpy_iterator(
self,
*,
split: SplitT,
process_record: Optional[Callable[[ExampleT], T]] = None,
shards: Optional[int] = None,
process_record: Callable[[ExampleT], T] | None = None,
shards: int | None = None,
custom_metadata_type_limit: int | None = None,
shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
shard_filter: Callable[[ShardInfo], bool] | None = None,
repeat: bool = True,
shuffle: int = 1_000,
) -> Iterable[ExampleT] | Iterable[T]:
@@ -607,10 +605,10 @@ def as_numpy_iterator(
split (SplitT): Split, see SplitT.
process_record (Optional[Callable[[ExampleT], T]]): Optional
process_record (Callable[[ExampleT], T] | None): Optional
function that processes a single record.
shards (Optional[int]): If specified limits the dataset to the
shards (int | None): If specified limits the dataset to the
first `shards` shards.
custom_metadata_type_limit (int | None): Ignored when None. If
@@ -619,7 +617,7 @@ def as_numpy_iterator(
shards with the concrete `custom_metadata`. This is best effort for
different `custom_metadata` (hashed as a tuple of sorted items).
shard_filter (Optional[Callable[[ShardInfo], bool]): If present
shard_filter (Callable[[ShardInfo], bool | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
@@ -687,9 +685,9 @@ def as_numpy_iterator_rust( # pylint: disable=too-many-arguments
self,
*,
split: SplitT,
process_record: Optional[Callable[[ExampleT], T]] = None,
shards: Optional[int] = None,
shard_filter: Optional[Callable[[ShardInfo], bool]] = None,
process_record: Callable[[ExampleT], T] | None = None,
shards: int | None = None,
shard_filter: Callable[[ShardInfo], bool] | None = None,
repeat: bool = True,
file_parallelism: int = os.cpu_count() or 1,
shuffle: int = 1_000,
@@ -701,13 +699,13 @@ def as_numpy_iterator_rust( # pylint: disable=too-many-arguments
split (SplitT): Split, see SplitT.
process_record (Optional[Callable[[ExampleT], T]]): Optional
process_record (Callable[[ExampleT], T] | None): Optional
function that processes a single record.
shards (Optional[int]): If specified limits the dataset to the
shards (int | None): If specified limits the dataset to the
first `shards` shards.
shard_filter (Optional[Callable[[ShardInfo], bool]): If present
shard_filter (Callable[[ShardInfo], bool | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
18 changes: 7 additions & 11 deletions src/sedpack/io/dataset_writing.py
@@ -19,11 +19,7 @@
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
get_args,
)
import uuid
@@ -50,10 +46,10 @@ def write_multiprocessing( # pylint: disable=too-many-arguments
self,
*,
feed_writer: Callable[..., Any],
custom_arguments: List[Any],
custom_kwarguments: Optional[List[Dict[str, Any]]] = None,
custom_arguments: list[Any],
custom_kwarguments: list[dict[str, Any]] | None = None,
consistency_check: bool = True,
single_process: bool = False) -> List[Any]:
single_process: bool = False) -> list[Any]:
"""Multiprocessing write support. Spawn `len(custom_arguments)`
processes to write examples in parallel. Note that all computation is
run on the CPU (using `tf.device("CPU")`) in order to prevent each
Expand All @@ -68,11 +64,11 @@ def write_multiprocessing( # pylint: disable=too-many-arguments
`custom_kwarguments` are provided (defaults to no keyword
arguments).
custom_arguments (List[Any]): A list of arguments which are passed
custom_arguments (list[Any]): A list of arguments which are passed
to `feed_writer` instances. A pool of `len(custom_kwarguments)`
processes is being created to do this.
custom_kwarguments (Optional[List[Dict[str, Any]]]): A list of
custom_kwarguments (list[dict[str, Any]] | None): A list of
keyword arguments which are passed to `feed_writer` instances.
Defaults to no keyword arguments. Needs to have the same length as
`custom_arguments`.
@@ -227,10 +223,10 @@ def check(self, show_progressbar: bool = True) -> None:

# We want to get results back.
def _wrapper_func(
feed_writer_dataset_filler_args_kwargs: Tuple[Callable[...,
feed_writer_dataset_filler_args_kwargs: tuple[Callable[...,
Any], DatasetFiller,
Any, Mapping[str, Any]]
) -> Tuple[DatasetFiller, Any]:
) -> tuple[DatasetFiller, Any]:
"""Helper function for write_multiprocessing. Needs to be pickleable.
"""
# Prevent each process from hoarding the whole GPU memory.
(Diffs for the remaining changed files are not shown.)
