diff --git a/RELEASE.md b/RELEASE.md
index 2840a9af17..5aafe60b46 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -14,6 +14,7 @@
 * The default `kedro` environment names can now be set in `settings.py` with the help of the `CONFIG_LOADER_ARGS` variable. The relevant keys to be supplied are `base_env` and `default_run_env`. These values are set to `base` and `local` respectively as a default.
 * Added `kedro.config.abstract_config.AbstractConfigLoader` as an abstract base class for all `ConfigLoader` implementations. `ConfigLoader` and `TemplatedConfigLoader` now inherit directly from this base class.
 * Streamlined the `ConfigLoader.get` and `TemplatedConfigLoader.get` API and delegated the actual `get` method functional implementation to the `kedro.config.common` module.
+* The `hook_manager` is no longer a global singleton. The `hook_manager` lifecycle is now managed by the `KedroSession`: a new `hook_manager` is created every time a `session` is instantiated.
 * Added the following new datasets:

 | Type | Description | Location |
@@ -66,6 +67,8 @@
 * Changed the behaviour of `kedro build-reqs` to compile requirements from `requirements.txt` instead of `requirements.in` and save them to `requirements.lock` instead of `requirements.txt`.
 * Removed `ProjectHooks.register_catalog` `hook_spec` in favour of loading `DATA_CATALOG_CLASS` directly from `settings.py`. The default option for `DATA_CATALOG_CLASS` is now set to `kedro.io.DataCatalog`.
 * Removed `RegistrationSpecs` and all registration hooks that belonged to it. Going forward users can register custom library components through `settings.py`.
+* Added the `PluginManager` `hook_manager` argument to `KedroContext` and the `Runner.run()` method; it is supplied by the `KedroSession`.
+* Removed the public method `get_hook_manager()` and replaced its functionality with `_create_hook_manager()`.
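The changelog entries above describe the move away from the global hook-manager singleton. As a rough illustration of the new calling convention — a hedged sketch based on the updated tests in this changeset, not official documentation — standalone code that previously relied on the implicit global `get_hook_manager()` now builds a `PluginManager` explicitly and passes it to the runner. The one-node identity pipeline and catalog entries below are purely illustrative.

```python
# Hedged sketch of the new hook_manager lifecycle outside a KedroSession,
# mirroring the updated tests in this diff. The identity node and catalog
# entries are illustrative placeholders, not part of the changeset.
from kedro.framework.hooks import _create_hook_manager
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline, node
from kedro.runner import SequentialRunner


def identity(arg):
    return arg


pipeline = Pipeline([node(identity, "input", "output")])
catalog = DataCatalog({}, feed_dict={"input": 42})

# Runner.run() now takes the hook manager explicitly instead of fetching the
# removed get_hook_manager() singleton; inside a KedroSession this argument is
# supplied by the session itself.
result = SequentialRunner().run(pipeline, catalog, _create_hook_manager())
print(result)  # {'output': 42}
```

Inside a project the same wiring happens automatically: `KedroSession` creates the `PluginManager`, registers the hooks from `settings.py` and from installed plugins, and hands it to `KedroContext` and the runner.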
## Thanks for supporting contributions diff --git a/docs/conf.py b/docs/conf.py index e4bf7138e5..e5bab16106 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,6 +133,7 @@ "integer -- return number of occurrences of value", "integer -- return first index of value.", "kedro.extras.datasets.pandas.json_dataset.JSONDataSet", + "pluggy._manager.PluginManager", ), "py:data": ( "typing.Any", diff --git a/features/load_context.feature b/features/load_context.feature index 826ed52a6d..60e334df96 100644 --- a/features/load_context.feature +++ b/features/load_context.feature @@ -15,13 +15,20 @@ Feature: Custom Kedro project And I execute the kedro command "run" Then I should get a successful exit code - Scenario: Hooks from installed plugins are automatically registered + Scenario: Hooks from installed plugins are automatically registered and work with the default runner Given I have installed the test plugin When I execute the kedro command "run" Then I should get a successful exit code And I should get a message including "Registered hooks from 1 installed plugin(s): test-plugin-0.1" And I should get a message including "Reached after_catalog_created hook" + Scenario: Hooks from installed plugins are automatically registered and work with the parallel runner + Given I have installed the test plugin + When I execute the kedro command "run --runner=ParallelRunner" + Then I should get a successful exit code + And I should get a message including "Registered hooks from 1 installed plugin(s): test-plugin-0.1" + And I should get a message including "Reached after_catalog_created hook" + Scenario: Disable automatically registered plugin hooks Given I have installed the test plugin And I have disabled hooks for "test-plugin" plugin via config diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/tests/test_run.py b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/tests/test_run.py index cec7571094..059aeac8db 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/tests/test_run.py +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/tests/test_run.py @@ -13,6 +13,7 @@ import pytest from kedro.config import ConfigLoader from kedro.framework.context import KedroContext +from kedro.framework.hooks import _create_hook_manager @pytest.fixture @@ -26,6 +27,7 @@ def project_context(config_loader): package_name="{{ cookiecutter.python_package }}", project_path=Path.cwd(), config_loader=config_loader, + hook_manager=_create_hook_manager(), ) diff --git a/kedro/extras/extensions/ipython.py b/kedro/extras/extensions/ipython.py index 7141aab8a4..62b543ac33 100644 --- a/kedro/extras/extensions/ipython.py +++ b/kedro/extras/extensions/ipython.py @@ -23,15 +23,6 @@ def _remove_cached_modules(package_name): del sys.modules[module] # pragma: no cover -def _clear_hook_manager(): - from kedro.framework.hooks import get_hook_manager - - hook_manager = get_hook_manager() - name_plugin_pairs = hook_manager.list_name_plugin() - for name, plugin in name_plugin_pairs: - hook_manager.unregister(name=name, plugin=plugin) # pragma: no cover - - def _find_kedro_project(current_dir): # pragma: no cover from kedro.framework.startup import _is_project @@ -53,8 +44,6 @@ def reload_kedro(path, env: str = None, extra_params: Dict[str, Any] = None): from kedro.framework.session.session import _activate_session from kedro.framework.startup import bootstrap_project - _clear_hook_manager() - path = path or project_path metadata = bootstrap_project(path) diff --git 
a/kedro/framework/context/context.py b/kedro/framework/context/context.py index 5ee9f214d8..f8f92e053d 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -6,8 +6,9 @@ from urllib.parse import urlparse from warnings import warn +from pluggy import PluginManager + from kedro.config import ConfigLoader, MissingConfigException -from kedro.framework.hooks import get_hook_manager from kedro.framework.project import settings from kedro.io import DataCatalog from kedro.pipeline.pipeline import _transcode_split @@ -168,6 +169,7 @@ def __init__( package_name: str, project_path: Union[Path, str], config_loader: ConfigLoader, + hook_manager: PluginManager, env: str = None, extra_params: Dict[str, Any] = None, ): # pylint: disable=too-many-arguments @@ -183,6 +185,7 @@ def __init__( package_name: Package name for the Kedro project the context is created for. project_path: Project path to define the context for. + hook_manager: The ``PluginManager`` to activate hooks, supplied by the session. env: Optional argument for configuration default environment to be used for running the pipeline. If not specified, it defaults to "local". extra_params: Optional dictionary containing extra project parameters. @@ -194,6 +197,7 @@ def __init__( self._config_loader = config_loader self._env = env self._extra_params = deepcopy(extra_params) + self._hook_manager = hook_manager @property # type: ignore def env(self) -> Optional[str]: @@ -279,8 +283,7 @@ def _get_catalog( catalog.add_feed_dict(feed_dict) if catalog.layers: _validate_layers_for_transcoding(catalog) - hook_manager = get_hook_manager() - hook_manager.hook.after_catalog_created( # pylint: disable=no-member + self._hook_manager.hook.after_catalog_created( catalog=catalog, conf_catalog=conf_catalog, conf_creds=conf_creds, diff --git a/kedro/framework/hooks/__init__.py b/kedro/framework/hooks/__init__.py index 35885aecfa..8ce7a9b695 100644 --- a/kedro/framework/hooks/__init__.py +++ b/kedro/framework/hooks/__init__.py @@ -1,5 +1,5 @@ """``kedro.framework.hooks`` provides primitives to use hooks to extend KedroContext's behaviour""" -from .manager import get_hook_manager +from .manager import _create_hook_manager from .markers import hook_impl -__all__ = ["get_hook_manager", "hook_impl"] +__all__ = ["_create_hook_manager", "hook_impl"] diff --git a/kedro/framework/hooks/manager.py b/kedro/framework/hooks/manager.py index 72084fba53..68544dec60 100644 --- a/kedro/framework/hooks/manager.py +++ b/kedro/framework/hooks/manager.py @@ -1,7 +1,6 @@ """This module provides an utility function to retrieve the global hook_manager singleton in a Kedro's execution process. """ -# pylint: disable=global-statement,invalid-name import logging from typing import Any, Iterable @@ -10,8 +9,6 @@ from .markers import HOOK_NAMESPACE from .specs import DataCatalogSpecs, DatasetSpecs, NodeSpecs, PipelineSpecs -_hook_manager = None - _PLUGIN_HOOKS = "kedro.hooks" # entry-point to load hooks from for installed plugins logger = logging.getLogger(__name__) @@ -27,14 +24,6 @@ def _create_hook_manager() -> PluginManager: return manager -def get_hook_manager(): - """Create or return the global _hook_manager singleton instance.""" - global _hook_manager - if _hook_manager is None: - _hook_manager = _create_hook_manager() - return _hook_manager - - def _register_hooks(hook_manager: PluginManager, hooks: Iterable[Any]) -> None: """Register all hooks as specified in ``hooks`` with the global ``hook_manager``. 
diff --git a/kedro/framework/project/__init__.py b/kedro/framework/project/__init__.py index c084fd73c3..353e22b90f 100644 --- a/kedro/framework/project/__init__.py +++ b/kedro/framework/project/__init__.py @@ -10,8 +10,6 @@ from dynaconf import LazySettings from dynaconf.validator import ValidationError, Validator -from kedro.framework.hooks import get_hook_manager -from kedro.framework.hooks.manager import _register_hooks, _register_hooks_setuptools from kedro.pipeline import Pipeline @@ -181,11 +179,6 @@ def configure_project(package_name: str): settings_module = f"{package_name}.settings" settings.configure(settings_module) - # set up all hooks so we can discover all pipelines - hook_manager = get_hook_manager() - _register_hooks(hook_manager, settings.HOOKS) - _register_hooks_setuptools(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS) - pipelines_module = f"{package_name}.pipeline_registry" pipelines.configure(pipelines_module) diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index 5c66e41885..51c738e117 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -15,7 +15,8 @@ from kedro.config import ConfigLoader from kedro.framework.context import KedroContext from kedro.framework.context.context import _convert_paths_to_absolute_posix -from kedro.framework.hooks import get_hook_manager +from kedro.framework.hooks import _create_hook_manager +from kedro.framework.hooks.manager import _register_hooks, _register_hooks_setuptools from kedro.framework.project import ( configure_logging, configure_project, @@ -107,6 +108,11 @@ def __init__( self._package_name = package_name self._store = self._init_store() + hook_manager = _create_hook_manager() + _register_hooks(hook_manager, settings.HOOKS) + _register_hooks_setuptools(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS) + self._hook_manager = hook_manager + @classmethod def create( # pylint: disable=too-many-arguments cls, @@ -244,6 +250,7 @@ def load_context(self) -> KedroContext: config_loader=config_loader, env=env, extra_params=extra_params, + hook_manager=self._hook_manager, ) return context @@ -374,13 +381,13 @@ def run( # pylint: disable=too-many-arguments,too-many-locals # Run the runner runner = runner or SequentialRunner() - hook_manager = get_hook_manager() + hook_manager = self._hook_manager hook_manager.hook.before_pipeline_run( # pylint: disable=no-member run_params=record_data, pipeline=filtered_pipeline, catalog=catalog ) try: - run_result = runner.run(filtered_pipeline, catalog, run_id) + run_result = runner.run(filtered_pipeline, catalog, hook_manager, run_id) except Exception as error: hook_manager.hook.on_pipeline_error( error=error, diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index a68a3a22fd..8eaeb2bde6 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -14,6 +14,14 @@ from pickle import PicklingError from typing import Any, Dict, Iterable, Set +from pluggy import PluginManager + +from kedro.framework.hooks.manager import ( + _create_hook_manager, + _register_hooks, + _register_hooks_setuptools, +) +from kedro.framework.project import settings from kedro.io import DataCatalog, DataSetError, MemoryDataSet from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -89,11 +97,8 @@ def _run_node_synchronization( # pylint: disable=too-many-arguments conf_logging: Dict[str, Any] = None, ) -> Node: """Run a single `Node` with inputs from and outputs to the `catalog`. 
- `KedroSession` instance is activated in every subprocess because of Windows - (and latest OSX with Python 3.8) limitation. - Windows has no "fork", so every subprocess is a brand new process - created via "spawn", hence the need to a) setup the logging, b) register - the hooks, and c) activate `KedroSession` in every subprocess. + A `PluginManager` `hook_manager` instance is created in every subprocess because + the `PluginManager` can't be serialised. Args: node: The ``Node`` to run. @@ -112,7 +117,11 @@ def _run_node_synchronization( # pylint: disable=too-many-arguments conf_logging = conf_logging or {} _bootstrap_subprocess(package_name, conf_logging) - return run_node(node, catalog, is_async, run_id) + hook_manager = _create_hook_manager() + _register_hooks(hook_manager, settings.HOOKS) + _register_hooks_setuptools(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS) + + return run_node(node, catalog, hook_manager, is_async, run_id) class ParallelRunner(AbstractRunner): @@ -252,7 +261,11 @@ def _get_required_workers_count(self, pipeline: Pipeline): return min(required_processes, self._max_workers) def _run( # pylint: disable=too-many-locals,useless-suppression - self, pipeline: Pipeline, catalog: DataCatalog, run_id: str = None + self, + pipeline: Pipeline, + catalog: DataCatalog, + hook_manager: PluginManager, + run_id: str = None, ) -> None: """The abstract interface for running pipelines. diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index a3fdd0d731..a3c0dcc32a 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -13,7 +13,8 @@ ) from typing import Any, Dict, Iterable -from kedro.framework.hooks import get_hook_manager +from pluggy import PluginManager + from kedro.io import AbstractDataSet, DataCatalog from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -39,7 +40,11 @@ def _logger(self): return logging.getLogger(self.__module__) def run( - self, pipeline: Pipeline, catalog: DataCatalog, run_id: str = None + self, + pipeline: Pipeline, + catalog: DataCatalog, + hook_manager: PluginManager, + run_id: str = None, ) -> Dict[str, Any]: """Run the ``Pipeline`` using the ``DataSet``s provided by ``catalog`` and save results back to the same objects. @@ -47,6 +52,7 @@ def run( Args: pipeline: The ``Pipeline`` to run. catalog: The ``DataCatalog`` from which to fetch data. + hook_manager: The ``PluginManager`` to activate hooks. run_id: The id of the run. Raises: @@ -76,14 +82,14 @@ def run( self._logger.info( "Asynchronous mode is enabled for loading and saving data" ) - self._run(pipeline, catalog, run_id) + self._run(pipeline, catalog, hook_manager, run_id) self._logger.info("Pipeline execution completed successfully.") return {ds_name: catalog.load(ds_name) for ds_name in free_outputs} def run_only_missing( - self, pipeline: Pipeline, catalog: DataCatalog + self, pipeline: Pipeline, catalog: DataCatalog, hook_manager: PluginManager ) -> Dict[str, Any]: """Run only the missing outputs from the ``Pipeline`` using the ``DataSet``s provided by ``catalog`` and save results back to the same @@ -92,6 +98,7 @@ def run_only_missing( Args: pipeline: The ``Pipeline`` to run. catalog: The ``DataCatalog`` from which to fetch data. + hook_manager: The ``PluginManager`` to activate hooks. Raises: ValueError: Raised when ``Pipeline`` inputs cannot be satisfied. 
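For completeness, here is a hedged sketch of what the revised runner interface shown in this part of the diff implies for a custom runner: `_run()` now receives the `hook_manager` and forwards it to `run_node()`, much as `SequentialRunner` does later in this changeset. The `TinyRunner` class below is illustrative only and is not part of the diff.

```python
# Hedged sketch of a minimal custom runner against the revised interface.
# Assumes the new signatures introduced in this diff; TinyRunner itself is
# a hypothetical example, not Kedro code.
from pluggy import PluginManager

from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet
from kedro.pipeline import Pipeline
from kedro.runner.runner import AbstractRunner, run_node


class TinyRunner(AbstractRunner):
    """Runs nodes one by one; exists only to show the new signature."""

    def create_default_data_set(self, ds_name: str) -> AbstractDataSet:
        # Unregistered datasets fall back to in-memory storage.
        return MemoryDataSet()

    def _run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        hook_manager: PluginManager,
        run_id: str = None,
    ) -> None:
        # hook_manager is passed through explicitly instead of being looked
        # up via a global singleton.
        for node in pipeline.nodes:
            run_node(node, catalog, hook_manager, self._is_async, run_id)
```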
@@ -115,11 +122,15 @@ def run_only_missing( input_from_memory = to_rerun.inputs() & memory_sets to_rerun += output_to_memory.to_outputs(*input_from_memory) - return self.run(to_rerun, catalog) + return self.run(to_rerun, catalog, hook_manager) @abstractmethod # pragma: no cover def _run( - self, pipeline: Pipeline, catalog: DataCatalog, run_id: str = None + self, + pipeline: Pipeline, + catalog: DataCatalog, + hook_manager: PluginManager, + run_id: str = None, ) -> None: """The abstract interface for running pipelines, assuming that the inputs have already been checked and normalized by run(). @@ -127,6 +138,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. catalog: The ``DataCatalog`` from which to fetch data. + hook_manager: The ``PluginManager`` to activate hooks. run_id: The id of the run. """ @@ -170,13 +182,18 @@ def _suggest_resume_scenario( def run_node( - node: Node, catalog: DataCatalog, is_async: bool = False, run_id: str = None + node: Node, + catalog: DataCatalog, + hook_manager: PluginManager, + is_async: bool = False, + run_id: str = None, ) -> Node: """Run a single `Node` with inputs from and outputs to the `catalog`. Args: node: The ``Node`` to run. catalog: A ``DataCatalog`` containing the node's inputs and outputs. + hook_manager: The ``PluginManager`` to activate hooks. is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. run_id: The id of the pipeline run @@ -186,9 +203,9 @@ def run_node( """ if is_async: - node = _run_node_async(node, catalog, run_id) + node = _run_node_async(node, catalog, hook_manager, run_id) else: - node = _run_node_sequential(node, catalog, run_id) + node = _run_node_sequential(node, catalog, hook_manager, run_id) for name in node.confirms: catalog.confirm(name) @@ -200,11 +217,12 @@ def _collect_inputs_from_hook( catalog: DataCatalog, inputs: Dict[str, Any], is_async: bool, + hook_manager: PluginManager, run_id: str = None, ) -> Dict[str, Any]: + # pylint: disable=too-many-arguments inputs = inputs.copy() # shallow copy to prevent in-place modification by the hook - hook_manager = get_hook_manager() - hook_response = hook_manager.hook.before_node_run( # pylint: disable=no-member + hook_response = hook_manager.hook.before_node_run( node=node, catalog=catalog, inputs=inputs, @@ -231,13 +249,14 @@ def _call_node_run( catalog: DataCatalog, inputs: Dict[str, Any], is_async: bool, + hook_manager: PluginManager, run_id: str = None, ) -> Dict[str, Any]: - hook_manager = get_hook_manager() + # pylint: disable=too-many-arguments try: outputs = node.run(inputs) except Exception as exc: - hook_manager.hook.on_node_error( # pylint: disable=no-member + hook_manager.hook.on_node_error( error=exc, node=node, catalog=catalog, @@ -246,7 +265,7 @@ def _call_node_run( run_id=run_id, ) raise exc - hook_manager.hook.after_node_run( # pylint: disable=no-member + hook_manager.hook.after_node_run( node=node, catalog=catalog, inputs=inputs, @@ -257,55 +276,49 @@ def _call_node_run( return outputs -def _run_node_sequential(node: Node, catalog: DataCatalog, run_id: str = None) -> Node: +def _run_node_sequential( + node: Node, catalog: DataCatalog, hook_manager: PluginManager, run_id: str = None +) -> Node: inputs = {} - hook_manager = get_hook_manager() for name in node.inputs: - hook_manager.hook.before_dataset_loaded( # pylint: disable=no-member - dataset_name=name - ) + hook_manager.hook.before_dataset_loaded(dataset_name=name) inputs[name] = catalog.load(name) - hook_manager.hook.after_dataset_loaded( 
# pylint: disable=no-member - dataset_name=name, data=inputs[name] - ) + hook_manager.hook.after_dataset_loaded(dataset_name=name, data=inputs[name]) is_async = False additional_inputs = _collect_inputs_from_hook( - node, catalog, inputs, is_async, run_id=run_id + node, catalog, inputs, is_async, hook_manager, run_id=run_id ) inputs.update(additional_inputs) - outputs = _call_node_run(node, catalog, inputs, is_async, run_id=run_id) + outputs = _call_node_run( + node, catalog, inputs, is_async, hook_manager, run_id=run_id + ) for name, data in outputs.items(): - hook_manager.hook.before_dataset_saved( # pylint: disable=no-member - dataset_name=name, data=data - ) + hook_manager.hook.before_dataset_saved(dataset_name=name, data=data) catalog.save(name, data) - hook_manager.hook.after_dataset_saved( # pylint: disable=no-member - dataset_name=name, data=data - ) + hook_manager.hook.after_dataset_saved(dataset_name=name, data=data) return node -def _run_node_async(node: Node, catalog: DataCatalog, run_id: str = None) -> Node: +def _run_node_async( + node: Node, catalog: DataCatalog, hook_manager: PluginManager, run_id: str = None +) -> Node: def _synchronous_dataset_load(dataset_name: str): """Minimal wrapper to ensure Hooks are run synchronously within an asynchronous dataset load.""" - hook_manager.hook.before_dataset_loaded( # pylint: disable=no-member - dataset_name=dataset_name - ) + hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name) return_ds = catalog.load(dataset_name) - hook_manager.hook.after_dataset_loaded( # pylint: disable=no-member + hook_manager.hook.after_dataset_loaded( dataset_name=dataset_name, data=return_ds ) return return_ds with ThreadPoolExecutor() as pool: inputs: Dict[str, Future] = {} - hook_manager = get_hook_manager() for name in node.inputs: inputs[name] = pool.submit(_synchronous_dataset_load, name) @@ -314,25 +327,25 @@ def _synchronous_dataset_load(dataset_name: str): inputs = {key: value.result() for key, value in inputs.items()} is_async = True additional_inputs = _collect_inputs_from_hook( - node, catalog, inputs, is_async, run_id=run_id + node, catalog, inputs, is_async, hook_manager, run_id=run_id ) inputs.update(additional_inputs) - outputs = _call_node_run(node, catalog, inputs, is_async, run_id=run_id) + outputs = _call_node_run( + node, catalog, inputs, is_async, hook_manager, run_id=run_id + ) save_futures = set() for name, data in outputs.items(): - hook_manager.hook.before_dataset_saved( # pylint: disable=no-member - dataset_name=name, data=data - ) + hook_manager.hook.before_dataset_saved(dataset_name=name, data=data) save_futures.add(pool.submit(catalog.save, name, data)) for future in as_completed(save_futures): exception = future.exception() if exception: raise exception - hook_manager.hook.after_dataset_saved( # pylint: disable=no-member + hook_manager.hook.after_dataset_saved( dataset_name=name, data=data # pylint: disable=undefined-loop-variable ) return node diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index 20a578f372..be95debcad 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -6,6 +6,8 @@ from collections import Counter from itertools import chain +from pluggy import PluginManager + from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet from kedro.pipeline import Pipeline from kedro.runner.runner import AbstractRunner, run_node @@ -41,7 +43,11 @@ def create_default_data_set(self, ds_name: str) -> AbstractDataSet: return MemoryDataSet() 
def _run( - self, pipeline: Pipeline, catalog: DataCatalog, run_id: str = None + self, + pipeline: Pipeline, + catalog: DataCatalog, + hook_manager: PluginManager, + run_id: str = None, ) -> None: """The method implementing sequential pipeline running. @@ -60,7 +66,7 @@ def _run( for exec_index, node in enumerate(nodes): try: - run_node(node, catalog, self._is_async, run_id) + run_node(node, catalog, hook_manager, self._is_async, run_id) done_nodes.add(node) except Exception: self._suggest_resume_scenario(pipeline, done_nodes) diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index 7f9f902928..8610d32dff 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -8,6 +8,8 @@ from itertools import chain from typing import Set +from pluggy import PluginManager + from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -79,7 +81,11 @@ def _get_required_workers_count(self, pipeline: Pipeline): ) def _run( # pylint: disable=too-many-locals,useless-suppression - self, pipeline: Pipeline, catalog: DataCatalog, run_id: str = None + self, + pipeline: Pipeline, + catalog: DataCatalog, + hook_manager: PluginManager, + run_id: str = None, ) -> None: """The abstract interface for running pipelines. @@ -107,7 +113,14 @@ def _run( # pylint: disable=too-many-locals,useless-suppression todo_nodes -= ready for node in ready: futures.add( - pool.submit(run_node, node, catalog, self._is_async, run_id) + pool.submit( + run_node, + node, + catalog, + hook_manager, + self._is_async, + run_id, + ) ) if not futures: assert not todo_nodes, (todo_nodes, done_nodes, ready, done) diff --git a/requirements.txt b/requirements.txt index ce043a05af..fe2a7ca9fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ gitpython~=3.0 jmespath>=0.9.5, <1.0 jupyter_client>=5.1, <7.0 pip-tools~=6.4 -pluggy~=0.13.0 +pluggy~=1.0.0 python-json-logger~=2.0 PyYAML>=4.2, <6.0 rope~=0.21.0 # subject to LGPLv3 license diff --git a/tests/extras/datasets/spark/test_deltatable_dataset.py b/tests/extras/datasets/spark/test_deltatable_dataset.py index 73fcc9c537..26566d858c 100644 --- a/tests/extras/datasets/spark/test_deltatable_dataset.py +++ b/tests/extras/datasets/spark/test_deltatable_dataset.py @@ -5,6 +5,7 @@ from pyspark.sql.utils import AnalysisException from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet +from kedro.framework.hooks import _create_hook_manager from kedro.io import DataCatalog, DataSetError from kedro.pipeline import Pipeline, node from kedro.runner import ParallelRunner @@ -86,4 +87,6 @@ def no_output(x): r"multiprocessing: \['delta_in'\]" ) with pytest.raises(AttributeError, match=pattern): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run( + pipeline, catalog, _create_hook_manager() + ) diff --git a/tests/extras/datasets/spark/test_spark_dataset.py b/tests/extras/datasets/spark/test_spark_dataset.py index 286408ceb3..faf2b7c078 100644 --- a/tests/extras/datasets/spark/test_spark_dataset.py +++ b/tests/extras/datasets/spark/test_spark_dataset.py @@ -18,6 +18,7 @@ _dbfs_glob, _get_dbutils, ) +from kedro.framework.hooks import _create_hook_manager from kedro.io import DataCatalog, DataSetError, Version from kedro.io.core import generate_timestamp from kedro.pipeline import Pipeline, node @@ -275,7 +276,9 @@ def test_parallel_runner(self, is_async, spark_in): r"multiprocessing: \['spark_in'\]" ) with 
pytest.raises(AttributeError, match=pattern): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run( + pipeline, catalog, _create_hook_manager() + ) def test_s3_glob_refresh(self): spark_dataset = SparkDataSet(filepath="s3a://bucket/data") @@ -808,7 +811,9 @@ class TestDataFlowSequentialRunner: def test_spark_load_save(self, is_async, data_catalog): """SparkDataSet(load) -> node -> Spark (save).""" pipeline = Pipeline([node(identity, "spark_in", "spark_out")]) - SequentialRunner(is_async=is_async).run(pipeline, data_catalog) + SequentialRunner(is_async=is_async).run( + pipeline, data_catalog, _create_hook_manager() + ) save_path = Path(data_catalog._data_sets["spark_out"]._filepath.as_posix()) files = list(save_path.glob("*.parquet")) @@ -819,7 +824,9 @@ def test_spark_pickle(self, is_async, data_catalog): pipeline = Pipeline([node(identity, "spark_in", "pickle_ds")]) pattern = ".* was not serialized due to.*" with pytest.raises(DataSetError, match=pattern): - SequentialRunner(is_async=is_async).run(pipeline, data_catalog) + SequentialRunner(is_async=is_async).run( + pipeline, data_catalog, _create_hook_manager() + ) def test_spark_memory_spark(self, is_async, data_catalog): """SparkDataSet(load) -> node -> MemoryDataSet (save and then load) -> @@ -830,7 +837,9 @@ def test_spark_memory_spark(self, is_async, data_catalog): node(identity, "memory_ds", "spark_out"), ] ) - SequentialRunner(is_async=is_async).run(pipeline, data_catalog) + SequentialRunner(is_async=is_async).run( + pipeline, data_catalog, _create_hook_manager() + ) save_path = Path(data_catalog._data_sets["spark_out"]._filepath.as_posix()) files = list(save_path.glob("*.parquet")) diff --git a/tests/framework/cli/pipeline/test_pipeline_package.py b/tests/framework/cli/pipeline/test_pipeline_package.py index 50649fe3ca..1174d32877 100644 --- a/tests/framework/cli/pipeline/test_pipeline_package.py +++ b/tests/framework/cli/pipeline/test_pipeline_package.py @@ -414,7 +414,6 @@ def test_package_pipeline_with_deep_nested_parameters( assert sdist_file.is_file() assert len(list(sdist_location.iterdir())) == 1 - # pylint: disable=consider-using-with with tarfile.open(sdist_file, "r") as tar: sdist_contents = set(tar.getnames()) assert ( diff --git a/tests/framework/cli/test_cli_hooks.py b/tests/framework/cli/test_cli_hooks.py index d9dd0a20ea..afa416d443 100644 --- a/tests/framework/cli/test_cli_hooks.py +++ b/tests/framework/cli/test_cli_hooks.py @@ -61,7 +61,7 @@ def fake_plugin_distribution(mocker): version="0.1", ) mocker.patch( - "pluggy.manager.importlib_metadata.distributions", + "pluggy._manager.importlib_metadata.distributions", return_value=[fake_distribution], ) return fake_distribution diff --git a/tests/framework/context/test_context.py b/tests/framework/context/test_context.py index 2f025ea6f7..827df454d4 100644 --- a/tests/framework/context/test_context.py +++ b/tests/framework/context/test_context.py @@ -20,6 +20,7 @@ _update_nested_dict, _validate_layers_for_transcoding, ) +from kedro.framework.hooks import _create_hook_manager from kedro.framework.project import ( ValidationError, _ProjectSettings, @@ -223,6 +224,7 @@ def dummy_context( MOCK_PACKAGE_NAME, str(tmp_path), config_loader=config_loader, + hook_manager=_create_hook_manager(), env=env, extra_params=extra_params, ) diff --git a/tests/framework/session/conftest.py b/tests/framework/session/conftest.py index 7af110f7c1..8c89f3a46d 100644 --- a/tests/framework/session/conftest.py +++ 
b/tests/framework/session/conftest.py @@ -12,7 +12,6 @@ from kedro import __version__ as kedro_version from kedro.framework.hooks import hook_impl -from kedro.framework.hooks.manager import get_hook_manager from kedro.framework.project import _ProjectPipelines, _ProjectSettings from kedro.framework.session import KedroSession from kedro.io import DataCatalog @@ -92,15 +91,6 @@ def local_config(tmp_path): } -@pytest.fixture(autouse=True) -def clear_hook_manager(): - yield - hook_manager = get_hook_manager() - plugins = hook_manager.get_plugins() - for plugin in plugins: - hook_manager.unregister(plugin) - - @pytest.fixture(autouse=True) def config_dir(tmp_path, local_config, local_logging_config): catalog = tmp_path / "conf" / "base" / "catalog.yml" @@ -379,6 +369,7 @@ def _mock_imported_settings_paths(mocker, mock_settings): for path in [ "kedro.framework.session.session.settings", "kedro.framework.project.settings", + "kedro.runner.parallel_runner.settings", ]: mocker.patch(path, mock_settings) return mock_settings @@ -396,9 +387,11 @@ class MockSettings(_ProjectSettings): def mock_session( mock_settings, mock_package_name, tmp_path ): # pylint: disable=unused-argument - return KedroSession.create( + session = KedroSession.create( mock_package_name, tmp_path, extra_params={"params:key": "value"} ) + yield session + session.close() @pytest.fixture(autouse=True) diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py index a57b7ec6b7..62997f63bb 100644 --- a/tests/framework/session/test_session.py +++ b/tests/framework/session/test_session.py @@ -530,7 +530,7 @@ def test_run( """Test running the project via the session""" mock_hook = mocker.patch( - "kedro.framework.session.session.get_hook_manager" + "kedro.framework.session.session._create_hook_manager" ).return_value.hook mock_pipelines = mocker.patch( "kedro.framework.session.session.pipelines", @@ -567,7 +567,7 @@ def test_run( run_params=record_data, pipeline=mock_pipeline, catalog=mock_catalog ) mock_runner.run.assert_called_once_with( - mock_pipeline, mock_catalog, fake_session_id + mock_pipeline, mock_catalog, session._hook_manager, fake_session_id ) mock_hook.after_pipeline_run.assert_called_once_with( run_params=record_data, @@ -602,7 +602,7 @@ def test_run_exception( # pylint: disable=too-many-locals ): """Test exception being raise during the run""" mock_hook = mocker.patch( - "kedro.framework.session.session.get_hook_manager" + "kedro.framework.session.session._create_hook_manager" ).return_value.hook mock_pipelines = mocker.patch( "kedro.framework.session.session.pipelines", diff --git a/tests/framework/session/test_session_extension_hooks.py b/tests/framework/session/test_session_extension_hooks.py index 5f1b8c1473..0fff449e1d 100644 --- a/tests/framework/session/test_session_extension_hooks.py +++ b/tests/framework/session/test_session_extension_hooks.py @@ -13,7 +13,7 @@ from kedro.framework.project import _ProjectPipelines, _ProjectSettings, pipelines from kedro.framework.session import KedroSession from kedro.io import DataCatalog, MemoryDataSet -from kedro.pipeline import Pipeline, node +from kedro.pipeline import node, pipeline from kedro.pipeline.node import Node from kedro.runner import ParallelRunner from kedro.runner.runner import _run_node_async @@ -37,7 +37,7 @@ def broken_node(): @pytest.fixture def broken_pipeline(): - return Pipeline( + return pipeline( [ node(broken_node, None, "A", name="node1"), node(broken_node, None, "B", name="node2"), @@ -59,22 +59,6 @@ def 
mock_get_pipelines_registry_callable(): return mock_get_pipelines_registry_callable() -@pytest.fixture -def mock_pipelines(mocker, mock_pipeline): - def mock_get_pipelines_registry_callable(): - return { - "__default__": mock_pipeline, - "pipe": mock_pipeline, - } - - mocker.patch.object( - _ProjectPipelines, - "_get_pipelines_registry_callable", - return_value=mock_get_pipelines_registry_callable, - ) - return mock_get_pipelines_registry_callable() - - class TestCatalogHooks: def test_after_catalog_created_hook(self, mocker, mock_session, caplog): context = mock_session.load_context() @@ -569,7 +553,11 @@ def test_after_dataset_load_hook_async( mock_session.load_context() # run the node asynchronously with an instance of `LogCatalog` - _run_node_async(node=sample_node, catalog=memory_catalog) + _run_node_async( + node=sample_node, + catalog=memory_catalog, + hook_manager=mock_session._hook_manager, + ) hooks_log_messages = [r.message for r in logs_listener.logs] diff --git a/tests/framework/session/test_session_hook_manager.py b/tests/framework/session/test_session_hook_manager.py index 9cc31ecf28..0cccc89b84 100644 --- a/tests/framework/session/test_session_hook_manager.py +++ b/tests/framework/session/test_session_hook_manager.py @@ -3,7 +3,7 @@ import pytest from dynaconf.validator import Validator -from kedro.framework.hooks.manager import _register_hooks, get_hook_manager +from kedro.framework.hooks.manager import _register_hooks from kedro.framework.project import _ProjectSettings from kedro.framework.session import KedroSession from tests.framework.session.conftest import _mock_imported_settings_paths @@ -35,19 +35,14 @@ class MockSettings(_ProjectSettings): class TestSessionHookManager: """Test the process of registering hooks with the hook manager in a session.""" - def test_assert_register_hooks(self, request, project_hooks): - hook_manager = get_hook_manager() - assert not hook_manager.is_registered(project_hooks) - - # call the fixture to construct the session - request.getfixturevalue("mock_session") - + def test_assert_register_hooks(self, project_hooks, mock_session): + hook_manager = mock_session._hook_manager assert hook_manager.is_registered(project_hooks) @pytest.mark.usefixtures("mock_session") - def test_calling_register_hooks_twice(self, project_hooks): + def test_calling_register_hooks_twice(self, project_hooks, mock_session): """Calling hook registration multiple times should not raise""" - hook_manager = get_hook_manager() + hook_manager = mock_session._hook_manager assert hook_manager.is_registered(project_hooks) _register_hooks(hook_manager, (project_hooks,)) @@ -58,19 +53,19 @@ def test_calling_register_hooks_twice(self, project_hooks): def test_hooks_registered_when_session_created( self, mocker, request, caplog, project_hooks, num_plugins ): - hook_manager = get_hook_manager() - assert not hook_manager.get_plugins() - load_setuptools_entrypoints = mocker.patch.object( - hook_manager, "load_setuptools_entrypoints", return_value=num_plugins + load_setuptools_entrypoints = mocker.patch( + "pluggy._manager.PluginManager.load_setuptools_entrypoints", + return_value=num_plugins, ) distinfo = [("plugin_obj_1", MockDistInfo("test-project-a", "0.1"))] - list_distinfo_mock = mocker.patch.object( - hook_manager, "list_plugin_distinfo", return_value=distinfo + list_distinfo_mock = mocker.patch( + "pluggy._manager.PluginManager.list_plugin_distinfo", return_value=distinfo ) # call a fixture which creates a session - request.getfixturevalue("mock_session") + session = 
request.getfixturevalue("mock_session") + hook_manager = session._hook_manager assert hook_manager.is_registered(project_hooks) load_setuptools_entrypoints.assert_called_once_with("kedro.hooks") @@ -94,22 +89,24 @@ def test_disabling_auto_discovered_hooks( naughty_plugin, good_plugin, ): - hook_manager = get_hook_manager() - assert not hook_manager.get_plugins() distinfo = [("plugin_obj_1", naughty_plugin), ("plugin_obj_2", good_plugin)] - list_distinfo_mock = mocker.patch.object( - hook_manager, "list_plugin_distinfo", return_value=distinfo + mocked_distinfo = mocker.patch( + "pluggy._manager.PluginManager.list_plugin_distinfo", return_value=distinfo ) - mocker.patch.object( - hook_manager, "load_setuptools_entrypoints", return_value=len(distinfo) + + mocker.patch( + "pluggy._manager.PluginManager.load_setuptools_entrypoints", + return_value=len(distinfo), ) - unregister_mock = mocker.patch.object(hook_manager, "unregister") + unregister_mock = mocker.patch("pluggy._manager.PluginManager.unregister") + # create a session that will use the mock_settings_with_disabled_hooks from the fixture. KedroSession.create( mock_package_name, tmp_path, extra_params={"params:key": "value"} ) - list_distinfo_mock.assert_called_once_with() + + mocked_distinfo.assert_called_once_with() unregister_mock.assert_called_once_with(plugin=distinfo[0][0]) # check the logs diff --git a/tests/pipeline/test_pipeline_from_missing.py b/tests/pipeline/test_pipeline_from_missing.py index 12cc9d928d..a62b16aeaa 100644 --- a/tests/pipeline/test_pipeline_from_missing.py +++ b/tests/pipeline/test_pipeline_from_missing.py @@ -2,6 +2,7 @@ import pytest +from kedro.framework.hooks import _create_hook_manager from kedro.io import DataCatalog, LambdaDataSet from kedro.pipeline import Pipeline, node from kedro.runner import SequentialRunner @@ -20,6 +21,11 @@ def biconcat(input1: str, input2: str): return input1 + input2 # pragma: no cover +@pytest.fixture +def hook_manager(): + return _create_hook_manager() + + @pytest.fixture def branched_pipeline(): # #### Pipeline execution order #### @@ -87,58 +93,58 @@ def _pipeline_contains(pipe, nodes): return set(nodes) == {n.name for n in pipe.nodes} -def _from_missing(pipeline, catalog): +def _from_missing(pipeline, catalog, hook_manager): """Create a new pipeline based on missing outputs.""" name = "kedro.runner.runner.AbstractRunner.run" with mock.patch(name) as run: - SequentialRunner().run_only_missing(pipeline, catalog) + SequentialRunner().run_only_missing(pipeline, catalog, hook_manager) _, args, _ = run.mock_calls[0] new_pipeline = args[0] return new_pipeline class TestPipelineMissing: - def test_all_missing(self, branched_pipeline): + def test_all_missing(self, branched_pipeline, hook_manager): catalog = _make_catalog(non_existent=["A", "B", "C", "D", "E", "F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipelines_equal(branched_pipeline, new_pipeline) - def test_none_missing(self, branched_pipeline): + def test_none_missing(self, branched_pipeline, hook_manager): catalog = _make_catalog(existent=["A", "B", "C", "D", "E", "F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, []) - def test_none_missing_feeddict_only(self, branched_pipeline): + def test_none_missing_feeddict_only(self, branched_pipeline, hook_manager): feed_dict = {"A": 1, "B": 2, "C": 3, "D": 
4, "E": 5, "F": 6} catalog = _make_catalog(feed_dict=feed_dict) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, []) - def test_first_missing(self, branched_pipeline): + def test_first_missing(self, branched_pipeline, hook_manager): """combine from B and C is missing.""" catalog = _make_catalog(non_existent=["B", "C"], existent=["A", "D", "E", "F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipelines_equal(branched_pipeline, new_pipeline) - def test_only_left_missing(self, branched_pipeline): + def test_only_left_missing(self, branched_pipeline, hook_manager): catalog = _make_catalog(non_existent=["B"], existent=["A", "C", "D", "E", "F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains( new_pipeline, ["left_in", "combine", "split", "right_out"] ) - def test_last_missing(self, branched_pipeline): + def test_last_missing(self, branched_pipeline, hook_manager): """r-out from F is missing.""" catalog = _make_catalog(non_existent=["F"], existent=["A", "B", "C", "D", "E"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, ["split", "right_out"]) - def test_missing_and_no_exists(self, branched_pipeline, caplog): + def test_missing_and_no_exists(self, branched_pipeline, caplog, hook_manager): """If F doesn't have exists(), F is treated as missing.""" catalog = _make_catalog( existent=["A", "B", "C", "D", "E"], no_exists_method=["F"] ) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, ["split", "right_out"]) log_record = caplog.records[0] @@ -147,9 +153,9 @@ def test_missing_and_no_exists(self, branched_pipeline, caplog): "`exists()` not implemented for `LambdaDataSet`" in log_record.getMessage() ) - def test_all_no_exists_method(self, branched_pipeline, caplog): + def test_all_no_exists_method(self, branched_pipeline, caplog, hook_manager): catalog = _make_catalog(no_exists_method=["A", "B", "C", "D", "E", "F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipelines_equal(branched_pipeline, new_pipeline) log_msgs = [record.getMessage() for record in caplog.records] @@ -159,63 +165,63 @@ def test_all_no_exists_method(self, branched_pipeline, caplog): ) assert expected_msg in log_msgs - def test_catalog_and_feed_dict(self, branched_pipeline): + def test_catalog_and_feed_dict(self, branched_pipeline, hook_manager): """Mix of feed_dict and non-existent F.""" catalog = _make_catalog(non_existent=["F"], existent=["D", "E"]) catalog.add_feed_dict({"A": 1, "B": 2, "C": 3}) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, ["split", "right_out"]) class TestPipelineUnregistered: - def test_propagate_up(self, branched_pipeline): + def test_propagate_up(self, branched_pipeline, hook_manager): """If a node needs to be rerun and requires unregistered (node-to-node) inputs, all necessary upstream nodes should be added. 
""" catalog = _make_catalog(existent=["A"], non_existent=["E"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains( new_pipeline, ["left_in", "right_in", "combine", "split"] ) - def test_propagate_down_then_up(self, branched_pipeline): + def test_propagate_down_then_up(self, branched_pipeline, hook_manager): """Unregistered (node-to-node) inputs for downstream nodes should be included, too. """ catalog = _make_catalog(existent=["A", "D", "E"], non_existent=["C"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipelines_equal(branched_pipeline, new_pipeline) - def test_ignore_unneccessary_unreg(self, branched_pipeline): + def test_ignore_unneccessary_unreg(self, branched_pipeline, hook_manager): """Unregistered (node-to-node) data sources should not trigger reruns, unless necessary to recreate registered data sources. """ catalog = _make_catalog(existent=["A", "E", "F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, []) - def test_partial_propagation(self, branched_pipeline): + def test_partial_propagation(self, branched_pipeline, hook_manager): """Unregistered (node-to-node) data sources should not trigger reruns, unless necessary to recreate registered data sources. """ catalog = _make_catalog(existent=["A", "D"], no_exists_method=["F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, ["split", "right_out"]) - def test_partial_non_existent_propagation(self, branched_pipeline): + def test_partial_non_existent_propagation(self, branched_pipeline, hook_manager): """A non existent data set whose node has one unregistered input and one existent input should be recalculated correctly. """ catalog = _make_catalog(existent=["A", "C", "E", "F"], non_existent=["D"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains( new_pipeline, ["left_in", "combine", "split", "right_out"] ) - def test_free_output(self, branched_pipeline): + def test_free_output(self, branched_pipeline, hook_manager): """Free outputs are the only unregistered data sources that should trigger runs. 
""" catalog = _make_catalog(existent=["A", "B", "C", "F"]) - new_pipeline = _from_missing(branched_pipeline, catalog) + new_pipeline = _from_missing(branched_pipeline, catalog, hook_manager) assert _pipeline_contains(new_pipeline, ["combine", "split"]) diff --git a/tests/pipeline/test_pipeline_integration.py b/tests/pipeline/test_pipeline_integration.py index a26c403b05..10511cd63f 100644 --- a/tests/pipeline/test_pipeline_integration.py +++ b/tests/pipeline/test_pipeline_integration.py @@ -1,3 +1,4 @@ +from kedro.framework.hooks import _create_hook_manager from kedro.io import DataCatalog from kedro.pipeline import Pipeline, node, pipeline from kedro.runner import SequentialRunner @@ -40,7 +41,7 @@ def test_connect_existing_pipelines(self): for pipe in [pipeline1, pipeline2, pipeline3]: catalog = DataCatalog({}, feed_dict={"frozen_meat": "frozen_meat_data"}) - result = SequentialRunner().run(pipe, catalog) + result = SequentialRunner().run(pipe, catalog, _create_hook_manager()) assert result == {"output": "frozen_meat_data_defrosted_grilled_done"} def test_reuse_same_pipeline(self): @@ -77,7 +78,7 @@ def test_reuse_same_pipeline(self): "lunch.frozen_meat": "lunch_frozen_meat", }, ) - result = SequentialRunner().run(pipe, catalog) + result = SequentialRunner().run(pipe, catalog, _create_hook_manager()) assert result == { "breakfast_output": "breakfast_frozen_meat_defrosted_grilled_done", "lunch_output": "lunch_frozen_meat_defrosted_grilled_done", diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py new file mode 100644 index 0000000000..2532a42190 --- /dev/null +++ b/tests/runner/conftest.py @@ -0,0 +1,95 @@ +from random import random + +import pytest + +from kedro.framework.hooks import _create_hook_manager +from kedro.io import DataCatalog +from kedro.pipeline import Pipeline, node + + +def source(): + return "stuff" + + +def identity(arg): + return arg + + +def sink(arg): # pylint: disable=unused-argument + pass + + +def fan_in(*args): + return args + + +def exception_fn(arg): + raise Exception("test exception") + + +def return_none(arg): + arg = None + return arg + + +def return_not_serializable(arg): # pylint: disable=unused-argument + return lambda x: x + + +@pytest.fixture +def catalog(): + return DataCatalog() + + +@pytest.fixture +def hook_manager(): + return _create_hook_manager() + + +@pytest.fixture +def fan_out_fan_in(): + return Pipeline( + [ + node(identity, "A", "B"), + node(identity, "B", "C"), + node(identity, "B", "D"), + node(identity, "B", "E"), + node(fan_in, ["C", "D", "E"], "Z"), + ] + ) + + +@pytest.fixture +def branchless_no_input_pipeline(): + """The pipeline runs in the order A->B->C->D->E.""" + return Pipeline( + [ + node(identity, "D", "E", name="node1"), + node(identity, "C", "D", name="node2"), + node(identity, "A", "B", name="node3"), + node(identity, "B", "C", name="node4"), + node(random, None, "A", name="node5"), + ] + ) + + +@pytest.fixture +def branchless_pipeline(): + return Pipeline( + [ + node(identity, "ds1", "ds2", name="node1"), + node(identity, "ds2", "ds3", name="node2"), + ] + ) + + +@pytest.fixture +def saving_result_pipeline(): + return Pipeline([node(identity, "ds", "dsX")]) + + +@pytest.fixture +def saving_none_pipeline(): + return Pipeline( + [node(random, None, "A"), node(return_none, "A", "B"), node(identity, "B", "C")] + ) diff --git a/tests/runner/test_parallel_runner.py b/tests/runner/test_parallel_runner.py index 9c0ef564f3..ce4ff6f2ca 100644 --- a/tests/runner/test_parallel_runner.py +++ 
b/tests/runner/test_parallel_runner.py @@ -19,53 +19,14 @@ _run_node_synchronization, _SharedMemoryDataSet, ) - - -def source(): - return "stuff" - - -def identity(arg): - return arg - - -def sink(arg): # pylint: disable=unused-argument - pass - - -def fan_in(*args): - return args - - -def exception_fn(arg): - raise Exception("test exception") - - -def return_none(arg): - arg = None - return arg - - -def return_not_serializable(arg): # pylint: disable=unused-argument - return lambda x: x - - -@pytest.fixture -def catalog(): - return DataCatalog() - - -@pytest.fixture -def fan_out_fan_in(): - return Pipeline( - [ - node(identity, "A", "B"), - node(identity, "B", "C"), - node(identity, "B", "D"), - node(identity, "B", "E"), - node(fan_in, ["C", "D", "E"], "Z"), - ] - ) +from tests.runner.conftest import ( + exception_fn, + identity, + return_none, + return_not_serializable, + sink, + source, +) @pytest.mark.skipif( @@ -78,18 +39,20 @@ def test_create_default_data_set(self): assert isinstance(data_set, _SharedMemoryDataSet) @pytest.mark.parametrize("is_async", [False, True]) - def test_parallel_run(self, is_async, fan_out_fan_in, catalog): + def test_parallel_run(self, is_async, fan_out_fan_in, catalog, hook_manager): catalog.add_feed_dict(dict(A=42)) - result = ParallelRunner(is_async=is_async).run(fan_out_fan_in, catalog) + result = ParallelRunner(is_async=is_async).run( + fan_out_fan_in, catalog, hook_manager + ) assert "Z" in result assert len(result["Z"]) == 3 assert result["Z"] == (42, 42, 42) @pytest.mark.parametrize("is_async", [False, True]) - def test_memory_dataset_input(self, is_async, fan_out_fan_in): + def test_memory_dataset_input(self, is_async, fan_out_fan_in, hook_manager): pipeline = Pipeline([fan_out_fan_in]) catalog = DataCatalog({"A": MemoryDataSet("42")}) - result = ParallelRunner(is_async=is_async).run(pipeline, catalog) + result = ParallelRunner(is_async=is_async).run(pipeline, catalog, hook_manager) assert "Z" in result assert len(result["Z"]) == 3 assert result["Z"] == ("42", "42", "42") @@ -121,6 +84,7 @@ def test_specified_max_workers_bellow_cpu_cores_count( cpu_cores, user_specified_number, expected_number, + hook_manager, ): # pylint: disable=too-many-arguments """ The system has 2 cores, but we initialize the runner with max_workers=4. 
@@ -137,7 +101,7 @@ def test_specified_max_workers_bellow_cpu_cores_count( catalog.add_feed_dict(dict(A=42)) result = ParallelRunner( max_workers=user_specified_number, is_async=is_async - ).run(fan_out_fan_in, catalog) + ).run(fan_out_fan_in, catalog, hook_manager) assert result == {"Z": (42, 42, 42)} executor_cls_mock.assert_called_once_with(max_workers=expected_number) @@ -159,36 +123,36 @@ def test_max_worker_windows(self, mocker): ) @pytest.mark.parametrize("is_async", [False, True]) class TestInvalidParallelRunner: - def test_task_validation(self, is_async, fan_out_fan_in, catalog): + def test_task_validation(self, is_async, fan_out_fan_in, catalog, hook_manager): """ParallelRunner cannot serialize the lambda function.""" catalog.add_feed_dict(dict(A=42)) pipeline = Pipeline([fan_out_fan_in, node(lambda x: x, "Z", "X")]) with pytest.raises(AttributeError): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run(pipeline, catalog, hook_manager) - def test_task_exception(self, is_async, fan_out_fan_in, catalog): + def test_task_exception(self, is_async, fan_out_fan_in, catalog, hook_manager): catalog.add_feed_dict(feed_dict=dict(A=42)) pipeline = Pipeline([fan_out_fan_in, node(exception_fn, "Z", "X")]) with pytest.raises(Exception, match="test exception"): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run(pipeline, catalog, hook_manager) - def test_memory_dataset_output(self, is_async, fan_out_fan_in): + def test_memory_dataset_output(self, is_async, fan_out_fan_in, hook_manager): """ParallelRunner does not support output to externally created MemoryDataSets. """ pipeline = Pipeline([fan_out_fan_in]) catalog = DataCatalog({"C": MemoryDataSet()}, dict(A=42)) with pytest.raises(AttributeError, match="['C']"): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run(pipeline, catalog, hook_manager) - def test_node_returning_none(self, is_async): + def test_node_returning_none(self, is_async, hook_manager): pipeline = Pipeline([node(identity, "A", "B"), node(return_none, "B", "C")]) catalog = DataCatalog({"A": MemoryDataSet("42")}) pattern = "Saving `None` to a `DataSet` is not allowed" with pytest.raises(DataSetError, match=pattern): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run(pipeline, catalog, hook_manager) - def test_data_set_not_serializable(self, is_async, fan_out_fan_in): + def test_data_set_not_serializable(self, is_async, fan_out_fan_in, hook_manager): """Data set A cannot be serializable because _load and _save are not defined in global scope. 
""" @@ -204,9 +168,9 @@ def _save(arg): pipeline = Pipeline([fan_out_fan_in]) with pytest.raises(AttributeError, match="['A']"): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run(pipeline, catalog, hook_manager) - def test_memory_dataset_not_serializable(self, is_async, catalog): + def test_memory_dataset_not_serializable(self, is_async, catalog, hook_manager): """Memory dataset cannot be serializable because of data it stores.""" data = return_not_serializable(None) pipeline = Pipeline([node(return_not_serializable, "A", "B")]) @@ -217,10 +181,10 @@ def test_memory_dataset_not_serializable(self, is_async, catalog): ) with pytest.raises(DataSetError, match=pattern): - ParallelRunner(is_async=is_async).run(pipeline, catalog) + ParallelRunner(is_async=is_async).run(pipeline, catalog, hook_manager) def test_unable_to_schedule_all_nodes( - self, mocker, is_async, fan_out_fan_in, catalog + self, mocker, is_async, fan_out_fan_in, catalog, hook_manager ): """Test the error raised when `futures` variable is empty, but `todo_nodes` is not (can barely happen in real life). @@ -241,7 +205,7 @@ def test_unable_to_schedule_all_nodes( pattern = "Unable to schedule new tasks although some nodes have not been run" with pytest.raises(RuntimeError, match=pattern): - runner.run(fan_out_fan_in, catalog) + runner.run(fan_out_fan_in, catalog, hook_manager) class LoggingDataSet(AbstractDataSet): @@ -276,7 +240,7 @@ def _describe(self) -> Dict[str, Any]: ) @pytest.mark.parametrize("is_async", [False, True]) class TestParallelRunnerRelease: - def test_dont_release_inputs_and_outputs(self, is_async): + def test_dont_release_inputs_and_outputs(self, is_async, hook_manager): runner = ParallelRunner(is_async=is_async) log = runner._manager.list() @@ -291,12 +255,12 @@ def test_dont_release_inputs_and_outputs(self, is_async): "out": runner._manager.LoggingDataSet(log, "out"), } ) - ParallelRunner().run(pipeline, catalog) + ParallelRunner().run(pipeline, catalog, hook_manager) # we don't want to see release in or out in here assert list(log) == [("load", "in"), ("load", "middle"), ("release", "middle")] - def test_release_at_earliest_opportunity(self, is_async): + def test_release_at_earliest_opportunity(self, is_async, hook_manager): runner = ParallelRunner(is_async=is_async) log = runner._manager.list() @@ -314,7 +278,7 @@ def test_release_at_earliest_opportunity(self, is_async): "second": runner._manager.LoggingDataSet(log, "second"), } ) - runner.run(pipeline, catalog) + runner.run(pipeline, catalog, hook_manager) # we want to see "release first" before "load second" assert list(log) == [ @@ -324,7 +288,7 @@ def test_release_at_earliest_opportunity(self, is_async): ("release", "second"), ] - def test_count_multiple_loads(self, is_async): + def test_count_multiple_loads(self, is_async, hook_manager): runner = ParallelRunner(is_async=is_async) log = runner._manager.list() @@ -339,7 +303,7 @@ def test_count_multiple_loads(self, is_async): catalog = DataCatalog( {"dataset": runner._manager.LoggingDataSet(log, "dataset")} ) - runner.run(pipeline, catalog) + runner.run(pipeline, catalog, hook_manager) # we want to the release after both the loads assert list(log) == [ @@ -348,7 +312,7 @@ def test_count_multiple_loads(self, is_async): ("release", "dataset"), ] - def test_release_transcoded(self, is_async): + def test_release_transcoded(self, is_async, hook_manager): runner = ParallelRunner(is_async=is_async) log = runner._manager.list() @@ -362,7 +326,7 @@ def 
test_release_transcoded(self, is_async): } ) - ParallelRunner().run(pipeline, catalog) + ParallelRunner().run(pipeline, catalog, hook_manager) # we want to see both datasets being released assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")] @@ -411,7 +375,7 @@ def test_package_name_and_logging_provided( package_name=package_name, conf_logging=conf_logging, ) - mock_run_node.assert_called_once_with(node_, catalog, is_async, run_id) + mock_run_node.assert_called_once() mock_logging.assert_called_once_with(conf_logging) mock_configure_project.assert_called_once_with(package_name) @@ -432,7 +396,7 @@ def test_package_name_provided( _run_node_synchronization( node_, catalog, is_async, run_id, package_name=package_name ) - mock_run_node.assert_called_once_with(node_, catalog, is_async, run_id) + mock_run_node.assert_called_once() mock_logging.assert_called_once_with({}) mock_configure_project.assert_called_once_with(package_name) @@ -448,5 +412,5 @@ def test_package_name_not_provided( _run_node_synchronization( node_, catalog, is_async, run_id, package_name=package_name ) - mock_run_node.assert_called_once_with(node_, catalog, is_async, run_id) + mock_run_node.assert_called_once() mock_logging.assert_not_called() diff --git a/tests/runner/test_sequential_runner.py b/tests/runner/test_sequential_runner.py index 0987b7aea3..0c7d6f3ce6 100644 --- a/tests/runner/test_sequential_runner.py +++ b/tests/runner/test_sequential_runner.py @@ -1,5 +1,3 @@ -# pylint: disable=unused-argument -from random import random from typing import Any, Dict import pandas as pd @@ -14,6 +12,7 @@ ) from kedro.pipeline import Pipeline, node from kedro.runner import SequentialRunner +from tests.runner.conftest import identity, sink, source @pytest.fixture @@ -36,90 +35,48 @@ def conflicting_feed_dict(pandas_df_feed_dict): return {"ds1": ds1, "ds3": ds3} -def source(): - return "stuff" - - -def identity(arg): - return arg - - -def sink(arg): - pass - - -def return_none(arg): - return None - - def multi_input_list_output(arg1, arg2): return [arg1, arg2] -@pytest.fixture -def branchless_no_input_pipeline(): - """The pipeline runs in the order A->B->C->D->E.""" - return Pipeline( - [ - node(identity, "D", "E", name="node1"), - node(identity, "C", "D", name="node2"), - node(identity, "A", "B", name="node3"), - node(identity, "B", "C", name="node4"), - node(random, None, "A", name="node5"), - ] - ) - - -@pytest.fixture -def branchless_pipeline(): - return Pipeline( - [ - node(identity, "ds1", "ds2", name="node1"), - node(identity, "ds2", "ds3", name="node2"), - ] - ) - - -@pytest.fixture -def saving_result_pipeline(): - return Pipeline([node(identity, "ds", "dsX")]) - - -@pytest.fixture -def saving_none_pipeline(): - return Pipeline( - [node(random, None, "A"), node(return_none, "A", "B"), node(identity, "B", "C")] - ) - - @pytest.mark.parametrize("is_async", [False, True]) class TestSeqentialRunnerBranchlessPipeline: - def test_no_input_seq(self, is_async, branchless_no_input_pipeline): + def test_no_input_seq( + self, is_async, branchless_no_input_pipeline, catalog, hook_manager + ): outputs = SequentialRunner(is_async=is_async).run( - branchless_no_input_pipeline, DataCatalog() + branchless_no_input_pipeline, catalog, hook_manager ) assert "E" in outputs assert len(outputs) == 1 - def test_no_data_sets(self, is_async, branchless_pipeline): + def test_no_data_sets(self, is_async, branchless_pipeline, hook_manager): catalog = DataCatalog({}, {"ds1": 42}) - outputs = 
SequentialRunner(is_async=is_async).run(branchless_pipeline, catalog) + outputs = SequentialRunner(is_async=is_async).run( + branchless_pipeline, catalog, hook_manager + ) assert "ds3" in outputs assert outputs["ds3"] == 42 - def test_no_feed(self, is_async, memory_catalog, branchless_pipeline): + def test_no_feed(self, is_async, memory_catalog, branchless_pipeline, hook_manager): outputs = SequentialRunner(is_async=is_async).run( - branchless_pipeline, memory_catalog + branchless_pipeline, memory_catalog, hook_manager ) assert "ds3" in outputs assert outputs["ds3"]["data"] == 42 - def test_node_returning_none(self, is_async, saving_none_pipeline): + def test_node_returning_none( + self, is_async, saving_none_pipeline, catalog, hook_manager + ): pattern = "Saving `None` to a `DataSet` is not allowed" with pytest.raises(DataSetError, match=pattern): - SequentialRunner(is_async=is_async).run(saving_none_pipeline, DataCatalog()) + SequentialRunner(is_async=is_async).run( + saving_none_pipeline, catalog, hook_manager + ) - def test_result_saved_not_returned(self, is_async, saving_result_pipeline): + def test_result_saved_not_returned( + self, is_async, saving_result_pipeline, hook_manager + ): """The pipeline runs ds->dsX but save does not save the output.""" def _load(): @@ -135,7 +92,7 @@ def _save(arg): } ) output = SequentialRunner(is_async=is_async).run( - saving_result_pipeline, catalog + saving_result_pipeline, catalog, hook_manager ) assert output == {} @@ -157,11 +114,16 @@ def unfinished_outputs_pipeline(): @pytest.mark.parametrize("is_async", [False, True]) class TestSeqentialRunnerBranchedPipeline: def test_input_seq( - self, is_async, memory_catalog, unfinished_outputs_pipeline, pandas_df_feed_dict + self, + is_async, + memory_catalog, + unfinished_outputs_pipeline, + pandas_df_feed_dict, + hook_manager, ): memory_catalog.add_feed_dict(pandas_df_feed_dict, replace=True) outputs = SequentialRunner(is_async=is_async).run( - unfinished_outputs_pipeline, memory_catalog + unfinished_outputs_pipeline, memory_catalog, hook_manager ) assert set(outputs.keys()) == {"ds8", "ds5", "ds6"} # the pipeline runs ds2->ds5 @@ -178,21 +140,24 @@ def test_conflict_feed_catalog( memory_catalog, unfinished_outputs_pipeline, conflicting_feed_dict, + hook_manager, ): """ds1 and ds3 will be replaced with new inputs.""" memory_catalog.add_feed_dict(conflicting_feed_dict, replace=True) outputs = SequentialRunner(is_async=is_async).run( - unfinished_outputs_pipeline, memory_catalog + unfinished_outputs_pipeline, memory_catalog, hook_manager ) assert isinstance(outputs["ds8"], dict) assert outputs["ds8"]["data"] == 0 assert isinstance(outputs["ds6"], pd.DataFrame) - def test_unsatisfied_inputs(self, is_async, unfinished_outputs_pipeline): + def test_unsatisfied_inputs( + self, is_async, unfinished_outputs_pipeline, catalog, hook_manager + ): """ds1, ds2 and ds3 were not specified.""" with pytest.raises(ValueError, match=r"not found in the DataCatalog"): SequentialRunner(is_async=is_async).run( - unfinished_outputs_pipeline, DataCatalog() + unfinished_outputs_pipeline, catalog, hook_manager ) @@ -219,7 +184,7 @@ def _describe(self) -> Dict[str, Any]: @pytest.mark.parametrize("is_async", [False, True]) class TestSequentialRunnerRelease: - def test_dont_release_inputs_and_outputs(self, is_async): + def test_dont_release_inputs_and_outputs(self, is_async, hook_manager): log = [] pipeline = Pipeline( [node(identity, "in", "middle"), node(identity, "middle", "out")] @@ -231,12 +196,12 @@ def 
test_dont_release_inputs_and_outputs(self, is_async): "out": LoggingDataSet(log, "out"), } ) - SequentialRunner(is_async=is_async).run(pipeline, catalog) + SequentialRunner(is_async=is_async).run(pipeline, catalog, hook_manager) # we don't want to see release in or out in here assert log == [("load", "in"), ("load", "middle"), ("release", "middle")] - def test_release_at_earliest_opportunity(self, is_async): + def test_release_at_earliest_opportunity(self, is_async, hook_manager): log = [] pipeline = Pipeline( [ @@ -251,7 +216,7 @@ def test_release_at_earliest_opportunity(self, is_async): "second": LoggingDataSet(log, "second"), } ) - SequentialRunner(is_async=is_async).run(pipeline, catalog) + SequentialRunner(is_async=is_async).run(pipeline, catalog, hook_manager) # we want to see "release first" before "load second" assert log == [ @@ -261,7 +226,7 @@ def test_release_at_earliest_opportunity(self, is_async): ("release", "second"), ] - def test_count_multiple_loads(self, is_async): + def test_count_multiple_loads(self, is_async, hook_manager): log = [] pipeline = Pipeline( [ @@ -271,12 +236,12 @@ def test_count_multiple_loads(self, is_async): ] ) catalog = DataCatalog({"dataset": LoggingDataSet(log, "dataset")}) - SequentialRunner(is_async=is_async).run(pipeline, catalog) + SequentialRunner(is_async=is_async).run(pipeline, catalog, hook_manager) # we want to the release after both the loads assert log == [("load", "dataset"), ("load", "dataset"), ("release", "dataset")] - def test_release_transcoded(self, is_async): + def test_release_transcoded(self, is_async, hook_manager): log = [] pipeline = Pipeline( [node(source, None, "ds@save"), node(sink, "ds@load", None)] @@ -288,7 +253,7 @@ def test_release_transcoded(self, is_async): } ) - SequentialRunner(is_async=is_async).run(pipeline, catalog) + SequentialRunner(is_async=is_async).run(pipeline, catalog, hook_manager) # we want to see both datasets being released assert log == [("release", "save"), ("load", "load"), ("release", "load")] @@ -305,8 +270,8 @@ def test_release_transcoded(self, is_async): ), ], ) - def test_confirms(self, mocker, pipeline, is_async): + def test_confirms(self, mocker, pipeline, is_async, hook_manager): fake_dataset_instance = mocker.Mock() catalog = DataCatalog(data_sets={"ds1": fake_dataset_instance}) - SequentialRunner(is_async=is_async).run(pipeline, catalog) + SequentialRunner(is_async=is_async).run(pipeline, catalog, hook_manager) fake_dataset_instance.confirm.assert_called_once_with() diff --git a/tests/runner/test_thread_runner.py b/tests/runner/test_thread_runner.py index 70a5c51cb4..59a4e5e7d9 100644 --- a/tests/runner/test_thread_runner.py +++ b/tests/runner/test_thread_runner.py @@ -6,49 +6,7 @@ from kedro.io import AbstractDataSet, DataCatalog, DataSetError, MemoryDataSet from kedro.pipeline import Pipeline, node from kedro.runner import ThreadRunner - - -def source(): - return "stuff" - - -def identity(arg): - return arg - - -def sink(arg): # pylint: disable=unused-argument - pass - - -def fan_in(*args): - return args - - -def exception_fn(arg): - raise Exception("test exception") - - -def return_none(arg): - arg = None - return arg - - -@pytest.fixture -def catalog(): - return DataCatalog() - - -@pytest.fixture -def fan_out_fan_in(): - return Pipeline( - [ - node(identity, "A", "B"), - node(identity, "B", "C"), - node(identity, "B", "D"), - node(identity, "B", "E"), - node(fan_in, ["C", "D", "E"], "Z"), - ] - ) +from tests.runner.conftest import exception_fn, identity, return_none, sink, source 
 
 
 class TestValidThreadRunner:
@@ -56,15 +14,15 @@ def test_create_default_data_set(self):
         data_set = ThreadRunner().create_default_data_set("")
         assert isinstance(data_set, MemoryDataSet)
 
-    def test_thread_run(self, fan_out_fan_in, catalog):
+    def test_thread_run(self, fan_out_fan_in, catalog, hook_manager):
         catalog.add_feed_dict(dict(A=42))
-        result = ThreadRunner().run(fan_out_fan_in, catalog)
+        result = ThreadRunner().run(fan_out_fan_in, catalog, hook_manager)
         assert "Z" in result
         assert result["Z"] == (42, 42, 42)
 
-    def test_memory_dataset_input(self, fan_out_fan_in):
+    def test_memory_dataset_input(self, fan_out_fan_in, hook_manager):
         catalog = DataCatalog({"A": MemoryDataSet("42")})
-        result = ThreadRunner().run(fan_out_fan_in, catalog)
+        result = ThreadRunner().run(fan_out_fan_in, catalog, hook_manager)
         assert "Z" in result
         assert result["Z"] == ("42", "42", "42")
 
@@ -85,6 +43,7 @@ def test_specified_max_workers(
         catalog,
         user_specified_number,
         expected_number,
+        hook_manager,
     ):  # pylint: disable=too-many-arguments
         """
         We initialize the runner with max_workers=4.
@@ -98,7 +57,7 @@ def test_specified_max_workers(
         catalog.add_feed_dict(dict(A=42))
 
         result = ThreadRunner(max_workers=user_specified_number).run(
-            fan_out_fan_in, catalog
+            fan_out_fan_in, catalog, hook_manager
         )
         assert result == {"Z": (42, 42, 42)}
 
@@ -111,7 +70,7 @@ def test_init_with_negative_process_count(self):
 
 
 class TestIsAsync:
-    def test_thread_run(self, fan_out_fan_in, catalog):
+    def test_thread_run(self, fan_out_fan_in, catalog, hook_manager):
         catalog.add_feed_dict(dict(A=42))
         pattern = (
             "`ThreadRunner` doesn't support loading and saving the "
@@ -119,24 +78,26 @@ def test_thread_run(self, fan_out_fan_in, catalog):
             "Setting `is_async` to False."
         )
         with pytest.warns(UserWarning, match=pattern):
-            result = ThreadRunner(is_async=True).run(fan_out_fan_in, catalog)
+            result = ThreadRunner(is_async=True).run(
+                fan_out_fan_in, catalog, hook_manager
+            )
         assert "Z" in result
         assert result["Z"] == (42, 42, 42)
 
 
 class TestInvalidThreadRunner:
-    def test_task_exception(self, fan_out_fan_in, catalog):
+    def test_task_exception(self, fan_out_fan_in, catalog, hook_manager):
         catalog.add_feed_dict(feed_dict=dict(A=42))
         pipeline = Pipeline([fan_out_fan_in, node(exception_fn, "Z", "X")])
         with pytest.raises(Exception, match="test exception"):
-            ThreadRunner().run(pipeline, catalog)
+            ThreadRunner().run(pipeline, catalog, hook_manager)
 
-    def test_node_returning_none(self):
+    def test_node_returning_none(self, hook_manager):
         pipeline = Pipeline([node(identity, "A", "B"), node(return_none, "B", "C")])
         catalog = DataCatalog({"A": MemoryDataSet("42")})
         pattern = "Saving `None` to a `DataSet` is not allowed"
         with pytest.raises(DataSetError, match=pattern):
-            ThreadRunner().run(pipeline, catalog)
+            ThreadRunner().run(pipeline, catalog, hook_manager)
 
 
 class LoggingDataSet(AbstractDataSet):
@@ -161,7 +122,7 @@ def _describe(self) -> Dict[str, Any]:
 
 
 class TestThreadRunnerRelease:
-    def test_dont_release_inputs_and_outputs(self):
+    def test_dont_release_inputs_and_outputs(self, hook_manager):
         log = []
 
         pipeline = Pipeline(
@@ -174,12 +135,12 @@ def test_dont_release_inputs_and_outputs(self):
                 "out": LoggingDataSet(log, "out"),
             }
         )
-        ThreadRunner().run(pipeline, catalog)
+        ThreadRunner().run(pipeline, catalog, hook_manager)
 
         # we don't want to see release in or out in here
         assert list(log) == [("load", "in"), ("load", "middle"), ("release", "middle")]
 
-    def test_release_at_earliest_opportunity(self):
+    def test_release_at_earliest_opportunity(self, hook_manager):
         runner = ThreadRunner()
         log = []
 
@@ -196,7 +157,7 @@ def test_release_at_earliest_opportunity(self):
                 "second": LoggingDataSet(log, "second"),
             }
         )
-        runner.run(pipeline, catalog)
+        runner.run(pipeline, catalog, hook_manager)
 
         # we want to see "release first" before "load second"
         assert list(log) == [
@@ -206,7 +167,7 @@ def test_release_at_earliest_opportunity(self):
             ("release", "second"),
         ]
 
-    def test_count_multiple_loads(self):
+    def test_count_multiple_loads(self, hook_manager):
         runner = ThreadRunner()
         log = []
 
@@ -218,7 +179,7 @@ def test_count_multiple_loads(self):
             ]
         )
         catalog = DataCatalog({"dataset": LoggingDataSet(log, "dataset")})
-        runner.run(pipeline, catalog)
+        runner.run(pipeline, catalog, hook_manager)
 
         # we want to the release after both the loads
        assert list(log) == [
@@ -227,7 +188,7 @@ def test_count_multiple_loads(self):
             ("release", "dataset"),
         ]
 
-    def test_release_transcoded(self):
+    def test_release_transcoded(self, hook_manager):
         log = []
 
         pipeline = Pipeline(
@@ -240,7 +201,7 @@ def test_release_transcoded(self):
             }
         )
 
-        ThreadRunner().run(pipeline, catalog)
+        ThreadRunner().run(pipeline, catalog, hook_manager)
 
         # we want to see both datasets being released
         assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")]
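
Reviewer note (not part of the patch): the runner tests above now take `catalog`, `fan_out_fan_in` and `hook_manager` as fixtures and import helpers such as `identity`, `sink`, `source`, `exception_fn` and `return_none` from `tests/runner/conftest.py`, which is not shown in this excerpt. Below is a minimal sketch of what the shared `hook_manager` fixture could look like; it assumes the fixture simply returns a fresh, plugin-free manager from `_create_hook_manager()`, so treat the body as illustrative rather than the actual conftest contents.

    # tests/runner/conftest.py (hypothetical sketch, not part of this patch)
    import pytest

    from kedro.framework.hooks import _create_hook_manager


    @pytest.fixture
    def hook_manager():
        # Build a new PluginManager per test, so runner tests exercise the hook
        # call sites without relying on a global singleton or installed plugins.
        return _create_hook_manager()

With such a fixture in place, a test body can call, for example, `SequentialRunner().run(pipeline, catalog, hook_manager)` exactly as the hunks above do.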