diff --git a/batconf/manager.py b/batconf/manager.py index cfc9352..1f41ee6 100644 --- a/batconf/manager.py +++ b/batconf/manager.py @@ -4,6 +4,7 @@ from .source import SourceList +@typing.runtime_checkable class ConfigProtocol(typing.Protocol): __dataclass_fields__: dict __module__: str diff --git a/batconf/sources/dataclass.py b/batconf/sources/dataclass.py index a229e84..29c7040 100644 --- a/batconf/sources/dataclass.py +++ b/batconf/sources/dataclass.py @@ -1,24 +1,25 @@ from dataclasses import ( - dataclass, fields, - is_dataclass, _MISSING_TYPE, ) +from typing import Any + from ..source import SourceInterface +from ..manager import ConfigProtocol class DataclassConfig(SourceInterface): - def __init__(self, ConfigClass: dataclass): + def __init__(self, ConfigClass: ConfigProtocol): '''Extract default values from the Config dataclass Properties without defaults are set to None ''' self._root = ConfigClass.__module__ - self._data = dict() + self._data: dict = dict() for field in fields(ConfigClass): - if is_dataclass(field.type): + if isinstance(field.type, ConfigProtocol): self._data[field.name] = DataclassConfig(field.type) elif type(field.default) is _MISSING_TYPE: self._data[field.name] = None @@ -35,7 +36,7 @@ def get(self, key: str, module: str = None): else: path = key.split('.') - conf = self._data + conf: Any = self._data for k in path: if not (conf := conf.get(k)): return conf