-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathnode_hook.py
66 lines (54 loc) · 2.11 KB
/
node_hook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import Any, Dict
import mlflow
from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog
from kedro.pipeline.node import Node
from kedro_mlflow.framework.context import get_mlflow_config
class MlflowNodeHook:
    """Kedro hook that logs the parameter inputs of each node to mlflow.

    The flattening options are read from ``get_mlflow_config().node_hook_opts``.
    """

    def __init__(
        self, flatten_dict_params: bool = False, recursive: bool = True, sep: str = "."
    ):
        """Load the node-hook options from the kedro-mlflow configuration.

        Values present in ``node_hook_opts`` take precedence. The constructor
        arguments are only used when the corresponding key is missing.
        (Previously they were accepted but silently ignored.)

        Args:
            flatten_dict_params: Fallback for ``node_hook_opts["flatten_dict_params"]``;
                whether dictionary parameters are flattened before logging.
            recursive: Fallback for ``node_hook_opts["recursive"]``; whether
                nested dictionaries are flattened at every depth.
            sep: Fallback for ``node_hook_opts["sep"]``; separator used to
                join parent and child keys when flattening.
        """
        # NOTE(review): assumes node_hook_opts is a dict-like mapping
        # (it is indexed by string keys in the original code) - confirm.
        opts = get_mlflow_config().node_hook_opts
        self.flatten = opts.get("flatten_dict_params", flatten_dict_params)
        self.recursive = opts.get("recursive", recursive)
        self.sep = opts.get("sep", sep)

    @hook_impl
    def before_node_run(
        self,
        node: Node,
        catalog: DataCatalog,
        inputs: Dict[str, Any],
        is_async: bool,
        run_id: str,
    ) -> None:
        """Hook invoked before a node runs; logs the node's parameters to mlflow.

        Only two kinds of input are logged. Inputs named ``params:<name>`` are
        logged under ``<name>``. The input named ``parameters`` is logged under
        that name. Every other input is ignored.

        Args:
            node: The ``Node`` to run.
            catalog: A ``DataCatalog`` containing the node's inputs and outputs.
            inputs: The dictionary of input datasets, keyed by dataset name.
            is_async: Whether the node was run in ``async`` mode.
            run_id: The id of the run.
        """
        # Only parameters are logged; artifacts must be declared manually
        # in the catalog.
        prefix = "params:"
        params_inputs = {}
        for name, value in inputs.items():
            if name.startswith(prefix):
                params_inputs[name[len(prefix):]] = value
            elif name == "parameters":
                params_inputs[name] = value

        # Dictionary parameters may be flattened for readability.
        if self.flatten:
            params_inputs = flatten_dict(
                d=params_inputs, recursive=self.recursive, sep=self.sep
            )

        mlflow.log_params(params_inputs)
def flatten_dict(
    d: Dict[str, Any], recursive: bool = True, sep: str = "."
) -> Dict[str, Any]:
    """Flatten nested dictionary values into a single-level dictionary.

    A value that is a dict is replaced by its entries, each keyed as
    ``f"{parent}{sep}{child}"``. Non-dict values keep their original key.

    Args:
        d: The dictionary to flatten.
        recursive: If True, flatten nested dictionaries at every depth.
            If False, only the first level of nesting is expanded, and deeper
            dict values are kept as-is.
        sep: Separator inserted between parent and child keys at every level.

    Returns:
        A new flattened dictionary; ``d`` is not modified.

    Example:
        >>> flatten_dict({"a": {"b": {"c": 1}}, "d": 2}, sep="_")
        {'a_b_c': 1, 'd': 2}
    """

    def expand(key, value):
        if isinstance(value, dict):
            # Propagate ``recursive`` and ``sep`` so every nesting level uses
            # the caller's separator (the old call fell back to "." below
            # the first level).
            new_value = (
                flatten_dict(value, recursive=recursive, sep=sep)
                if recursive
                else value
            )
            # f-string so non-string child keys (e.g. ints from YAML) don't
            # raise TypeError on concatenation.
            return [(f"{key}{sep}{k}", v) for k, v in new_value.items()]
        return [(key, value)]

    return dict(item for k, v in d.items() for item in expand(k, v))