From 2213b22a6d69f41da1dde4520f58cf41c049213b Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Fri, 18 Nov 2022 13:22:06 +0530 Subject: [PATCH] Address @feluelle's comments --- sql-cli/sql_cli/workflow_directory_parser.py | 30 +++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/sql-cli/sql_cli/workflow_directory_parser.py b/sql-cli/sql_cli/workflow_directory_parser.py index b84c69daa..d58ff4fd3 100644 --- a/sql-cli/sql_cli/workflow_directory_parser.py +++ b/sql-cli/sql_cli/workflow_directory_parser.py @@ -2,11 +2,13 @@ from abc import abstractmethod from pathlib import Path -from typing import Any, Iterable +from typing import Iterable import airflow import frontmatter import yaml +from airflow.models import BaseOperator +from yaml.scanner import ScannerError from astro.sql import LoadFileOperator from astro.sql.operators.transform import TransformOperator @@ -147,7 +149,7 @@ def write_raw_content_to_target_path(self) -> None: target_path.write_text(self.raw_content) @abstractmethod - def to_operator(self) -> Any: + def to_operator(self) -> BaseOperator: raise NotImplementedError() @@ -185,12 +187,18 @@ class YamlFile(WorkflowFile): def __init__(self, root_directory: Path, path: Path, target_directory: Path) -> None: super().__init__(root_directory, path, target_directory, YAML_FILE_TYPE) - with open(path) as yaml_file: - self.yaml_content = yaml.load(yaml_file, Loader=yaml.Loader) + with path.open() as yaml_file: + self.yaml_content = yaml.safe_load(yaml_file) self.operator = self._get_operator() self.yaml_file_name = self.path.stem def _get_operator(self) -> str: + top_level_keys = list(self.yaml_content.keys()) + top_level_keys_count = len(top_level_keys) + if top_level_keys_count > 1: + raise ScannerError( + f"Only one top level operator expected. Got {top_level_keys_count} operators: {top_level_keys}" + ) operator = list(self.yaml_content.keys())[0] if operator not in YamlFile.operator_instance_builder_callable_map: raise NotImplementedError(f"Operator support for {operator} not available") @@ -205,7 +213,7 @@ def to_operator(self) -> LoadFileOperator: ) -FILE_TYPE_CLASS_MAP = {SQL_FILE_TYPE: SqlFile, YAML_FILE_TYPE: YamlFile} +SUPPORTED_FILES = {"*.sql": SqlFile, "*.yaml": YamlFile, "*.yml": YamlFile} def get_files_by_type( @@ -221,15 +229,17 @@ def get_files_by_type( :returns: the workflow files found in the directory. """ return { - FILE_TYPE_CLASS_MAP[file_type]( + SUPPORTED_FILES[file_type]( root_directory=directory, path=child, target_directory=target_directory # type: ignore ) - for child in directory.rglob(f"*.{file_type}") + for child in directory.rglob(file_type) if child.is_file() and not child.is_symlink() } def get_workflow_files(directory: Path, target_directory: Path | None) -> set[WorkflowFile]: - sql_files = get_files_by_type(directory, SQL_FILE_TYPE, target_directory=target_directory) - yaml_files = get_files_by_type(directory, YAML_FILE_TYPE, target_directory=target_directory) - return sql_files.union(yaml_files) + workflow_files: set[WorkflowFile] = set() + for file_type in SUPPORTED_FILES: + files = get_files_by_type(directory, file_type, target_directory=target_directory) + workflow_files = workflow_files.union(files) + return workflow_files