Skip to content

Commit

Permalink
Address @feluelle's comments
Browse files: browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Nov 18, 2022
1 parent d8ec6fe commit 2213b22
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions sql-cli/sql_cli/workflow_directory_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from abc import abstractmethod
from pathlib import Path
from typing import Any, Iterable
from typing import Iterable

import airflow
import frontmatter
import yaml
from airflow.models import BaseOperator
from yaml.scanner import ScannerError

from astro.sql import LoadFileOperator
from astro.sql.operators.transform import TransformOperator
Expand Down Expand Up @@ -147,7 +149,7 @@ def write_raw_content_to_target_path(self) -> None:
target_path.write_text(self.raw_content)

@abstractmethod
def to_operator(self) -> Any:
def to_operator(self) -> BaseOperator:
raise NotImplementedError()


Expand Down Expand Up @@ -185,12 +187,18 @@ class YamlFile(WorkflowFile):

def __init__(self, root_directory: Path, path: Path, target_directory: Path) -> None:
super().__init__(root_directory, path, target_directory, YAML_FILE_TYPE)
with open(path) as yaml_file:
self.yaml_content = yaml.load(yaml_file, Loader=yaml.Loader)
with path.open() as yaml_file:
self.yaml_content = yaml.safe_load(yaml_file)
self.operator = self._get_operator()
self.yaml_file_name = self.path.stem

def _get_operator(self) -> str:
top_level_keys = list(self.yaml_content.keys())
top_level_keys_count = len(top_level_keys)
if top_level_keys_count > 1:
raise ScannerError(
f"Only one top level operator expected. Got {top_level_keys_count} operators: {top_level_keys}"
)
operator = list(self.yaml_content.keys())[0]
if operator not in YamlFile.operator_instance_builder_callable_map:
raise NotImplementedError(f"Operator support for {operator} not available")
Expand All @@ -205,7 +213,7 @@ def to_operator(self) -> LoadFileOperator:
)


FILE_TYPE_CLASS_MAP = {SQL_FILE_TYPE: SqlFile, YAML_FILE_TYPE: YamlFile}
SUPPORTED_FILES = {"*.sql": SqlFile, "*.yaml": YamlFile, "*.yml": YamlFile}


def get_files_by_type(
Expand All @@ -221,15 +229,17 @@ def get_files_by_type(
:returns: the workflow files found in the directory.
"""
return {
FILE_TYPE_CLASS_MAP[file_type](
SUPPORTED_FILES[file_type](
root_directory=directory, path=child, target_directory=target_directory # type: ignore
)
for child in directory.rglob(f"*.{file_type}")
for child in directory.rglob(file_type)
if child.is_file() and not child.is_symlink()
}


def get_workflow_files(directory: Path, target_directory: Path | None) -> set[WorkflowFile]:
sql_files = get_files_by_type(directory, SQL_FILE_TYPE, target_directory=target_directory)
yaml_files = get_files_by_type(directory, YAML_FILE_TYPE, target_directory=target_directory)
return sql_files.union(yaml_files)
workflow_files: set[WorkflowFile] = set()
for file_type in SUPPORTED_FILES:
files = get_files_by_type(directory, file_type, target_directory=target_directory)
workflow_files = workflow_files.union(files)
return workflow_files

0 comments on commit 2213b22

Please sign in to comment.