diff --git a/cosmos/providers/dbt/dag.py b/cosmos/providers/dbt/dag.py index 255912fed..0362efd43 100644 --- a/cosmos/providers/dbt/dag.py +++ b/cosmos/providers/dbt/dag.py @@ -8,6 +8,7 @@ def DbtDag( dbt_args: dict = None, emit_datasets: bool = True, dbt_root_path: str = "/usr/local/airflow/dbt", + include_tags: list = None, **kwargs, ): """ @@ -23,6 +24,8 @@ def DbtDag( :type emit_datasets: bool :param dbt_root_path: The path to the dbt root directory :type dbt_root_path: str + :param include_tags: A list of tags to filter the dbt project + :type include_tags: list :param kwargs: Additional kwargs to pass to the DAG :type kwargs: dict :return: The rendered DAG @@ -35,6 +38,7 @@ def DbtDag( dbt_args=dbt_args, emit_datasets=emit_datasets, dbt_root_path=dbt_root_path, + include_tags=include_tags, ) group = parser.parse() diff --git a/cosmos/providers/dbt/parser/jinja.py b/cosmos/providers/dbt/parser/jinja.py index 4b6d27a97..05c06d395 100644 --- a/cosmos/providers/dbt/parser/jinja.py +++ b/cosmos/providers/dbt/parser/jinja.py @@ -1,10 +1,13 @@ +from typing import Dict + import os -from typing import Dict, List from pathlib import Path import jinja2 +import yaml + -def extract_deps_from_models(project_path: str) -> Dict[str, List[str]]: +def extract_model_data(project_path: str) -> Dict[str, Dict[str, list]]: """ Takes a path to a dbt project and returns a list of all the models with their dependencies. Return value is a dictionary @@ -30,23 +33,15 @@ def extract_deps_from_models(project_path: str) -> Dict[str, List[str]]: raise ValueError(f"models_path {models_path} is not a directory.") # get all the files recursively - all_files = [ - os.path.join(root, f) - for root, _, files in os.walk(models_path) - for f in files - ] - - # get all the model files - model_files = [ - path for path in all_files - if os.path.isfile(path) and path.endswith(".sql") - ] + all_files = [os.path.join(root, f) for root, _, files in os.walk(models_path) for f in files] + model_data = {} - dependencies = {} + # MODELS (SQL Files) + model_files = [path for path in all_files if os.path.isfile(path) and path.endswith(".sql")] for model_file in model_files: - with open(model_file, "r", encoding="utf-8") as f: + with open(model_file, encoding="utf-8") as f: model = f.read() env = jinja2.Environment() @@ -54,12 +49,44 @@ def extract_deps_from_models(project_path: str) -> Dict[str, List[str]]: # if there's a dependency, add it to the list model_deps = [] + tags = [] for base_node in ast.find_all(jinja2.nodes.Call): if hasattr(base_node.node, "name"): if base_node.node.name == "ref": model_deps.append(base_node.args[0].value) + if base_node.node.name == "config" and base_node.kwargs[0].key == "tags": + lst = base_node.kwargs[0].value + tags = [x.value for x in lst.items if lst.items] # add the dependencies to the dictionary, without the .sql extension - dependencies[Path(model_file).stem] = model_deps + # add the tags to the dictionary + model_data[Path(model_file).stem] = {"dependencies": model_deps, "tags": tags} + + # PROPERTIES.YML + properties_files = [path for path in all_files if os.path.isfile(path) and path.endswith(".yml")] + + # get tags for models from property files + for properties_file in properties_files: + with open(properties_file) as f: + properties = yaml.safe_load(f) + + models = properties["models"] + for model in models: + properties_tags = [] + + if "config" in model: + config = model["config"] + [properties_tags.append(item) for item in config.get("tags", [])] + [properties_tags.append(item) for item in config.get("+tags", [])] + + if model.get("name") in model_data: + # add tags to model_data without deleting the tags collected from models + exiting_tags = model_data[model.get("name")]["tags"] + model_data[model.get("name")]["tags"] = list(set(exiting_tags).union(set(properties_tags))) + + # TODO: Decide if we want to include tags from the project file + + # DBT_PROJECT.YML + # project_file = os.path.join(project_path, "dbt_project.yml") - return dependencies \ No newline at end of file + return model_data diff --git a/cosmos/providers/dbt/parser/project.py b/cosmos/providers/dbt/parser/project.py index 724c29c4b..0a5c6e5ab 100644 --- a/cosmos/providers/dbt/parser/project.py +++ b/cosmos/providers/dbt/parser/project.py @@ -6,7 +6,9 @@ from cosmos.core.graph.entities import Group, Task from cosmos.core.parse.base_parser import BaseParser -from .jinja import extract_deps_from_models +from .jinja import extract_model_data + +# from .jinja import extract_model_tags from .utils import validate_directory logger = logging.getLogger(__name__) @@ -24,6 +26,7 @@ def __init__( dbt_args: dict = {}, dbt_root_path: str = "/usr/local/airflow/dbt", emit_datasets: bool = True, + include_tags: list = None, ): """ Initializes the parser. @@ -38,6 +41,8 @@ def __init__( :type dbt_root_path: str :param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies :type emit_datasets: bool + :param include_tags: A list of tags to filter the dbt project + :type include_tags: list """ # validate the dbt root path self.project_name = project_name @@ -59,11 +64,14 @@ def __init__( # emit datasets from test tasks unless false self.emit_datasets = emit_datasets + # list of tags to filter down to + self.include_tags = include_tags + def parse(self): """ Parses the dbt project in the project_path into `cosmos` entities. """ - models = extract_deps_from_models(project_path=self.project_path) + models = extract_model_data(project_path=self.project_path) project_name = os.path.basename(self.project_path) @@ -71,49 +79,65 @@ def parse(self): entities = {} - for model, deps in models.items(): - args = { - "conn_id": self.conn_id, - "project_dir": self.project_path, - "models": model, - **self.dbt_args, - } - # make the run task - run_task = Task( - id=f"{model}_run", - operator_class="cosmos.providers.dbt.core.operators.DbtRunOperator", - arguments=args, - ) - entities[run_task.id] = run_task - - # make the test task - if self.emit_datasets: - args["outlets"] = [Dataset(f"DBT://{self.conn_id.upper()}/{self.project_name.upper()}/{model.upper()}")] - test_task = Task( - id=f"{model}_test", - operator_class="cosmos.providers.dbt.core.operators.DbtTestOperator", - upstream_entity_ids=[run_task.id], - arguments=args, - ) - entities[test_task.id] = test_task - - # make the group - model_group = Group( - id=model, - entities=[run_task, test_task], - ) - entities[model_group.id] = model_group - - # just add to base group for now - base_group.add_entity(entity=model_group) + for model, attr in models.items(): + included = True + if self.include_tags: + for item in self.include_tags: + if item not in attr["tags"]: + included = False + + if included: + args = { + "conn_id": self.conn_id, + "project_dir": self.project_path, + "models": model, + **self.dbt_args, + } + # make the run task + run_task = Task( + id=f"{model}_run", + operator_class="cosmos.providers.dbt.core.operators.DbtRunOperator", + arguments=args, + ) + entities[run_task.id] = run_task + + # make the test task + if self.emit_datasets: + args["outlets"] = [ + Dataset(f"DBT://{self.conn_id.upper()}/{self.project_name.upper()}/{model.upper()}") + ] + test_task = Task( + id=f"{model}_test", + operator_class="cosmos.providers.dbt.core.operators.DbtTestOperator", + upstream_entity_ids=[run_task.id], + arguments=args, + ) + entities[test_task.id] = test_task + + # make the group + model_group = Group( + id=model, + entities=[run_task, test_task], + ) + entities[model_group.id] = model_group + + # just add to base group for now + base_group.add_entity(entity=model_group) # add dependencies - for model, deps in models.items(): - for dep in deps: - try: - dep_task = entities[dep] - entities[model].add_upstream(dep_task) - except KeyError: - logger.error(f"Dependency {dep} not found for model {model}") + for model, attr in models.items(): + included = True + if self.include_tags: + for item in self.include_tags: + if item not in attr["tags"]: + included = False + + if included: + for dep in attr["dependencies"]: + try: + dep_task = entities[dep] + entities[model].add_upstream(dep_task) + except KeyError: + logger.error(f"Dependency {dep} not found for model {model}") return base_group diff --git a/cosmos/providers/dbt/task_group.py b/cosmos/providers/dbt/task_group.py index 4a9b84db2..e08b2518c 100644 --- a/cosmos/providers/dbt/task_group.py +++ b/cosmos/providers/dbt/task_group.py @@ -11,6 +11,7 @@ def DbtTaskGroup( dbt_args: dict = None, emit_datasets: bool = True, dbt_root_path: str = "/usr/local/airflow/dbt", + include_tags: list = None, **kwargs, ): """ @@ -27,6 +28,8 @@ def DbtTaskGroup( :param kwargs: Additional kwargs to pass to the DAG :param dbt_root_path: The path to the dbt root directory :type dbt_root_path: str + :param include_tags: A list of tags to filter the dbt project + :type include_tags: list :type kwargs: dict :return: The rendered Task Group :rtype: airflow.utils.task_group.TaskGroup @@ -38,6 +41,7 @@ def DbtTaskGroup( dbt_args=dbt_args, emit_datasets=emit_datasets, dbt_root_path=dbt_root_path, + include_tags=include_tags, ) group = parser.parse()