Skip to content

Commit

Permalink
Fix import of custom types (google#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxrdv authored Feb 20, 2025
1 parent e764366 commit 9df0972
Show file tree
Hide file tree
Showing 27 changed files with 44 additions and 65 deletions.
7 changes: 3 additions & 4 deletions docs/tutorials/quick_start/mnist_read_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,16 @@
"""
import argparse
from functools import partial
from typing import Any, Dict, Tuple
from typing import Any

from jax import Array
from jax.typing import ArrayLike
from flax import nnx
import jax.numpy as jnp
import numpy as np
import optax
from tqdm import tqdm

from sedpack.io import Dataset
from sedpack.io.types import ExampleT, TFModelT


def process_batch(d: Any) -> dict[str, Array]:
Expand Down Expand Up @@ -154,7 +152,8 @@ def main() -> None:
if step > train_steps:
break

# Run the optimization for one step and make a stateful update to the following:
# Run the optimization for one step and make a stateful update to the
# following:
# - The train state's model parameters
# - The optimizer state
# - The training loss and accuracy batch metrics
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/quick_start/mnist_read_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
python mnist_read_keras.py -d "~/Datasets/my_new_dataset/"
"""
import argparse
from typing import Any, Dict, Tuple
from typing import Any, Tuple

import numpy as np
import tensorflow as tf
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/quick_start/mnist_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main() -> None:
random.shuffle(train_indices)
validation_split_position: int = int(len(x_train) * 0.1)
for index_position, index in enumerate(
tqdm(train_indices, desc='train and val')):
tqdm(train_indices, desc="train and val")):

