Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include model tags in parsers as well as a list parameter to filter them #69

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cosmos/providers/dbt/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def DbtDag(
dbt_args: dict = None,
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dbt",
include_tags: list = None,
**kwargs,
):
"""
Expand All @@ -23,6 +24,8 @@ def DbtDag(
:type emit_datasets: bool
:param dbt_root_path: The path to the dbt root directory
:type dbt_root_path: str
:param include_tags: A list of tags to filter the dbt project
:type include_tags: list
:param kwargs: Additional kwargs to pass to the DAG
:type kwargs: dict
:return: The rendered DAG
Expand All @@ -35,6 +38,7 @@ def DbtDag(
dbt_args=dbt_args,
emit_datasets=emit_datasets,
dbt_root_path=dbt_root_path,
include_tags=include_tags,
)
group = parser.parse()

Expand Down
61 changes: 44 additions & 17 deletions cosmos/providers/dbt/parser/jinja.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Dict

import os
from typing import Dict, List
from pathlib import Path

import jinja2
import yaml


def extract_deps_from_models(project_path: str) -> Dict[str, List[str]]:
def extract_model_data(project_path: str) -> Dict[str, Dict[str, list]]:
"""
Takes a path to a dbt project and returns all of the models with
their dependencies and tags. Return value is a dictionary
Expand All @@ -30,36 +33,60 @@ def extract_deps_from_models(project_path: str) -> Dict[str, List[str]]:
raise ValueError(f"models_path {models_path} is not a directory.")

# get all the files recursively
all_files = [
os.path.join(root, f)
for root, _, files in os.walk(models_path)
for f in files
]

# get all the model files
model_files = [
path for path in all_files
if os.path.isfile(path) and path.endswith(".sql")
]
all_files = [os.path.join(root, f) for root, _, files in os.walk(models_path) for f in files]

model_data = {}

dependencies = {}
# MODELS (SQL Files)
model_files = [path for path in all_files if os.path.isfile(path) and path.endswith(".sql")]

for model_file in model_files:
with open(model_file, "r", encoding="utf-8") as f:
with open(model_file, encoding="utf-8") as f:
model = f.read()

env = jinja2.Environment()
ast = env.parse(model)

# if there's a dependency, add it to the list
model_deps = []
tags = []
for base_node in ast.find_all(jinja2.nodes.Call):
if hasattr(base_node.node, "name"):
if base_node.node.name == "ref":
model_deps.append(base_node.args[0].value)
if base_node.node.name == "config" and base_node.kwargs[0].key == "tags":
lst = base_node.kwargs[0].value
tags = [x.value for x in lst.items if lst.items]

# add the dependencies to the dictionary, without the .sql extension
dependencies[Path(model_file).stem] = model_deps
# add the tags to the dictionary
model_data[Path(model_file).stem] = {"dependencies": model_deps, "tags": tags}

# PROPERTIES.YML
properties_files = [path for path in all_files if os.path.isfile(path) and path.endswith(".yml")]

# get tags for models from property files
for properties_file in properties_files:
with open(properties_file) as f:
properties = yaml.safe_load(f)

models = properties["models"]
for model in models:
properties_tags = []

if "config" in model:
config = model["config"]
[properties_tags.append(item) for item in config.get("tags", [])]
[properties_tags.append(item) for item in config.get("+tags", [])]

if model.get("name") in model_data:
# add tags to model_data without deleting the tags collected from models
existing_tags = model_data[model.get("name")]["tags"]
model_data[model.get("name")]["tags"] = list(set(existing_tags).union(set(properties_tags)))

# TODO: Decide if we want to include tags from the project file

# DBT_PROJECT.YML
# project_file = os.path.join(project_path, "dbt_project.yml")

return dependencies
return model_data
112 changes: 68 additions & 44 deletions cosmos/providers/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from cosmos.core.graph.entities import Group, Task
from cosmos.core.parse.base_parser import BaseParser

from .jinja import extract_deps_from_models
from .jinja import extract_model_data

# from .jinja import extract_model_tags
from .utils import validate_directory

logger = logging.getLogger(__name__)
Expand All @@ -24,6 +26,7 @@ def __init__(
dbt_args: dict = {},
dbt_root_path: str = "/usr/local/airflow/dbt",
emit_datasets: bool = True,
include_tags: list = None,
):
"""
Initializes the parser.
Expand All @@ -38,6 +41,8 @@ def __init__(
:type dbt_root_path: str
:param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies
:type emit_datasets: bool
:param include_tags: A list of tags to filter the dbt project
:type include_tags: list
"""
# validate the dbt root path
self.project_name = project_name
Expand All @@ -59,61 +64,80 @@ def __init__(
# emit datasets from test tasks unless false
self.emit_datasets = emit_datasets

# list of tags to filter down to
self.include_tags = include_tags

def parse(self):
    """
    Parses the dbt project in the project_path into `cosmos` entities.

    Models are filtered by ``self.include_tags``: a model is kept only if it
    carries every tag in that list (no filtering when it is None or empty).

    :return: a Group containing, per included model, a sub-group with a run
        task and a test task, wired together by each model's ``ref``
        dependencies
    """
    models = extract_model_data(project_path=self.project_path)

    project_name = os.path.basename(self.project_path)
    base_group = Group(id=project_name)
    entities = {}  # id -> Task/Group, used to wire dependencies below

    def _included(attrs: dict) -> bool:
        # A model is included when no tag filter is set, or when it has
        # every tag the caller asked for.
        if not self.include_tags:
            return True
        return all(tag in attrs["tags"] for tag in self.include_tags)

    for model, attrs in models.items():
        if not _included(attrs):
            continue

        args = {
            "conn_id": self.conn_id,
            "project_dir": self.project_path,
            "models": model,
            **self.dbt_args,
        }

        # make the run task
        run_task = Task(
            id=f"{model}_run",
            operator_class="cosmos.providers.dbt.core.operators.DbtRunOperator",
            arguments=args,
        )
        entities[run_task.id] = run_task

        # make the test task; optionally emit an Airflow Dataset for
        # downstream cross-DAG dependencies
        if self.emit_datasets:
            # NOTE(review): this mutates the same ``args`` dict already
            # passed to run_task above, so the run task's arguments also
            # gain "outlets" unless Task copies its arguments — confirm.
            args["outlets"] = [
                Dataset(f"DBT://{self.conn_id.upper()}/{self.project_name.upper()}/{model.upper()}")
            ]
        test_task = Task(
            id=f"{model}_test",
            operator_class="cosmos.providers.dbt.core.operators.DbtTestOperator",
            upstream_entity_ids=[run_task.id],
            arguments=args,
        )
        entities[test_task.id] = test_task

        # group the run/test pair; just add to the base group for now
        model_group = Group(
            id=model,
            entities=[run_task, test_task],
        )
        entities[model_group.id] = model_group
        base_group.add_entity(entity=model_group)

    # add dependencies between the included models; a dep that was filtered
    # out (or doesn't exist) is logged rather than raised
    for model, attrs in models.items():
        if not _included(attrs):
            continue
        for dep in attrs["dependencies"]:
            try:
                dep_task = entities[dep]
                entities[model].add_upstream(dep_task)
            except KeyError:
                logger.error(f"Dependency {dep} not found for model {model}")

    return base_group
4 changes: 4 additions & 0 deletions cosmos/providers/dbt/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def DbtTaskGroup(
dbt_args: dict = None,
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dbt",
include_tags: list = None,
**kwargs,
):
"""
Expand All @@ -27,6 +28,8 @@ def DbtTaskGroup(
:param kwargs: Additional kwargs to pass to the Task Group
:param dbt_root_path: The path to the dbt root directory
:type dbt_root_path: str
:param include_tags: A list of tags to filter the dbt project
:type include_tags: list
:type kwargs: dict
:return: The rendered Task Group
:rtype: airflow.utils.task_group.TaskGroup
Expand All @@ -38,6 +41,7 @@ def DbtTaskGroup(
dbt_args=dbt_args,
emit_datasets=emit_datasets,
dbt_root_path=dbt_root_path,
include_tags=include_tags,
)
group = parser.parse()

Expand Down