diff --git a/docs/src/code/algo/asha.rst b/docs/src/code/algo/asha.rst index 638f5e1f6..948544981 100644 --- a/docs/src/code/algo/asha.rst +++ b/docs/src/code/algo/asha.rst @@ -3,6 +3,6 @@ Asynchronous Successive Halving Algorithm Can't build documentation because of import order. Sphinx is loading ``orion.algo.asha`` before ``orion.algo`` and therefore -there is a cycle between the definition of ``OptimizationAlgorithm`` and +there is a cycle between the definition of ``BaseAlgorithm`` and ``ASHA`` as the meta-class ``Factory`` is trying to import ``ASHA``. `PR #135 `_ should get rid of this problem. diff --git a/docs/src/code/core/io/database.rst b/docs/src/code/core/io/database.rst index 11d2059f3..099a7c985 100644 --- a/docs/src/code/core/io/database.rst +++ b/docs/src/code/core/io/database.rst @@ -12,5 +12,3 @@ Databases .. automodule:: orion.core.io.database :members: :show-inheritance: - - diff --git a/docs/src/user/parallel.rst b/docs/src/user/parallel.rst index d3f5d3948..08e04314f 100644 --- a/docs/src/user/parallel.rst +++ b/docs/src/user/parallel.rst @@ -42,7 +42,7 @@ Executor backends It is also possible to execute multiple workers using the argument ``--n-workers`` in commandline or ``experiment.workon(n_workers)`` using the python API. The workers will work together using the same mechanisms explained above, but an -:class:`orion.executor.base.Executor` backend will be used in addition +:class:`orion.executor.base.BaseExecutor` backend will be used in addition to spawn the workers and maintain them alive. The default backend is :ref:`executor-joblib`. You can configure it diff --git a/docs/src/user/storage.rst b/docs/src/user/storage.rst index 00157b46d..dbbbef9d9 100644 --- a/docs/src/user/storage.rst +++ b/docs/src/user/storage.rst @@ -477,18 +477,22 @@ Here's an example on how you could remove an experiment -------------- .. automethod:: orion.core.io.database.Database.read + :noindex: :hidden:`write` --------------- .. automethod:: orion.core.io.database.Database.write + :noindex: :hidden:`remove` ---------------- .. automethod:: orion.core.io.database.Database.remove + :noindex: :hidden:`read_and_write` ------------------------ .. automethod:: orion.core.io.database.Database.read_and_write + :noindex: diff --git a/setup.py b/setup.py index 07f67bcaf..941f272b9 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ "console_scripts": [ "orion = orion.core.cli:main", ], - "OptimizationAlgorithm": [ + "BaseAlgorithm": [ "random = orion.algo.random:Random", "gridsearch = orion.algo.gridsearch:GridSearch", "asha = orion.algo.asha:ASHA", @@ -54,11 +54,11 @@ "tpe = orion.algo.tpe:TPE", "EvolutionES = orion.algo.evolution_es:EvolutionES", ], - "Storage": [ + "BaseStorageProtocol": [ "track = orion.storage.track:Track", "legacy = orion.storage.legacy:Legacy", ], - "Executor": [ + "BaseExecutor": [ "singleexecutor = orion.executor.single_backend:SingleExecutor", "joblib = orion.executor.joblib_backend:Joblib", "dask = orion.executor.dask_backend:Dask", diff --git a/src/orion/algo/base.py b/src/orion/algo/base.py index 8b5cd5fd0..df1515ea8 100644 --- a/src/orion/algo/base.py +++ b/src/orion/algo/base.py @@ -4,7 +4,14 @@ ===================== Formulation of a general search algorithm with respect to some objective. -Algorithm implementations must inherit from `orion.algo.base.OptimizationAlgorithm`. +Algorithm implementations must inherit from `orion.algo.base.BaseAlgorithm`. + +Algorithms can be created using `algo_factory.create()`. 
+ +Examples +-------- +>>> algo_factory.create('random', space, seed=1) +>>> algo_factory.create('some_fancy_algo', space, **some_fancy_algo_config) """ import copy @@ -13,7 +20,7 @@ from abc import ABCMeta, abstractmethod from orion.algo.space import Fidelity -from orion.core.utils import Factory +from orion.core.utils import GenericFactory log = logging.getLogger(__name__) @@ -24,7 +31,7 @@ def infer_trial_id(point): # pylint: disable=too-many-public-methods -class BaseAlgorithm(object, metaclass=ABCMeta): +class BaseAlgorithm: """Base class describing what an algorithm can do. Parameters @@ -109,22 +116,12 @@ def __init__(self, space, **kwargs): self._param_names = list(kwargs.keys()) # Instantiate tunable parameters of an algorithm for varname, param in kwargs.items(): - # Check if tunable element is another algorithm - if isinstance(param, dict) and len(param) == 1: - subalgo_type = list(param)[0] - subalgo_kwargs = param[subalgo_type] - if isinstance(subalgo_kwargs, dict): - param = OptimizationAlgorithm(subalgo_type, space, **subalgo_kwargs) - elif ( - isinstance(param, str) and param.lower() in OptimizationAlgorithm.types - ): - # pylint: disable=too-many-function-args - param = OptimizationAlgorithm(param, space) - elif varname == "seed": - self.seed_rng(param) - setattr(self, varname, param) + # TODO: move this inside an initialization function. + if hasattr(self, "seed"): + self.seed_rng(self.seed) + def seed_rng(self, seed): """Seed the state of the random number generator. @@ -394,10 +391,7 @@ def configuration(self): for attrname in self._param_names: if attrname.startswith("_"): # Do not log _space or others in conf continue - attr = getattr(self, attrname) - if isinstance(attr, BaseAlgorithm): - attr = attr.configuration - dict_form[attrname] = attr + dict_form[attrname] = getattr(self, attrname) return {self.__class__.__name__.lower(): dict_form} @@ -407,16 +401,9 @@ def space(self): return self._space @space.setter - def space(self, space_): - """Propagate changes in defined space to possibly nested algorithms.""" - self._space = space_ - for attr in self.__dict__.values(): - if isinstance(attr, BaseAlgorithm): - attr.space = space_ - + def space(self, space): + """Set space.""" + self._space = space -# pylint: disable=too-few-public-methods,abstract-method -class OptimizationAlgorithm(BaseAlgorithm, metaclass=Factory): - """Class used to inject dependency on an algorithm implementation.""" - pass +algo_factory = GenericFactory(BaseAlgorithm) diff --git a/src/orion/benchmark/__init__.py b/src/orion/benchmark/__init__.py index aceb185f1..b7634472e 100644 --- a/src/orion/benchmark/__init__.py +++ b/src/orion/benchmark/__init__.py @@ -11,7 +11,7 @@ import orion.core from orion.client import create_experiment -from orion.executor.base import Executor +from orion.executor.base import executor_factory class Benchmark: @@ -51,7 +51,7 @@ class Benchmark: storage: dict, optional Configuration of the storage backend. 
- executor: `orion.executor.base.Executor`, optional + executor: `orion.executor.base.BaseExecutor`, optional Executor to run the benchmark experiments """ @@ -62,7 +62,7 @@ def __init__(self, name, algorithms, targets, storage=None, executor=None): self.targets = targets self.metadata = {} self.storage_config = storage - self.executor = executor or Executor( + self.executor = executor or executor_factory.create( orion.core.config.worker.executor, n_workers=orion.core.config.worker.n_workers, **orion.core.config.worker.executor_configuration, diff --git a/src/orion/benchmark/assessment/__init__.py b/src/orion/benchmark/assessment/__init__.py index 5815fac7b..7bef8c2ec 100644 --- a/src/orion/benchmark/assessment/__init__.py +++ b/src/orion/benchmark/assessment/__init__.py @@ -5,6 +5,11 @@ from .averagerank import AverageRank from .averageresult import AverageResult -from .base import BaseAssess +from .base import BenchmarkAssessment, bench_assessment_factory -__all__ = ["BaseAssess", "AverageRank", "AverageResult"] +__all__ = [ + "bench_assessment_factory", + "BenchmarkAssessment", + "AverageRank", + "AverageResult", +] diff --git a/src/orion/benchmark/assessment/averagerank.py b/src/orion/benchmark/assessment/averagerank.py index 9b4c54c81..d858af536 100644 --- a/src/orion/benchmark/assessment/averagerank.py +++ b/src/orion/benchmark/assessment/averagerank.py @@ -7,11 +7,11 @@ from collections import defaultdict -from orion.benchmark.assessment.base import BaseAssess +from orion.benchmark.assessment.base import BenchmarkAssessment from orion.plotting.base import rankings -class AverageRank(BaseAssess): +class AverageRank(BenchmarkAssessment): """ Evaluate the average performance (objective value) between different search algorithms from the rank perspective at different time steps (trial number). diff --git a/src/orion/benchmark/assessment/averageresult.py b/src/orion/benchmark/assessment/averageresult.py index c11fdba8d..56275b33e 100644 --- a/src/orion/benchmark/assessment/averageresult.py +++ b/src/orion/benchmark/assessment/averageresult.py @@ -6,11 +6,11 @@ """ from collections import defaultdict -from orion.benchmark.assessment.base import BaseAssess +from orion.benchmark.assessment.base import BenchmarkAssessment from orion.plotting.base import regrets -class AverageResult(BaseAssess): +class AverageResult(BenchmarkAssessment): """ Evaluate the average performance (objective value) for each search algorithm at different time steps (trial number). diff --git a/src/orion/benchmark/assessment/base.py b/src/orion/benchmark/assessment/base.py index 45856e471..e1b521dc9 100644 --- a/src/orion/benchmark/assessment/base.py +++ b/src/orion/benchmark/assessment/base.py @@ -7,10 +7,10 @@ from abc import ABC, abstractmethod -from orion.core.utils import Factory +from orion.core.utils import GenericFactory -class BaseAssess(ABC): +class BenchmarkAssessment(ABC): """Base class describing what an assessment can do. 
Parameters @@ -52,8 +52,4 @@ def configuration(self): return {self.__class__.__qualname__: self._param_names} -# pylint: disable=too-few-public-methods,abstract-method -class BenchmarkAssessment(BaseAssess, metaclass=Factory): - """Class used to inject dependency on an assessment implementation.""" - - pass +bench_assessment_factory = GenericFactory(BenchmarkAssessment) diff --git a/src/orion/benchmark/benchmark_client.py b/src/orion/benchmark/benchmark_client.py index 444e6008e..50d1646da 100644 --- a/src/orion/benchmark/benchmark_client.py +++ b/src/orion/benchmark/benchmark_client.py @@ -8,8 +8,8 @@ import logging from orion.benchmark import Benchmark, Study -from orion.benchmark.assessment.base import BenchmarkAssessment -from orion.benchmark.task.base import BenchmarkTask +from orion.benchmark.assessment.base import bench_assessment_factory +from orion.benchmark.task.base import bench_task_factory from orion.core.io.database import DuplicateKeyError from orion.core.utils.exceptions import NoConfigurationError from orion.storage.base import get_storage, setup_storage @@ -38,7 +38,7 @@ def get_or_create_benchmark( Task objects storage: dict, optional Configuration of the storage backend. - executor: `orion.executor.base.Executor`, optional + executor: `orion.executor.base.BaseExecutor`, optional Executor to run the benchmark experiments debug: bool, optional If using in debug mode, the storage config is overrided with legacy:EphemeralDB. @@ -94,11 +94,11 @@ def get_or_create_benchmark( def _get_task(name, **kwargs): - return BenchmarkTask(of_type=name, **kwargs) + return bench_task_factory.create(of_type=name, **kwargs) def _get_assessment(name, **kwargs): - return BenchmarkAssessment(of_type=name, **kwargs) + return bench_assessment_factory.create(of_type=name, **kwargs) def _resolve_db_config(db_config): diff --git a/src/orion/benchmark/task/__init__.py b/src/orion/benchmark/task/__init__.py index 4c2b7f05c..4426b2e31 100644 --- a/src/orion/benchmark/task/__init__.py +++ b/src/orion/benchmark/task/__init__.py @@ -3,10 +3,17 @@ =========================== """ -from .base import BaseTask +from .base import BenchmarkTask, bench_task_factory from .branin import Branin from .carromtable import CarromTable from .eggholder import EggHolder from .rosenbrock import RosenBrock -__all__ = ["BaseTask", "RosenBrock", "Branin", "CarromTable", "EggHolder"] +__all__ = [ + "BenchmarkTask", + "RosenBrock", + "Branin", + "CarromTable", + "EggHolder", + "bench_task_factory", +] diff --git a/src/orion/benchmark/task/base.py b/src/orion/benchmark/task/base.py index 7ae4f432a..167bd45fe 100644 --- a/src/orion/benchmark/task/base.py +++ b/src/orion/benchmark/task/base.py @@ -7,10 +7,10 @@ from abc import ABC, abstractmethod -from orion.core.utils import Factory +from orion.core.utils import GenericFactory -class BaseTask(ABC): +class BenchmarkTask(ABC): """Base class describing what a task can do. A task will define the objective function and search space of it. 
@@ -59,8 +59,4 @@ def configuration(self): return {self.__class__.__qualname__: self._param_names} -# pylint: disable=too-few-public-methods,abstract-method -class BenchmarkTask(BaseTask, metaclass=Factory): - """Class used to inject dependency on an task implementation.""" - - pass +bench_task_factory = GenericFactory(BenchmarkTask) diff --git a/src/orion/benchmark/task/branin.py b/src/orion/benchmark/task/branin.py index a57e40a50..93953b274 100644 --- a/src/orion/benchmark/task/branin.py +++ b/src/orion/benchmark/task/branin.py @@ -9,10 +9,10 @@ import numpy -from orion.benchmark.task.base import BaseTask +from orion.benchmark.task.base import BenchmarkTask -class Branin(BaseTask): +class Branin(BenchmarkTask): """`Branin function `_ as benchmark task """ diff --git a/src/orion/benchmark/task/carromtable.py b/src/orion/benchmark/task/carromtable.py index 26280514e..b26c3e94c 100644 --- a/src/orion/benchmark/task/carromtable.py +++ b/src/orion/benchmark/task/carromtable.py @@ -6,10 +6,10 @@ """ import numpy -from orion.benchmark.task.base import BaseTask +from orion.benchmark.task.base import BenchmarkTask -class CarromTable(BaseTask): +class CarromTable(BenchmarkTask): """`CarromTable function `_ as benchmark task""" diff --git a/src/orion/benchmark/task/eggholder.py b/src/orion/benchmark/task/eggholder.py index f3544e534..d80753f3c 100644 --- a/src/orion/benchmark/task/eggholder.py +++ b/src/orion/benchmark/task/eggholder.py @@ -6,10 +6,10 @@ """ import numpy -from orion.benchmark.task.base import BaseTask +from orion.benchmark.task.base import BenchmarkTask -class EggHolder(BaseTask): +class EggHolder(BenchmarkTask): """`EggHolder function `_ as benchmark task""" diff --git a/src/orion/benchmark/task/rosenbrock.py b/src/orion/benchmark/task/rosenbrock.py index f4edb0e39..47d4b6f1e 100644 --- a/src/orion/benchmark/task/rosenbrock.py +++ b/src/orion/benchmark/task/rosenbrock.py @@ -6,10 +6,10 @@ """ import numpy -from orion.benchmark.task.base import BaseTask +from orion.benchmark.task.base import BenchmarkTask -class RosenBrock(BaseTask): +class RosenBrock(BenchmarkTask): """`RosenBrock function `_ as benchmark task""" diff --git a/src/orion/client/__init__.py b/src/orion/client/__init__.py index a2a7ebd1a..f1e49cf64 100644 --- a/src/orion/client/__init__.py +++ b/src/orion/client/__init__.py @@ -177,7 +177,7 @@ def build_experiment( config_change_type: str, optional How to resolve config change automatically. Must be one of 'noeffect', 'unsure' or 'break'. Defaults to 'break'. 
- executor: `orion.executor.base.Executor`, optional + executor: `orion.executor.base.BaseExecutor`, optional Executor to run the experiment Raises diff --git a/src/orion/client/experiment.py b/src/orion/client/experiment.py index 8e2180d11..05ab33bb1 100644 --- a/src/orion/client/experiment.py +++ b/src/orion/client/experiment.py @@ -25,7 +25,7 @@ from orion.core.utils.flatten import flatten, unflatten from orion.core.worker.trial import Trial, TrialCM from orion.core.worker.trial_pacemaker import TrialPacemaker -from orion.executor.base import Executor +from orion.executor.base import executor_factory from orion.plotting.base import PlotAccessor from orion.storage.base import FailedUpdate @@ -81,7 +81,7 @@ def __init__(self, experiment, producer, executor=None, heartbeat=None): if heartbeat is None: heartbeat = orion.core.config.worker.heartbeat self.heartbeat = heartbeat - self.executor = executor or Executor( + self.executor = executor or executor_factory.create( orion.core.config.worker.executor, n_workers=orion.core.config.worker.n_workers, **orion.core.config.worker.executor_configuration, @@ -628,15 +628,15 @@ def tmp_executor(self, executor, **config): Parameters ---------- - executor: str or :class:`orion.executor.base.Executor` + executor: str or :class:`orion.executor.base.BaseExecutor` The executor to use. If it is a ``str``, the provided ``config`` will be used - to create the executor with ``Executor(executor, **config)``. + to create the executor with ``executor_factory.create(executor, **config)``. **config: Configuration to use if ``executor`` is a ``str``. """ if isinstance(executor, str): - executor = Executor(executor, **config) + executor = executor_factory.create(executor, **config) old_executor = self.executor self.executor = executor with executor: diff --git a/src/orion/core/cli/checks/creation.py b/src/orion/core/cli/checks/creation.py index ca09f4e70..b4cde140d 100644 --- a/src/orion/core/cli/checks/creation.py +++ b/src/orion/core/cli/checks/creation.py @@ -8,7 +8,7 @@ """ -from orion.core.io.database import Database +from orion.core.io.database import database_factory from orion.core.utils.exceptions import CheckError @@ -37,7 +37,7 @@ def check_database_creation(self): db_type = database.pop("type") try: - db = Database(of_type=db_type, **database) + db = database_factory.create(db_type, **database) except ValueError as ex: raise CheckError(str(ex)) diff --git a/src/orion/core/cli/db/setup.py b/src/orion/core/cli/db/setup.py index c6dbeccbb..02cd30592 100644 --- a/src/orion/core/cli/db/setup.py +++ b/src/orion/core/cli/db/setup.py @@ -14,7 +14,7 @@ import yaml import orion.core -from orion.core.io.database import Database +from orion.core.io.database import database_factory from orion.core.utils.terminal import ask_question log = logging.getLogger(__name__) @@ -57,12 +57,12 @@ def main(*args): # Get database type. _type = ask_question( "Enter the database", - choice=sorted(Database.types.keys()), + choice=sorted(database_factory.get_classes().keys()), default="mongodb", ignore_case=True, ).lower() # Get database arguments. 
- db_class = Database.types[_type] + db_class = database_factory.get_classes()[_type] db_args = db_class.get_defaults() arg_vals = {} for arg_name, default_value in sorted(db_args.items()): diff --git a/src/orion/core/evc/adapters.py b/src/orion/core/evc/adapters.py index 86542f718..e34fdc8d7 100644 --- a/src/orion/core/evc/adapters.py +++ b/src/orion/core/evc/adapters.py @@ -36,7 +36,7 @@ from orion.algo.space import Dimension from orion.core.io.space_builder import DimensionBuilder -from orion.core.utils import Factory +from orion.core.utils import GenericFactory from orion.core.worker.trial import Trial log = logging.getLogger(__name__) @@ -45,6 +45,31 @@ class BaseAdapter(object, metaclass=ABCMeta): """Base class describing what an adapter can do.""" + @classmethod + def build(cls, adapter_dicts): + """Builder method for a list of adapters. + + Parameters + ---------- + adapter_dicts: list of `dict` + List of adapter representation in dictionary form as expected to be saved in a database. + + Returns + ------- + `orion.core.evc.adapters.CompositeAdapter` + An adapter which may contain many adapters + + """ + adapters = [] + for adapter_dict in adapter_dicts: + if isinstance(adapter_dict, (list, tuple)): + adapter = BaseAdapter.build(adapter_dict) + else: + adapter = adapter_factory.create(**adapter_dict) + adapters.append(adapter) + + return CompositeAdapter(*adapters) + @abstractmethod def forward(self, trials): """Adapt trials of the parent experiment such that they are compatible to the child @@ -906,34 +931,4 @@ def to_dict(self): return ret -# pylint: disable=too-few-public-methods,abstract-method -class Adapter(BaseAdapter, metaclass=Factory): - """Class used to inject dependency on an adapter implementation. - - .. seealso:: `orion.core.utils.Factory` metaclass and `BaseAlgorithm` interface. - """ - - @classmethod - def build(cls, adapter_dicts): - """Builder method for a list of adapters. - - Parameters - ---------- - adapter_dicts: list of `dict` - List of adapter representation in dictionary form as expected to be saved in a database. - - Returns - ------- - `orion.core.evc.adapters.CompositeAdapter` - An adapter which may contain many adapters - - """ - adapters = [] - for adapter_dict in adapter_dicts: - if isinstance(adapter_dict, (list, tuple)): - adapter = Adapter.build(adapter_dict) - else: - adapter = cls(**adapter_dict) - adapters.append(adapter) - - return CompositeAdapter(*adapters) +adapter_factory = GenericFactory(BaseAdapter) diff --git a/src/orion/core/io/convert.py b/src/orion/core/io/convert.py index 4a195aa42..08b9ea091 100644 --- a/src/orion/core/io/convert.py +++ b/src/orion/core/io/convert.py @@ -6,7 +6,7 @@ Defines and instantiates a converter for configuration file types. Given a file path infer which configuration file parser/emitter it corresponds to. -Define `Converter` classes with a common interface for many popular configuration +Define `BaseConverter` classes with a common interface for many popular configuration file types. Currently supported: @@ -23,7 +23,7 @@ from abc import ABC, abstractmethod from collections import deque -from orion.core.utils import Factory, nesteddict +from orion.core.utils import GenericFactory, nesteddict def infer_converter_from_file_type(config_path, regex=None, default_keyword=""): @@ -31,7 +31,7 @@ def infer_converter_from_file_type(config_path, regex=None, default_keyword=""): converter. 
""" _, ext_type = os.path.splitext(os.path.abspath(config_path)) - for klass in Converter.types.values(): + for klass in config_converter_factory.get_classes().values(): if ext_type in klass.file_extensions: return klass() @@ -145,7 +145,7 @@ class GenericConverter(BaseConverter): generic text parser, semantics are going to be tied to their consequent usage. A template document is going to be created on `parse` and filled with values on `read`. This template document consists the state of this - `Converter` object. + `BaseConverter` object. Dimension should be defined for instance as: ``meaningful_name~uniform(0, 4)`` @@ -228,7 +228,7 @@ def parse(self, filepath): ) self.template = substituted - # Wrap it in style of what the rest of `Converter`s return + # Wrap it in style of what the rest of `BaseConverter`s return ret_nested = nesteddict() for namespace, expression in ret.items(): keys = namespace.split("/") @@ -271,19 +271,10 @@ def generate(self, filepath, data): name = namespace[0] unnested_data[self.has_leading.get(name, "") + name] = stuff - print(self.template) - print(unnested_data) document = self.template.format(**unnested_data) with open(filepath, "w") as f: f.write(document) -# pylint: disable=too-few-public-methods,abstract-method -class Converter(BaseConverter, metaclass=Factory): - """Class used to inject dependency on a configuration file parser/generator. - - .. seealso:: :class:`orion.core.utils.Factory` metaclass and `BaseConverter` interface. - """ - - pass +config_converter_factory = GenericFactory(BaseConverter) diff --git a/src/orion/core/io/database/__init__.py b/src/orion/core/io/database/__init__.py index f99c6d7d3..b64e16d9a 100644 --- a/src/orion/core/io/database/__init__.py +++ b/src/orion/core/io/database/__init__.py @@ -3,20 +3,21 @@ Wrappers for database frameworks ================================ -Contains :class:`AbstractDB`, an interface for databases. -Currently, implemented wrappers: +Contains :class:`Database`, an interface for databases. - - :class:`orion.core.io.database.mongodb.MongoDB` +Database objects can be created using ``database_factory.create()``. +See :py:class:`orion.core.utils.GenericFactory` for more information on the factory. """ +# :obj:`database_factory`. import logging from abc import abstractmethod, abstractproperty -from orion.core.utils.singleton import AbstractSingletonType, SingletonFactory +from orion.core.utils.singleton import GenericSingletonFactory # pylint: disable=too-many-public-methods -class AbstractDB(object, metaclass=AbstractSingletonType): +class Database(object): """Base class for database framework wrappers. Attributes @@ -42,7 +43,7 @@ class AbstractDB(object, metaclass=AbstractSingletonType): def __init__( self, host=None, name=None, port=None, username=None, password=None, **kwargs ): - """Init method, see attributes of :class:`AbstractDB`.""" + """Init method, see attributes of :class:`Database`.""" defaults = self.get_defaults() host = defaults.get("host", None) if host is None or host == "" else host name = defaults.get("name", None) if name is None or name == "" else name @@ -65,7 +66,7 @@ def is_connected(self): @abstractmethod def initiate_connection(self): - """Connect to database, unless `AbstractDB` `is_connected`. + """Connect to database, unless `Database` `is_connected`. 
Raises ------ @@ -77,7 +78,7 @@ def initiate_connection(self): @abstractmethod def close_connection(self): - """Disconnect from database, if `AbstractDB` `is_connected`.""" + """Disconnect from database, if `Database` `is_connected`.""" pass @abstractmethod @@ -91,8 +92,8 @@ def ensure_index(self, collection_name, keys, unique=False): keys: str or list of tuples Can be a string representing a key to index, or a list of tuples with the structure `[(key_name, sort_order)]`. `key_name` must be a - string and sort_order can be either ``AbstractDB.ASCENDING`` or - ``AbstractDB.DESCENDING``. + string and sort_order can be either ``Database.ASCENDING`` or + ``Database.DESCENDING``. unique: bool, optional Ensure each document have a different key value. If not, operations like `write()` and `read_and_write()` will raise @@ -170,7 +171,7 @@ def write(self, collection_name, data, query=None): ------ DuplicateKeyError If the operation is creating duplicate keys in two different documents. Only occurs if - the keys have unique indexes. See :meth:`AbstractDB.ensure_index` for more information + the keys have unique indexes. See :meth:`Database.ensure_index` for more information about indexes. """ @@ -225,7 +226,7 @@ def read_and_write(self, collection_name, query, data, selection=None): ------ DuplicateKeyError If the operation is creating duplicate keys in two different documents. Only occurs if - the keys have unique indexes. See :meth:`AbstractDB.ensure_index` for more information + the keys have unique indexes. See :meth:`Database.ensure_index` for more information about indexes. """ @@ -297,7 +298,7 @@ class ReadOnlyDB(object): ) def __init__(self, database): - """Init method, see attributes of :class:`AbstractDB`.""" + """Init method, see attributes of :class:`Database`.""" self._database = database def __getattr__(self, attr): @@ -338,11 +339,7 @@ class OutdatedDatabaseError(DatabaseError): pass -# pylint: disable=too-few-public-methods,abstract-method -class Database(AbstractDB, metaclass=SingletonFactory): - """Class used to inject dependency on a database framework.""" - - pass +database_factory = GenericSingletonFactory(Database) # set per-module log level diff --git a/src/orion/core/io/database/ephemeraldb.py b/src/orion/core/io/database/ephemeraldb.py index f8d20bae2..347dc7209 100644 --- a/src/orion/core/io/database/ephemeraldb.py +++ b/src/orion/core/io/database/ephemeraldb.py @@ -3,13 +3,13 @@ Non permanent database ====================== -Implement non permanent version of :class:`orion.core.io.database.AbstractDB` +Implement non permanent version of :class:`orion.core.io.database.Database` """ import copy from collections import defaultdict -from orion.core.io.database import AbstractDB, DatabaseError, DuplicateKeyError +from orion.core.io.database import Database, DatabaseError, DuplicateKeyError from orion.core.utils.flatten import flatten, unflatten @@ -25,13 +25,13 @@ def _convert_keys_to_name(keys): # pylint: disable=too-many-public-methods -class EphemeralDB(AbstractDB): +class EphemeralDB(Database): """Non permanent database This database is meant for debugging purposes. It only lives through one execution and all information saved during it is lost when the process is terminated. - .. seealso:: :class:`orion.core.io.database.AbstractDB` for more on attributes. + .. seealso:: :class:`orion.core.io.database.Database` for more on attributes. 
""" @@ -66,7 +66,7 @@ def drop_index(self, collection_name, name): def write(self, collection_name, data, query=None): """Write new information to a collection. Perform insert or update. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.write` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.write` for argument documentation. """ dbcollection = self._db[collection_name] @@ -85,7 +85,7 @@ def write(self, collection_name, data, query=None): def read(self, collection_name, query=None, selection=None): """Read a collection and return a value according to the query. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.read` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. """ dbcollection = self._db[collection_name] @@ -99,7 +99,7 @@ def read_and_write(self, collection_name, query, data, selection=None): Returns the updated document, or None if nothing found. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.read_and_write` for + .. seealso:: :meth:`orion.core.io.database.Database.read_and_write` for argument documentation. """ @@ -114,7 +114,7 @@ def read_and_write(self, collection_name, query, data, selection=None): def count(self, collection_name, query=None): """Count the number of documents in a collection which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.count` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. """ dbcollection = self._db[collection_name] @@ -123,7 +123,7 @@ def count(self, collection_name, query=None): def remove(self, collection_name, query): """Delete from a collection document[s] which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.remove` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. """ dbcollection = self._db[collection_name] @@ -134,7 +134,7 @@ def remove(self, collection_name, query): def get_defaults(cls): """Get database arguments needed to create a database instance. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.get_defaults` + .. seealso:: :meth:`orion.core.io.database.Database.get_defaults` for argument documentation. """ @@ -201,7 +201,7 @@ def _register_keys(self, document): def find(self, query=None, selection=None): """Find documents in the collection and return a value according to the query. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.read` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. """ found_documents = [] @@ -296,7 +296,7 @@ def _upsert(self, query, update): def count(self, query=None): """Count the number of documents in a collection which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.count` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. """ return len(self.find(query)) @@ -304,7 +304,7 @@ def count(self, query=None): def delete_many(self, query=None): """Delete from a collection document[s] which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.remove` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. 
""" deleted = 0 diff --git a/src/orion/core/io/database/mongodb.py b/src/orion/core/io/database/mongodb.py index 223ad27f9..0cf3d2abd 100644 --- a/src/orion/core/io/database/mongodb.py +++ b/src/orion/core/io/database/mongodb.py @@ -8,7 +8,7 @@ import pymongo from orion.core.io.database import ( - AbstractDB, + Database, DatabaseError, DatabaseTimeout, DuplicateKeyError, @@ -76,7 +76,7 @@ def _decorator(self, *args, **kwargs): # pylint: disable=too-many-public-methods -class MongoDB(AbstractDB): +class MongoDB(Database): """Wrap MongoDB with three primary methods `read`, `write`, `remove`. Attributes @@ -88,7 +88,7 @@ class MongoDB(AbstractDB): Information on MongoDB `connection string `_. - .. seealso:: :class:`orion.core.io.database.AbstractDB` for more on attributes. + .. seealso:: :class:`orion.core.io.database.Database` for more on attributes. """ @@ -101,7 +101,7 @@ def __init__( password=None, serverSelectionTimeoutMS=5000, ): - """Init method, see attributes of :class:`AbstractDB`.""" + """Init method, see attributes of :class:`Database`.""" if host == "": host = "localhost" self.uri = None @@ -186,7 +186,7 @@ def close_connection(self): def ensure_index(self, collection_name, keys, unique=False): """Create given indexes if they do not already exist in database. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.ensure_index` for argument + .. seealso:: :meth:`orion.core.io.database.Database.ensure_index` for argument documentation. """ @@ -225,7 +225,7 @@ def _convert_index_keys(self, keys): return converted_keys def _convert_sort_order(self, sort_order): - """Convert generic `AbstractDB` sort orders to MongoDB ones.""" + """Convert generic `Database` sort orders to MongoDB ones.""" if sort_order is self.ASCENDING: return pymongo.ASCENDING elif sort_order is self.DESCENDING: @@ -237,7 +237,7 @@ def _convert_sort_order(self, sort_order): def write(self, collection_name, data, query=None): """Write new information to a collection. Perform insert or update. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.write` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.write` for argument documentation. """ dbcollection = self._db[collection_name] @@ -261,7 +261,7 @@ def write(self, collection_name, data, query=None): def read(self, collection_name, query=None, selection=None): """Read a collection and return a value according to the query. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.read` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. """ dbcollection = self._db[collection_name] @@ -277,7 +277,7 @@ def read_and_write(self, collection_name, query, data, selection=None): Returns the updated document, or None if nothing found. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.read_and_write` for + .. seealso:: :meth:`orion.core.io.database.Database.read_and_write` for argument documentation. """ @@ -297,7 +297,7 @@ def read_and_write(self, collection_name, query, data, selection=None): def count(self, collection_name, query=None): """Count the number of documents in a collection which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.count` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. 
""" dbcollection = self._db[collection_name] @@ -311,7 +311,7 @@ def count(self, collection_name, query=None): def remove(self, collection_name, query): """Delete from a collection document[s] which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.remove` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. """ dbcollection = self._db[collection_name] @@ -350,7 +350,7 @@ def _sanitize_attrs(self): def get_defaults(cls): """Get database arguments needed to create a database instance. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.get_defaults` + .. seealso:: :meth:`orion.core.io.database.Database.get_defaults` for argument documentation. """ diff --git a/src/orion/core/io/database/pickleddb.py b/src/orion/core/io/database/pickleddb.py index 4a3835616..0b4593e77 100644 --- a/src/orion/core/io/database/pickleddb.py +++ b/src/orion/core/io/database/pickleddb.py @@ -17,7 +17,7 @@ from filelock import FileLock, SoftFileLock, Timeout import orion.core -from orion.core.io.database import AbstractDB, DatabaseTimeout +from orion.core.io.database import Database, DatabaseTimeout from orion.core.io.database.ephemeraldb import EphemeralDB log = logging.getLogger(__name__) @@ -78,7 +78,7 @@ def find_unpickable_field(doc): # pylint: disable=too-many-public-methods -class PickledDB(AbstractDB): +class PickledDB(Database): """Pickled EphemeralDB to support permanancy and concurrency This is a very simple and inefficient implementation of a permanent database on disk for OrĂ­on. @@ -143,7 +143,7 @@ def drop_index(self, collection_name, name): def write(self, collection_name, data, query=None): """Write new information to a collection. Perform insert or update. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.write` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.write` for argument documentation. """ with self.locked_database() as database: @@ -152,7 +152,7 @@ def write(self, collection_name, data, query=None): def read(self, collection_name, query=None, selection=None): """Read a collection and return a value according to the query. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.read` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. """ with self.locked_database(write=False) as database: @@ -163,7 +163,7 @@ def read_and_write(self, collection_name, query, data, selection=None): Returns the updated document, or None if nothing found. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.read_and_write` for + .. seealso:: :meth:`orion.core.io.database.Database.read_and_write` for argument documentation. """ @@ -175,7 +175,7 @@ def read_and_write(self, collection_name, query, data, selection=None): def count(self, collection_name, query=None): """Count the number of documents in a collection which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.count` for argument documentation. + .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. """ with self.locked_database(write=False) as database: @@ -184,7 +184,7 @@ def count(self, collection_name, query=None): def remove(self, collection_name, query): """Delete from a collection document[s] which match the `query`. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.remove` for argument documentation. + .. 
seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. """ with self.locked_database() as database: @@ -247,7 +247,7 @@ def locked_database(self, write=True): def get_defaults(cls): """Get database arguments needed to create a database instance. - .. seealso:: :meth:`orion.core.io.database.AbstractDB.get_defaults` + .. seealso:: :meth:`orion.core.io.database.Database.get_defaults` for argument documentation. """ diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index e4ffcb9db..95e20ab5f 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -82,8 +82,9 @@ import orion.core import orion.core.utils.backward as backward +from orion.algo.base import algo_factory from orion.algo.space import Space -from orion.core.evc.adapters import Adapter +from orion.core.evc.adapters import BaseAdapter from orion.core.evc.conflicts import ExperimentNameConflict, detect_conflicts from orion.core.io import resolve_config from orion.core.io.database import DuplicateKeyError @@ -97,8 +98,8 @@ RaceCondition, ) from orion.core.worker.experiment import Experiment -from orion.core.worker.primary_algo import PrimaryAlgo -from orion.core.worker.strategy import Strategy +from orion.core.worker.primary_algo import SpaceTransformAlgoWrapper +from orion.core.worker.strategy import strategy_factory from orion.storage.base import get_storage, setup_storage log = logging.getLogger(__name__) @@ -465,7 +466,7 @@ def _instantiate_adapters(config): List of adapter configurations to build a CompositeAdapter for the EVC. """ - return Adapter.build(config) + return BaseAdapter.build(config) def _instantiate_space(config): @@ -495,14 +496,18 @@ def _instantiate_algo(space, max_trials, config=None, ignore_unavailable=False): (orion.core.config.experiment.algorithms). ignore_unavailable: bool, optional If True and algorithm is not available (plugin not installed), return the configuration. - Otherwise, raise Factory error from PrimaryAlgo + Otherwise, raise Factory error. """ if not config: config = orion.core.config.experiment.algorithms try: - algo = PrimaryAlgo(space, config) + backported_config = backward.port_algo_config(config) + algo_constructor = algo_factory.get_class(backported_config.pop("of_type")) + algo = SpaceTransformAlgoWrapper( + algo_constructor, space=space, **backported_config + ) algo.algorithm.max_trials = max_trials except NotImplementedError as e: if not ignore_unavailable: @@ -524,7 +529,7 @@ def _instantiate_strategy(config=None, ignore_unavailable=False): (orion.core.config.producer.strategy). ignore_unavailable: bool, optional If True and algorithm is not available (plugin not installed), return the configuration. - Otherwise, raise Factory error from PrimaryAlgo + Otherwise, raise Factory error. 
""" @@ -539,7 +544,7 @@ def _instantiate_strategy(config=None, ignore_unavailable=False): strategy_type, config = next(iter(config.items())) try: - strategy = Strategy(of_type=strategy_type, **config) + strategy = strategy_factory.create(strategy_type, **config) except NotImplementedError as e: if not ignore_unavailable: raise e diff --git a/src/orion/core/utils/__init__.py b/src/orion/core/utils/__init__.py index b4697f168..886a86511 100644 --- a/src/orion/core/utils/__init__.py +++ b/src/orion/core/utils/__init__.py @@ -64,18 +64,8 @@ def get_all_types(parent_cls, cls_name): def _import_modules(cls): cls.modules = [] - base = import_module(cls.__base__.__module__) - try: - py_files = glob(os.path.abspath(os.path.join(base.__path__[0], "[A-Za-z]*.py"))) - py_mods = map( - lambda x: "." + os.path.split(os.path.splitext(x)[0])[1], py_files - ) - for py_mod in py_mods: - cls.modules.append(import_module(py_mod, package=cls.__base__.__module__)) - except AttributeError: - # This means that base class and implementations reside in a module - # itself and not a subpackage. - pass + # TODO: remove? + # base = import_module(cls.__base__.__module__) # Get types advertised through entry points! for entry_point in pkg_resources.iter_entry_points(cls.__name__): @@ -96,19 +86,77 @@ def _set_typenames(cls): log.debug("Implementations found: %s", sorted(cls.types.keys())) -class Factory(ABCMeta): - """Instantiate appropriate wrapper for the infrastructure based on input - argument, ``of_type``. +class GenericFactory: + """Factory to create instances of classes inheriting a given ``base`` class. + + The factory can instantiate children of the base class at any level of inheritance. + The children class must have different names (capitalization insensitive). To instantiate + objects with the factory, use ``factory.create('name_of_the_children_class')`` passing the name + of the children class to instantiate. + + To support classes even when they are not imported, register them in the ``entry_points`` + of the package's ``setup.py``. The factory will import all registered classes in the + entry_points before looking for available children to create new objects. - Attributes + Parameters ---------- - types : dict of subclasses of ``cls.__base__`` - Updated to contain all possible implementations currently. Check out code. + base: class + Base class of all children that the factory can instantiate. """ + def __init__(self, base): + self.base = base + + def create(self, of_type, *args, **kwargs): + """Create an object, instance of ``self.base`` + + Parameters + ---------- + of_type: str + Name of class, subclass of ``self.base``. Capitalization insensitive + + args: * + Positional arguments to construct the givin class. + + kwargs: ** + Keyword arguments to construct the givin class. + """ + + constructor = self.get_class(of_type) + return constructor(*args, **kwargs) + + def get_class(self, of_type): + """Get the class object (not instantiated) + + Parameters + ---------- + of_type: str + Name of class, subclass of ``self.base``. 
Capitalization insensitive + """ + of_type = of_type.lower() + constructors = self.get_classes() + + if of_type not in constructors: + error = "Could not find implementation of {0}, type = '{1}'".format( + self.base.__name__, of_type + ) + error += "\nCurrently, there is an implementation for types:\n" + error += str(sorted(constructors.keys())) + raise NotImplementedError(error) + + return constructors[of_type] + + def get_classes(self): + """Get children classes of ``self.base``""" + _import_modules(self.base) + return get_all_types(self.base, self.base.__name__) + + +class Factory(ABCMeta): + """Deprecated, will be removed in v0.3.0. See GenericFactory instead""" + def __init__(cls, names, bases, dictionary): - """Search in directory for attribute names subclassing `bases[0]`""" super(Factory, cls).__init__(names, bases, dictionary) cls.types = {} try: @@ -118,24 +166,7 @@ def __init__(cls, names, bases, dictionary): _set_typenames(cls) def __call__(cls, of_type, *args, **kwargs): - """Create an object, instance of ``cls.__base__``, on first call. - - :param of_type: Name of class, subclass of ``cls.__base__``, wrapper - of a database framework that will be instantiated on the first call. - :param args: positional arguments to initialize ``cls.__base__``'s instance (if any) - :param kwargs: keyword arguments to initialize ``cls.__base__``'s instance (if any) - - .. seealso:: - `Factory.types` keys for values of argument `of_type`. - - .. seealso:: - Attributes of ``cls.__base__`` and ``cls.__base__.__init__`` for - values of `args` and `kwargs`. - - .. note:: New object is saved as `Factory`'s internal state. - - :return: The object which was created on the first call. - """ + """Create an object, instance of ``cls.__base__``, on first call.""" _import_modules(cls) _set_typenames(cls) diff --git a/src/orion/core/utils/backward.py b/src/orion/core/utils/backward.py index 999e198ed..54c8d4a36 100644 --- a/src/orion/core/utils/backward.py +++ b/src/orion/core/utils/backward.py @@ -147,3 +147,27 @@ def get_algo_requirements(algorithm): shape_requirement=algorithm.requires_shape, dist_requirement=algorithm.requires_dist, ) + + +def port_algo_config(config): + """Convert algorithm configuration to be compliant with factory interface + + Examples + -------- + >>> port_algo_config('algo_name') + {'of_type': 'algo_name'} + >>> port_algo_config({'algo_name': {'some': 'args'}}) + {'of_type': 'algo_name', 'some': 'args'} + >>> port_algo_config({'of_type': 'algo_name', 'some': 'args'}) + {'of_type': 'algo_name', 'some': 'args'} + + """ + config = copy.deepcopy(config) + if isinstance(config, dict) and len(config) == 1: + algo_name, algo_config = next(iter(config.items())) + config = algo_config + config["of_type"] = algo_name + elif isinstance(config, str): + config = {"of_type": config} + + return config diff --git a/src/orion/core/utils/singleton.py b/src/orion/core/utils/singleton.py index 794a4b58f..a53aa4451 100644 --- a/src/orion/core/utils/singleton.py +++ b/src/orion/core/utils/singleton.py @@ -6,7 +6,7 @@ """ from abc import ABCMeta -from orion.core.utils import Factory +from orion.core.utils import Factory, GenericFactory class SingletonAlreadyInstantiatedError(ValueError): @@ -74,19 +74,75 @@ def update_singletons(values=None): values = {} # Avoiding circular import problems when importing this module. 
- from orion.core.io.database import Database - from orion.core.io.database.ephemeraldb import EphemeralDB - from orion.core.io.database.mongodb import MongoDB - from orion.core.io.database.pickleddb import PickledDB - from orion.storage.base import Storage - from orion.storage.legacy import Legacy - from orion.storage.track import Track + from orion.core.io.database import database_factory + from orion.storage.base import storage_factory - singletons = (Storage, Legacy, Database, MongoDB, PickledDB, EphemeralDB, Track) + singletons = (storage_factory, database_factory) updated_singletons = {} for singleton in singletons: - updated_singletons[singleton] = singleton.instance - singleton.instance = values.get(singleton, None) + name = singleton.base.__name__.lower() + updated_singletons[name] = singleton.instance + singleton.instance = values.get(name, None) return updated_singletons + + +class GenericSingletonFactory(GenericFactory): + """Factory to create singleton instances of classes inheriting a given ``base`` class. + + .. seealso:: + + :py:class:`orion.core.utils.GenericFactory` + + """ + + def __init__(self, base): + super(GenericSingletonFactory, self).__init__(base=base) + self.instance = None + + def create(self, of_type=None, *args, **kwargs): + """Create an object, instance of ``self.base`` + + If the instance is already created, ``self.create`` can only be called without arguments + and will return the singleton. + + Cannot be called without arguments if the singleton was not already created. + + Parameters + ---------- + of_type: str, optional + Name of class, subclass of ``self.base``. Capitalization insensitive. + + args: * + Positional arguments to construct the givin class. + + kwargs: ** + Keyword arguments to construct the givin class. + + Raises + ------ + `SingletonNotInstantiatedError` + - If ``self.create()`` was never called and is called without arguments for the first + time. + - If ``self.create()`` was never called and the current call raises an error. + `SingletonAlreadyInstantiatedError` + If ``self.create()`` was already called with arguments (the singleton exist) and + is called again with arguments. + + """ + + if self.instance is None and of_type is None: + raise SingletonNotInstantiatedError(self.base.__name__) + elif self.instance is None: + try: + self.instance = super(GenericSingletonFactory, self).create( + of_type, *args, **kwargs + ) + except TypeError as exception: + raise SingletonNotInstantiatedError(self.base.__name__) from exception + + elif of_type or args or kwargs: + raise SingletonAlreadyInstantiatedError(self.base.__name__) + + return self.instance diff --git a/src/orion/core/worker/experiment.py b/src/orion/core/worker/experiment.py index 1dff1852b..62d90b0e8 100644 --- a/src/orion/core/worker/experiment.py +++ b/src/orion/core/worker/experiment.py @@ -57,7 +57,7 @@ class Experiment: it will overwrite the previous one. space: Space Object representing the optimization space. - algorithms : `PrimaryAlgo` object. + algorithms : `BaseAlgorithm` object or a wrapper. Complete specification of the optimization and dynamical procedures taking place in this `Experiment`. 
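# --- Illustration (not part of the patch) ----------------------------------
# A minimal sketch of how an algorithm is now built from the generic factory
# plus the space wrapper, mirroring the updated ``_instantiate_algo`` above and
# the ``SpaceTransformAlgoWrapper`` defined in the next file. The 'random'
# algorithm and the single 'x' dimension are illustrative assumptions, not
# something this diff adds.
import orion.core.utils.backward as backward
from orion.algo.base import algo_factory
from orion.core.io.space_builder import SpaceBuilder
from orion.core.worker.primary_algo import SpaceTransformAlgoWrapper

space = SpaceBuilder().build({"x": "uniform(0, 10)"})

# Normalize the user-facing config, look the class up by name, then wrap it.
config = backward.port_algo_config({"random": {"seed": 1}})
algo_constructor = algo_factory.get_class(config.pop("of_type"))
algo = SpaceTransformAlgoWrapper(algo_constructor, space=space, **config)
algo.algorithm.max_trials = 10
# ----------------------------------------------------------------------------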
diff --git a/src/orion/core/worker/primary_algo.py b/src/orion/core/worker/primary_algo.py index 91e00c9c1..126314131 100644 --- a/src/orion/core/worker/primary_algo.py +++ b/src/orion/core/worker/primary_algo.py @@ -7,12 +7,11 @@ """ import orion.core.utils.backward as backward -from orion.algo.base import BaseAlgorithm from orion.core.worker.transformer import build_required_space # pylint: disable=too-many-public-methods -class PrimaryAlgo(BaseAlgorithm): +class SpaceTransformAlgoWrapper: """Perform checks on points and transformations. Wrap the primary algorithm. 1. Checks requirements on the parameter space from algorithms and create the @@ -20,24 +19,22 @@ class PrimaryAlgo(BaseAlgorithm): of the primary algorithm. 2. Checks whether incoming and outcoming points are compliant with a space. - """ - - def __init__(self, space, algorithm_config): - """ - Initialize the primary algorithm. + Parameters + ---------- + algo_constructor: Child class of `BaseAlgorithm` + Class constructor to build the algorithm object. + space : `orion.algo.space.Space` + The original definition of a problem's parameters space. + algorithm_config : dict + Configuration for the algorithm. - Parameters - ---------- - space : `orion.algo.space.Space` - The original definition of a problem's parameters space. - algorithm_config : dict - Configuration for the algorithm. + """ - """ - self.algorithm = None - super(PrimaryAlgo, self).__init__(space, algorithm=algorithm_config) - requirements = backward.get_algo_requirements(self.algorithm) - self.transformed_space = build_required_space(self.space, **requirements) + def __init__(self, algo_constructor, space, **algorithm_config): + self._space = space + requirements = backward.get_algo_requirements(algo_constructor) + self.transformed_space = build_required_space(space, **requirements) + self.algorithm = algo_constructor(space, **algorithm_config) self.algorithm.space = self.transformed_space def seed_rng(self, seed): @@ -175,3 +172,17 @@ def space(self): .. note:: Redefining property here without setter, denies base class' setter. """ return self._space + + def get_id(self, point, ignore_fidelity=False): + """Compute a unique hash for a point based on params""" + return self.algorithm.get_id( + self.transformed_space.transform(point), ignore_fidelity=ignore_fidelity + ) + + @property + def fidelity_index(self): + """Compute the index of the point where fidelity is. + + Returns None if there is no fidelity dimension. + """ + return self.algorithm.fidelity_index diff --git a/src/orion/core/worker/producer.py b/src/orion/core/worker/producer.py index b03c9940a..5a4f38cf3 100644 --- a/src/orion/core/worker/producer.py +++ b/src/orion/core/worker/producer.py @@ -50,7 +50,7 @@ def __init__(self, experiment, max_idle_time=None): self.max_idle_time = max_idle_time self.strategy = experiment.producer["strategy"] self.naive_algorithm = None - # TODO: Move trials_history into PrimaryAlgo during the refactoring of Algorithm with + # TODO: Move trials_history into BaseAlgorithm during the refactoring of Algorithm with # Strategist and Scheduler. self.trials_history = TrialsHistory() self.params_hashes = set() diff --git a/src/orion/core/worker/strategy.py b/src/orion/core/worker/strategy.py index 7a9df2690..d74dc94fd 100644 --- a/src/orion/core/worker/strategy.py +++ b/src/orion/core/worker/strategy.py @@ -3,13 +3,14 @@ Parallel Strategies =================== -Register objectives for incomplete trials +Register objectives for incomplete trials. 
+ +Parallel strategy objects can be created using `strategy_factory.create('strategy_name')`. """ import logging -from abc import ABCMeta, abstractmethod -from orion.core.utils import Factory +from orion.core.utils import GenericFactory from orion.core.worker.trial import Trial log = logging.getLogger(__name__) @@ -48,13 +49,12 @@ def get_objective(trial): return objective -class BaseParallelStrategy(object, metaclass=ABCMeta): +class ParallelStrategy(object): """Strategy to give intermediate results for incomplete trials""" def __init__(self, *args, **kwargs): pass - @abstractmethod def observe(self, points, results): """Observe completed trials @@ -73,7 +73,7 @@ def observe(self, points, results): # NOTE: In future points and results will be converted to trials for coherence with # `Strategy.lie()` as well as for coherence with `Algorithm.observe` which will also be # converted to expect trials instead of lists and dictionaries. - pass + raise NotImplementedError() # pylint: disable=no-self-use def lie(self, trial): @@ -109,15 +109,15 @@ def configuration(self): return self.__class__.__name__ -class NoParallelStrategy(BaseParallelStrategy): +class NoParallelStrategy(ParallelStrategy): """No parallel strategy""" def observe(self, points, results): - """See BaseParallelStrategy.observe""" + """See ParallelStrategy.observe""" pass def lie(self, trial): - """See BaseParallelStrategy.lie""" + """See ParallelStrategy.lie""" result = super(NoParallelStrategy, self).lie(trial) if result: return result @@ -125,7 +125,7 @@ def lie(self, trial): return None -class MaxParallelStrategy(BaseParallelStrategy): +class MaxParallelStrategy(ParallelStrategy): """Parallel strategy that uses the max of completed objectives""" def __init__(self, default_result=float("inf")): @@ -140,8 +140,7 @@ def configuration(self): return {self.__class__.__name__: {"default_result": self.default_result}} def observe(self, points, results): - """See BaseParallelStrategy.observe""" - super(MaxParallelStrategy, self).observe(points, results) + """See ParallelStrategy.observe""" results = [ result["objective"] for result in results if result["objective"] is not None ] @@ -149,7 +148,7 @@ def observe(self, points, results): self.max_result = max(results) def lie(self, trial): - """See BaseParallelStrategy.lie""" + """See ParallelStrategy.lie""" result = super(MaxParallelStrategy, self).lie(trial) if result: return result @@ -157,7 +156,7 @@ def lie(self, trial): return Trial.Result(name="lie", type="lie", value=self.max_result) -class MeanParallelStrategy(BaseParallelStrategy): +class MeanParallelStrategy(ParallelStrategy): """Parallel strategy that uses the mean of completed objectives""" def __init__(self, default_result=float("inf")): @@ -172,8 +171,7 @@ def configuration(self): return {self.__class__.__name__: {"default_result": self.default_result}} def observe(self, points, results): - """See BaseParallelStrategy.observe""" - super(MeanParallelStrategy, self).observe(points, results) + """See ParallelStrategy.observe""" objective_values = [ result["objective"] for result in results if result["objective"] is not None ] @@ -183,7 +181,7 @@ def observe(self, points, results): ) def lie(self, trial): - """See BaseParallelStrategy.lie""" + """See ParallelStrategy.lie""" result = super(MeanParallelStrategy, self).lie(trial) if result: return result @@ -191,7 +189,7 @@ def lie(self, trial): return Trial.Result(name="lie", type="lie", value=self.mean_result) -class StubParallelStrategy(BaseParallelStrategy): +class 
StubParallelStrategy(ParallelStrategy): """Parallel strategy that returns static objective value for incompleted trials.""" def __init__(self, stub_value=None): @@ -205,11 +203,11 @@ def configuration(self): return {self.__class__.__name__: {"stub_value": self.stub_value}} def observe(self, points, results): - """See BaseParallelStrategy.observe""" + """See ParallelStrategy.observe""" pass def lie(self, trial): - """See BaseParallelStrategy.lie""" + """See ParallelStrategy.lie""" result = super(StubParallelStrategy, self).lie(trial) if result: return result @@ -217,11 +215,4 @@ def lie(self, trial): return Trial.Result(name="lie", type="lie", value=self.stub_value) -# pylint: disable=too-few-public-methods,abstract-method -class Strategy(BaseParallelStrategy, metaclass=Factory): - """Class used to build a parallel strategy given name and params - - .. seealso:: `orion.core.utils.Factory` metaclass and `BaseParallelStrategy` interface. - """ - - pass +strategy_factory = GenericFactory(ParallelStrategy) diff --git a/src/orion/executor/base.py b/src/orion/executor/base.py index 581ee455d..6be25cc59 100644 --- a/src/orion/executor/base.py +++ b/src/orion/executor/base.py @@ -7,7 +7,7 @@ """ -from orion.core.utils import Factory +from orion.core.utils import GenericFactory class BaseExecutor: @@ -62,6 +62,4 @@ def __exit__(self, exc_type, exc_value, traceback): pass -# pylint: disable=too-few-public-methods,abstract-method -class Executor(BaseExecutor, metaclass=Factory): - """Factory class to build Executors""" +executor_factory = GenericFactory(BaseExecutor) diff --git a/src/orion/storage/base.py b/src/orion/storage/base.py index 32a870855..09d12c4d4 100644 --- a/src/orion/storage/base.py +++ b/src/orion/storage/base.py @@ -5,13 +5,27 @@ Implement a generic protocol to allow Orion to communicate using different storage backend. +Storage protocol is a generic way of allowing Orion to interface with different storage. +MongoDB, track, cometML, MLFLow, etc... + +Examples +-------- +>>> storage_factory.create('track', uri='file://orion_test.json') +>>> storage_factory.create('legacy', experiment=...) + +Notes +----- +When retrieving an already initialized Storage object you should use `get_storage`. +`storage_factory.create()` should only be used for initialization purposes as `get_storage` +raises more granular error messages. + """ import copy import logging import orion.core from orion.core.io import resolve_config -from orion.core.utils.singleton import AbstractSingletonType, SingletonFactory +from orion.core.utils.singleton import GenericSingletonFactory log = logging.getLogger(__name__) @@ -62,7 +76,7 @@ class MissingArguments(Exception): pass -class BaseStorageProtocol(metaclass=AbstractSingletonType): +class BaseStorageProtocol: """Implement a generic protocol to allow Orion to communicate using different storage backend @@ -380,25 +394,7 @@ def update_heartbeat(self, trial): raise NotImplementedError() -# pylint: disable=too-few-public-methods,abstract-method -class Storage(BaseStorageProtocol, metaclass=SingletonFactory): - """Storage protocol is a generic way of allowing Orion to interface with different storage. - MongoDB, track, cometML, MLFLow, etc... - - Examples - -------- - >>> Storage('track', uri='file://orion_test.json') - >>> Storage('legacy', experiment=...) - - Notes - ----- - When retrieving an already initialized Storage object you should use `get_storage`. 
- `Storage()` should only be used for initialization purposes as `get_storage` - raises more granular error messages. - - """ - - pass +storage_factory = GenericSingletonFactory(BaseStorageProtocol) def get_storage(): @@ -418,7 +414,7 @@ with the appropriate arguments for the chosen backend """ - return Storage() + return storage_factory.create() def setup_storage(storage=None, debug=False): @@ -457,9 +453,9 @@ log.debug("Creating %s storage client with args: %s", storage_type, storage) try: - Storage(of_type=storage_type, **storage) + storage_factory.create(of_type=storage_type, **storage) except ValueError: - if Storage().__class__.__name__.lower() != storage_type.lower(): + if storage_factory.create().__class__.__name__.lower() != storage_type.lower(): raise diff --git a/src/orion/storage/legacy.py b/src/orion/storage/legacy.py index 773cfe800..42c068275 100644 --- a/src/orion/storage/legacy.py +++ b/src/orion/storage/legacy.py @@ -12,7 +12,7 @@ import orion.core import orion.core.utils.backward as backward -from orion.core.io.database import Database, OutdatedDatabaseError +from orion.core.io.database import Database, OutdatedDatabaseError, database_factory from orion.core.utils.exceptions import MissingResultFile from orion.core.worker.trial import Trial, validate_status from orion.storage.base import ( @@ -42,7 +42,7 @@ def get_database(): with the appropriate arguments for the chosen backend """ - return Database() + return database_factory.create() def setup_database(config=None): @@ -63,11 +63,7 @@ dbtype = db_opts.pop("type") log.debug("Creating %s database client with args: %s", dbtype, db_opts) - try: - Database(of_type=dbtype, **db_opts) - except ValueError: - if Database().__class__.__name__.lower() != dbtype.lower(): - raise + return database_factory.create(dbtype, **db_opts) class Legacy(BaseStorageProtocol): @@ -88,7 +84,7 @@ def __init__(self, database=None, setup=True): if database is not None: setup_database(database) - self._db = Database() + self._db = database_factory.create() if setup: self._setup_db() diff --git a/src/orion/testing/algo.py b/src/orion/testing/algo.py index 5c1023e19..f49329701 100644 --- a/src/orion/testing/algo.py +++ b/src/orion/testing/algo.py @@ -17,7 +17,7 @@ from orion.algo.tpe import TPE from orion.benchmark.task.branin import Branin from orion.core.io.space_builder import SpaceBuilder -from orion.core.worker.primary_algo import PrimaryAlgo +from orion.core.worker.primary_algo import SpaceTransformAlgoWrapper from orion.testing.space import build_space algorithms = { @@ -108,9 +108,9 @@ class BaseAlgoTests: This test-suite covers all typical cases for HPO algorithms. To use it for a new algorithm, the class inheriting from this one must redefine the attributes ``algo_name`` with the name of the algorithm used to create it with the algorithm factory - ``orion.core.worker.primary_algo.PrimaryAlgo`` and ``config`` with a base configuration for the - algorithm that contains all its arguments. The base space can be redefine if needed - with the attribute ``space``. + ``orion.core.worker.primary_algo.SpaceTransformAlgoWrapper`` and ``config`` with a base + configuration for the algorithm that contains all its arguments. The base space can be redefined + if needed with the attribute ``space``. Many algorithms have different phases that should be tested. For instance TPE have a first phase of random search and a second of Bayesian Optimization. 
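A minimal sketch of how a test suite would plug into ``BaseAlgoTests`` after this rename, assuming ``BaseAlgoTests`` is imported from ``orion.testing.algo``; the class name and configuration values are illustrative only (compare ``TestGenericASHA`` and the other suites further down):

from orion.testing.algo import BaseAlgoTests  # assumed import path


class TestGenericRandom(BaseAlgoTests):
    """Illustrative suite running the generic algorithm tests against random search."""

    algo_name = "random"       # name resolved by the algorithm factory
    config = {"seed": 123456}  # base configuration holding all of the algorithm's arguments
    # The base space can optionally be redefined, as the fidelity-based suites below do.
    space = {"x": "uniform(0, 1)", "y": "uniform(0, 1)"}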
@@ -169,9 +169,10 @@ def create_algo(self, config=None, space=None, **kwargs): """ config = copy.deepcopy(config or self.config) config.update(kwargs) - algo = PrimaryAlgo( + algo = SpaceTransformAlgoWrapper( + orion.algo.base.algo_factory.get_class(self.algo_name), space or self.create_space(), - {self.algo_name: config}, + **config, ) algo.algorithm.max_trials = self.max_trials return algo @@ -210,7 +211,7 @@ def observe_points(self, points, algo, objective=0): ---------- points: list of points Trials formatted as tuples of values - algo: ``orion.algo.base.OptimizationAlgorithm`` + algo: ``orion.algo.base.BaseAlgorithm`` The algorithm used to observe points. objective: int, optional The base objective for the trials. All objectives @@ -235,7 +236,7 @@ def force_observe(self, num, algo): ---------- num: int Number of trials to suggest and observe. - algo: ``orion.algo.base.OptimizationAlgorithm`` + algo: ``orion.algo.base.BaseAlgorithm`` The algorithm that must suggest and observe. Raises @@ -277,7 +278,7 @@ def spy_phase(self, mocker, num, algo, attribute): Mocker from ``pytest_mock``. Should be given by fixtures of the tests. num: int Number of trials to suggest and observe - algo: ``orion.algo.base.OptimizationAlgorithm`` + algo: ``orion.algo.base.BaseAlgorithm`` The algorithm to test attribute: str The algorithm attribute or method to mock. The path is respective to the @@ -300,7 +301,7 @@ def assert_callbacks(self, spy, num, algo): Object mocked by ``BaseAlgoTests.spy_phase``. num: int number of points of the phase. - algo: ``orion.algo.base.OptimizationAlgorithm`` + algo: ``orion.algo.base.BaseAlgorithm`` The algorithm being tested. """ pass @@ -317,7 +318,7 @@ def assert_dim_type_supported(self, mocker, num, attr, test_space): Mocker from ``pytest_mock``. Should be given by fixtures of the tests. num: int Number of trials to suggest and observe - algo: ``orion.algo.base.OptimizationAlgorithm`` + algo: ``orion.algo.base.BaseAlgorithm`` The algorithm to test attribute: str The algorithm attribute or method to mock. 
The path is respective to the @@ -350,28 +351,32 @@ def test_get_id(self): algo = self.create_algo(space=space) - assert algo.get_id(["is here", 1]) == algo.get_id(["is here", 1]) - assert algo.get_id(["is here", 1]) != algo.get_id(["is here", 2]) - assert algo.get_id(["matters", 1]) != algo.get_id(["is here", 1]) + assert algo.get_id([1, 1, 1]) == algo.get_id([1, 1, 1]) + assert algo.get_id([1, 1, 1]) != algo.get_id([1, 2, 2]) + assert algo.get_id([1, 1, 1]) != algo.get_id([2, 1, 1]) - assert algo.get_id(["is here", 1], ignore_fidelity=False) == algo.get_id( - ["is here", 1], ignore_fidelity=False + assert algo.get_id([1, 1, 1], ignore_fidelity=False) == algo.get_id( + [1, 1, 1], ignore_fidelity=False ) - assert algo.get_id(["is here", 1], ignore_fidelity=False) != algo.get_id( - ["is here", 2], ignore_fidelity=False + # Fidelity changes id + assert algo.get_id([1, 1, 1], ignore_fidelity=False) != algo.get_id( + [2, 1, 1], ignore_fidelity=False ) - assert algo.get_id(["matters", 1], ignore_fidelity=False) != algo.get_id( - ["is here", 1], ignore_fidelity=False + # Non-fidelity changes id + assert algo.get_id([1, 1, 1], ignore_fidelity=False) != algo.get_id( + [1, 1, 2], ignore_fidelity=False ) - assert algo.get_id(["is here", 1], ignore_fidelity=True) == algo.get_id( - ["is here", 1], ignore_fidelity=True + assert algo.get_id([1, 1, 1], ignore_fidelity=True) == algo.get_id( + [1, 1, 1], ignore_fidelity=True ) - assert algo.get_id(["is here", 1], ignore_fidelity=True) != algo.get_id( - ["is here", 2], ignore_fidelity=True + # Fidelity does not change id + assert algo.get_id([1, 1, 1], ignore_fidelity=True) == algo.get_id( + [2, 1, 1], ignore_fidelity=True ) - assert algo.get_id(["whatever", 1], ignore_fidelity=True) == algo.get_id( - ["is here", 1], ignore_fidelity=True + # Non-fidelity still changes id + assert algo.get_id([1, 1, 1], ignore_fidelity=True) != algo.get_id( + [1, 1, 2], ignore_fidelity=True ) @phase diff --git a/src/orion/testing/state.py b/src/orion/testing/state.py index 38d861645..fc8e80c6d 100644 --- a/src/orion/testing/state.py +++ b/src/orion/testing/state.py @@ -19,7 +19,7 @@ update_singletons, ) from orion.core.worker.trial import Trial -from orion.storage.base import Storage, get_storage +from orion.storage.base import get_storage, storage_factory # pylint: disable=no-self-use,protected-access @@ -197,7 +197,7 @@ def storage(self, config=None): try: config["of_type"] = config.pop("type") - db = Storage(**config) + db = storage_factory.create(**config) self.storage_config = config except SingletonAlreadyInstantiatedError: db = get_storage() diff --git a/tests/conftest.py b/tests/conftest.py index a7ef281d9..92bd9fc11 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,14 +13,14 @@ import orion.core import orion.core.utils.backward as backward -from orion.algo.base import BaseAlgorithm, OptimizationAlgorithm +from orion.algo.base import BaseAlgorithm from orion.core.io import resolve_config -from orion.core.io.database import Database +from orion.core.io.database import database_factory from orion.core.io.database.mongodb import MongoDB from orion.core.io.database.pickleddb import PickledDB from orion.core.utils.singleton import update_singletons from orion.core.worker.trial import Trial -from orion.storage.base import Storage, get_storage, setup_storage +from orion.storage.base import get_storage, setup_storage, storage_factory from orion.storage.legacy import Legacy from orion.testing import OrionState, mocked_datetime @@ -304,11 +304,8 @@ def 
clean_db(database, exp_config): @pytest.fixture() def null_db_instances(): """Nullify singleton instance so that we can assure independent instantiation tests.""" - Storage.instance = None - Legacy.instance = None - Database.instance = None - MongoDB.instance = None - PickledDB.instance = None + storage_factory.instance = None + database_factory.instance = None @pytest.fixture(scope="function") diff --git a/tests/functional/backward_compatibility/test_versions.py b/tests/functional/backward_compatibility/test_versions.py index 9473d5cc2..9ea1adf08 100644 --- a/tests/functional/backward_compatibility/test_versions.py +++ b/tests/functional/backward_compatibility/test_versions.py @@ -10,11 +10,8 @@ import orion.core.io.experiment_builder as experiment_builder from orion.client import create_experiment -from orion.core.io.database import Database -from orion.core.io.database.mongodb import MongoDB -from orion.core.io.database.pickleddb import PickledDB -from orion.storage.base import Storage, get_storage -from orion.storage.legacy import Legacy +from orion.core.io.database import database_factory +from orion.storage.base import get_storage, storage_factory DIRNAME = os.path.dirname(os.path.abspath(__file__)) @@ -226,11 +223,8 @@ def fill_db(request): def null_db_instances(): """Nullify singleton instance so that we can assure independent instantiation tests.""" - Storage.instance = None - Legacy.instance = None - Database.instance = None - MongoDB.instance = None - PickledDB.instance = None + storage_factory.instance = None + database_factory.instance = None def build_storage(): diff --git a/tests/functional/benchmark/test_benchmark_flow.py b/tests/functional/benchmark/test_benchmark_flow.py index 79aa2a817..e0f2c8336 100644 --- a/tests/functional/benchmark/test_benchmark_flow.py +++ b/tests/functional/benchmark/test_benchmark_flow.py @@ -7,7 +7,13 @@ from orion.benchmark.assessment import AverageRank, AverageResult from orion.benchmark.benchmark_client import get_or_create_benchmark -from orion.benchmark.task import BaseTask, Branin, CarromTable, EggHolder, RosenBrock +from orion.benchmark.task import ( + BenchmarkTask, + Branin, + CarromTable, + EggHolder, + RosenBrock, +) algorithms = [ {"algorithm": {"random": {"seed": 1}}}, @@ -15,7 +21,7 @@ ] -class BirdLike(BaseTask): +class BirdLike(BenchmarkTask): """User defined benchmark task""" def __init__(self, max_trials=20): diff --git a/tests/functional/commands/conftest.py b/tests/functional/commands/conftest.py index 03b09b1e7..b94a452b5 100644 --- a/tests/functional/commands/conftest.py +++ b/tests/functional/commands/conftest.py @@ -12,7 +12,7 @@ import orion.core.cli import orion.core.io.experiment_builder as experiment_builder import orion.core.utils.backward as backward -from orion.core.io.database import Database +from orion.core.io.database import database_factory from orion.core.worker.trial import Trial from orion.storage.base import get_storage @@ -117,7 +117,9 @@ def with_experiment_missing_conf_file(monkeypatch, one_experiment): conf_file = "idontexist.yaml" exp.metadata["user_config"] = conf_file exp.metadata["user_args"] += ["--config", conf_file] - Database().write("experiments", exp.configuration, query={"_id": exp.id}) + database_factory.create().write( + "experiments", exp.configuration, query={"_id": exp.id} + ) return exp @@ -144,7 +146,7 @@ def single_without_success(one_experiment): x["value"] = x_value trial = Trial(experiment=exp.id, params=[x], status=status) x_value += 1 - Database().write("trials", 
trial.to_dict()) + database_factory.create().write("trials", trial.to_dict()) @pytest.fixture @@ -155,7 +157,7 @@ def single_with_trials(single_without_success): x = {"name": "/x", "type": "real", "value": 100} results = {"name": "obj", "type": "objective", "value": 0} trial = Trial(experiment=exp.id, params=[x], status="completed", results=[results]) - Database().write("trials", trial.to_dict()) + database_factory.create().write("trials", trial.to_dict()) return exp.configuration @@ -208,8 +210,8 @@ def family_with_trials(two_experiments): x["value"] = x_value + 0.5 # To avoid duplicates trial2 = Trial(experiment=exp2.id, params=[x, y], status=status) x_value += 1 - Database().write("trials", trial.to_dict()) - Database().write("trials", trial2.to_dict()) + database_factory.create().write("trials", trial.to_dict()) + database_factory.create().write("trials", trial2.to_dict()) @pytest.fixture @@ -217,8 +219,8 @@ def unrelated_with_trials(family_with_trials, single_with_trials): """Create two unrelated experiments with all types of trials.""" exp = experiment_builder.build(name="test_double_exp_child") - Database().remove("trials", {"experiment": exp.id}) - Database().remove("experiments", {"_id": exp.id}) + database_factory.create().remove("trials", {"experiment": exp.id}) + database_factory.create().remove("experiments", {"_id": exp.id}) @pytest.fixture @@ -266,7 +268,7 @@ def three_family_with_trials(three_experiments_family, family_with_trials): z["value"] = x_value * 100 trial = Trial(experiment=exp.id, params=[x, z], status=status) x_value += 1 - Database().write("trials", trial.to_dict()) + database_factory.create().write("trials", trial.to_dict()) @pytest.fixture @@ -310,7 +312,7 @@ def three_family_branch_with_trials( z["value"] = x_value * 100 trial = Trial(experiment=exp.id, params=[x, y, z], status=status) x_value += 1 - Database().write("trials", trial.to_dict()) + database_factory.create().write("trials", trial.to_dict()) @pytest.fixture @@ -430,7 +432,7 @@ def three_experiments_same_name_with_trials(two_experiments_same_name, storage): trial = Trial(experiment=exp.id, params=[x], status=status) trial2 = Trial(experiment=exp2.id, params=[x, y], status=status) trial3 = Trial(experiment=exp3.id, params=[x, y, z], status=status) - Database().write("trials", trial.to_dict()) - Database().write("trials", trial2.to_dict()) - Database().write("trials", trial3.to_dict()) + database_factory.create().write("trials", trial.to_dict()) + database_factory.create().write("trials", trial2.to_dict()) + database_factory.create().write("trials", trial3.to_dict()) x_value += 1 diff --git a/tests/functional/commands/test_setup_command.py b/tests/functional/commands/test_setup_command.py index 4a93157e8..772164978 100644 --- a/tests/functional/commands/test_setup_command.py +++ b/tests/functional/commands/test_setup_command.py @@ -8,7 +8,7 @@ import orion.core import orion.core.cli -from orion.core.io.database import Database +from orion.core.io.database import database_factory class _mock_input: @@ -120,7 +120,8 @@ def test_invalid_database(monkeypatch, tmp_path, capsys): for invalid_db_name in invalid_db_names: assert ( "Unexpected value: {}. 
Must be one of: {}\n".format( - invalid_db_name, ", ".join(sorted(Database.types.keys())) + invalid_db_name, + ", ".join(sorted(database_factory.get_classes().keys())), ) in captured_output ) diff --git a/tests/functional/configuration/conftest.py b/tests/functional/configuration/conftest.py index ebf0b35f0..147d3693a 100644 --- a/tests/functional/configuration/conftest.py +++ b/tests/functional/configuration/conftest.py @@ -1,6 +1,6 @@ """Common fixtures and utils for configuration tests.""" -from orion.algo.base import BaseAlgorithm, OptimizationAlgorithm -from orion.core.worker.strategy import BaseParallelStrategy, Strategy +from orion.algo.base import BaseAlgorithm +from orion.core.worker.strategy import ParallelStrategy, strategy_factory def __init__(self, *args, **params): @@ -19,19 +19,21 @@ def configuration(self): return {self.__class__.__name__.lower(): self.params} +# Keep pointers to classes so that they are not garbage collected. +algo_classes = [] for char in "ABCDE": algo_class = type(f"A{char}", (BaseAlgorithm,), {"suggest": stub, "observe": stub}) - # Hack it into being discoverable - OptimizationAlgorithm.types[algo_class.__name__.lower()] = algo_class + algo_classes.append(algo_class) +# Keep pointers to classes so that they are not garbage collected. +strategy_classes = [] for char in "ABCDE": strategy_class = type( - f"S{char}", (BaseParallelStrategy,), {"observe": stub, "__init__": __init__} + f"S{char}", (ParallelStrategy,), {"observe": stub, "__init__": __init__} ) strategy_class.configuration = property(configuration) - # Hack it into being discoverable - Strategy.types[strategy_class.__name__.lower()] = strategy_class + strategy_classes.append(strategy_class) diff --git a/tests/functional/gradient_descent_algo/setup.py b/tests/functional/gradient_descent_algo/setup.py index 118e08622..dd69c2b8c 100644 --- a/tests/functional/gradient_descent_algo/setup.py +++ b/tests/functional/gradient_descent_algo/setup.py @@ -15,7 +15,7 @@ package_dir={"": "src"}, include_package_data=True, entry_points={ - "OptimizationAlgorithm": [ + "BaseAlgorithm": [ "gradient_descent = orion.algo.gradient_descent:Gradient_Descent" ], }, diff --git a/tests/unittests/algo/test_asha.py b/tests/unittests/algo/test_asha.py index 202f088eb..e7f0e42e8 100644 --- a/tests/unittests/algo/test_asha.py +++ b/tests/unittests/algo/test_asha.py @@ -493,7 +493,7 @@ class TestGenericASHA(BaseAlgoTests): "num_brackets": 2, "repetitions": 3, } - space = {"x": "uniform(0, 1)", "f": "fidelity(1, 10, base=2)"} + space = {"x": "uniform(0, 1)", "y": "uniform(0, 1)", "f": "fidelity(1, 10, base=2)"} def test_suggest_n(self, mocker, num, attr): algo = self.create_algo() diff --git a/tests/unittests/algo/test_base.py b/tests/unittests/algo/test_base.py index 4e9a8ee0b..3a3993696 100644 --- a/tests/unittests/algo/test_base.py +++ b/tests/unittests/algo/test_base.py @@ -2,30 +2,25 @@ # -*- coding: utf-8 -*- """Example usage and tests for :mod:`orion.algo.base`.""" -from orion.algo.base import BaseAlgorithm from orion.algo.space import Integer, Real, Space def test_init(dumbalgo): """Check if initialization works for nested algos.""" nested_algo = {"DumbAlgo": dict(value=6, scoring=5)} - algo = dumbalgo(8, value=1, subone=nested_algo) + algo = dumbalgo(8, value=1) assert algo.space == 8 assert algo.value == 1 assert algo.scoring == 0 assert algo.judgement is None assert algo.suspend is False assert algo.done is False - assert isinstance(algo.subone, BaseAlgorithm) - assert algo.subone.space == 8 - assert 
algo.subone.value == 6 - assert algo.subone.scoring == 5 def test_configuration(dumbalgo): """Check configuration getter works for nested algos.""" nested_algo = {"DumbAlgo": dict(value=6, scoring=5)} - algo = dumbalgo(8, value=1, subone=nested_algo) + algo = dumbalgo(8, value=1) config = algo.configuration assert config == { "dumbalgo": { @@ -35,45 +30,14 @@ def test_configuration(dumbalgo): "judgement": None, "suspend": False, "done": False, - "subone": { - "dumbalgo": { - "seed": None, - "value": 6, - "scoring": 5, - "judgement": None, - "suspend": False, - "done": False, - } - }, } } -def test_space_setter(dumbalgo): - """Check whether space setter works for nested algos.""" - nested_algo = { - "DumbAlgo": dict( - value=9, - ) - } - nested_algo2 = { - "DumbAlgo": dict( - judgement=10, - ) - } - algo = dumbalgo(8, value=1, naedw=nested_algo, naekei=nested_algo2) - algo.space = "etsh" - assert algo.space == "etsh" - assert algo.naedw.space == "etsh" - assert algo.naedw.value == 9 - assert algo.naekei.space == "etsh" - assert algo.naekei.judgement == 10 - - def test_state_dict(dumbalgo): """Check whether trials_info is in the state dict""" nested_algo = {"DumbAlgo": dict(value=6, scoring=5)} - algo = dumbalgo(8, value=1, subone=nested_algo) + algo = dumbalgo(8, value=1) algo.suggest(1) assert not algo.state_dict["_trials_info"] algo.observe([(1, 2)], [{"objective": 3}]) diff --git a/tests/unittests/algo/test_evolution_es.py b/tests/unittests/algo/test_evolution_es.py index 9af1a7b4e..17d77af31 100644 --- a/tests/unittests/algo/test_evolution_es.py +++ b/tests/unittests/algo/test_evolution_es.py @@ -369,7 +369,7 @@ class TestGenericEvolutionES(BaseAlgoTests): "max_retries": 1000, "mutate": None, } - space = {"x": "uniform(0, 1)", "f": "fidelity(1, 10, base=2)"} + space = {"x": "uniform(0, 1)", "y": "uniform(0, 1)", "f": "fidelity(1, 10, base=2)"} @pytest.mark.skip(reason="See https://github.com/Epistimio/orion/issues/598") def test_is_done_cardinality(self): diff --git a/tests/unittests/algo/test_hyperband.py b/tests/unittests/algo/test_hyperband.py index 1e8194402..5e4a56fb5 100644 --- a/tests/unittests/algo/test_hyperband.py +++ b/tests/unittests/algo/test_hyperband.py @@ -754,7 +754,7 @@ class TestGenericHyperband(BaseAlgoTests): "seed": 123456, "repetitions": 3, } - space = {"x": "uniform(0, 1)", "f": "fidelity(1, 10, base=2)"} + space = {"x": "uniform(0, 1)", "y": "uniform(0, 1)", "f": "fidelity(1, 10, base=2)"} @phase def test_suggest_lots(self, mocker, num, attr): diff --git a/tests/unittests/benchmark/test_benchmark_client.py b/tests/unittests/benchmark/test_benchmark_client.py index f3c1c0345..006d6594f 100644 --- a/tests/unittests/benchmark/test_benchmark_client.py +++ b/tests/unittests/benchmark/test_benchmark_client.py @@ -88,9 +88,8 @@ def test_create_benchmark_bad_storage(self, benchmark_config_py): } get_or_create_benchmark(**benchmark_config_py) - assert ( - "Could not find implementation of AbstractDB, type = 'idontexist'" - in str(exc.value) + assert "Could not find implementation of Database, type = 'idontexist'" in str( + exc.value ) def test_create_experiment_debug_mode(self, tmp_path, benchmark_config_py): @@ -240,7 +239,9 @@ def test_create_with_not_loaded_targets(self, benchmark_config): with OrionState(benchmarks=cfg_invalid_assess): with pytest.raises(NotImplementedError) as exc: get_or_create_benchmark(benchmark_config["name"]) - assert "Could not find implementation of BaseAssess" in str(exc.value) + assert "Could not find implementation of BenchmarkAssessment" 
in str( + exc.value + ) cfg_invalid_task = copy.deepcopy(benchmark_config) cfg_invalid_task["targets"][0]["task"]["idontexist"] = {"max_trials": 2} @@ -248,7 +249,7 @@ def test_create_with_not_loaded_targets(self, benchmark_config): with OrionState(benchmarks=cfg_invalid_task): with pytest.raises(NotImplementedError) as exc: get_or_create_benchmark(benchmark_config["name"]) - assert "Could not find implementation of BaseTask" in str(exc.value) + assert "Could not find implementation of BenchmarkTask" in str(exc.value) def test_create_with_not_exist_targets_parameters(self, benchmark_config): """Test creation with not existing assessment parameters""" diff --git a/tests/unittests/client/test_client.py b/tests/unittests/client/test_client.py index e940cfda6..352f78852 100644 --- a/tests/unittests/client/test_client.py +++ b/tests/unittests/client/test_client.py @@ -192,9 +192,8 @@ def test_create_experiment_bad_storage(self): storage={"type": "legacy", "database": {"type": "idontexist"}}, ) - assert ( - "Could not find implementation of AbstractDB, type = 'idontexist'" - in str(exc.value) + assert "Could not find implementation of Database, type = 'idontexist'" in str( + exc.value ) def test_create_experiment_new_default(self): diff --git a/tests/unittests/client/test_experiment_client.py b/tests/unittests/client/test_experiment_client.py index f1e690c53..f961cf0cf 100644 --- a/tests/unittests/client/test_experiment_client.py +++ b/tests/unittests/client/test_experiment_client.py @@ -17,7 +17,6 @@ SampleTimeout, ) from orion.core.worker.trial import Trial -from orion.executor.base import Executor from orion.executor.joblib_backend import Joblib from orion.storage.base import get_storage from orion.testing import create_experiment, mock_space_iterate diff --git a/tests/unittests/core/database/test_mongodb.py b/tests/unittests/core/database/test_mongodb.py index 30f15dffc..e2748d2b9 100644 --- a/tests/unittests/core/database/test_mongodb.py +++ b/tests/unittests/core/database/test_mongodb.py @@ -19,7 +19,12 @@ test_collection, ) -from orion.core.io.database import Database, DatabaseError, DuplicateKeyError +from orion.core.io.database import ( + Database, + DatabaseError, + DuplicateKeyError, + database_factory, +) from orion.core.io.database.mongodb import AUTH_FAILED_MESSAGES, MongoDB @@ -232,8 +237,9 @@ def test_overwrite_partial_uri(self, monkeypatch): def test_singleton(self): """Test that MongoDB class is a singleton.""" - orion_db = MongoDB( - "mongodb://localhost", + orion_db = database_factory.create( + of_type="mongodb", + host="mongodb://localhost", port=27017, name="orion_test", username="user", @@ -242,7 +248,7 @@ def test_singleton(self): # reinit connection does not change anything orion_db.initiate_connection() orion_db.close_connection() - assert MongoDB() is orion_db + assert database_factory.create() is orion_db def test_change_server_timeout(self): """Test that the server timeout is correctly changed.""" diff --git a/tests/unittests/core/evc/test_adapters.py b/tests/unittests/core/evc/test_adapters.py index 80dc55599..dccba07e1 100644 --- a/tests/unittests/core/evc/test_adapters.py +++ b/tests/unittests/core/evc/test_adapters.py @@ -7,8 +7,8 @@ from orion.algo.space import Real from orion.core.evc.adapters import ( - Adapter, AlgorithmChange, + BaseAdapter, CodeChange, CompositeAdapter, DimensionAddition, @@ -310,8 +310,8 @@ def test_composite_adapter_init_with_bad_adapters(self): def test_adapter_creation(dummy_param): - """Test initialization using 
:meth:`orion.core.evc.adapters.Adapter.build`""" - adapter = Adapter.build( + """Test initialization using :meth:`orion.core.evc.adapters.BaseAdapter.build`""" + adapter = BaseAdapter.build( [{"of_type": "DimensionAddition", "param": dummy_param.to_dict()}] ) @@ -854,7 +854,9 @@ def test_dimension_addition_configuration(dummy_param): assert configuration["of_type"] == "dimensionaddition" assert configuration["param"] == dummy_param.to_dict() - assert Adapter.build([configuration]).adapters[0].configuration[0] == configuration + assert ( + BaseAdapter.build([configuration]).adapters[0].configuration[0] == configuration + ) def test_dimension_deletion_configuration(dummy_param): @@ -866,7 +868,9 @@ def test_dimension_deletion_configuration(dummy_param): assert configuration["of_type"] == "dimensiondeletion" assert configuration["param"] == dummy_param.to_dict() - assert Adapter.build([configuration]).adapters[0].configuration[0] == configuration + assert ( + BaseAdapter.build([configuration]).adapters[0].configuration[0] == configuration + ) def test_dimension_prior_change_configuration(small_prior, large_prior): @@ -883,7 +887,9 @@ def test_dimension_prior_change_configuration(small_prior, large_prior): assert configuration["old_prior"] == small_prior assert configuration["new_prior"] == large_prior - assert Adapter.build([configuration]).adapters[0].configuration[0] == configuration + assert ( + BaseAdapter.build([configuration]).adapters[0].configuration[0] == configuration + ) def test_dimension_renaming_configuration(): @@ -898,7 +904,9 @@ def test_dimension_renaming_configuration(): assert configuration["old_name"] == old_name assert configuration["new_name"] == new_name - assert Adapter.build([configuration]).adapters[0].configuration[0] == configuration + assert ( + BaseAdapter.build([configuration]).adapters[0].configuration[0] == configuration + ) def test_algorithm_change_configuration(): @@ -909,7 +917,9 @@ def test_algorithm_change_configuration(): assert configuration["of_type"] == "algorithmchange" - assert Adapter.build([configuration]).adapters[0].configuration[0] == configuration + assert ( + BaseAdapter.build([configuration]).adapters[0].configuration[0] == configuration + ) def test_orion_version_change_configuration(): @@ -920,7 +930,9 @@ def test_orion_version_change_configuration(): assert configuration["of_type"] == "orionversionchange" - assert Adapter.build([configuration]).adapters[0].configuration[0] == configuration + assert ( + BaseAdapter.build([configuration]).adapters[0].configuration[0] == configuration + ) def test_code_change_configuration(): @@ -932,7 +944,9 @@ def test_code_change_configuration(): assert configuration["of_type"] == "codechange" assert configuration["change_type"] == CodeChange.UNSURE - assert Adapter.build([configuration]).adapters[0].configuration[0] == configuration + assert ( + BaseAdapter.build([configuration]).adapters[0].configuration[0] == configuration + ) def test_composite_configuration(dummy_param): @@ -950,5 +964,11 @@ def test_composite_configuration(dummy_param): assert configuration[0] == dimension_addition_adapter.configuration[0] assert configuration[1] == dimension_deletion_adapter.configuration[0] - assert Adapter.build(configuration).adapters[0].configuration[0] == configuration[0] - assert Adapter.build(configuration).adapters[1].configuration[0] == configuration[1] + assert ( + BaseAdapter.build(configuration).adapters[0].configuration[0] + == configuration[0] + ) + assert ( + 
BaseAdapter.build(configuration).adapters[1].configuration[0] + == configuration[1] + ) diff --git a/tests/unittests/core/evc/test_experiment_tree.py b/tests/unittests/core/evc/test_experiment_tree.py index cd5d10ba5..2d1dabd53 100644 --- a/tests/unittests/core/evc/test_experiment_tree.py +++ b/tests/unittests/core/evc/test_experiment_tree.py @@ -5,7 +5,7 @@ import pytest from orion.client import build_experiment, get_experiment -from orion.core.evc.adapters import Adapter, CodeChange +from orion.core.evc.adapters import CodeChange from orion.core.evc.experiment import ExperimentNode from orion.testing.evc import ( build_child_experiment, diff --git a/tests/unittests/core/io/database_test.py b/tests/unittests/core/io/database_test.py index 8459b1398..b37104782 100644 --- a/tests/unittests/core/io/database_test.py +++ b/tests/unittests/core/io/database_test.py @@ -4,7 +4,7 @@ import pytest -from orion.core.io.database import Database, ReadOnlyDB +from orion.core.io.database import ReadOnlyDB, database_factory from orion.core.io.database.pickleddb import PickledDB from orion.core.utils.singleton import ( SingletonAlreadyInstantiatedError, @@ -17,7 +17,7 @@ class TestDatabaseFactory(object): """Test the creation of a determinate `Database` type, by a complete spefication of a database by-itself (this on which every `Database` acts on as part - of its being, attributes of an `AbstractDB`) and for-itself (what essentially + of its being, attributes of an `Database`) and for-itself (what essentially differentiates one concrete `Database` from one other). """ @@ -31,25 +31,24 @@ def test_empty_first_call(self): Type indeterminate <-> type abstracted from its property <-> No type """ with pytest.raises(SingletonNotInstantiatedError): - Database() + database_factory.create() def test_notfound_type_first_call(self): """Raise when supplying not implemented wrapper name.""" with pytest.raises(NotImplementedError) as exc_info: - Database("notfound") + database_factory.create("notfound") - assert "AbstractDB" in str(exc_info.value) + assert "Database" in str(exc_info.value) def test_instantiation_and_singleton(self): """Test create just one object, that object persists between calls.""" - database = Database(of_type="PickledDB", name="orion_test") + database = database_factory.create(of_type="PickledDB", name="orion_test") assert isinstance(database, PickledDB) - assert database is PickledDB() - assert database is Database() + assert database is database_factory.create() with pytest.raises(SingletonAlreadyInstantiatedError): - Database("fire", [], {"it_matters": "it's singleton"}) + database_factory.create("fire", [], {"it_matters": "it's singleton"}) @pytest.mark.usefixtures("null_db_instances") diff --git a/tests/unittests/core/io/test_experiment_builder.py b/tests/unittests/core/io/test_experiment_builder.py index 5747ae79f..f0ad8293d 100644 --- a/tests/unittests/core/io/test_experiment_builder.py +++ b/tests/unittests/core/io/test_experiment_builder.py @@ -9,7 +9,6 @@ import orion.core.io.experiment_builder as experiment_builder import orion.core.utils.backward as backward -from orion.algo.base import BaseAlgorithm from orion.algo.space import Space from orion.core.evc.adapters import BaseAdapter from orion.core.io.database.ephemeraldb import EphemeralDB @@ -21,6 +20,7 @@ UnsupportedOperation, ) from orion.core.utils.singleton import update_singletons +from orion.core.worker.primary_algo import SpaceTransformAlgoWrapper from orion.storage.base import get_storage from orion.storage.legacy import 
Legacy from orion.testing import OrionState @@ -773,7 +773,7 @@ def test_instantiation_after_init(self, new_config): with OrionState(experiments=[new_config], trials=[]): exp = experiment_builder.build(**new_config) - assert isinstance(exp.algorithms, BaseAlgorithm) + assert isinstance(exp.algorithms, SpaceTransformAlgoWrapper) assert isinstance(exp.space, Space) assert isinstance(exp.refers["adapter"], BaseAdapter) @@ -1150,7 +1150,7 @@ def test_load_unavailable_strategy(strategy_unavailable_config, capsys): with pytest.raises(NotImplementedError) as exc: experiment_builder.build("supernaekei") - exc.match("Could not find implementation of BaseParallelStrategy") + exc.match("Could not find implementation of ParallelStrategy") class TestInitExperimentReadWrite(object): diff --git a/tests/unittests/core/test_primary_algo.py b/tests/unittests/core/test_primary_algo.py index 7f1c98fce..586b01da8 100644 --- a/tests/unittests/core/test_primary_algo.py +++ b/tests/unittests/core/test_primary_algo.py @@ -4,25 +4,23 @@ import pytest -from orion.core.worker.primary_algo import PrimaryAlgo +from orion.algo.base import algo_factory +from orion.core.worker.primary_algo import SpaceTransformAlgoWrapper @pytest.fixture() def palgo(dumbalgo, space, fixed_suggestion): - """Set up a PrimaryAlgo with dumb configuration.""" + """Set up a SpaceTransformAlgoWrapper with dumb configuration.""" algo_config = { - "DumbAlgo": { - "value": fixed_suggestion, - "subone": {"DumbAlgo": dict(value=6, scoring=5)}, - } + "value": fixed_suggestion, } - palgo = PrimaryAlgo(space, algo_config) + palgo = SpaceTransformAlgoWrapper(dumbalgo, space, **algo_config) return palgo -class TestPrimaryAlgoWraps(object): - """Test if PrimaryAlgo is actually wrapping the configured algorithm. +class TestSpaceTransformAlgoWrapperWraps(object): + """Test if SpaceTransformAlgoWrapper is actually wrapping the configured algorithm. Does not test for transformations. 
""" @@ -38,16 +36,6 @@ def test_init_and_configuration(self, dumbalgo, palgo, fixed_suggestion): "judgement": None, "suspend": False, "done": False, - "subone": { - "dumbalgo": { - "seed": None, - "value": 6, - "scoring": 5, - "judgement": None, - "suspend": False, - "done": False, - } - }, } } @@ -104,9 +92,3 @@ def test_judge(self, palgo, fixed_suggestion): assert palgo.algorithm._measurements == 8 with pytest.raises(AssertionError): palgo.judge((5,), 8) - - -class TestPrimaryAlgoTransforms(object): - """Check if PrimaryAlgo appropriately transforms spaces and samples.""" - - pass diff --git a/tests/unittests/core/test_strategy.py b/tests/unittests/core/test_strategy.py index 609815fd3..0d8229487 100644 --- a/tests/unittests/core/test_strategy.py +++ b/tests/unittests/core/test_strategy.py @@ -9,8 +9,8 @@ MaxParallelStrategy, MeanParallelStrategy, NoParallelStrategy, - Strategy, StubParallelStrategy, + strategy_factory, ) from orion.core.worker.trial import Trial @@ -52,8 +52,8 @@ def corrupted_trial(): def test_handle_corrupted_trials(caplog, strategy, corrupted_trial): """Verify that corrupted trials are handled properly""" with caplog.at_level(logging.WARNING, logger="orion.core.worker.strategy"): - Strategy(strategy).observe([corrupted_trial], [{"objective": 1}]) - lie = Strategy(strategy).lie(corrupted_trial) + strategy_factory.create(strategy).observe([corrupted_trial], [{"objective": 1}]) + lie = strategy_factory.create(strategy).lie(corrupted_trial) match = "Trial `{}` has an objective but status is not completed".format( corrupted_trial.id @@ -68,8 +68,10 @@ def test_handle_corrupted_trials(caplog, strategy, corrupted_trial): def test_handle_uncompleted_trials(caplog, strategy, incomplete_trial): """Verify that no warning is logged if trial is valid""" with caplog.at_level(logging.WARNING, logger="orion.core.worker.strategy"): - Strategy(strategy).observe([incomplete_trial], [{"objective": None}]) - Strategy(strategy).lie(incomplete_trial) + strategy_factory.create(strategy).observe( + [incomplete_trial], [{"objective": None}] + ) + strategy_factory.create(strategy).lie(incomplete_trial) assert "Trial `{}` has an objective but status is not completed" not in caplog.text @@ -79,12 +81,12 @@ class TestStrategyFactory: def test_create_noparallel(self): """Test creating a NoParallelStrategy class""" - strategy = Strategy("NoParallelStrategy") + strategy = strategy_factory.create("NoParallelStrategy") assert isinstance(strategy, NoParallelStrategy) def test_create_meanparallel(self): """Test creating a MeanParallelStrategy class""" - strategy = Strategy("MeanParallelStrategy") + strategy = strategy_factory.create("MeanParallelStrategy") assert isinstance(strategy, MeanParallelStrategy) diff --git a/tests/unittests/core/utils/test_backward.py b/tests/unittests/core/utils/test_backward.py index 7c3bdd094..6d71d99c5 100644 --- a/tests/unittests/core/utils/test_backward.py +++ b/tests/unittests/core/utils/test_backward.py @@ -65,3 +65,24 @@ def test_no_changes(requirement): "shape_requirement": "flattened", "dist_requirement": "linear", } + + +def test_port_algo_config_str(): + """Function should convert string to compliant dict""" + assert backward.port_algo_config("algo_name") == {"of_type": "algo_name"} + + +def test_port_algo_config_dict_legacy(): + """Function should convert dict to be compliant""" + assert backward.port_algo_config({"algo_name": {"some": "args"}}) == { + "of_type": "algo_name", + "some": "args", + } + + +def test_port_algo_config_dict_compliant(): + """Function 
should leave compliant dict as-is""" + assert backward.port_algo_config({"of_type": "algo_name", "some": "args"}) == { + "of_type": "algo_name", + "some": "args", + } diff --git a/tests/unittests/core/utils/test_utils.py b/tests/unittests/core/utils/test_utils.py index 348f827b0..736fd66ad 100644 --- a/tests/unittests/core/utils/test_utils.py +++ b/tests/unittests/core/utils/test_utils.py @@ -4,11 +4,12 @@ import pytest -from orion.core.utils import Factory, float_to_digits_list +from orion.core.utils import Factory, GenericFactory, float_to_digits_list -def test_factory_subclasses_detection(): +def test_deprecated_factory_subclasses_detection(): """Verify that meta-class Factory finds all subclasses""" + # TODO: Remove in v0.3.0 class Base(object): pass @@ -57,6 +58,50 @@ class Random(Base): assert type(MyFactory(of_type="random")) is Random +def test_new_factory_subclasses_detection(): + """Verify that Factory finds all subclasses""" + + class Base(object): + pass + + class A(Base): + pass + + class B(Base): + pass + + class AA(A): + pass + + class AB(A): + pass + + class AAA(AA): + pass + + class AA_AB(AA, AB): + pass + + factory = GenericFactory(Base) + + assert type(factory.create(of_type="A")) is A + assert type(factory.create(of_type="B")) is B + assert type(factory.create(of_type="AA")) is AA + assert type(factory.create(of_type="AAA")) is AAA + assert type(factory.create(of_type="AA_AB")) is AA_AB + + with pytest.raises(NotImplementedError) as exc_info: + factory.create(of_type="random") + assert "Could not find implementation of Base, type = 'random'" in str( + exc_info.value + ) + + class Random(Base): + pass + + assert type(factory.create(of_type="random")) is Random + + @pytest.mark.parametrize( "number,digits_list", [ diff --git a/tests/unittests/core/worker/test_experiment.py b/tests/unittests/core/worker/test_experiment.py index 01d3dd0f6..8a3459626 100644 --- a/tests/unittests/core/worker/test_experiment.py +++ b/tests/unittests/core/worker/test_experiment.py @@ -19,7 +19,7 @@ from orion.core.io.space_builder import SpaceBuilder from orion.core.utils.exceptions import UnsupportedOperation from orion.core.worker.experiment import Experiment -from orion.core.worker.primary_algo import PrimaryAlgo +from orion.core.worker.primary_algo import SpaceTransformAlgoWrapper from orion.core.worker.trial import Trial from orion.storage.base import get_storage from orion.testing import OrionState @@ -172,9 +172,9 @@ def space(): @pytest.fixture() -def algorithm(space): +def algorithm(dumbalgo, space): """Build a dumb algo object""" - return PrimaryAlgo(space, "dumbalgo") + return SpaceTransformAlgoWrapper(dumbalgo, space=space) class TestReserveTrial(object): @@ -574,6 +574,8 @@ def test_experiment_pickleable(): assert len(exp_trials) > 0 + from orion.storage.base import storage_factory + exp_bytes = pickle.dumps(exp) new_exp = pickle.loads(exp_bytes) diff --git a/tests/unittests/core/worker/test_producer.py b/tests/unittests/core/worker/test_producer.py index 2fd5a8122..6f0335f70 100644 --- a/tests/unittests/core/worker/test_producer.py +++ b/tests/unittests/core/worker/test_producer.py @@ -18,13 +18,13 @@ class DumbParallelStrategy: """Mock object for parallel strategy""" def observe(self, points, results): - """See BaseParallelStrategy.observe""" + """See ParallelStrategy.observe""" self._observed_points = points self._observed_results = results self._value = None def lie(self, trial): - """See BaseParallelStrategy.lie""" + """See ParallelStrategy.lie""" if self._value: value = 
self._value else: diff --git a/tests/unittests/storage/test_legacy.py b/tests/unittests/storage/test_legacy.py index dec110546..1a5faf2a7 100644 --- a/tests/unittests/storage/test_legacy.py +++ b/tests/unittests/storage/test_legacy.py @@ -9,7 +9,7 @@ import pytest -from orion.core.io.database import Database +from orion.core.io.database import database_factory from orion.core.io.database.pickleddb import PickledDB from orion.core.utils.exceptions import MissingResultFile from orion.core.utils.singleton import ( @@ -75,7 +75,7 @@ def test_setup_database_default(monkeypatch): """Test that database is setup using default config""" update_singletons() setup_database() - database = Database() + database = database_factory.create() assert isinstance(database, PickledDB) @@ -92,7 +92,7 @@ def test_setup_database_custom(): """Test setup with local configuration""" update_singletons() setup_database({"type": "pickleddb", "host": "test.pkl"}) - database = Database() + database = database_factory.create() assert isinstance(database, PickledDB) assert database.host == os.path.abspath("test.pkl") @@ -101,7 +101,7 @@ def test_setup_database_bad_override(): """Test setup with different type than existing singleton""" update_singletons() setup_database({"type": "pickleddb", "host": "test.pkl"}) - database = Database() + database = database_factory.create() assert isinstance(database, PickledDB) with pytest.raises(SingletonAlreadyInstantiatedError) as exc: setup_database({"type": "mongodb"}) @@ -109,12 +109,11 @@ def test_setup_database_bad_override(): assert exc.match("A singleton instance of \(type: Database\)") -@pytest.mark.xfail(reason="Fix this when introducing #135 in v0.2.0") def test_setup_database_bad_config_override(): """Test setup with different config than existing singleton""" update_singletons() setup_database({"type": "pickleddb", "host": "test.pkl"}) - database = Database() + database = database_factory.create() assert isinstance(database, PickledDB) with pytest.raises(SingletonAlreadyInstantiatedError): setup_database({"type": "pickleddb", "host": "other.pkl"}) diff --git a/tests/unittests/storage/test_storage.py b/tests/unittests/storage/test_storage.py index a44a11b03..294a0d75b 100644 --- a/tests/unittests/storage/test_storage.py +++ b/tests/unittests/storage/test_storage.py @@ -23,9 +23,9 @@ from orion.storage.base import ( FailedUpdate, MissingArguments, - Storage, get_storage, setup_storage, + storage_factory, ) from orion.storage.legacy import Legacy from orion.storage.track import HAS_TRACK, REASON @@ -136,7 +136,7 @@ def test_setup_storage_default(): """Test that storage is setup using default config""" update_singletons() setup_storage() - storage = Storage() + storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) @@ -156,7 +156,7 @@ def test_setup_storage_custom(): setup_storage( {"type": "legacy", "database": {"type": "pickleddb", "host": "test.pkl"}} ) - storage = Storage() + storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) assert storage._db.host == os.path.abspath("test.pkl") @@ -166,7 +166,7 @@ def test_setup_storage_custom_type_missing(): """Test setup with local configuration with type missing""" update_singletons() setup_storage({"database": {"type": "pickleddb", "host": "test.pkl"}}) - storage = Storage() + storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) assert storage._db.host == 
os.path.abspath("test.pkl") @@ -177,7 +177,7 @@ def test_setup_storage_custom_legacy_emtpy(): """Test setup with local configuration with legacy but no config""" update_singletons() setup_storage({"type": "legacy"}) - storage = Storage() + storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) assert storage._db.host == orion.core.config.storage.database.host @@ -189,13 +189,13 @@ def test_setup_storage_bad_override(): setup_storage( {"type": "legacy", "database": {"type": "pickleddb", "host": "test.pkl"}} ) - storage = Storage() + storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) with pytest.raises(SingletonAlreadyInstantiatedError) as exc: setup_storage({"type": "track"}) - assert exc.match("A singleton instance of \(type: Storage\)") + assert exc.match("A singleton instance of \(type: BaseStorageProtocol\)") @pytest.mark.xfail(reason="Fix this when introducing #135 in v0.2.0") @@ -203,7 +203,7 @@ def test_setup_storage_bad_config_override(): """Test setup with different config than existing singleton""" update_singletons() setup_storage({"database": {"type": "pickleddb", "host": "test.pkl"}}) - storage = Storage() + storage = storage_factory.create() assert isinstance(storage, Legacy) assert isinstance(storage._db, PickledDB) with pytest.raises(SingletonAlreadyInstantiatedError): @@ -225,7 +225,9 @@ def test_get_storage_uninitiated(): with pytest.raises(SingletonNotInstantiatedError) as exc: get_storage() - assert exc.match("No singleton instance of \(type: Storage\) was created") + assert exc.match( + "No singleton instance of \(type: BaseStorageProtocol\) was created" + ) def test_get_storage():