From cdc4bd78e14d8066fac498a3beac60fbec8aa1ed Mon Sep 17 00:00:00 2001 From: "Michael A. Smith" Date: Mon, 31 May 2021 15:47:36 -0400 Subject: [PATCH 1/6] AVRO-2921: Add Type Hints to Setup.py --- lang/py/setup.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/lang/py/setup.py b/lang/py/setup.py index d8d0fa3ff0e..5124319560b 100755 --- a/lang/py/setup.py +++ b/lang/py/setup.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# -*- coding: utf-8 -*- ## # Licensed to the Apache Software Foundation (ASF) under one @@ -19,6 +20,8 @@ import distutils.errors +import distutils.file_util +import distutils.log import glob import os import subprocess @@ -30,7 +33,7 @@ _VERSION_FILE_NAME = "VERSION.txt" -def _is_distribution(): +def _is_distribution() -> bool: """Tests whether setup.py is invoked from a distribution. Returns: @@ -42,13 +45,13 @@ def _is_distribution(): return os.path.exists(os.path.join(_HERE, "PKG-INFO")) -def _generate_package_data(): +def _generate_package_data() -> None: """Generate package data. This data will already exist in a distribution package, so this function only runs for local version control work tree. """ - distutils.log.info("Generating package data") + distutils.log.info("Generating package data", "") # Avro top-level source directory: root_dir = os.path.dirname(os.path.dirname(_HERE)) @@ -75,9 +78,7 @@ def _generate_package_data(): ) for src, dst in avsc_files: - src = os.path.join(*src) - dst = os.path.join(_AVRO_DIR, *dst) - distutils.file_util.copy_file(src, dst) + distutils.file_util.copy_file(os.path.join(*src), os.path.join(_AVRO_DIR, *dst)) class GenerateInteropDataCommand(setuptools.Command): @@ -88,14 +89,14 @@ class GenerateInteropDataCommand(setuptools.Command): ("output-path=", None, "path to output Avro data files"), ] - def initialize_options(self): + def initialize_options(self) -> None: self.schema_file = os.path.join(_AVRO_DIR, "interop.avsc") self.output_path = os.path.join(_AVRO_DIR, "test", "interop", "data") - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: # Late import -- this can only be run when avro is on the pythonpath, # more or less after install. import avro.test.gen_interop_data @@ -105,7 +106,7 @@ def run(self): avro.test.gen_interop_data.generate(self.schema_file, os.path.join(self.output_path, "py.avro")) -def _get_version(): +def _get_version() -> str: curdir = os.getcwd() if os.path.isfile("avro/VERSION.txt"): version_file = "avro/VERSION.txt" @@ -119,7 +120,7 @@ def _get_version(): return verfile.read().rstrip().replace("-", "+") -def main(): +def main() -> None: if not _is_distribution(): _generate_package_data() From 93ade6a35a78cd04998369a42470f5d9c05db91a Mon Sep 17 00:00:00 2001 From: "Michael A. 
Smith" Date: Mon, 31 May 2021 16:05:59 -0400 Subject: [PATCH 2/6] AVRO-2921: Add Type Hints to avro.codecs --- lang/py/avro/codecs.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/lang/py/avro/codecs.py b/lang/py/avro/codecs.py index fa89d77e94e..c5b11ff169e 100644 --- a/lang/py/avro/codecs.py +++ b/lang/py/avro/codecs.py @@ -32,6 +32,9 @@ import struct import sys import zlib +from array import array +from mmap import mmap +from typing import List, Sequence, Tuple, Union import avro.errors import avro.io @@ -66,7 +69,7 @@ class Codec(abc.ABC): """Abstract base class for all Avro codec classes.""" @abc.abstractmethod - def compress(self, data): + def compress(self, data: bytes) -> Tuple[bytes, int]: """Compress the passed data. :param data: a byte string to be compressed @@ -77,7 +80,7 @@ def compress(self, data): """ @abc.abstractmethod - def decompress(self, readers_decoder): + def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder: """Read compressed data via the passed BinaryDecoder and decompress it. :param readers_decoder: a BinaryDecoder object currently being used for @@ -91,22 +94,22 @@ def decompress(self, readers_decoder): class NullCodec(Codec): - def compress(self, data): + def compress(self, data: bytes) -> Tuple[bytes, int]: return data, len(data) - def decompress(self, readers_decoder): + def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder: readers_decoder.skip_long() return readers_decoder class DeflateCodec(Codec): - def compress(self, data): + def compress(self, data: bytes) -> Tuple[bytes, int]: # The first two characters and last character are zlib # wrappers around deflate data. compressed_data = zlib.compress(data)[2:-1] return compressed_data, len(compressed_data) - def decompress(self, readers_decoder): + def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder: # Compressed data is stored as (length, data), which # corresponds to how the "bytes" type is encoded. 
data = readers_decoder.read_bytes() @@ -119,11 +122,11 @@ def decompress(self, readers_decoder): if has_bzip2: class BZip2Codec(Codec): - def compress(self, data): + def compress(self, data: bytes) -> Tuple[bytes, int]: compressed_data = bz2.compress(data) return compressed_data, len(compressed_data) - def decompress(self, readers_decoder): + def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder: length = readers_decoder.read_long() data = readers_decoder.read(length) uncompressed = bz2.decompress(data) @@ -133,13 +136,13 @@ def decompress(self, readers_decoder): if has_snappy: class SnappyCodec(Codec): - def compress(self, data): + def compress(self, data: bytes) -> Tuple[bytes, int]: compressed_data = snappy.compress(data) # A 4-byte, big-endian CRC32 checksum compressed_data += STRUCT_CRC32.pack(binascii.crc32(data) & 0xFFFFFFFF) return compressed_data, len(compressed_data) - def decompress(self, readers_decoder): + def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder: # Compressed data includes a 4-byte CRC32 checksum length = readers_decoder.read_long() data = readers_decoder.read(length - 4) @@ -148,20 +151,20 @@ def decompress(self, readers_decoder): self.check_crc32(uncompressed, checksum) return avro.io.BinaryDecoder(io.BytesIO(uncompressed)) - def check_crc32(self, bytes, checksum): + def check_crc32(self, bytes_: bytes, checksum: Union["array[int]", bytes, bytearray, memoryview, mmap]) -> None: checksum = STRUCT_CRC32.unpack(checksum)[0] - if binascii.crc32(bytes) & 0xFFFFFFFF != checksum: + if binascii.crc32(bytes_) & 0xFFFFFFFF != checksum: raise avro.errors.AvroException("Checksum failure") if has_zstandard: class ZstandardCodec(Codec): - def compress(self, data): + def compress(self, data: bytes) -> Tuple[bytes, int]: compressed_data = zstd.ZstdCompressor().compress(data) return compressed_data, len(compressed_data) - def decompress(self, readers_decoder): + def decompress(self, readers_decoder: avro.io.BinaryDecoder) -> avro.io.BinaryDecoder: length = readers_decoder.read_long() data = readers_decoder.read(length) uncompressed = bytearray() @@ -175,7 +178,7 @@ def decompress(self, readers_decoder): return avro.io.BinaryDecoder(io.BytesIO(uncompressed)) -def get_codec(codec_name): +def get_codec(codec_name: str) -> Codec: codec_name = codec_name.lower() if codec_name == "null": return NullCodec() @@ -190,7 +193,7 @@ def get_codec(codec_name): raise avro.errors.UnsupportedCodec(f"Unsupported codec: {codec_name}. (Is it installed?)") -def supported_codec_names(): +def supported_codec_names() -> List[str]: codec_names = ["null", "deflate"] if has_bzip2: codec_names.append("bzip2") From 3d6e23e4e20383540df550b2055385ec48b3fff4 Mon Sep 17 00:00:00 2001 From: "Michael A. Smith" Date: Mon, 31 May 2021 16:13:06 -0400 Subject: [PATCH 3/6] AVRO-2921: Add Type Hints to avro.compatibility --- lang/py/avro/compatibility.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/lang/py/avro/compatibility.py b/lang/py/avro/compatibility.py index 1200773cfb3..d927bd32485 100644 --- a/lang/py/avro/compatibility.py +++ b/lang/py/avro/compatibility.py @@ -18,7 +18,7 @@ # limitations under the License.
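# [Editor's note] A hedged aside on the memoization below: the ReaderWriter key
# defined in this module hashes and compares by object identity, so two equal
# but separately parsed schemas are distinct cache keys. For example:
#
#     >>> import avro.schema
#     >>> a, b = avro.schema.parse('"int"'), avro.schema.parse('"int"')
#     >>> ReaderWriter(a, b) == ReaderWriter(a, b)
#     True
#     >>> ReaderWriter(a, b) == ReaderWriter(b, a)
#     False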
from copy import copy from enum import Enum -from typing import List, Optional, Set, cast +from typing import Any, List, MutableMapping, Optional, Set, cast from avro.errors import AvroRuntimeException from avro.schema import ( @@ -85,7 +85,7 @@ def __init__( incompatibilities: List[SchemaIncompatibilityType] = None, messages: Optional[Set[str]] = None, locations: Optional[Set[str]] = None, - ): + ) -> None: self.locations = locations or {"/"} self.messages = messages or set() self.compatibility = compatibility @@ -128,16 +128,15 @@ def __init__(self, reader: Schema, writer: Schema) -> None: def __hash__(self) -> int: return id(self.reader) ^ id(self.writer) - def __eq__(self, other) -> bool: - if not isinstance(other, ReaderWriter): - return False - return self.reader is other.reader and self.writer is other.writer + def __eq__(self, other: Any) -> bool: + return isinstance(other, ReaderWriter) and (self.reader is other.reader) and (self.writer is other.writer) class ReaderWriterCompatibilityChecker: ROOT_REFERENCE_TOKEN = "/" + memoize_map: MutableMapping[ReaderWriter, SchemaCompatibilityResult] - def __init__(self): + def __init__(self) -> None: self.memoize_map = {} def get_compatibility( @@ -374,9 +373,7 @@ def incompatible(incompat_type: SchemaIncompatibilityType, message: str, locatio def schema_name_equals(reader: NamedSchema, writer: NamedSchema) -> bool: - if reader.name == writer.name: - return True - return writer.fullname in reader.props.get("aliases", []) + return (reader.name == writer.name) or (writer.fullname in reader.props.get("aliases", [])) def lookup_writer_field(writer_schema: RecordSchema, reader_field: Field) -> Optional[Field]: From 63c342d99e69de3bc19d8ba33ea751dea7f69fa6 Mon Sep 17 00:00:00 2001 From: "Michael A. Smith" Date: Mon, 31 May 2021 17:48:40 -0400 Subject: [PATCH 4/6] AVRO-2921: Add Type Hints to avro.datafile --- lang/py/avro/datafile.py | 142 +++++++++++++++++++++++---------------- 1 file changed, 84 insertions(+), 58 deletions(-) diff --git a/lang/py/avro/datafile.py b/lang/py/avro/datafile.py index 8a660bc8019..818856c8a7a 100644 --- a/lang/py/avro/datafile.py +++ b/lang/py/avro/datafile.py @@ -19,11 +19,15 @@ """Read/Write Avro File Object Containers.""" +import abc import io import json import os import random import zlib +from contextlib import AbstractContextManager +from types import TracebackType +from typing import Any, MutableMapping, Optional, Type import avro.codecs import avro.errors @@ -64,60 +68,63 @@ # -class _DataFile: +class _DataFile(AbstractContextManager, abc.ABC): """Mixin for methods common to both reading and writing.""" block_count = 0 - _meta = None - _sync_marker = None + _meta: Optional[MutableMapping[str, bytes]] = None + _sync_marker: bytes - def __enter__(self): + @abc.abstractmethod + def close(self) -> None: + """Close the datafile""" + + def __enter__(self) -> "_DataFile": return self - def __exit__(self, type, value, traceback): + def __exit__(self, type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: # Perform a close if there's no exception if type is None: self.close() - def get_meta(self, key): + def get_meta(self, key: str) -> Optional[bytes]: return self.meta.get(key) - def set_meta(self, key, val): + def set_meta(self, key: str, val: bytes) -> None: self.meta[key] = val @property - def sync_marker(self): + def sync_marker(self) -> bytes: return self._sync_marker @property - def meta(self): + def meta(self) -> MutableMapping[str, bytes]: """Read-only
dictionary of metadata for this datafile.""" if self._meta is None: self._meta = {} return self._meta @property - def codec(self): + def codec(self) -> str: """Meta are stored as bytes, but codec is returned as a string.""" - try: - return self.get_meta(CODEC_KEY).decode() - except AttributeError: - return "null" + codec = self.get_meta(CODEC_KEY) + return codec.decode() if codec else "null" @codec.setter - def codec(self, value): + def codec(self, value: str) -> None: """Meta are stored as bytes, but codec is set as a string.""" if value not in VALID_CODECS: raise avro.errors.DataFileException(f"Unknown codec: {value!r}") self.set_meta(CODEC_KEY, value.encode()) @property - def schema(self): + def schema(self) -> Optional[str]: """Meta are stored as bytes, but schema is returned as a string.""" - return self.get_meta(SCHEMA_KEY).decode() + schema = self.get_meta(SCHEMA_KEY) + return schema.decode() if schema else None @schema.setter - def schema(self, value): + def schema(self, value: str) -> None: """Meta are stored as bytes, but schema is set as a string.""" self.set_meta(SCHEMA_KEY, value.encode()) @@ -125,7 +132,9 @@ def schema(self, value): class DataFileWriter(_DataFile): # TODO(hammer): make 'encoder' a metadata property - def __init__(self, writer, datum_writer, writers_schema=None, codec=NULL_CODEC): + def __init__( + self, writer: io.BytesIO, datum_writer: avro.io.DatumWriter, writers_schema: Optional[avro.schema.Schema] = None, codec: str = NULL_CODEC + ) -> None: """ If the schema is not present, presume we're appending. @@ -162,31 +171,33 @@ def __init__(self, writer, datum_writer, writers_schema=None, codec=NULL_CODEC): self._header_written = True # read-only properties - writer = property(lambda self: self._writer) - encoder = property(lambda self: self._encoder) - datum_writer = property(lambda self: self._datum_writer) - buffer_writer = property(lambda self: self._buffer_writer) - buffer_encoder = property(lambda self: self._buffer_encoder) + @property + def writer(self) -> io.BytesIO: + return self._writer - def _write_header(self): - header = {"magic": MAGIC, "meta": self.meta, "sync": self.sync_marker} - self.datum_writer.write_data(META_SCHEMA, header, self.encoder) - self._header_written = True + @property + def encoder(self) -> avro.io.BinaryEncoder: + return self._encoder @property - def codec(self): - """Meta are stored as bytes, but codec is returned as a string.""" - return self.get_meta(CODEC_KEY).decode() + def datum_writer(self) -> avro.io.DatumWriter: + return self._datum_writer - @codec.setter - def codec(self, value): - """Meta are stored as bytes, but codec is set as a string.""" - if value not in VALID_CODECS: - raise avro.errors.DataFileException(f"Unknown codec: {value!r}") - self.set_meta(CODEC_KEY, value.encode()) + @property + def buffer_writer(self) -> io.BytesIO: + return self._buffer_writer + + @property + def buffer_encoder(self) -> avro.io.BinaryEncoder: + return self._buffer_encoder + + def _write_header(self) -> None: + header = {"magic": MAGIC, "meta": self.meta, "sync": self.sync_marker} + self.datum_writer.write_data(META_SCHEMA, header, self.encoder) + self._header_written = True # TODO(hammer): make a schema for blocks and use datum_writer - def _write_block(self): + def _write_block(self) -> None: if not self._header_written: self._write_header() @@ -213,7 +224,7 @@ def _write_block(self): self.buffer_writer.seek(0) self.block_count = 0 - def append(self, datum): + def append(self, datum: Any) -> None: """Append a datum to the file.""" 
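        # [Editor's note] A hedged usage sketch (filenames and field names here
        # are illustrative only): each append is buffered in memory and only
        # written out as a block once the buffer passes SYNC_INTERVAL bytes, or
        # on sync()/flush()/close():
        #
        #     schema = avro.schema.parse('{"type": "record", "name": "R", "fields": [{"name": "x", "type": "long"}]}')
        #     with DataFileWriter(open("r.avro", "wb"), avro.io.DatumWriter(), schema) as writer:
        #         writer.append({"x": 1})
        #         writer.append({"x": 2})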
self.datum_writer.write(datum, self.buffer_encoder) self.block_count += 1 @@ -222,7 +233,7 @@ def append(self, datum): if self.buffer_writer.tell() >= SYNC_INTERVAL: self._write_block() - def sync(self): + def sync(self) -> int: """ Return the current position as a value that may be passed to DataFileReader.seek(long). Forces the end of the current block, @@ -231,12 +242,12 @@ def sync(self): self._write_block() return self.writer.tell() - def flush(self): + def flush(self) -> None: """Flush the current state of the file, including metadata.""" self._write_block() self.writer.flush() - def close(self): + def close(self) -> None: """Close the file.""" self.flush() self.writer.close() @@ -245,13 +256,14 @@ def close(self): class DataFileReader(_DataFile): """Read files written by DataFileWriter.""" + _datum_decoder: Optional[avro.io.BinaryDecoder] = None + # TODO(hammer): allow user to specify expected schema? # TODO(hammer): allow user to specify the encoder - def __init__(self, reader, datum_reader): + def __init__(self, reader: io.BytesIO, datum_reader: avro.io.DatumReader) -> None: self._reader = reader self._raw_decoder = avro.io.BinaryDecoder(reader) - self._datum_decoder = None # Maybe reset at every block. self._datum_reader = datum_reader # read the header: magic, meta, sync @@ -264,17 +276,31 @@ def __init__(self, reader, datum_reader): self.block_count = 0 self.datum_reader.writers_schema = avro.schema.parse(self.schema) - def __iter__(self): + def __iter__(self) -> "DataFileReader": return self # read-only properties - reader = property(lambda self: self._reader) - raw_decoder = property(lambda self: self._raw_decoder) - datum_decoder = property(lambda self: self._datum_decoder) - datum_reader = property(lambda self: self._datum_reader) - file_length = property(lambda self: self._file_length) + @property + def reader(self) -> io.BytesIO: + return self._reader + + @property + def raw_decoder(self) -> avro.io.BinaryDecoder: + return self._raw_decoder + + @property + def datum_decoder(self) -> Optional[avro.io.BinaryDecoder]: + return self._datum_decoder + + @property + def datum_reader(self) -> avro.io.DatumReader: + return self._datum_reader + + @property + def file_length(self) -> int: + return self._file_length - def determine_file_length(self): + def determine_file_length(self) -> int: """ Get file length and leave file cursor where we found it. """ @@ -284,10 +310,10 @@ def determine_file_length(self): self.reader.seek(remember_pos) return file_length - def is_EOF(self): + def is_EOF(self) -> bool: return self.reader.tell() == self.file_length - def _read_header(self): + def _read_header(self) -> None: # seek to the beginning of the file to get magic block self.reader.seek(0, 0) @@ -304,12 +330,12 @@ def _read_header(self): # set sync marker self._sync_marker = header["sync"] - def _read_block_header(self): + def _read_block_header(self) -> None: self.block_count = self.raw_decoder.read_long() codec = avro.codecs.get_codec(self.codec) self._datum_decoder = codec.decompress(self.raw_decoder) - def _skip_sync(self): + def _skip_sync(self) -> bool: """ Read the length of the sync marker; if it matches the sync marker, return True. Otherwise, seek back to where we started and return False.
@@ -320,7 +346,7 @@ def _skip_sync(self): return False return True - def __next__(self): + def __next__(self) -> Any: """Return the next datum in the file.""" while self.block_count == 0: if self.is_EOF() or (self._skip_sync() and self.is_EOF()): raise StopIteration self._read_block_header() datum = self.datum_reader.read(self.datum_decoder) self.block_count -= 1 return datum - def close(self): + def close(self) -> None: """Close this reader.""" self.reader.close() -def generate_sixteen_random_bytes(): +def generate_sixteen_random_bytes() -> bytes: try: return os.urandom(16) except NotImplementedError: From f51e859fe1c289a911589b977d3df0bc8e72aa9e Mon Sep 17 00:00:00 2001 From: "Michael A. Smith" Date: Mon, 31 May 2021 17:51:35 -0400 Subject: [PATCH 5/6] AVRO-2921: Add Type Hints to avro.errors --- lang/py/avro/errors.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lang/py/avro/errors.py b/lang/py/avro/errors.py index b613e559fa7..e4176e9b75d 100644 --- a/lang/py/avro/errors.py +++ b/lang/py/avro/errors.py @@ -18,9 +18,12 @@ # limitations under the License. import json +from typing import TYPE_CHECKING, Any, Optional +if TYPE_CHECKING: + from avro.schema import Schema -def _safe_pretty(schema): + +def _safe_pretty(schema: Any) -> Any: """Try to pretty-print a schema, but never raise an exception within another exception.""" try: return json.dumps(json.loads(str(schema)), indent=2) @@ -51,7 +54,7 @@ class IgnoredLogicalType(AvroWarning): class AvroTypeException(AvroException): """Raised when datum is not an example of schema.""" - def __init__(self, *args): + def __init__(self, *args: Any) -> None: try: expected_schema, datum = args[:2] except (IndexError, ValueError): @@ -62,7 +65,7 @@ def __init__(self, *args): class AvroOutOfScaleException(AvroTypeException): """Raised when attempting to write a decimal datum with an exponent too large for the decimal schema.""" - def __init__(self, *args): + def __init__(self, *args: Any) -> None: try: scale, datum, exponent = args[:3] except (IndexError, ValueError): @@ -71,7 +74,7 @@ def __init__(self, *args): class SchemaResolutionException(AvroException): - def __init__(self, fail_msg, writers_schema=None, readers_schema=None, *args): + def __init__(self, fail_msg: str, writers_schema: Optional["Schema"] = None, readers_schema: Optional["Schema"] = None, *args: Any) -> None: writers_message = f"\nWriter's Schema: {_safe_pretty(writers_schema)}" if writers_schema else "" readers_message = f"\nReader's Schema: {_safe_pretty(readers_schema)}" if readers_schema else "" super().__init__((fail_msg or "") + writers_message + readers_message, *args) From 5aca8e2dcd9a9bc6d550c6de8ccab73a790d314f Mon Sep 17 00:00:00 2001 From: "Michael A.
Smith" Date: Mon, 31 May 2021 23:29:02 -0400 Subject: [PATCH 6/6] AVRO-2921: Add Type Hints to avro.io --- lang/py/avro/datafile.py | 2 + lang/py/avro/errors.py | 8 + lang/py/avro/io.py | 403 +++++++++++++++++++++------------------ lang/py/avro/schema.py | 7 +- 4 files changed, 231 insertions(+), 189 deletions(-) diff --git a/lang/py/avro/datafile.py b/lang/py/avro/datafile.py index 818856c8a7a..909483d5944 100644 --- a/lang/py/avro/datafile.py +++ b/lang/py/avro/datafile.py @@ -353,6 +353,8 @@ def __next__(self) -> Any: raise StopIteration self._read_block_header() + if self.datum_decoder is None: + raise avro.errors.UninitializedDataFileException datum = self.datum_reader.read(self.datum_decoder) self.block_count -= 1 return datum diff --git a/lang/py/avro/errors.py b/lang/py/avro/errors.py index e4176e9b75d..49e7f57e87e 100644 --- a/lang/py/avro/errors.py +++ b/lang/py/avro/errors.py @@ -106,3 +106,11 @@ class UsageError(RuntimeError, AvroException): class AvroRuntimeException(RuntimeError, AvroException): """Raised when compatibility parsing encounters an unknown type""" + + +class UninitializedDataFileException(AvroException): + """Raised when attempting to use a DataFile without a datum decoder.""" + + +class UninitializedDatumIOException(AvroException): + """Raised when attempting to use a DatumReader or DatumWriter without a schema.""" diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py index 0978b3abe1e..75c4750000a 100644 --- a/lang/py/avro/io.py +++ b/lang/py/avro/io.py @@ -89,11 +89,14 @@ import decimal import json import struct +from typing import Any, Dict, List, Optional, Sequence, TypeVar, cast import avro.constants import avro.errors import avro.timezones +DatumType = Any # Best if we can refine this at some point. + # # Constants # @@ -115,7 +118,7 @@ ValidationNode = collections.namedtuple("ValidationNode", ["schema", "datum", "name"]) -def validate(expected_schema, datum, raise_on_error=False): +def validate(expected_schema: avro.schema.Schema, datum: DatumType, raise_on_error: bool = False) -> bool: """Return True if the provided datum is valid for the expected schema If raise_on_error is passed and True, then raise a validation error @@ -130,8 +133,7 @@ def validate(expected_schema, datum, raise_on_error=False): :returns: True if datum is valid for expected_schema, False if not. """ # use a FIFO queue to process schema nodes breadth first. - nodes = collections.deque() - nodes.append(ValidationNode(expected_schema, datum, getattr(expected_schema, "name", None))) + nodes = collections.deque([ValidationNode(expected_schema, datum, getattr(expected_schema, "name", None))]) while nodes: current_node = nodes.popleft() @@ -139,31 +141,26 @@ def validate(expected_schema, datum, raise_on_error=False): # _validate_node returns the node for iteration if it is valid. Or it returns None # if current_node.schema.type in {'array', 'map', 'record'}: validated_schema = current_node.schema.validate(current_node.datum) - if validated_schema: - valid_node = ValidationNode(validated_schema, current_node.datum, current_node.name) - else: - valid_node = None - # else: - # valid_node = _validate_node(current_node) - - if valid_node is not None: - # if there are children of this node to append, do so. - for child_node in _iterate_node(valid_node): - nodes.append(child_node) - else: + valid_node = ValidationNode(validated_schema, current_node.datum, current_node.name) if validated_schema else None + + if valid_node is None: # the current node was not valid. 
if raise_on_error: raise avro.errors.AvroTypeException(current_node.schema, current_node.datum) - else: - # preserve the prior validation behavior of returning false when there are problems. - return False + # preserve the prior validation behavior of returning false when there are problems. + return False + # if there are children of this node to append, do so. + for child_node in _iterate_node(valid_node): + nodes.append(child_node) return True -def _iterate_node(node): - for item in _ITERATORS.get(node.schema.type, _default_iterator)(node): - yield ValidationNode(*item) +from typing import Generator + + +def _iterate_node(node: ValidationNode) -> Generator[ValidationNode, None, None]: + return (ValidationNode(*item) for item in _ITERATORS.get(node.schema.type, _default_iterator)(node)) ############# @@ -171,7 +168,7 @@ def _iterate_node(node): ############# -def _default_iterator(_): +def _default_iterator(_: ValidationNode) -> Generator[ValidationNode, None, None]: """Immediately raise StopIteration. This exists to prevent problems with iteration over unsupported container types. @@ -179,26 +176,23 @@ def _default_iterator(_): yield from () -def _record_iterator(node): +def _record_iterator(node: ValidationNode) -> Generator[ValidationNode, None, None]: """Yield each child node of the provided record node.""" schema, datum, name = node - for field in schema.fields: - yield ValidationNode(field.type, datum.get(field.name), field.name) # type: ignore + return (ValidationNode(field.type, datum.get(field.name), field.name) for field in schema.fields) -def _array_iterator(node): +def _array_iterator(node: ValidationNode) -> Generator[ValidationNode, None, None]: """Yield each child node of the provided array node.""" schema, datum, name = node - for item in datum: # type: ignore - yield ValidationNode(schema.items, item, name) + return (ValidationNode(schema.items, item, name) for item in datum) -def _map_iterator(node): +def _map_iterator(node: ValidationNode) -> Generator[ValidationNode, None, None]: """Yield each child node of the provided map node.""" schema, datum, _ = node child_schema = schema.values - for child_name, child_datum in datum.items(): # type: ignore - yield ValidationNode(child_schema, child_datum, child_name) + return (ValidationNode(child_schema, child_datum, child_name) for child_name, child_datum in datum.items()) _ITERATORS = { @@ -212,46 +206,49 @@ def _map_iterator(node): # # Decoder/Encoder # +import io class BinaryDecoder: """Read leaf values.""" - def __init__(self, reader): + def __init__(self, reader: io.BytesIO) -> None: """ reader is a Python object on which we can call read, seek, and tell. """ self._reader = reader # read-only properties - reader = property(lambda self: self._reader) + @property + def reader(self) -> io.BytesIO: + return self._reader - def read(self, n): + def read(self, n: int) -> bytes: """ Read n bytes. """ return self.reader.read(n) - def read_null(self): + def read_null(self) -> None: """ null is written as zero bytes """ return None - def read_boolean(self): + def read_boolean(self) -> bool: """ a boolean is written as a single byte whose value is either 0 (false) or 1 (true). """ return ord(self.read(1)) == 1 - def read_int(self): + def read_int(self) -> int: """ int and long values are written using variable-length, zig-zag coding. """ return self.read_long() - def read_long(self): + def read_long(self) -> int: """ int and long values are written using variable-length, zig-zag coding. 
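        A hedged worked example of the decode step in the body below,
        n -> (n >> 1) ^ -(n & 1), which un-interleaves the zig-zag encoding:

            >>> [(n >> 1) ^ -(n & 1) for n in range(6)]
            [0, -1, 1, -2, 2, -3]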
""" @@ -265,23 +262,23 @@ def read_long(self): datum = (n >> 1) ^ -(n & 1) return datum - def read_float(self): + def read_float(self) -> float: """ A float is written as 4 bytes. The float is converted into a 32-bit integer using a method equivalent to Java's floatToIntBits and then encoded in little-endian format. """ - return STRUCT_FLOAT.unpack(self.read(4))[0] + return cast(float, STRUCT_FLOAT.unpack(self.read(4))[0]) - def read_double(self): + def read_double(self) -> float: """ A double is written as 8 bytes. The double is converted into a 64-bit integer using a method equivalent to Java's doubleToLongBits and then encoded in little-endian format. """ - return STRUCT_DOUBLE.unpack(self.read(8))[0] + return cast(float, STRUCT_DOUBLE.unpack(self.read(8))[0]) - def read_decimal_from_bytes(self, precision, scale): + def read_decimal_from_bytes(self, precision: int, scale: int) -> decimal.Decimal: """ Decimal bytes are decoded as signed short, int or long depending on the size of bytes. @@ -289,7 +286,7 @@ def read_decimal_from_bytes(self, precision, scale): size = self.read_long() return self.read_decimal_from_fixed(precision, scale, size) - def read_decimal_from_fixed(self, precision, scale, size): + def read_decimal_from_fixed(self, precision: int, scale: int, size: int) -> decimal.Decimal: """ Decimal is encoded as fixed. Fixed instances are encoded using the number of bytes declared in the schema. @@ -318,20 +315,20 @@ def read_decimal_from_fixed(self, precision, scale, size): decimal.getcontext().prec = original_prec return scaled_datum - def read_bytes(self): + def read_bytes(self) -> bytes: """ Bytes are encoded as a long followed by that many bytes of data. """ return self.read(self.read_long()) - def read_utf8(self): + def read_utf8(self) -> str: """ A string is encoded as a long followed by that many bytes of UTF-8 encoded character data. """ return self.read_bytes().decode("utf-8") - def read_date_from_int(self): + def read_date_from_int(self) -> datetime.date: """ int is decoded as python date object. int stores the number of days from @@ -340,7 +337,7 @@ def read_date_from_int(self): days_since_epoch = self.read_int() return datetime.date(1970, 1, 1) + datetime.timedelta(days_since_epoch) - def _build_time_object(self, value, scale_to_micro): + def _build_time_object(self, value: int, scale_to_micro: int) -> datetime.time: value = value * scale_to_micro value, microseconds = divmod(value, 1000000) value, seconds = divmod(value, 60) @@ -349,7 +346,7 @@ def _build_time_object(self, value, scale_to_micro): return datetime.time(hour=hours, minute=minutes, second=seconds, microsecond=microseconds) - def read_time_millis_from_int(self): + def read_time_millis_from_int(self) -> datetime.time: """ int is decoded as python time object which represents the number of milliseconds after midnight, 00:00:00.000. @@ -357,7 +354,7 @@ def read_time_millis_from_int(self): milliseconds = self.read_int() return self._build_time_object(milliseconds, 1000) - def read_time_micros_from_long(self): + def read_time_micros_from_long(self) -> datetime.time: """ long is decoded as python time object which represents the number of microseconds after midnight, 00:00:00.000000. 
@@ -365,7 +362,7 @@ def read_time_micros_from_long(self): microseconds = self.read_long() return self._build_time_object(microseconds, 1) - def read_timestamp_millis_from_long(self): + def read_timestamp_millis_from_long(self) -> datetime.datetime: """ long is decoded as python datetime object which represents the number of milliseconds from the unix epoch, 1 January 1970. @@ -375,7 +372,7 @@ def read_timestamp_millis_from_long(self): unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) return unix_epoch_datetime + timedelta - def read_timestamp_micros_from_long(self): + def read_timestamp_micros_from_long(self) -> datetime.datetime: """ long is decoded as python datetime object which represents the number of microseconds from the unix epoch, 1 January 1970. @@ -385,72 +382,73 @@ def read_timestamp_micros_from_long(self): unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) return unix_epoch_datetime + timedelta - def skip_null(self): + def skip_null(self) -> None: pass - def skip_boolean(self): + def skip_boolean(self) -> None: self.skip(1) - def skip_int(self): + def skip_int(self) -> None: self.skip_long() - def skip_long(self): + def skip_long(self) -> None: b = ord(self.read(1)) while (b & 0x80) != 0: b = ord(self.read(1)) - def skip_float(self): + def skip_float(self) -> None: self.skip(4) - def skip_double(self): + def skip_double(self) -> None: self.skip(8) - def skip_bytes(self): + def skip_bytes(self) -> None: self.skip(self.read_long()) - def skip_utf8(self): + def skip_utf8(self) -> None: self.skip_bytes() - def skip(self, n): + def skip(self, n) -> None: self.reader.seek(self.reader.tell() + n) class BinaryEncoder: """Write leaf values.""" - def __init__(self, writer): + def __init__(self, writer: io.BytesIO) -> None: """ writer is a Python object on which we can call write. """ self._writer = writer # read-only properties - writer = property(lambda self: self._writer) + @property + def writer(self) -> io.BytesIO: + return self._writer - def write(self, datum): + def write(self, datum: DatumType) -> None: """Write an arbitrary datum.""" self.writer.write(datum) - def write_null(self, datum): + def write_null(self, datum: None) -> None: """ null is written as zero bytes """ - pass - def write_boolean(self, datum): + def write_boolean(self, datum: bool) -> None: """ a boolean is written as a single byte whose value is either 0 (false) or 1 (true). """ self.write(bytearray([bool(datum)])) - def write_int(self, datum): + def write_int(self, datum: int) -> None: """ int and long values are written using variable-length, zig-zag coding. """ self.write_long(datum) - def write_long(self, datum): + def write_long(self, datum: int) -> None: """ int and long values are written using variable-length, zig-zag coding. """ @@ -460,7 +458,7 @@ def write_long(self, datum): datum >>= 7 self.write(bytearray([datum])) - def write_float(self, datum): + def write_float(self, datum: float) -> None: """ A float is written as 4 bytes. The float is converted into a 32-bit integer using a method equivalent to @@ -468,7 +466,7 @@ def write_float(self, datum): """ self.write(STRUCT_FLOAT.pack(datum)) - def write_double(self, datum): + def write_double(self, datum: float) -> None: """ A double is written as 8 bytes. 
The double is converted into a 64-bit integer using a method equivalent to @@ -476,7 +474,7 @@ def write_double(self, datum): """ self.write(STRUCT_DOUBLE.pack(datum)) - def write_decimal_bytes(self, datum, scale): + def write_decimal_bytes(self, datum: decimal.Decimal, scale: int) -> None: """ Decimal in bytes are encoded as long. Since size of packed value in bytes for signed long is 8, 8 bytes are written. @@ -503,7 +501,7 @@ def write_decimal_bytes(self, datum, scale): bits_to_write = packed_bits >> (8 * index) self.write(bytearray([bits_to_write & 0xFF])) - def write_decimal_fixed(self, datum, scale, size): + def write_decimal_fixed(self, datum: decimal.Decimal, scale: int, size: int) -> None: """ Decimal in fixed are encoded as size of fixed bytes. """ @@ -544,22 +542,22 @@ def write_decimal_fixed(self, datum, scale, size): bits_to_write = unscaled_datum >> (8 * index) self.write(bytearray([bits_to_write & 0xFF])) - def write_bytes(self, datum): + def write_bytes(self, datum: bytes) -> None: """ Bytes are encoded as a long followed by that many bytes of data. """ self.write_long(len(datum)) self.write(struct.pack(f"{len(datum)}s", datum)) - def write_utf8(self, datum): + def write_utf8(self, datum: str) -> None: """ A string is encoded as a long followed by that many bytes of UTF-8 encoded character data. """ - datum = datum.encode("utf-8") - self.write_bytes(datum) + bytes_ = datum.encode("utf-8") + self.write_bytes(bytes_) - def write_date_int(self, datum): + def write_date_int(self, datum: datetime.date) -> None: """ Encode python date object as int. It stores the number of days from @@ -568,7 +566,7 @@ def write_date_int(self, datum): delta_date = datum - datetime.date(1970, 1, 1) self.write_int(delta_date.days) - def write_time_millis_int(self, datum): + def write_time_millis_int(self, datum: datetime.time) -> None: """ Encode python time object as int. It stores the number of milliseconds from midnight, 00:00:00.000 @@ -576,7 +574,7 @@ def write_time_millis_int(self, datum): milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000 self.write_int(milliseconds) - def write_time_micros_long(self, datum): + def write_time_micros_long(self, datum: datetime.time) -> None: """ Encode python time object as long. It stores the number of microseconds from midnight, 00:00:00.000000 @@ -584,10 +582,10 @@ def write_time_micros_long(self, datum): microseconds = datum.hour * 3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond self.write_long(microseconds) - def _timedelta_total_microseconds(self, timedelta): + def _timedelta_total_microseconds(self, timedelta: datetime.timedelta) -> int: return timedelta.microseconds + (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6 - def write_timestamp_millis_long(self, datum): + def write_timestamp_millis_long(self, datum: datetime.datetime) -> None: """ Encode python datetime object as long. It stores the number of milliseconds from midnight of unix epoch, 1 January 1970. @@ -597,7 +595,7 @@ def write_timestamp_millis_long(self, datum): milliseconds = self._timedelta_total_microseconds(timedelta) // 1000 self.write_long(milliseconds) - def write_timestamp_micros_long(self, datum): + def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: """ Encode python datetime object as long. It stores the number of microseconds from midnight of unix epoch, 1 January 1970. 
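        A hedged worked example, assuming a timezone-aware input (the epoch
        reference used here is UTC):

            >>> import datetime
            >>> import avro.timezones
            >>> epoch = datetime.datetime(1970, 1, 1, tzinfo=avro.timezones.utc)
            >>> dt = datetime.datetime(1970, 1, 2, tzinfo=avro.timezones.utc)
            >>> (dt - epoch) // datetime.timedelta(microseconds=1)
            86400000000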
@@ -614,7 +612,7 @@ def write_timestamp_micros_long(self, datum): class DatumReader: """Deserialize Avro-encoded data into a Python data structure.""" - def __init__(self, writers_schema=None, readers_schema=None): + def __init__(self, writers_schema: Optional[avro.schema.Schema] = None, readers_schema: Optional[avro.schema.Schema] = None) -> None: """ As defined in the Avro specification, we call the schema encoded in the data the "writer's schema", and the schema expected by the @@ -623,23 +621,30 @@ def __init__(self, writers_schema=None, readers_schema=None): self._writers_schema = writers_schema self._readers_schema = readers_schema - # read/write properties - def set_writers_schema(self, writers_schema): + @property + def writers_schema(self) -> Optional[avro.schema.Schema]: + return self._writers_schema + + @writers_schema.setter + def writers_schema(self, writers_schema: avro.schema.Schema) -> None: self._writers_schema = writers_schema - writers_schema = property(lambda self: self._writers_schema, set_writers_schema) + @property + def readers_schema(self) -> Optional[avro.schema.Schema]: + return self._readers_schema - def set_readers_schema(self, readers_schema): + @readers_schema.setter + def readers_schema(self, readers_schema: avro.schema.Schema) -> None: self._readers_schema = readers_schema - readers_schema = property(lambda self: self._readers_schema, set_readers_schema) - - def read(self, decoder): + def read(self, decoder: avro.io.BinaryDecoder) -> DatumType: + if self.writers_schema is None: + raise avro.errors.UninitializedDatumIOException if self.readers_schema is None: self.readers_schema = self.writers_schema return self.read_data(self.writers_schema, self.readers_schema, decoder) - def read_data(self, writers_schema, readers_schema, decoder): + def read_data(self, writers_schema: avro.schema.Schema, readers_schema: avro.schema.Schema, decoder: avro.io.BinaryDecoder) -> DatumType: # schema matching if not readers_schema.match(writers_schema): fail_msg = "Schemas do not match." 
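            # [Editor's note] A hedged round-trip sketch of the resolution this
            # guards: per Avro's promotion rules an "int" writer matches a "long"
            # reader, while e.g. "int" vs. "string" fails the match above:
            #
            #     >>> import io
            #     >>> import avro.io, avro.schema
            #     >>> w, r = avro.schema.parse('"int"'), avro.schema.parse('"long"')
            #     >>> buf = io.BytesIO()
            #     >>> avro.io.DatumWriter(w).write(7, avro.io.BinaryEncoder(buf))
            #     >>> _ = buf.seek(0)
            #     >>> avro.io.DatumReader(w, r).read(avro.io.BinaryDecoder(buf))
            #     7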
@@ -649,11 +654,12 @@ def read_data(self, writers_schema, readers_schema, decoder): # function dispatch for reading data based on type of writer's schema if writers_schema.type in ["union", "error_union"]: + writers_schema = cast(avro.schema.UnionSchema, writers_schema) return self.read_union(writers_schema, readers_schema, decoder) if readers_schema.type in ["union", "error_union"]: # schema resolution: reader's schema is a union, writer's schema is not - for s in readers_schema.schemas: + for s in cast(avro.schema.UnionSchema, readers_schema).schemas: if s.match(writers_schema): return self.read_data(writers_schema, s, decoder) @@ -663,99 +669,100 @@ def read_data(self, writers_schema, readers_schema, decoder): if writers_schema.type == "null": return decoder.read_null() - elif writers_schema.type == "boolean": + if writers_schema.type == "boolean": return decoder.read_boolean() - elif writers_schema.type == "string": + if writers_schema.type == "string": return decoder.read_utf8() - elif writers_schema.type == "int": + if writers_schema.type == "int": if logical_type == avro.constants.DATE: return decoder.read_date_from_int() if logical_type == avro.constants.TIME_MILLIS: return decoder.read_time_millis_from_int() return decoder.read_int() - elif writers_schema.type == "long": + if writers_schema.type == "long": if logical_type == avro.constants.TIME_MICROS: return decoder.read_time_micros_from_long() - elif logical_type == avro.constants.TIMESTAMP_MILLIS: + if logical_type == avro.constants.TIMESTAMP_MILLIS: return decoder.read_timestamp_millis_from_long() - elif logical_type == avro.constants.TIMESTAMP_MICROS: + if logical_type == avro.constants.TIMESTAMP_MICROS: return decoder.read_timestamp_micros_from_long() - else: - return decoder.read_long() - elif writers_schema.type == "float": + return decoder.read_long() + if writers_schema.type == "float": return decoder.read_float() - elif writers_schema.type == "double": + if writers_schema.type == "double": return decoder.read_double() - elif writers_schema.type == "bytes": + if writers_schema.type == "bytes": if logical_type == "decimal": return decoder.read_decimal_from_bytes( writers_schema.get_prop("precision"), writers_schema.get_prop("scale"), ) - else: - return decoder.read_bytes() - elif writers_schema.type == "fixed": + return decoder.read_bytes() + if writers_schema.type == "fixed": + writers_schema = cast(avro.schema.FixedSchema, writers_schema) if logical_type == "decimal": return decoder.read_decimal_from_fixed( writers_schema.get_prop("precision"), writers_schema.get_prop("scale"), - writers_schema.size, + cast(avro.schema.FixedSchema, writers_schema).size, ) return self.read_fixed(writers_schema, readers_schema, decoder) - elif writers_schema.type == "enum": - return self.read_enum(writers_schema, readers_schema, decoder) - elif writers_schema.type == "array": - return self.read_array(writers_schema, readers_schema, decoder) - elif writers_schema.type == "map": - return self.read_map(writers_schema, readers_schema, decoder) - elif writers_schema.type in ["record", "error", "request"]: - return self.read_record(writers_schema, readers_schema, decoder) - else: - raise avro.errors.AvroException(f"Cannot read unknown schema type: {writers_schema.type}") - - def skip_data(self, writers_schema, decoder): + if writers_schema.type == "enum": + # Schema match ensures correct cast. 
+ return self.read_enum(cast(avro.schema.EnumSchema, writers_schema), cast(avro.schema.EnumSchema, readers_schema), decoder) + if writers_schema.type == "array": + # Schema match ensures correct cast. + return self.read_array(cast(avro.schema.ArraySchema, writers_schema), cast(avro.schema.ArraySchema, readers_schema), decoder) + if writers_schema.type == "map": + # Schema match ensures correct cast. + return self.read_map(cast(avro.schema.MapSchema, writers_schema), cast(avro.schema.MapSchema, readers_schema), decoder) + if writers_schema.type in ["record", "error", "request"]: + # Schema match ensures correct cast. + return self.read_record(cast(avro.schema.RecordSchema, writers_schema), cast(avro.schema.RecordSchema, readers_schema), decoder) + raise avro.errors.AvroException(f"Cannot read unknown schema type: {writers_schema.type}") + + def skip_data(self, writers_schema: avro.schema.Schema, decoder: avro.io.BinaryDecoder) -> None: if writers_schema.type == "null": return decoder.skip_null() - elif writers_schema.type == "boolean": + if writers_schema.type == "boolean": return decoder.skip_boolean() - elif writers_schema.type == "string": + if writers_schema.type == "string": return decoder.skip_utf8() - elif writers_schema.type == "int": + if writers_schema.type == "int": return decoder.skip_int() - elif writers_schema.type == "long": + if writers_schema.type == "long": return decoder.skip_long() - elif writers_schema.type == "float": + if writers_schema.type == "float": return decoder.skip_float() - elif writers_schema.type == "double": + if writers_schema.type == "double": return decoder.skip_double() - elif writers_schema.type == "bytes": + if writers_schema.type == "bytes": return decoder.skip_bytes() - elif writers_schema.type == "fixed": - return self.skip_fixed(writers_schema, decoder) - elif writers_schema.type == "enum": - return self.skip_enum(writers_schema, decoder) - elif writers_schema.type == "array": - return self.skip_array(writers_schema, decoder) - elif writers_schema.type == "map": - return self.skip_map(writers_schema, decoder) - elif writers_schema.type in ["union", "error_union"]: - return self.skip_union(writers_schema, decoder) - elif writers_schema.type in ["record", "error", "request"]: - return self.skip_record(writers_schema, decoder) - else: - raise avro.errors.AvroException(f"Unknown schema type: {writers_schema.type}") + if writers_schema.type == "fixed": + return self.skip_fixed(cast(avro.schema.FixedSchema, writers_schema), decoder) + if writers_schema.type == "enum": + return self.skip_enum(cast(avro.schema.EnumSchema, writers_schema), decoder) + if writers_schema.type == "array": + return self.skip_array(cast(avro.schema.ArraySchema, writers_schema), decoder) + if writers_schema.type == "map": + return self.skip_map(cast(avro.schema.MapSchema, writers_schema), decoder) + if writers_schema.type in ["union", "error_union"]: + return self.skip_union(cast(avro.schema.UnionSchema, writers_schema), decoder) + if writers_schema.type in ["record", "error", "request"]: + return self.skip_record(cast(avro.schema.RecordSchema, writers_schema), decoder) + raise avro.errors.AvroException(f"Unknown schema type: {writers_schema.type}") - def read_fixed(self, writers_schema, readers_schema, decoder): + def read_fixed(self, writers_schema: avro.schema.FixedSchema, readers_schema: avro.schema.Schema, decoder: avro.io.BinaryDecoder) -> bytes: """ Fixed instances are encoded using the number of bytes declared in the schema. 
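        A hedged example: a fixed of size 4 is exactly 4 raw bytes with no
        length prefix, so b"abcd" decodes as-is:

            >>> import io
            >>> BinaryDecoder(io.BytesIO(b"abcd")).read(4)
            b'abcd'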
""" return decoder.read(writers_schema.size) - def skip_fixed(self, writers_schema, decoder): + def skip_fixed(self, writers_schema: avro.schema.FixedSchema, decoder: avro.io.BinaryDecoder) -> None: return decoder.skip(writers_schema.size) - def read_enum(self, writers_schema, readers_schema, decoder): + def read_enum(self, writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema, decoder: avro.io.BinaryDecoder) -> str: """ An enum is encoded by a int, representing the zero-based position of the symbol in the schema. @@ -764,7 +771,7 @@ def read_enum(self, writers_schema, readers_schema, decoder): index_of_symbol = decoder.read_int() if index_of_symbol >= len(writers_schema.symbols): raise avro.errors.SchemaResolutionException( - f"Can't access enum index {index_of_symbole} for enum with {len(writers_schema.symbols)} symbols", writers_schema, readers_schema + f"Can't access enum index {index_of_symbol} for enum with {len(writers_schema.symbols)} symbols", writers_schema, readers_schema ) read_symbol = writers_schema.symbols[index_of_symbol] @@ -774,10 +781,12 @@ def read_enum(self, writers_schema, readers_schema, decoder): return read_symbol - def skip_enum(self, writers_schema, decoder): + def skip_enum(self, writers_schema: avro.schema.EnumSchema, decoder: avro.io.BinaryDecoder) -> None: return decoder.skip_int() - def read_array(self, writers_schema, readers_schema, decoder): + def read_array( + self, writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema, decoder: avro.io.BinaryDecoder + ) -> Sequence[DatumType]: """ Arrays are encoded as a series of blocks. @@ -792,7 +801,7 @@ def read_array(self, writers_schema, readers_schema, decoder): The actual count in this case is the absolute value of the count written. """ - read_items = [] + read_items: List[DatumType] = [] block_count = decoder.read_long() while block_count != 0: if block_count < 0: @@ -803,7 +812,7 @@ def read_array(self, writers_schema, readers_schema, decoder): block_count = decoder.read_long() return read_items - def skip_array(self, writers_schema, decoder): + def skip_array(self, writers_schema: avro.schema.ArraySchema, decoder: avro.io.BinaryDecoder) -> None: block_count = decoder.read_long() while block_count != 0: if block_count < 0: @@ -814,7 +823,9 @@ def skip_array(self, writers_schema, decoder): self.skip_data(writers_schema.items, decoder) block_count = decoder.read_long() - def read_map(self, writers_schema, readers_schema, decoder): + def read_map( + self, writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema, decoder: avro.io.BinaryDecoder + ) -> Dict[str, DatumType]: """ Maps are encoded as a series of blocks. @@ -829,7 +840,7 @@ def read_map(self, writers_schema, readers_schema, decoder): The actual count in this case is the absolute value of the count written. 
""" - read_items = {} + read_items: Dict[str, DatumType] = {} block_count = decoder.read_long() while block_count != 0: if block_count < 0: @@ -841,7 +852,7 @@ def read_map(self, writers_schema, readers_schema, decoder): block_count = decoder.read_long() return read_items - def skip_map(self, writers_schema, decoder): + def skip_map(self, writers_schema: avro.schema.MapSchema, decoder: avro.io.BinaryDecoder) -> None: block_count = decoder.read_long() while block_count != 0: if block_count < 0: @@ -853,7 +864,7 @@ def skip_map(self, writers_schema, decoder): self.skip_data(writers_schema.values, decoder) block_count = decoder.read_long() - def read_union(self, writers_schema, readers_schema, decoder): + def read_union(self, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.Schema, decoder: avro.io.BinaryDecoder) -> DatumType: """ A union is encoded by first writing an int value indicating the zero-based position within the union of the schema of its value. @@ -861,16 +872,15 @@ def read_union(self, writers_schema, readers_schema, decoder): """ # schema resolution index_of_schema = int(decoder.read_long()) - if index_of_schema >= len(writers_schema.schemas): + try: + selected_writers_schema = writers_schema.schemas[index_of_schema] + except IndexError: raise avro.errors.SchemaResolutionException( f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema, readers_schema ) - selected_writers_schema = writers_schema.schemas[index_of_schema] + return self.read_data(selected_writers_schema, readers_schema, decoder) # read data - # read data - return self.read_data(selected_writers_schema, readers_schema, decoder) - - def skip_union(self, writers_schema, decoder): + def skip_union(self, writers_schema: avro.schema.UnionSchema, decoder: avro.io.BinaryDecoder) -> None: index_of_schema = int(decoder.read_long()) if index_of_schema >= len(writers_schema.schemas): raise avro.errors.SchemaResolutionException( @@ -878,7 +888,9 @@ def skip_union(self, writers_schema, decoder): ) return self.skip_data(writers_schema.schemas[index_of_schema], decoder) - def read_record(self, writers_schema, readers_schema, decoder): + def read_record( + self, writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema, decoder: avro.io.BinaryDecoder + ) -> Dict[str, DatumType]: """ A record is encoded by encoding the values of its fields in the order that they are declared. 
In other words, a record @@ -900,12 +912,11 @@ def read_record(self, writers_schema, readers_schema, decoder): """ # schema resolution readers_fields_dict = readers_schema.fields_dict - read_record = {} + read_record: Dict[str, DatumType] = {} for field in writers_schema.fields: readers_field = readers_fields_dict.get(field.name) if readers_field is not None: - field_val = self.read_data(field.type, readers_field.type, decoder) - read_record[field.name] = field_val + read_record[field.name] = self.read_data(field.type, readers_field.type, decoder) else: self.skip_data(field.type, decoder) @@ -920,41 +931,45 @@ def read_record(self, writers_schema, readers_schema, decoder): read_record[field.name] = field_val return read_record - def skip_record(self, writers_schema, decoder): + def skip_record(self, writers_schema: avro.schema.RecordSchema, decoder: avro.io.BinaryDecoder) -> None: for field in writers_schema.fields: self.skip_data(field.type, decoder) - def _read_default_value(self, field_schema, default_value): + def _read_default_value(self, field_schema: avro.schema.Schema, default_value: DatumType) -> DatumType: """ Basically a JSON Decoder? """ if field_schema.type == "null": return None - elif field_schema.type == "boolean": + if field_schema.type == "boolean": return bool(default_value) - elif field_schema.type == "int": + if field_schema.type == "int": return int(default_value) - elif field_schema.type == "long": + if field_schema.type == "long": return int(default_value) - elif field_schema.type in ["float", "double"]: + if field_schema.type in ["float", "double"]: return float(default_value) - elif field_schema.type in ["enum", "fixed", "string", "bytes"]: + if field_schema.type in ["enum", "fixed", "string", "bytes"]: return default_value - elif field_schema.type == "array": + if field_schema.type == "array": + field_schema = cast(avro.schema.ArraySchema, field_schema) read_array = [] for json_val in default_value: item_val = self._read_default_value(field_schema.items, json_val) read_array.append(item_val) return read_array - elif field_schema.type == "map": + if field_schema.type == "map": + field_schema = cast(avro.schema.MapSchema, field_schema) read_map = {} for key, json_val in default_value.items(): map_val = self._read_default_value(field_schema.values, json_val) read_map[key] = map_val return read_map - elif field_schema.type in ["union", "error_union"]: + if field_schema.type in ["union", "error_union"]: + field_schema = cast(avro.schema.UnionSchema, field_schema) return self._read_default_value(field_schema.schemas[0], default_value) - elif field_schema.type == "record": + if field_schema.type == "record": + field_schema = cast(avro.schema.RecordSchema, field_schema) read_record = {} for field in field_schema.fields: json_val = default_value.get(field.name) @@ -963,27 +978,33 @@ def _read_default_value(self, field_schema, default_value): field_val = self._read_default_value(field.type, json_val) read_record[field.name] = field_val return read_record - else: - raise avro.errors.AvroException(f"Unknown type: {field_schema.type}") + raise avro.errors.AvroException(f"Unknown type: {field_schema.type}") class DatumWriter: """DatumWriter for generic python objects.""" - def __init__(self, writers_schema=None): + _writers_schema: Optional[avro.schema.Schema] = None + + def __init__(self, writers_schema: Optional[avro.schema.Schema] = None) -> None: self._writers_schema = writers_schema # read/write properties - def set_writers_schema(self, writers_schema): - 
self._writers_schema = writers_schema + @property + def writers_schema(self) -> Optional[avro.schema.Schema]: + return self._writers_schema - writers_schema = property(lambda self: self._writers_schema, set_writers_schema) + @writers_schema.setter + def writers_schema(self, writers_schema: avro.schema.Schema) -> None: + self._writers_schema = writers_schema - def write(self, datum, encoder): + def write(self, datum: DatumType, encoder: avro.io.BinaryEncoder) -> None: + if self.writers_schema is None: + raise avro.errors.UninitializedDatumIOException validate(self.writers_schema, datum, raise_on_error=True) self.write_data(self.writers_schema, datum, encoder) - def write_data(self, writers_schema, datum, encoder): + def write_data(self, writers_schema: avro.schema.Schema, datum: Any, encoder: avro.io.BinaryEncoder) -> None: # function dispatch to write datum logical_type = getattr(writers_schema, "logical_type", None) if writers_schema.type == "null": @@ -1018,6 +1039,7 @@ def write_data(self, writers_schema, datum, encoder): else: encoder.write_bytes(datum) elif writers_schema.type == "fixed": + writers_schema = cast(avro.schema.FixedSchema, writers_schema) if logical_type == "decimal": encoder.write_decimal_fixed( datum, @@ -1027,26 +1049,31 @@ def write_data(self, writers_schema, datum, encoder): else: self.write_fixed(writers_schema, datum, encoder) elif writers_schema.type == "enum": + writers_schema = cast(avro.schema.EnumSchema, writers_schema) self.write_enum(writers_schema, datum, encoder) elif writers_schema.type == "array": + writers_schema = cast(avro.schema.ArraySchema, writers_schema) self.write_array(writers_schema, datum, encoder) elif writers_schema.type == "map": + writers_schema = cast(avro.schema.MapSchema, writers_schema) self.write_map(writers_schema, datum, encoder) elif writers_schema.type in ["union", "error_union"]: + writers_schema = cast(avro.schema.UnionSchema, writers_schema) self.write_union(writers_schema, datum, encoder) elif writers_schema.type in ["record", "error", "request"]: + writers_schema = cast(avro.schema.RecordSchema, writers_schema) self.write_record(writers_schema, datum, encoder) else: raise avro.errors.AvroException(f"Unknown type: {writers_schema.type}") - def write_fixed(self, writers_schema, datum, encoder): + def write_fixed(self, writers_schema: avro.schema.FixedSchema, datum: bytes, encoder: avro.io.BinaryEncoder) -> None: """ Fixed instances are encoded using the number of bytes declared in the schema. """ encoder.write(datum) - def write_enum(self, writers_schema, datum, encoder): + def write_enum(self, writers_schema: avro.schema.EnumSchema, datum: str, encoder: avro.io.BinaryEncoder) -> None: """ An enum is encoded by a int, representing the zero-based position of the symbol in the schema. """ index_of_datum = writers_schema.symbols.index(datum) encoder.write_int(index_of_datum) - def write_array(self, writers_schema, datum, encoder): + def write_array(self, writers_schema: avro.schema.ArraySchema, datum: List[Any], encoder: avro.io.BinaryEncoder) -> None: """ Arrays are encoded as a series of blocks.
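        A hedged worked example (it mirrors the Avro specification's own): the
        long array [3, 27] encodes as a single block with count 2 (zig-zag 0x04),
        items 0x06 and 0x36, then the 0x00 end-of-blocks marker.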
@@ -1075,7 +1102,7 @@ def write_array(self, writers_schema, datum, encoder): self.write_data(writers_schema.items, item, encoder) encoder.write_long(0) - def write_map(self, writers_schema, datum, encoder): + def write_map(self, writers_schema: avro.schema.MapSchema, datum: Dict[str, Any], encoder: avro.io.BinaryEncoder) -> None: """ Maps are encoded as a series of blocks. @@ -1097,7 +1124,7 @@ def write_map(self, writers_schema, datum, encoder): self.write_data(writers_schema.values, val, encoder) encoder.write_long(0) - def write_union(self, writers_schema, datum, encoder): + def write_union(self, writers_schema: avro.schema.UnionSchema, datum: Any, encoder: avro.io.BinaryEncoder) -> None: """ A union is encoded by first writing an int value indicating the zero-based position within the union of the schema of its value. @@ -1115,7 +1142,7 @@ def write_union(self, writers_schema, datum, encoder): encoder.write_long(index_of_schema) self.write_data(writers_schema.schemas[index_of_schema], datum, encoder) - def write_record(self, writers_schema, datum, encoder): + def write_record(self, writers_schema: avro.schema.RecordSchema, datum: Dict[str, Any], encoder: avro.io.BinaryEncoder) -> None: """ A record is encoded by encoding the values of its fields in the order that they are declared. In other words, a record diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py index 43c264d1b04..91703eb1d47 100644 --- a/lang/py/avro/schema.py +++ b/lang/py/avro/schema.py @@ -48,6 +48,7 @@ import sys import uuid import warnings +from typing import List, Sequence, cast import avro.constants import avro.errors @@ -762,7 +763,11 @@ def __init__( self.set_prop("doc", doc) # read-only properties - symbols = property(lambda self: self.get_prop("symbols")) + @property + def symbols(self) -> Sequence[str]: + symbols = self.get_prop("symbols") + return cast(List[str], symbols) + doc = property(lambda self: self.get_prop("doc")) def match(self, writer):