Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define modular pipelines in config #3904

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Added `raise_errors` argument to `find_pipelines`. If `True`, the first pipeline for which autodiscovery fails will cause an error to be raised. The default behaviour is still to raise a warning for each failing pipeline.
* It is now possible to use Kedro without having `rich` installed.
* Updated custom logging behavior: `conf/logging.yml` will be used if it exists and `KEDRO_LOGGING_CONFIG` is not set; otherwise, `default_logging.yml` will be used.
* It is now possible to define modular pipelines in `conf/pipelines.yml`.

## Bug fixes and other changes
* User defined catch-all dataset factory patterns now override the default pattern provided by the runner.
Expand Down
1 change: 1 addition & 0 deletions kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__( # noqa: PLR0913
"parameters": ["parameters*", "parameters*/**", "**/parameters*"],
"credentials": ["credentials*", "credentials*/**", "**/credentials*"],
"globals": ["globals.yml"],
"pipelines": ["pipelines.yml"],
}
self.config_patterns.update(config_patterns or {})

Expand Down
98 changes: 81 additions & 17 deletions kedro/framework/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,27 @@ def _create_pipeline(pipeline_module: types.ModuleType) -> Pipeline | None:
return obj


def _get_pipeline_obj(pipeline_name: str, raise_errors: bool = False) -> Pipeline | None:
    """Import a modular pipeline module and build its ``Pipeline`` object.

    Imports ``<package>.pipelines.<pipeline_name>`` and passes the module to
    ``_create_pipeline``. If anything goes wrong during import or creation,
    either an ``ImportError`` is raised (when ``raise_errors`` is true) or a
    warning is emitted and ``None`` is returned.

    Args:
        pipeline_name: Name of the pipeline subpackage to import.
        raise_errors: If ``True``, re-raise import failures as ``ImportError``
            instead of warning.

    Returns:
        The created ``Pipeline``, or ``None`` when creation/import failed.
    """
    module_name = f"{PACKAGE_NAME}.pipelines.{pipeline_name}"
    result: Pipeline | None = None
    try:
        result = _create_pipeline(importlib.import_module(module_name))
    except Exception as exc:
        if raise_errors:
            raise ImportError(
                f"An error occurred while importing the "
                f"'{module_name}' module."
            ) from exc
        # Best-effort mode: surface the traceback as a warning and fall
        # through, leaving ``result`` as ``None``.
        warnings.warn(
            IMPORT_ERROR_MESSAGE.format(
                module=module_name, tb_exc=traceback.format_exc()
            )
        )
    return result


def find_pipelines(raise_errors: bool = False) -> dict[str, Pipeline]: # noqa: PLR0912
"""Automatically find modular pipelines having a ``create_pipeline``
function. By default, projects created using Kedro 0.18.3 and higher
Expand Down Expand Up @@ -419,24 +440,67 @@ def find_pipelines(raise_errors: bool = False) -> dict[str, Pipeline]: # noqa:
if pipeline_name.startswith("."):
continue

pipeline_module_name = f"{PACKAGE_NAME}.pipelines.{pipeline_name}"
try:
pipeline_module = importlib.import_module(pipeline_module_name)
except Exception as exc:
if raise_errors:
raise ImportError(
f"An error occurred while importing the "
f"'{pipeline_module_name}' module."
) from exc

warnings.warn(
IMPORT_ERROR_MESSAGE.format(
module=pipeline_module_name, tb_exc=traceback.format_exc()
)
)
continue
pipeline_obj = _get_pipeline_obj(pipeline_name, raise_errors)

pipeline_obj = _create_pipeline(pipeline_module)
if pipeline_obj is not None:
pipelines_dict[pipeline_name] = pipeline_obj
return pipelines_dict


def from_config(config_entry: dict[str, str], raise_errors: bool = False) -> Pipeline:
    """Create a ``Pipeline`` object from a config entry.

    The entry's ``pipe`` key lists the modular pipeline modules to import
    and combine; the optional ``inputs``, ``outputs``, ``parameters``,
    ``tags`` and ``namespace`` keys are forwarded to the ``pipeline``
    factory for the combined pipeline.

    Args:
        config_entry: Config entry dictionary. A missing or empty ``pipe``
            key yields an empty pipeline rather than an error.
        raise_errors: If ``True``, raise an error upon failed discovery.

    Returns:
        A generated ``Pipeline`` object. Similar to a modular pipeline
        normally defined in the pipeline registry.

    Raises:
        ImportError: When a module does not expose a ``create_pipeline``
            function, the ``create_pipeline`` function does not return a
            ``Pipeline`` object, or if the module import fails up front.
            If ``raise_errors`` is ``False``, see Warns section instead.

    Warns:
        UserWarning: When a module does not expose a ``create_pipeline``
            function, the ``create_pipeline`` function does not return a
            ``Pipeline`` object, or if the module import fails up front.
            If ``raise_errors`` is ``True``, see Raises section instead.

    Examples:
        pipelines.yml::

            processing:
                pipe:
                    - data_processing

        pipeline_registry.py:
        >>> from typing import Dict
        >>> from kedro.framework.project import find_pipelines, from_config
        >>> from kedro.config.omegaconf_config import OmegaConfigLoader
        >>> from kedro.framework.project import settings
        >>> def register_pipelines() -> Dict[str, Pipeline]:
        >>>     conf_path = str(Path(__file__).parents[2] / settings.CONF_SOURCE)
        >>>     conf_loader = OmegaConfigLoader(conf_path, env="base")
        >>>     processing_pipeline = from_config(conf_loader["pipelines"].get("processing"))
        >>>     pipelines["processing"] = processing_pipeline
        >>>     return pipelines
    """
    # Seed with an empty default so ``sum`` below always has at least one
    # pipeline to combine, even when every import fails or ``pipe`` is empty.
    pipelines_dict = {"__default__": pipeline([])}
    for pipeline_name in config_entry.get("pipe", []):
        pipeline_obj = _get_pipeline_obj(pipeline_name, raise_errors)
        if pipeline_obj is not None:
            pipelines_dict[pipeline_name] = pipeline_obj
    # Combine all successfully imported pipelines and apply the entry's
    # optional remapping/namespacing options.
    return pipeline(
        pipe=sum(pipelines_dict.values()),
        inputs=config_entry.get("inputs"),
        outputs=config_entry.get("outputs"),
        parameters=config_entry.get("parameters"),
        tags=config_entry.get("tags"),
        namespace=config_entry.get("namespace"),
    )
56 changes: 36 additions & 20 deletions tests/framework/project/test_pipeline_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from kedro.framework.project import configure_project, find_pipelines
from kedro.framework.project import configure_project, find_pipelines, from_config


@pytest.fixture
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_find_pipelines(mock_package_name_with_pipelines, pipeline_names):
indirect=True,
)
def test_find_pipelines_skips_modules_without_create_pipelines_function(
mock_package_name_with_pipelines, pipeline_names
mock_package_name_with_pipelines, pipeline_names
):
# Create a module without `create_pipelines` in the `pipelines` dir.
pipelines_dir = Path(sys.path[0]) / mock_package_name_with_pipelines / "pipelines"
Expand All @@ -76,7 +76,7 @@ def test_find_pipelines_skips_modules_without_create_pipelines_function(

configure_project(mock_package_name_with_pipelines)
with pytest.warns(
UserWarning, match="module does not expose a 'create_pipeline' function"
UserWarning, match="module does not expose a 'create_pipeline' function"
):
pipelines = find_pipelines()
assert set(pipelines) == pipeline_names | {"__default__"}
Expand All @@ -89,7 +89,7 @@ def test_find_pipelines_skips_modules_without_create_pipelines_function(
indirect=True,
)
def test_find_pipelines_skips_hidden_modules(
mock_package_name_with_pipelines, pipeline_names
mock_package_name_with_pipelines, pipeline_names
):
pipelines_dir = Path(sys.path[0]) / mock_package_name_with_pipelines / "pipelines"
pipeline_dir = pipelines_dir / ".ipynb_checkpoints"
Expand Down Expand Up @@ -120,7 +120,7 @@ def create_pipeline(**kwargs) -> Pipeline:
indirect=True,
)
def test_find_pipelines_skips_modules_with_unexpected_return_value_type(
mock_package_name_with_pipelines, pipeline_names
mock_package_name_with_pipelines, pipeline_names
):
# Define `create_pipelines` so that it does not return a `Pipeline`.
pipelines_dir = Path(sys.path[0]) / mock_package_name_with_pipelines / "pipelines"
Expand All @@ -145,11 +145,11 @@ def create_pipeline(**kwargs) -> dict[str, Pipeline]:

configure_project(mock_package_name_with_pipelines)
with pytest.warns(
UserWarning,
match=(
r"Expected the 'create_pipeline' function in the '\S+' "
r"module to return a 'Pipeline' object, got 'dict' instead."
),
UserWarning,
match=(
r"Expected the 'create_pipeline' function in the '\S+' "
r"module to return a 'Pipeline' object, got 'dict' instead."
),
):
pipelines = find_pipelines()
assert set(pipelines) == pipeline_names | {"__default__"}
Expand All @@ -162,7 +162,7 @@ def create_pipeline(**kwargs) -> dict[str, Pipeline]:
indirect=True,
)
def test_find_pipelines_skips_regular_files_within_the_pipelines_folder(
mock_package_name_with_pipelines, pipeline_names
mock_package_name_with_pipelines, pipeline_names
):
# Create a regular file (not a subdirectory) in the `pipelines` dir.
pipelines_dir = Path(sys.path[0]) / mock_package_name_with_pipelines / "pipelines"
Expand All @@ -186,7 +186,7 @@ def test_find_pipelines_skips_regular_files_within_the_pipelines_folder(
indirect=["mock_package_name_with_pipelines", "pipeline_names"],
)
def test_find_pipelines_skips_modules_that_cause_exceptions_upon_import(
mock_package_name_with_pipelines, pipeline_names, raise_errors
mock_package_name_with_pipelines, pipeline_names, raise_errors
):
# Create a module that will result in errors when we try to load it.
pipelines_dir = Path(sys.path[0]) / mock_package_name_with_pipelines / "pipelines"
Expand All @@ -196,8 +196,8 @@ def test_find_pipelines_skips_modules_that_cause_exceptions_upon_import(

configure_project(mock_package_name_with_pipelines)
with getattr(pytest, "raises" if raise_errors else "warns")(
ImportError if raise_errors else UserWarning,
match=r"An error occurred while importing the '\S+' module.",
ImportError if raise_errors else UserWarning,
match=r"An error occurred while importing the '\S+' module.",
):
pipelines = find_pipelines(raise_errors=raise_errors)
if not raise_errors:
Expand All @@ -211,7 +211,7 @@ def test_find_pipelines_skips_modules_that_cause_exceptions_upon_import(
indirect=True,
)
def test_find_pipelines_handles_simplified_project_structure(
mock_package_name_with_pipelines, pipeline_names
mock_package_name_with_pipelines, pipeline_names
):
(Path(sys.path[0]) / mock_package_name_with_pipelines / "pipeline.py").write_text(
textwrap.dedent(
Expand Down Expand Up @@ -241,7 +241,7 @@ def create_pipeline(**kwargs) -> Pipeline:
indirect=["mock_package_name_with_pipelines", "pipeline_names"],
)
def test_find_pipelines_skips_unimportable_pipeline_module(
mock_package_name_with_pipelines, pipeline_names, raise_errors
mock_package_name_with_pipelines, pipeline_names, raise_errors
):
(Path(sys.path[0]) / mock_package_name_with_pipelines / "pipeline.py").write_text(
textwrap.dedent(
Expand All @@ -259,8 +259,8 @@ def create_pipeline(**kwargs) -> Pipeline:

configure_project(mock_package_name_with_pipelines)
with getattr(pytest, "raises" if raise_errors else "warns")(
ImportError if raise_errors else UserWarning,
match=r"An error occurred while importing the '\S+' module.",
ImportError if raise_errors else UserWarning,
match=r"An error occurred while importing the '\S+' module.",
):
pipelines = find_pipelines(raise_errors=raise_errors)
if not raise_errors:
Expand All @@ -274,15 +274,15 @@ def create_pipeline(**kwargs) -> Pipeline:
indirect=["mock_package_name_with_pipelines"],
)
def test_find_pipelines_handles_project_structure_without_pipelines_dir(
mock_package_name_with_pipelines, simplified
mock_package_name_with_pipelines, simplified
):
# Delete the `pipelines` directory to simulate a project without it.
pipelines_dir = Path(sys.path[0]) / mock_package_name_with_pipelines / "pipelines"
shutil.rmtree(pipelines_dir)

if simplified:
(
Path(sys.path[0]) / mock_package_name_with_pipelines / "pipeline.py"
Path(sys.path[0]) / mock_package_name_with_pipelines / "pipeline.py"
).write_text(
textwrap.dedent(
"""
Expand All @@ -301,3 +301,19 @@ def create_pipeline(**kwargs) -> Pipeline:
assert sum(pipelines.values()).outputs() == (
{"simple_pipeline"} if simplified else set()
)


@pytest.fixture
def mock_config(pipeline_names):
    """Minimal ``pipelines.yml``-style entry listing the pipelines to combine."""
    return {"pipe": pipeline_names}


@pytest.mark.parametrize(
    "mock_package_name_with_pipelines,pipeline_names",
    [(names, names) for names in (set(), {"my_pipeline"})],
    indirect=True,
)
def test_from_config(mock_package_name_with_pipelines, mock_config, pipeline_names):
    # The pipeline built from config should expose exactly the outputs of
    # the modular pipelines named in the config entry.
    configure_project(mock_package_name_with_pipelines)
    result = from_config(mock_config)
    assert result.outputs() == pipeline_names