# Assign to either train or test (aka validation).
split: SplitT = "test"
Expand Down
11 changes: 6 additions & 5 deletions docs/tutorials/sca/tiny_aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
The SCAAML package is a requirement: `python3 -m pip install "scaaml>=3.0.3"`.
Example use:
python tiny_aes.py --dataset_path "~/datasets/tiny_aes_sedpack/" --original_files "~/datasets/tinyaes"
python tiny_aes.py --dataset_path "~/datasets/tiny_aes_sedpack/" \
--original_files "~/datasets/tinyaes"
"""
import argparse
from pathlib import Path
from typing import Any, get_args
from typing import Any

import keras
import numpy as np
Expand All @@ -37,7 +38,7 @@
DatasetStructure,
Attribute,
)
from sedpack.io.typing import SplitT
from sedpack.io.types import SplitT


def add_shard(shard_file: Path, dataset_filler: DatasetFillerContext,
Expand Down Expand Up @@ -121,7 +122,7 @@ def convert_to_sedpack(dataset_path: Path, original_files: Path) -> None:
year={2019},
editor={DEF CON}
}
""",
""", # pylint: disable=line-too-long
"original from":
"https://github.com/google/scaaml/tree/main/scaaml_intro",
},
Expand Down Expand Up @@ -271,7 +272,7 @@ def train(dataset_path: Path) -> None:
)

# Train the model.
history = model.fit(
_ = model.fit(
train_ds,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
Expand Down
1 change: 1 addition & 0 deletions project-words.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@ tfrecord
tfrecords
tinyaes
tobytes
uoffset
zstd
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ dev = [
[project.urls]
"Homepage" = "https://github.com/google/sedpack"
"Bug Tracker" = "https://github.com/google/sedpack"

[tool.ruff]
target-version = "py310"
1 change: 0 additions & 1 deletion src/sedpack/io/flatbuffer/shardfile/Attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# pylint: skip-file

import flatbuffers # type: ignore[import-untyped]
from flatbuffers.compat import import_numpy # type: ignore[import-untyped]

import numpy as np
import numpy.typing as npt
Expand Down
5 changes: 0 additions & 5 deletions tests/io/itertools/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
import random
from typing import Any, Union

import numpy as np
import numpy.typing as npt

from sedpack.io.itertools import *

Expand Down
3 changes: 1 addition & 2 deletions tests/io/shard/test_shard_write_and_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from pathlib import Path
import pytest
from typing import get_args

import numpy as np

Expand All @@ -24,7 +23,7 @@
from sedpack.io.flatbuffer import IterateShardFlatBuffer
from sedpack.io.npz import IterateShardNP
from sedpack.io.tfrec import IterateShardTFRec
from sedpack.io.shard.get_shard_writer import get_shard_writer, _SHARD_FILE_TYPE_TO_CLASS
from sedpack.io.shard.get_shard_writer import get_shard_writer


def shard_write_and_read(attributes: dict[str, np.ndarray], shard_file: Path,
Expand Down
3 changes: 1 addition & 2 deletions tests/io/shard/test_shard_write_and_read_process_and_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from pathlib import Path
from typing import get_args

import numpy as np

Expand All @@ -23,7 +22,7 @@
from sedpack.io.flatbuffer import IterateShardFlatBuffer
from sedpack.io.npz import IterateShardNP
from sedpack.io.tfrec import IterateShardTFRec
from sedpack.io.shard.get_shard_writer import get_shard_writer, _SHARD_FILE_TYPE_TO_CLASS
from sedpack.io.shard.get_shard_writer import get_shard_writer


def shard_write_and_read(attributes: dict[str, np.ndarray], shard_file: Path,
Expand Down
4 changes: 1 addition & 3 deletions tests/io/shard/test_shard_write_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path
from typing import get_args
import pytest

import numpy as np
Expand All @@ -24,7 +22,7 @@

from sedpack.io.flatbuffer import IterateShardFlatBuffer
from sedpack.io.npz import IterateShardNP
from sedpack.io.shard.get_shard_writer import get_shard_writer, _SHARD_FILE_TYPE_TO_CLASS
from sedpack.io.shard.get_shard_writer import get_shard_writer

pytest_plugins = ("pytest_asyncio",)

Expand Down
3 changes: 1 addition & 2 deletions tests/io/shard/test_shard_writer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from typing import get_args

from sedpack.io.types import ShardFileTypeT
from sedpack.io.shard.shard_writer_base import ShardWriterBase
from sedpack.io.shard.get_shard_writer import get_shard_writer, _SHARD_FILE_TYPE_TO_CLASS
from sedpack.io.shard.get_shard_writer import _SHARD_FILE_TYPE_TO_CLASS


def test_all_file_types_supported():
Expand Down
5 changes: 2 additions & 3 deletions tests/io/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
# limitations under the License.

from pathlib import Path
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT


Expand Down
5 changes: 2 additions & 3 deletions tests/io/test_check_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# limitations under the License.

from pathlib import Path
from typing import Any, get_args, Union
from typing import get_args, Union

import pytest
import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import CompressionT, HashChecksumT, ShardFileTypeT, TRAIN_SPLIT


Expand Down
4 changes: 1 addition & 3 deletions tests/io/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

import gzip
from pathlib import Path
from typing import Any, Union
from typing import Union

import pytest
import numpy as np
import numpy.typing as npt

from sedpack.io.compress import CompressedFile
from sedpack.io.types import CompressionT
Expand Down
6 changes: 2 additions & 4 deletions tests/io/test_continue_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
import os
from pathlib import Path
import random
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io.types import CompressionT, ShardFileTypeT
from sedpack.io import Metadata


def get_dataset(tmpdir: Union[str, Path]) -> Dataset:
Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_custom_metadata_type_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

from pathlib import Path
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT


Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

from pathlib import Path
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT


Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_end2end_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

from pathlib import Path
import pytest
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT

pytest_plugins = ("pytest_asyncio",)
Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_end2end_shuffled.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

from pathlib import Path
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT


Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_end2end_wrong_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

from pathlib import Path
import pytest
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT


Expand Down
2 changes: 1 addition & 1 deletion tests/io/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import sedpack
from sedpack.io import Dataset
from sedpack.io import Attribute, Metadata
from sedpack.io import Metadata
from sedpack.io.errors import DatasetExistsError


Expand Down
5 changes: 0 additions & 5 deletions tests/io/test_hash_checksums.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gzip
from pathlib import Path
from typing import Any, Union

import numpy as np
import numpy.typing as npt

from sedpack.io.utils import hash_checksums

Expand Down
5 changes: 2 additions & 3 deletions tests/io/test_rust_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
# limitations under the License.

from pathlib import Path
from typing import Any, get_args, Union
from typing import get_args, Union

import numpy as np
import numpy.typing as npt
import pytest

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io.flatbuffer.iterate import IterateShardFlatBuffer
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT

from sedpack import _sedpack_rs # type: ignore[attr-defined]
Expand Down
5 changes: 2 additions & 3 deletions tests/io/test_shard_custom_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
# limitations under the License.

from pathlib import Path
from typing import Any, Union
from typing import Union

import numpy as np
import numpy.typing as npt

import sedpack
from sedpack.io import Dataset
from sedpack.io import Metadata, DatasetStructure, Attribute
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT


Expand Down
Loading

0 comments on commit 9df0972

Please sign in to comment.