Skip to content

Commit

Permalink
Add allow_extra_keys param to SchemaValidator, refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Apr 28, 2021
1 parent b7b5b4a commit 62d7c71
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 39 deletions.
75 changes: 54 additions & 21 deletions hivemind/dht/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import binascii
import re
from typing import Type
from typing import Any, Dict, Type

import pydantic

Expand All @@ -19,14 +19,19 @@ class SchemaValidator(RecordValidatorBase):
This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
"""

def __init__(self, schema: pydantic.BaseModel):
def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool=True):
"""
:param schema: The Pydantic model (a subclass of pydantic.BaseModel).
You must always use strict types for the number fields
(e.g. ``StrictInt`` instead of ``int``,
``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
See the validate() docstring for details.
:param allow_extra_keys: Whether to allow keys that are not defined in the schema.
If a SchemaValidator is merged with another SchemaValidator, this option applies to
keys that are not defined in each of the schemas.
"""

self._alias_to_name = {}
Expand All @@ -40,6 +45,7 @@ def __init__(self, schema: pydantic.BaseModel):
schema.Config.extra = pydantic.Extra.forbid

self._schemas = [schema]
self._allow_extra_keys = allow_extra_keys

def validate(self, record: DHTRecord) -> bool:
"""
Expand All @@ -63,42 +69,59 @@ def validate(self, record: DHTRecord) -> bool:
.. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
"""

key_alias = self._key_id_to_str(record.key)
deserialized_value = DHTProtocol.serializer.loads(record.value)
if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
deserialized_record = {key_alias: {deserialized_subkey: deserialized_value}}
else:
if isinstance(deserialized_value, dict):
logger.warning(
f'Record {record} contains an improperly serialized dictionary (you must use '
f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
return False
deserialized_record = {key_alias: deserialized_value}
try:
record = self._deserialize_record(record)
except ValueError as e:
logger.warning(e)
return False
[key_alias] = list(record.keys())

parsed_record = None
n_outside_schema = 0
validation_errors = []
for schema in self._schemas:
try:
parsed_record = schema.parse_obj(deserialized_record)
parsed_record = schema.parse_obj(record)
except pydantic.ValidationError as e:
validation_errors.append(e)
if self._is_failed_due_to_extra_field(e):
n_outside_schema += 1
else:
validation_errors.append(e)
continue

parsed_value = parsed_record.dict(by_alias=True)[key_alias]
if parsed_value != deserialized_record[key_alias]:
if parsed_value != record[key_alias]:
validation_errors.append(ValueError(
f"Value {deserialized_record[key_alias]} needed type conversions to match "
f"Value {record[key_alias]} needed type conversions to match "
f"the schema: {parsed_value}. Type conversions are not allowed"))
else:
return True

readable_record = {self._alias_to_name.get(key_alias, key_alias):
deserialized_record[key_alias]}
readable_record = {self._alias_to_name.get(key_alias, key_alias): record[key_alias]}

if n_outside_schema == len(self._schemas):
if not self._allow_extra_keys:
logger.warning(f"Record {readable_record} contains a field that "
f"is not defined in each of the schemas")
return self._allow_extra_keys

logger.warning(
f"Record {readable_record} doesn't match any of the schemas: {validation_errors}")
return False

@staticmethod
def _deserialize_record(record: DHTRecord) -> Dict[str, Any]:
key_alias = SchemaValidator._key_id_to_str(record.key)
deserialized_value = DHTProtocol.serializer.loads(record.value)
if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
return {key_alias: {deserialized_subkey: deserialized_value}}
else:
if isinstance(deserialized_value, dict):
raise ValueError(
f'Record {record} contains an improperly serialized dictionary (you must use '
f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
return {key_alias: deserialized_value}

@staticmethod
def _key_id_to_str(key_id: bytes) -> str:
"""
Expand All @@ -108,12 +131,22 @@ def _key_id_to_str(key_id: bytes) -> str:

return binascii.hexlify(key_id).decode()

@staticmethod
def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
inner_errors = exc.errors()
return (
len(inner_errors) == 1 and
inner_errors[0]['type'] == 'value_error.extra' and
len(inner_errors[0]['loc']) == 1 # Require the extra field to be on the top level
)

def merge_with(self, other: RecordValidatorBase) -> bool:
if not isinstance(other, SchemaValidator):
return False

self._alias_to_name.update(other._alias_to_name)
self._schemas.extend(other._schemas)
self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
return True


Expand Down
53 changes: 37 additions & 16 deletions tests/test_dht_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@ class Schema(BaseModel):
return alice, bob


@pytest.mark.forked
@pytest.mark.asyncio
async def test_keys_outside_schema(dht_nodes_with_schema):
alice, bob = dht_nodes_with_schema

assert not await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)

