From e1a836415af9b88eb703c0bbeb6a4163756e97db Mon Sep 17 00:00:00 2001 From: Micah Victoria Date: Sun, 11 Aug 2024 22:28:55 -0400 Subject: [PATCH] update GraphSelector and SelectorConfig updated GraphSelector to take into account path and tag dbt selector methods update SelectorConfig to use regex to parse selection statement to handle graph and path statements --- cosmos/dbt/selector.py | 67 ++++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index 257d60721..9015b256d 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -84,6 +84,8 @@ def parse(text: str) -> GraphSelector | None: regex_match = re.search(GRAPH_SELECTOR_REGEX, text) if regex_match: precursors, node_name, descendants = regex_match.groups() + if "/" in node_name and not node_name.startswith(PATH_SELECTOR): + node_name = f"{PATH_SELECTOR}{node_name}" return GraphSelector(node_name, precursors, descendants) return None @@ -148,22 +150,43 @@ def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]: :return: set of node ids that matches current graph selector """ selected_nodes: set[str] = set() + root_nodes: set[str] = set() # Index nodes by name, we can improve performance by doing this once # for multiple GraphSelectors - node_by_name = {} - for node_id, node in nodes.items(): - node_by_name[node.name] = node_id + if PATH_SELECTOR in self.node_name: + path_selection = self.node_name[len(PATH_SELECTOR):] + + for node_id, node in nodes.items(): + if path_selection in str(node.file_path): + root_nodes.add(node_id) + + elif TAG_SELECTOR in self.node_name: + tag_selection = self.node_name[len(TAG_SELECTOR):] + + for node_id, node in nodes.items(): + if tag_selection in node.tags: + root_nodes.add(node_id) + + elif CONFIG_SELECTOR in self.node_name: ... - if self.node_name in node_by_name: - root_id = node_by_name[self.node_name] else: - logger.warn(f"Selector {self.node_name} not found.") - return selected_nodes + node_by_name = {} + for node_id, node in nodes.items(): + node_by_name[node.name] = node_id + + if self.node_name in node_by_name: + root_id = node_by_name[self.node_name] + root_nodes.add(root_id) + else: + logger.warn(f"Selector {self.node_name} not found.") + return selected_nodes - selected_nodes.add(root_id) - self.select_node_precursors(nodes, root_id, selected_nodes) - self.select_node_descendants(nodes, root_id, selected_nodes) + selected_nodes.update(root_nodes) + + for root_id in root_nodes: + self.select_node_precursors(nodes, root_id, selected_nodes) + self.select_node_descendants(nodes, root_id, selected_nodes) return selected_nodes @@ -210,14 +233,22 @@ def load_from_statement(self, statement: str) -> None: items = statement.split(",") for item in items: - if item.startswith(PATH_SELECTOR): - self._parse_path_selector(item) - elif item.startswith(TAG_SELECTOR): - self._parse_tag_selector(item) - elif item.startswith(CONFIG_SELECTOR): - self._parse_config_selector(item) - else: - self._parse_unknown_selector(item) + regex_match = re.search(GRAPH_SELECTOR_REGEX, item) + if regex_match: + precursors, node_name, descendants = regex_match.groups() + + if precursors or descendants: + self._parse_unknown_selector(item) + elif node_name.startswith(PATH_SELECTOR): + self._parse_path_selector(item) + elif "/" in node_name: + self._parse_path_selector(f"{PATH_SELECTOR}{node_name}") + elif node_name.startswith(TAG_SELECTOR): + self._parse_tag_selector(item) + elif node_name.startswith(CONFIG_SELECTOR): + self._parse_config_selector(item) + else: + self._parse_unknown_selector(item) def _parse_unknown_selector(self, item: str) -> None: if item: