Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement combining validators #249

Merged
merged 11 commits into from
Apr 28, 2021
2 changes: 1 addition & 1 deletion hivemind/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from hivemind.utils import *
from hivemind.optim import *

__version__ = '0.9.7'
__version__ = '0.9.8'
23 changes: 20 additions & 3 deletions hivemind/dht/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import ctypes
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
from functools import partial
from typing import Iterable, List, Optional, Sequence, Union, Callable, Awaitable, TypeVar

import hivemind
from hivemind.client import RemoteExpert
from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
from hivemind.dht.routing import DHTValue, DHTKey, Subkey
from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
from hivemind.utils.networking import Hostname, Endpoint, strip_port
from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, get_dht_time

Expand All @@ -49,12 +51,14 @@ class DHT(mp.Process):

def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
expiration: float = 300, **kwargs):
expiration: float = 300, record_validators: Iterable[RecordValidatorBase] = (),
**kwargs):
super().__init__()
assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
self.default_expiration = expiration
self._record_validator = CompositeValidator(record_validators)
self._port = mp.Value(ctypes.c_int32, 0) # initialized after dht starts
self._pipe, self.pipe = mp.Pipe(duplex=True)
self.ready = mp.Event()
Expand All @@ -70,7 +74,8 @@ def run(self) -> None:
async def _run():
node = await DHTNode.create(
initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
num_workers=self.max_workers or 1, **self.kwargs)
num_workers=self.max_workers or 1, record_validator=self._record_validator,
**self.kwargs)
if node.port is not None:
self._port.value = node.port
self.ready.set()
Expand Down Expand Up @@ -190,6 +195,18 @@ async def _run_coroutine(self, node: DHTNode, coro: Callable[[DHT, DHTNode], Awa
if not future.done():
future.set_exception(e)

def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
    """
    Register extra record validators on a DHT whose process is already running.

    :param record_validators: validators to hand over to the running node. Any iterable is accepted
        (including generators); it is materialized into a tuple before being forwarded.
    :raises RuntimeError: if ``self.ready`` is not set yet. In that case, pass the validators
        via DHT.__init__(record_validators=...) instead.
    """
    if not self.ready.is_set():
        raise RuntimeError(
            "Can't append new validators before the DHT process has started. "
            "Consider adding them to the initial list via DHT.__init__(record_validators=...)")

    # Materialize the iterable up front: generators and other one-shot iterators cannot be pickled.
    # NOTE(review): presumably run_coroutine ships this partial to the DHT process - confirm there.
    validators = tuple(record_validators)
    self.run_coroutine(partial(DHT._add_validators, record_validators=validators))

async def _add_validators(
        self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
    """
    Coroutine counterpart of add_validators().

    Extends the record validator held by the node's protocol with ``record_validators``.
    """
    active_validator = node.protocol.record_validator
    active_validator.extend(record_validators)

def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
"""
Get this machine's visible address by requesting other peers or using pre-specified network addresses.
Expand Down
27 changes: 27 additions & 0 deletions hivemind/dht/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,30 @@ def strip_value(self, record: DHTRecord) -> bytes:

def _serialize_record(self, record: DHTRecord) -> bytes:
    """Flatten the record dataclass into a tuple and serialize it with MSGPackSerializer."""
    record_fields = dataclasses.astuple(record)
    return MSGPackSerializer.dumps(record_fields)

@property
def priority(self) -> int:
    """
    Priority of the signature validator (10).

    On validation, this validator must be executed before validators
    that deserialize the record.
    """
    signature_priority = 10
    return signature_priority

def merge_with(self, other: RecordValidatorBase) -> bool:
    """
    Report whether ``other`` can be absorbed by this validator.

    Only another RSASignatureValidator is absorbed, and nothing is copied from it:
    it doesn't make sense to have several instances of this class, so the extra
    instance is ignored and the merge is reported as successful.

    :returns: True if ``other`` is an RSASignatureValidator, False otherwise.
    """
    return isinstance(other, RSASignatureValidator)

def __getstate__(self):
    """
    Build the state used for pickling.

    The private key object is replaced by its serialized bytes (PEM encoding, OpenSSH format,
    no encryption) to make the class instances picklable. The instance itself is not modified.
    """
    serialized_key = self._private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.OpenSSH,
        encryption_algorithm=serialization.NoEncryption())
    # A fresh dict is returned, so the live instance keeps its key object
    return {**self.__dict__, '_private_key': serialized_key}

def __setstate__(self, state):
    """Restore pickled state and load the serialized private key back into a key object."""
    self.__dict__.update(state)
    # After the update, _private_key holds the bytes written by __getstate__
    key_bytes = self._private_key
    self._private_key = serialization.load_ssh_private_key(key_bytes, password=None)
99 changes: 74 additions & 25 deletions hivemind/dht/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import binascii
import re
from typing import Type
from typing import Any, Dict, Type

import pydantic

Expand All @@ -19,26 +19,33 @@ class SchemaValidator(RecordValidatorBase):
This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
"""

def __init__(self, schema: pydantic.BaseModel):
def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool=True):
"""
:param schema: The Pydantic model (a subclass of pydantic.BaseModel).

You must always use strict types for the number fields
(e.g. ``StrictInt`` instead of ``int``,
``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
See the validate() docstring for details.

:param allow_extra_keys: Whether to allow keys that are not defined in the schema.

If a SchemaValidator is merged with another SchemaValidator, this option applies to
keys that are not defined in each of the schemas.
"""

self._alias_to_name = {}

for field in schema.__fields__.values():
field.alias = self._key_id_to_str(DHTID.generate(source=field.name.encode()).to_bytes())
self._alias_to_name[field.alias] = field.name

# Because validate() interface provides one key at a time
field.required = False
schema.Config.extra = pydantic.Extra.forbid

schema.Config.extra = pydantic.Extra.allow
self._schema = schema
self._schemas = [schema]
self._allow_extra_keys = allow_extra_keys

def validate(self, record: DHTRecord) -> bool:
"""
Expand All @@ -62,34 +69,58 @@ def validate(self, record: DHTRecord) -> bool:
.. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
"""

key_alias = self._key_id_to_str(record.key)
try:
record = self._deserialize_record(record)
except ValueError as e:
logger.warning(e)
return False
[key_alias] = list(record.keys())

n_outside_schema = 0
validation_errors = []
for schema in self._schemas:
try:
parsed_record = schema.parse_obj(record)
except pydantic.ValidationError as e:
if self._is_failed_due_to_extra_field(e):
n_outside_schema += 1
else:
validation_errors.append(e)
continue

parsed_value = parsed_record.dict(by_alias=True)[key_alias]
if parsed_value != record[key_alias]:
validation_errors.append(ValueError(
f"Value {record[key_alias]} needed type conversions to match "
f"the schema: {parsed_value}. Type conversions are not allowed"))
else:
return True

readable_record = {self._alias_to_name.get(key_alias, key_alias): record[key_alias]}

if n_outside_schema == len(self._schemas):
if not self._allow_extra_keys:
logger.warning(f"Record {readable_record} contains a field that "
f"is not defined in each of the schemas")
return self._allow_extra_keys

logger.warning(
f"Record {readable_record} doesn't match any of the schemas: {validation_errors}")
return False

@staticmethod
def _deserialize_record(record: DHTRecord) -> Dict[str, Any]:
key_alias = SchemaValidator._key_id_to_str(record.key)
deserialized_value = DHTProtocol.serializer.loads(record.value)
if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
deserialized_record = {key_alias: {deserialized_subkey: deserialized_value}}
return {key_alias: {deserialized_subkey: deserialized_value}}
else:
if isinstance(deserialized_value, dict):
logger.warning(
raise ValueError(
f'Record {record} contains an improperly serialized dictionary (you must use '
f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
return False
deserialized_record = {key_alias: deserialized_value}

try:
parsed_record = self._schema.parse_obj(deserialized_record)
except pydantic.ValidationError as e:
readable_record = {self._alias_to_name.get(key_alias, key_alias):
deserialized_record[key_alias]}
logger.warning(f"Record {readable_record} doesn't match the schema: {e}")
return False

parsed_value = parsed_record.dict(by_alias=True)[key_alias]
if parsed_value != deserialized_record[key_alias]:
logger.warning(
f"Value {deserialized_record[key_alias]} needed type conversions to match "
f" the schema: {parsed_value}. Type conversions are not allowed")
return False
return True
return {key_alias: deserialized_value}

@staticmethod
def _key_id_to_str(key_id: bytes) -> str:
Expand All @@ -100,6 +131,24 @@ def _key_id_to_str(key_id: bytes) -> str:

return binascii.hexlify(key_id).decode()

@staticmethod
def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
inner_errors = exc.errors()
return (
len(inner_errors) == 1 and
inner_errors[0]['type'] == 'value_error.extra' and
len(inner_errors[0]['loc']) == 1 # Require the extra field to be on the top level
)

def merge_with(self, other: RecordValidatorBase) -> bool:
    """
    Absorb another SchemaValidator into this one.

    The alias-to-name mapping and the list of schemas of ``other`` are added to ``self``;
    extra keys are allowed afterwards if either of the two validators allowed them.

    :returns: True if ``other`` is a SchemaValidator and was merged, False otherwise
        (``self`` is left untouched in that case).
    """
    if isinstance(other, SchemaValidator):
        self._alias_to_name.update(other._alias_to_name)
        self._schemas.extend(other._schemas)
        self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
        return True
    return False


def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
"""
Expand Down
68 changes: 68 additions & 0 deletions hivemind/dht/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import Iterable


@dataclasses.dataclass(init=True, repr=True, frozen=True)
Expand Down Expand Up @@ -52,3 +53,70 @@ def strip_value(self, record: DHTRecord) -> bytes:
"""

return record.value

@property
def priority(self) -> int:
    """
    Position of this validator relative to other validators when several are combined.

    - Signing a record applies validators in order of increasing priority.
    - Validating and stripping a record applies them in order of decreasing priority.

    The default priority is 0.
    """
    default_priority = 0
    return default_priority

def merge_with(self, other: 'RecordValidatorBase') -> bool:
    """
    Hook for combining two validators under a custom policy.

    By default, validators are applied sequentially: every validate() call must return True
    for a record to be accepted. A subclass may want a different policy (e.g. schema validators
    require only one validate() call to return True, because each of them bears a part
    of the schema). Such a subclass overrides this method so that it:

    - returns True after merging ``other`` into ``self``, so that ``self`` now combines
      the old ``self`` and ``other`` using the necessary policy; ``other`` should remain unchanged.

    - returns False if no merging happened; both ``self`` and ``other`` should remain unchanged.
      The DHT will then try merging ``other`` to another validator or add it as a separate
      validator (to be applied sequentially).

    The base implementation never merges.
    """
    merged = False
    return merged


class CompositeValidator(RecordValidatorBase):
    """
    A validator that wraps a list of other validators kept sorted by increasing priority.

    Signing walks the list from the lowest priority to the highest;
    validating and stripping walk it from the highest priority to the lowest.
    """

    def __init__(self, validators: Iterable[RecordValidatorBase]=()):
        """:param validators: initial validators, registered through extend()."""
        self._validators = []
        self.extend(validators)

    def extend(self, validators: Iterable[RecordValidatorBase]) -> None:
        """
        Register more validators.

        Each new validator is offered to the already registered ones via merge_with(), stopping at
        the first one that accepts it. If none does, it is stored as a separate validator.
        The list is re-sorted by priority afterwards.
        """
        for candidate in validators:
            absorbed = any(current.merge_with(candidate) for current in self._validators)
            if not absorbed:
                self._validators.append(candidate)
        self._validators.sort(key=lambda item: item.priority)

    def validate(self, record: DHTRecord) -> bool:
        """
        Run every validator on the record, highest priority first.

        After each successful check except the last one, the record value is replaced by that
        validator's strip_value() result before it is passed on.

        :returns: False as soon as one validator rejects the record, True otherwise.
        """
        last_position = len(self._validators) - 1
        for position, validator in enumerate(reversed(self._validators)):
            if not validator.validate(record):
                return False
            if position < last_position:
                stripped_value = validator.strip_value(record)
                record = dataclasses.replace(record, value=stripped_value)
        return True

    def sign_value(self, record: DHTRecord) -> bytes:
        """Apply sign_value() of every validator in order of increasing priority; return the final value."""
        for validator in self._validators:
            signed_value = validator.sign_value(record)
            record = dataclasses.replace(record, value=signed_value)
        return record.value

    def strip_value(self, record: DHTRecord) -> bytes:
        """Apply strip_value() of every validator in order of decreasing priority; return the final value."""
        for validator in reversed(self._validators):
            stripped_value = validator.strip_value(record)
            record = dataclasses.replace(record, value=stripped_value)
        return record.value
Loading