diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index 257d60721..9015b256d 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -84,6 +84,8 @@ def parse(text: str) -> GraphSelector | None: regex_match = re.search(GRAPH_SELECTOR_REGEX, text) if regex_match: precursors, node_name, descendants = regex_match.groups() + if "/" in node_name and not node_name.startswith(PATH_SELECTOR): + node_name = f"{PATH_SELECTOR}{node_name}" return GraphSelector(node_name, precursors, descendants) return None @@ -148,22 +150,43 @@ def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]: :return: set of node ids that matches current graph selector """ selected_nodes: set[str] = set() + root_nodes: set[str] = set() # Index nodes by name, we can improve performance by doing this once # for multiple GraphSelectors - node_by_name = {} - for node_id, node in nodes.items(): - node_by_name[node.name] = node_id + if PATH_SELECTOR in self.node_name: + path_selection = self.node_name[len(PATH_SELECTOR):] + + for node_id, node in nodes.items(): + if path_selection in str(node.file_path): + root_nodes.add(node_id) + + elif TAG_SELECTOR in self.node_name: + tag_selection = self.node_name[len(TAG_SELECTOR):] + + for node_id, node in nodes.items(): + if tag_selection in node.tags: + root_nodes.add(node_id) + + elif CONFIG_SELECTOR in self.node_name: ... - if self.node_name in node_by_name: - root_id = node_by_name[self.node_name] else: - logger.warn(f"Selector {self.node_name} not found.") - return selected_nodes + node_by_name = {} + for node_id, node in nodes.items(): + node_by_name[node.name] = node_id + + if self.node_name in node_by_name: + root_id = node_by_name[self.node_name] + root_nodes.add(root_id) + else: + logger.warn(f"Selector {self.node_name} not found.") + return selected_nodes - selected_nodes.add(root_id) - self.select_node_precursors(nodes, root_id, selected_nodes) - self.select_node_descendants(nodes, root_id, selected_nodes) + selected_nodes.update(root_nodes) + + for root_id in root_nodes: + self.select_node_precursors(nodes, root_id, selected_nodes) + self.select_node_descendants(nodes, root_id, selected_nodes) return selected_nodes @@ -210,14 +233,22 @@ def load_from_statement(self, statement: str) -> None: items = statement.split(",") for item in items: - if item.startswith(PATH_SELECTOR): - self._parse_path_selector(item) - elif item.startswith(TAG_SELECTOR): - self._parse_tag_selector(item) - elif item.startswith(CONFIG_SELECTOR): - self._parse_config_selector(item) - else: - self._parse_unknown_selector(item) + regex_match = re.search(GRAPH_SELECTOR_REGEX, item) + if regex_match: + precursors, node_name, descendants = regex_match.groups() + + if precursors or descendants: + self._parse_unknown_selector(item) + elif node_name.startswith(PATH_SELECTOR): + self._parse_path_selector(item) + elif "/" in node_name: + self._parse_path_selector(f"{PATH_SELECTOR}{node_name}") + elif node_name.startswith(TAG_SELECTOR): + self._parse_tag_selector(item) + elif node_name.startswith(CONFIG_SELECTOR): + self._parse_config_selector(item) + else: + self._parse_unknown_selector(item) def _parse_unknown_selector(self, item: str) -> None: if item: