From d31a93372b771de3a8f37d95c3fd25e7ac875fee Mon Sep 17 00:00:00 2001 From: Jarry Shaw Date: Tue, 7 Feb 2023 22:55:54 -0800 Subject: [PATCH] revised schema implementation (preparing for Protocol integration) --- pcapkit/protocols/schema/misc/null.py | 6 +++++ pcapkit/protocols/schema/schema.py | 35 +++++++++++++++++++-------- pcapkit/utilities/warnings.py | 5 ++++ 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/pcapkit/protocols/schema/misc/null.py b/pcapkit/protocols/schema/misc/null.py index 943ddb271..76699dc64 100644 --- a/pcapkit/protocols/schema/misc/null.py +++ b/pcapkit/protocols/schema/misc/null.py @@ -8,3 +8,9 @@ class NoPayload(Schema): """Schema for empty payload.""" + + # NOTE: We add this method for both type annotation and to mark that this + # class accepts no arguments at runtime, since :class:`Schema` explicitly + # skipped those whose :attr:`__dict__` is empty :obj:`dict`. + def __init__(self) -> 'None': # pylint: disable=super-init-not-called + pass diff --git a/pcapkit/protocols/schema/schema.py b/pcapkit/protocols/schema/schema.py index 486c3b2ff..66c09c8d1 100644 --- a/pcapkit/protocols/schema/schema.py +++ b/pcapkit/protocols/schema/schema.py @@ -4,16 +4,20 @@ import collections.abc import io import itertools -from typing import TYPE_CHECKING, Generic, TypeVar +import math +from typing import TYPE_CHECKING, Generic, TypeVar, cast from pcapkit.corekit.fields.field import NoValue, _Field from pcapkit.corekit.fields.misc import ConditionalField, PayloadField from pcapkit.utilities.compat import Mapping from pcapkit.utilities.exceptions import NoDefaultValue, ProtocolUnbound +from pcapkit.utilities.warnings import UnknownFieldWarning, warn if TYPE_CHECKING: from typing import IO, Any, Iterable, Iterator, Optional + from typing_extensions import Self + from pcapkit.corekit.fields.field import NoValueType __all__ = ['Schema'] @@ -56,7 +60,7 @@ def __new__(cls, *args: 'VT', **kwargs: 'VT') -> 'Schema': # pylint: disable=un cls.__map_reverse__ = {} temp = ['__map__', '__map_reverse__', '__builtin__', - '__fields__', '__buffer__', 'pack', 'unpack'] + '__fields__', '__buffer__', '__updated__'] for obj in cls.mro(): temp.extend(dir(obj)) cls.__builtin__ = set(temp) @@ -130,7 +134,7 @@ def __new__(cls, *args: 'VT', **kwargs: 'VT') -> 'Schema': # pylint: disable=un # NOTE: We only create the attributes for the instance itself, # to avoid creating shared attributes for the class. - self.__buffer__ = {} + self.__buffer__ = {field.name: NoValue for field in self.__fields__} self.__updated__ = False return self @@ -158,6 +162,10 @@ def __update__(self, dict_: 'Optional[Mapping[str, VT] | Iterable[tuple[str, VT] data_iter = itertools.chain(dict_, kwargs.items()) for (key, value) in data_iter: + if key not in self.__buffer__: + warn(f'{key!r} is not a valid field name', UnknownFieldWarning) + continue + if key in self.__builtin__: new_key = f'_{__name__}{key}' @@ -282,7 +290,7 @@ def to_bytes(self) -> 'bytes': """Convert :class:`Schema` into :obj:`bytes`.""" return self.__bytes__() - def pack(self) -> 'None': + def pack(self) -> 'bytes': """Pack :class:`Schema` into :obj:`bytes`.""" buffer = self.__buffer__ packet = self.__dict__ @@ -318,31 +326,36 @@ def pack(self) -> 'None': buffer[field.name] = temp self.__updated__ = False + return self.__bytes__() @classmethod - def unpack(cls, data: 'bytes | IO[bytes]') -> 'None': + def unpack(cls, data: 'bytes | IO[bytes]', length: 'Optional[int]' = None) -> 'Self': # type: ignore[valid-type] """Unpack :obj:`bytes` into :class:`Schema`. Args: data: Packed data. + length: Length of data. """ self = cls.__new__(cls) if isinstance(data, bytes): - length = len(data) + length = len(data) if length is None else length data = io.BytesIO(data) else: - current = data.tell() - length = data.seek(0, io.SEEK_END) - current - data.seek(current) + length = cast('int', math.inf) if length is None else length packet = self.__dict__ buffer = self.__buffer__ for field in self.__fields__: if isinstance(field, PayloadField): - payload_length = field.test_length(packet, length) + if math.isinf(length): + current = data.tell() + default_length = data.seek(0, io.SEEK_END) - current + data.seek(current) + + payload_length = field.test_length(packet, default_length) payload = data.read(payload_length) buffer[field.name] = payload @@ -359,6 +372,8 @@ def unpack(cls, data: 'bytes | IO[bytes]') -> 'None': value = field(packet).unpack(byte, packet) setattr(self, field.name, value) + length -= field.length self.__updated__ = False + return self diff --git a/pcapkit/utilities/warnings.py b/pcapkit/utilities/warnings.py index 3a86e00e2..8f6d33b66 100644 --- a/pcapkit/utilities/warnings.py +++ b/pcapkit/utilities/warnings.py @@ -24,6 +24,7 @@ # RuntimeWarning 'FileWarning', 'LayerWarning', 'ProtocolWarning', 'AttributeWarning', 'DevModeWarning', 'VendorRequestWarning', 'VendorRuntimeWarning', + 'UnknownFieldWarning', # ResourceWarning 'DPKTWarning', 'ScapyWarning', 'PySharkWarning', 'EmojiWarning', 'VendorWarning', @@ -114,6 +115,10 @@ class VendorRuntimeWarning(BaseWarning, RuntimeWarning): """Vendor failed during runtime.""" +class UnknownFieldWarning(BaseWarning, RuntimeWarning): + """Unknown field.""" + + ############################################################################## # ResourceWarning session. ##############################################################################