
Commit

Fix typehint problems
kralka committed Jan 7, 2025
1 parent c649810 commit 03c0767
Showing 16 changed files with 222 additions and 151 deletions.
6 changes: 3 additions & 3 deletions src/sedpack/io/compress.py
@@ -18,7 +18,7 @@
import gzip
import lzma

import lz4.frame # type: ignore
import lz4.frame
import zstandard as zstd

from sedpack.io.types import CompressionT
@@ -76,7 +76,7 @@ def compress(self, data: bytes) -> bytes:
case "LZMA":
return lzma.compress(data)
case "LZ4":
return lz4.frame.compress(data)
return lz4.frame.compress(data) # type: ignore[no-any-return]
case "ZSTD":
return zstd.compress(data)
case _:
@@ -102,7 +102,7 @@ def decompress(self, data: bytes) -> bytes:
case "LZMA":
return lzma.decompress(data)
case "LZ4":
return lz4.frame.decompress(data)
return lz4.frame.decompress(data) # type: ignore[no-any-return]
case "ZSTD":
return zstd.decompress(data)
case _:
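
Note on the compress.py hunks: the blanket "# type: ignore" on the lz4.frame import is dropped, and the two return statements instead get the error-code-scoped "# type: ignore[no-any-return]". That error code belongs to mypy's warn_return_any check, which fires when a function annotated to return bytes returns a value mypy only knows as Any (as with an untyped third-party call). A minimal sketch of the pattern with a hypothetical untyped_compress stand-in, not this repository's code:

    from typing import Any


    def untyped_compress(data: bytes) -> Any:
        # Stand-in for a third-party call whose return type mypy sees as Any.
        return bytes(data)


    def compress(data: bytes) -> bytes:
        # warn_return_any would report "Returning Any from function declared to
        # return bytes"; the scoped ignore silences only that diagnostic here.
        return untyped_compress(data)  # type: ignore[no-any-return]
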
53 changes: 33 additions & 20 deletions src/sedpack/io/dataset_iteration.py
@@ -28,7 +28,7 @@

import asyncstdlib
import numpy as np
import tensorflow as tf
import tensorflow as tf # type: ignore[import-untyped]

Check failure on line 31 in src/sedpack/io/dataset_iteration.py (GitHub Actions / mypy): Unused "type: ignore" comment [unused-ignore]

from sedpack.io.dataset_base import DatasetBase
from sedpack.io.flatbuffer import IterateShardFlatBuffer
@@ -169,8 +169,7 @@ def read_and_decode(self, tf_dataset: TFDatasetT, cycle_length: int | None,
tf_dataset = tf_dataset.interleave(
lambda x: tf.data.TFRecordDataset(
x,
compression_type=self.dataset_structure.
compression, # type: ignore
compression_type=self.dataset_structure.compression,

Check failure on line 172 in src/sedpack/io/dataset_iteration.py (GitHub Actions / mypy): Argument "compression_type" to "TFRecordDataset" has incompatible type "Literal['', 'BZ2', 'GZIP', 'LZMA', 'LZ4', 'ZIP', 'ZLIB', 'ZSTD']"; expected "Literal['ZLIB', 'GZIP', 'AUTO', '', 0, 1, 2] | None" [arg-type]
),
cycle_length=cycle_length,
block_length=1,
@@ -272,7 +271,7 @@ def as_tfdataset( # pylint: disable=too-many-arguments
)
if process_record:
tf_dataset = tf_dataset.map(
process_record, # type: ignore[arg-type]
process_record,

Check failure on line 274 in src/sedpack/io/dataset_iteration.py (GitHub Actions / mypy): Argument 1 to "map" of "Dataset" has incompatible type "Callable[[dict[str, str | int | ndarray[Any, dtype[generic]] | bytes]], T]"; expected "Callable[..., Iterable[dict[str, str | int | ndarray[Any, dtype[generic]] | bytes]]]" [arg-type]
num_parallel_calls=parallelism,
)
if shuffle:
@@ -303,7 +302,7 @@ def as_tfdataset( # pylint: disable=too-many-arguments
# Process each record if requested
if process_record:
tf_dataset = tf_dataset.map(
process_record, # type: ignore[arg-type]
process_record,
num_parallel_calls=parallelism,
)

@@ -393,24 +392,32 @@ async def as_numpy_iterator_async(
f"implemented.")

# Automatically shuffle.
# TODO(issue #85) Async iterator typing.
example_iterator: AsyncIterator[ExampleT]
if shuffle:
example_iterator = round_robin_async(
asyncstdlib.map(
shard_iterator.iterate_shard_async, # type: ignore
shard_iterator.
iterate_shard_async, # type: ignore[arg-type]
shard_paths_iterator,
),
buffer_size=file_parallelism,
)
) # type: ignore[assignment]
else:
example_iterator = asyncstdlib.chain.from_iterable(
asyncstdlib.map(
shard_iterator.iterate_shard_async, # type: ignore
shard_iterator.
iterate_shard_async, # type: ignore[arg-type]
shard_paths_iterator,
))

# Process each record if requested.
example_iterator_processed: AsyncIterator[ExampleT] | AsyncIterator[T]

Check failure on line 415 in src/sedpack/io/dataset_iteration.py (GitHub Actions / pylint 3.10, 3.11, 3.12): Unused variable 'example_iterator_processed' (unused-variable)
if process_record:
example_iterator = asyncstdlib.map(process_record, example_iterator)
example_iterator_processed = asyncstdlib.map(
process_record, example_iterator)
else:
example_iterator_processed = example_iterator

async for example in example_iterator:
yield example
@@ -465,13 +472,14 @@ def as_numpy_common(
if repeat:
shard_paths_iterator = itertools.cycle(shard_paths)
else:
shard_paths_iterator = shard_paths # type: ignore
shard_paths_iterator = shard_paths # type: ignore[assignment]

# Randomize only if > 0 -- no shuffle in test/validation
if shuffle:
shard_paths_iterator = shuffle_buffer(
shard_paths_iterator, # type: ignore
buffer_size=len(shard_paths))
shard_paths_iterator, # type: ignore[assignment]
buffer_size=len(shard_paths),
)
return shard_paths_iterator

def as_numpy_iterator_concurrent(
@@ -567,7 +575,8 @@ def as_numpy_iterator_concurrent(
with LazyPool(file_parallelism) as pool:
yield from round_robin(
pool.imap_unordered(
shard_iterator.process_and_list, # type: ignore
shard_iterator.
process_and_list, # type: ignore[arg-type]
shard_paths_iterator,
),
# round_robin keeps the whole shard files in memory.
@@ -669,19 +678,23 @@ def as_numpy_iterator(

example_iterator = itertools.chain.from_iterable(
map(
shard_iterator.iterate_shard, # type: ignore
shard_paths_iterator)) # type: ignore
shard_iterator.iterate_shard, # type: ignore[arg-type]
shard_paths_iterator,
))

# Process each record if requested
if process_record:
example_iterator = map(process_record,
example_iterator) # type: ignore
example_iterator = map(
process_record,
example_iterator, # type: ignore[assignment]
)

# Randomize only if > 0 -- no shuffle in test/validation
if shuffle:
example_iterator = shuffle_buffer(
example_iterator, # type: ignore
buffer_size=shuffle)
example_iterator, # type: ignore[assignment]
buffer_size=shuffle,
)

yield from example_iterator

Expand Down Expand Up @@ -832,7 +845,7 @@ def __init__(self,
shuffle (int): Size of the shuffle buffer.
"""
self._rust_iter: _sedpack_rs.RustIter | None = None
self._rust_iter: _sedpack_rs.RustIter | None = None # type: ignore[no-any-unimported]

Check failure on line 848 in src/sedpack/io/dataset_iteration.py (GitHub Actions / pylint 3.10, 3.11, 3.12): Line too long (94/80) (line-too-long)

self._dataset: DatasetIteration = dataset
self._split: SplitT = split
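
Note on the as_numpy_iterator_async hunks: the iterator variable is annotated up front (example_iterator: AsyncIterator[ExampleT]) so both the shuffled and non-shuffled branches can assign to the same name, and the remaining mismatches with asyncstdlib's generic signatures are silenced with error-code-scoped ignores (tracked as issue #85). A minimal, self-contained sketch of the asyncstdlib.map plus chain.from_iterable pattern used here, with hypothetical shard data standing in for real files:

    import asyncio
    from typing import AsyncIterator

    import asyncstdlib


    async def iterate_shard_async(path: str) -> AsyncIterator[dict[str, int]]:
        # Stand-in for reading examples out of one shard file.
        for i in range(3):
            yield {"value": i}


    async def shard_paths() -> AsyncIterator[str]:
        for path in ("shard-0", "shard-1"):
            yield path


    async def main() -> None:
        # Map the per-shard iterator over the shard paths, then flatten the
        # resulting stream of async iterators into one stream of examples.
        examples: AsyncIterator[dict[str, int]] = asyncstdlib.chain.from_iterable(
            asyncstdlib.map(iterate_shard_async, shard_paths()))
        async for example in examples:
            print(example)


    asyncio.run(main())
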
4 changes: 2 additions & 2 deletions src/sedpack/io/dataset_writing.py
@@ -25,7 +25,7 @@
import uuid

from tqdm.auto import tqdm
import tensorflow as tf
import tensorflow as tf # type: ignore[import-untyped]

Check failure on line 28 in src/sedpack/io/dataset_writing.py (GitHub Actions / mypy): Unused "type: ignore" comment [unused-ignore]

import sedpack
from sedpack.io.dataset_base import DatasetBase
@@ -177,7 +177,7 @@ def write_config(self, updated_infos: list[ShardListInfo]) -> FileInfo:
# Type narrowing with get_args does not seem to work with mypy.
if split not in get_args(SplitT):
raise ValueError(f"Not a known split value: {split}")
splits_to_update[split].append(info) # type: ignore
splits_to_update[split].append(info) # type: ignore[index]

# Merge recursively.
for split, updates in splits_to_update.items():
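
Note on the write_config hunk: as the in-code comment says, checking split not in get_args(SplitT) validates the value at runtime but does not narrow split from str to the SplitT Literal for mypy, so indexing a dict keyed by SplitT still needs either typing.cast or the scoped "# type: ignore[index]" used here. A minimal sketch with illustrative split names, not this repository's code:

    from typing import Literal, get_args

    SplitT = Literal["train", "test", "holdout"]

    splits_to_update: dict[SplitT, list[str]] = {s: [] for s in get_args(SplitT)}


    def record(split: str, info: str) -> None:
        # Runtime validation against the Literal's allowed values.
        if split not in get_args(SplitT):
            raise ValueError(f"Not a known split value: {split}")
        # mypy still types `split` as str here, hence the scoped ignore;
        # typing.cast(SplitT, split) would be the equivalent alternative.
        splits_to_update[split].append(info)  # type: ignore[index]
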
34 changes: 25 additions & 9 deletions src/sedpack/io/flatbuffer/iterate.py
@@ -16,11 +16,13 @@
information how it is saved.
"""

import logging
from pathlib import Path
from typing import Iterable
from typing import AsyncIterator, Iterable

import aiofiles
import numpy as np
import numpy.typing as npt

from sedpack.io.compress import CompressedFile
from sedpack.io.metadata import Attribute
@@ -30,6 +32,7 @@
from sedpack.io.utils import func_or_identity

# Autogenerated from src/sedpack/io/flatbuffer/shard.fbs
import sedpack.io.flatbuffer.shardfile.Attribute as fbapi_Attribute
import sedpack.io.flatbuffer.shardfile.Example as fbapi_Example
import sedpack.io.flatbuffer.shardfile.Shard as fbapi_Shard

@@ -40,10 +43,16 @@ class IterateShardFlatBuffer(IterateShardBase[T]):
"""

def _iterate_content(self, content: bytes) -> Iterable[ExampleT]:
shard = fbapi_Shard.Shard.GetRootAs(content, 0)
shard: fbapi_Shard.Shard = fbapi_Shard.Shard.GetRootAs(content, 0)

for example_id in range(shard.ExamplesLength()):
example: fbapi_Example.Example = shard.Examples(example_id)
maybe_example: fbapi_Example.Example | None = shard.Examples(
example_id)
if maybe_example is None:
logger = logging.getLogger("sedpack.io.Dataset")
logger.error("Unable to get an example, corrupted shard?")
continue
example: fbapi_Example.Example = maybe_example

example_dictionary: ExampleT = {}

@@ -59,8 +68,14 @@ def _iterate_content(self, content: bytes) -> Iterable[ExampleT]:
# - Speed, since we first need to check the type for every
# attribute.
# Bytearray representation. Little endian, just loaded.
np_bytes = example.Attributes(
attribute_id).AttributeBytesAsNumpy()
maybe_attribute_data: fbapi_Attribute.Attribute | None = example.Attributes(
attribute_id)

Check failure on line 71 in src/sedpack/io/flatbuffer/iterate.py (GitHub Actions / pylint 3.10, 3.11, 3.12): Line too long (92/80) (line-too-long)
if maybe_attribute_data is None:
logger = logging.getLogger("sedpack.io.Dataset")
logger.error("Unable to get an attribute, corrupted shard?")
break
attribute_data: fbapi_Attribute.Attribute = maybe_attribute_data
np_bytes = attribute_data.AttributeBytesAsNumpy()

np_array = IterateShardFlatBuffer.decode_array(
np_bytes=np_bytes,
@@ -76,9 +91,9 @@ def _iterate_content(self, content: bytes) -> Iterable[ExampleT]:
yield example_dictionary

@staticmethod
def decode_array(np_bytes: np.ndarray,
def decode_array(np_bytes: npt.NDArray[np.uint8],
attribute: Attribute,
batch_size: int = 0) -> np.ndarray:
batch_size: int = 0) -> npt.NDArray[np.generic]:
"""Decode an array. See `sedpack.io.shard.shard_writer_flatbuffer
.ShardWriterFlatBuffer.save_numpy_vector_as_bytearray`
for format description. The code tries to avoid unnecessary copies.
@@ -119,12 +134,13 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]:
"""
# Read then decompress (nice for benchmarking).
with open(file_path, "rb") as f:
content = f.read()
content: bytes = f.read()
content = CompressedFile(
self.dataset_structure.compression).decompress(content)
yield from self._iterate_content(content=content)

async def iterate_shard_async(self, file_path: Path):
async def iterate_shard_async(self,
file_path: Path) -> AsyncIterator[ExampleT]:

Check failure on line 142 in src/sedpack/io/flatbuffer/iterate.py (GitHub Actions / pylint 3.10, 3.11, 3.12): Method 'iterate_shard_async' was expected to be 'non-async', found it instead as 'async' (invalid-overridden-method)
"""Asynchronously iterate a shard.
"""
async with aiofiles.open(file_path, "rb") as f:
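
Note on the iterate.py hunks: decode_array now uses numpy.typing aliases (npt.NDArray[np.uint8] for the raw little-endian bytes pulled from the FlatBuffer, npt.NDArray[np.generic] for the decoded result), and the generated FlatBuffers accessors (Shard.Examples, Example.Attributes) are treated as returning Optional values and guarded before use. A minimal sketch of the annotation pattern with a hypothetical decode_le_bytes helper, not the library's actual decoding logic:

    import numpy as np
    import numpy.typing as npt


    def decode_le_bytes(np_bytes: npt.NDArray[np.uint8],
                        dtype: str,
                        shape: tuple[int, ...]) -> npt.NDArray[np.generic]:
        # Reinterpret the little-endian byte buffer as the target dtype and
        # reshape it; a contiguous buffer is viewed without copying.
        return np_bytes.view(np.dtype(dtype).newbyteorder("<")).reshape(shape)


    raw = np.arange(6, dtype=np.uint16).astype("<u2").view(np.uint8)
    print(decode_le_bytes(raw, "uint16", (2, 3)))  # values 0..5 in a 2x3 array
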
(The remaining 12 changed files in this commit are not shown.)

0 comments on commit 03c0767