for peer in [alice, bob]:
assert (await peer.get(b'unknown_key', latest=True)) is None


@pytest.mark.forked
@pytest.mark.asyncio
async def test_expecting_regular_value(dht_nodes_with_schema):
Expand Down Expand Up @@ -105,6 +94,34 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
dictionary[b'uid[owner:public-key]'].value == b'foo_bar')


@pytest.mark.forked
@pytest.mark.asyncio
async def test_keys_outside_schema(dht_nodes_with_schema):
class Schema(BaseModel):
some_field: StrictInt

class MergedSchema(BaseModel):
another_field: StrictInt

for allow_extra_keys in [False, True]:
validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))

alice = await DHTNode.create(record_validator=validator)
bob = await DHTNode.create(
record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])

store_ok = await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
assert store_ok == allow_extra_keys

for peer in [alice, bob]:
result = await peer.get(b'unknown_key', latest=True)
if allow_extra_keys:
assert result.value == b'foo_bar'
else:
assert result is None


@pytest.mark.forked
@pytest.mark.asyncio
async def test_merging_schema_validators(dht_nodes_with_schema):
Expand All @@ -126,18 +143,22 @@ class ThirdSchema(BaseModel):
another_field: StrictInt # Allow it to be a StrictInt as well

for schema in [SecondSchema, ThirdSchema]:
new_validator = SchemaValidator(schema)
new_validator = SchemaValidator(schema, allow_extra_keys=False)
for peer in [alice, bob]:
assert peer.protocol.record_validator.merge_with(new_validator)

assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
assert await bob.store(b'some_field', 777, get_dht_time() + 10)
assert await bob.store(b'another_field', 'string_value', get_dht_time() + 10)
assert not await bob.store(b'some_field', 'string_value', get_dht_time() + 10)
assert await bob.store(b'another_field', 42, get_dht_time() + 10)
assert not await bob.store(b'unknown_key', 777, get_dht_time() + 10)
assert await bob.store(b'another_field', 'string_value', get_dht_time() + 10)

# Unkown keys are allowed since the first schema is created with allow_extra_keys=True
assert await bob.store(b'unknown_key', 999, get_dht_time() + 10)

for peer in [alice, bob]:
assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
assert (await peer.get(b'some_field', latest=True)).value == 777
assert (await peer.get(b'another_field', latest=True)).value == 42
assert (await peer.get(b'unknown_key', latest=True)) is None
assert (await peer.get(b'another_field', latest=True)).value == 'string_value'

assert (await peer.get(b'unknown_key', latest=True)).value == 999
4 changes: 2 additions & 2 deletions tests/test_dht_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class SchemaB(BaseModel):
def validators_for_app():
# Each application may add its own validator set
return {
'A': [RSASignatureValidator(), SchemaValidator(SchemaA)],
'B': [SchemaValidator(SchemaB), RSASignatureValidator()],
'A': [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
'B': [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
}


Expand Down

0 comments on commit 62d7c71

Please sign in to comment.