From 2a4ef941cf967940e0ebf68f8d1956b8673ad94e Mon Sep 17 00:00:00 2001 From: kgopal Date: Tue, 19 Sep 2023 16:08:28 -0400 Subject: [PATCH 1/6] Add column name suggestions to presto validator --- .../server/lib/elasticsearch/search_table.py | 42 ++++ .../validation/base_query_validator.py | 1 + .../validators/base_sqlglot_validator.py | 25 ++- .../validators/presto_explain_validator.py | 1 + .../validators/presto_optimizing_validator.py | 211 +++++++++++++----- 5 files changed, 220 insertions(+), 60 deletions(-) diff --git a/querybook/server/lib/elasticsearch/search_table.py b/querybook/server/lib/elasticsearch/search_table.py index ad1d73160..71448a38f 100644 --- a/querybook/server/lib/elasticsearch/search_table.py +++ b/querybook/server/lib/elasticsearch/search_table.py @@ -1,9 +1,14 @@ +from typing import Dict, List, Tuple from lib.elasticsearch.query_utils import ( match_filters, highlight_fields, order_by_fields, combine_keyword_and_filter_query, ) +from lib.elasticsearch.search_utils import ( + ES_CONFIG, + get_matching_objects, +) FILTERS_TO_AND = ["tags", "data_elements"] @@ -173,3 +178,40 @@ def construct_tables_query_by_table_names( } return query + + +def get_column_name_suggestion( + fuzzy_column_name: str, full_table_names: List[str] +) -> Tuple[Dict, int]: + """Given an invalid column name and a list of tables to search from, uses fuzzy search to search + the correctly-spelled column name""" + should_clause = [] + for full_table_name in full_table_names: + schema_name, table_name = full_table_name.split(".") + should_clause.append( + { + "bool": { + "must": [ + {"match": {"name": table_name}}, + {"match": {"schema": schema_name}}, + ] + } + } + ) + + search_query = { + "query": { + "bool": { + "must": { + "match": { + "columns": {"query": fuzzy_column_name, "fuzziness": "AUTO"} + } + }, + "should": should_clause, + "minimum_should_match": 1, + }, + }, + "highlight": {"pre_tags": [""], "post_tags": [""], "fields": {"columns": {}}}, + } + + return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True) diff --git a/querybook/server/lib/query_analysis/validation/base_query_validator.py b/querybook/server/lib/query_analysis/validation/base_query_validator.py index d2ed3a3d2..6014c89db 100644 --- a/querybook/server/lib/query_analysis/validation/base_query_validator.py +++ b/querybook/server/lib/query_analysis/validation/base_query_validator.py @@ -67,6 +67,7 @@ def validate( query: str, uid: int, # who is doing the syntax check engine_id: int, # which engine they are checking against + **kwargs, ) -> List[QueryValidationResult]: raise NotImplementedError() diff --git a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py b/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py index 067f1eef0..073d13ad8 100644 --- a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py @@ -1,5 +1,5 @@ -from abc import ABCMeta, abstractmethod -from typing import List, Tuple +from abc import abstractmethod +from typing import Any, Dict, List, Tuple from sqlglot import Tokenizer from sqlglot.tokens import Token @@ -8,9 +8,13 @@ QueryValidationResultObjectType, QueryValidationSeverity, ) +from lib.query_analysis.validation.base_query_validator import BaseQueryValidator -class BaseSQLGlotValidator(metaclass=ABCMeta): +class BaseSQLGlotValidator(BaseQueryValidator): + def __init__(self, name: str = "", config: Dict[str, Any] = {}): + super(BaseSQLGlotValidator, self).__init__(name, config) + @property @abstractmethod def message(self) -> str: @@ -33,6 +37,12 @@ def _get_query_coordinate_by_index(self, query: str, index: int) -> Tuple[int, i rows = query[: index + 1].splitlines(keepends=False) return len(rows) - 1, len(rows[-1]) - 1 + def _get_query_index_by_coordinate( + self, query: str, start_line: int, start_ch: int + ) -> int: + rows = query.splitlines(keepends=True)[:start_line] + return sum([len(row) for row in rows]) + start_ch + def _get_query_validation_result( self, query: str, @@ -56,7 +66,12 @@ def _get_query_validation_result( ) @abstractmethod - def get_query_validation_results( - self, query: str, raw_tokens: List[Token] = None + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[Token] = None, + **kwargs ) -> List[QueryValidationResult]: raise NotImplementedError() diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py index 3a0736b9f..a307921d5 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py @@ -72,6 +72,7 @@ def validate( query: str, uid: int, # who is doing the syntax check engine_id: int, # which engine they are checking against + **kwargs, ) -> List[QueryValidationResult]: validation_errors = [] ( diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py index 5d726a4ff..0aee31210 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py @@ -1,8 +1,12 @@ -from typing import List +import re +from itertools import chain from sqlglot import TokenType, Tokenizer from sqlglot.dialects import Trino from sqlglot.tokens import Token +from typing import List +from lib.elasticsearch.search_table import get_column_name_suggestion +from lib.query_analysis.lineage import process_query from lib.query_analysis.validation.base_query_validator import ( BaseQueryValidator, QueryValidationResult, @@ -14,15 +18,33 @@ from lib.query_analysis.validation.validators.base_sqlglot_validator import ( BaseSQLGlotValidator, ) +from logic.admin import get_query_engine_by_id + +class BasePrestoSQLGlotDecorator(BaseSQLGlotValidator): + def __init__(self, validator: BaseQueryValidator): + self._validator = validator + + def languages(self): + return ["presto", "trino"] -class BasePrestoSQLGlotValidator(BaseSQLGlotValidator): @property def tokenizer(self) -> Tokenizer: return Trino.Tokenizer() + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[Token] = None, + **kwargs, + ): + """Override this method to add suggestions to validation results""" + return self._validator.validate(query, uid, engine_id, **kwargs) + -class UnionAllValidator(BasePrestoSQLGlotValidator): +class UnionAllValidator(BasePrestoSQLGlotDecorator): @property def message(self): return "Using UNION ALL instead of UNION will execute faster" @@ -31,27 +53,34 @@ def message(self): def severity(self) -> str: return QueryValidationSeverity.WARNING - def get_query_validation_results( - self, query: str, raw_tokens: List[Token] = None + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[Token] = None, + **kwargs, ) -> List[QueryValidationResult]: if raw_tokens is None: raw_tokens = self._tokenize_query(query) - validation_errors = [] + validation_results = self._validator.validate( + query, uid, engine_id, raw_tokens=raw_tokens + ) for i, token in enumerate(raw_tokens): if token.token_type == TokenType.UNION: if ( i < len(raw_tokens) - 1 and raw_tokens[i + 1].token_type != TokenType.ALL ): - validation_errors.append( + validation_results.append( self._get_query_validation_result( query, token.start, token.end, "UNION ALL" ) ) - return validation_errors + return validation_results -class ApproxDistinctValidator(BasePrestoSQLGlotValidator): +class ApproxDistinctValidator(BasePrestoSQLGlotDecorator): @property def message(self): return ( @@ -62,13 +91,20 @@ def message(self): def severity(self) -> str: return QueryValidationSeverity.WARNING - def get_query_validation_results( - self, query: str, raw_tokens: List[Token] = None + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[Token] = None, + **kwargs, ) -> List[QueryValidationResult]: if raw_tokens is None: raw_tokens = self._tokenize_query(query) - validation_errors = [] + validation_results = self._validator.validate( + query, uid, engine_id, raw_tokens=raw_tokens + ) for i, token in enumerate(raw_tokens): if ( i < len(raw_tokens) - 2 @@ -77,7 +113,7 @@ def get_query_validation_results( and raw_tokens[i + 1].token_type == TokenType.L_PAREN and raw_tokens[i + 2].token_type == TokenType.DISTINCT ): - validation_errors.append( + validation_results.append( self._get_query_validation_result( query, token.start, @@ -85,10 +121,10 @@ def get_query_validation_results( "APPROX_DISTINCT(", ) ) - return validation_errors + return validation_results -class RegexpLikeValidator(BasePrestoSQLGlotValidator): +class RegexpLikeValidator(BasePrestoSQLGlotDecorator): @property def message(self): return "Combining multiple LIKEs into one REGEXP_LIKE will execute faster" @@ -103,13 +139,20 @@ def _get_regexp_like_suggestion(self, column_name: str, like_strings: List[str]) ] return f"REGEXP_LIKE({column_name}, '{'|'.join(sanitized_like_strings)}')" - def get_query_validation_results( - self, query: str, raw_tokens: List[Token] = None + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[Token] = None, + **kwargs, ) -> List[QueryValidationResult]: if raw_tokens is None: raw_tokens = self._tokenize_query(query) - validation_errors = [] + validation_results = self._validator.validate( + query, uid, engine_id, raw_tokens=raw_tokens + ) start_column_token = None like_strings = [] @@ -139,7 +182,7 @@ def get_query_validation_results( ): # No "OR" token following the phrase, so we cannot combine additional phrases # Check if there are multiple phrases that can be combined if len(like_strings) > 1: - validation_errors.append( + validation_results.append( self._get_query_validation_result( query, start_column_token.start, @@ -157,7 +200,7 @@ def get_query_validation_results( if ( len(like_strings) > 1 ): # Check if a validation suggestion can be created - validation_errors.append( + validation_results.append( self._get_query_validation_result( query, start_column_token.start, @@ -171,53 +214,111 @@ def get_query_validation_results( like_strings = [] token_idx += 1 - return validation_errors + return validation_results -class PrestoOptimizingValidator(BaseQueryValidator): - def languages(self): - return ["presto", "trino"] +class ColumnNameSuggester(BasePrestoSQLGlotDecorator): + @property + def message(self): + return "" # Unused, message is not changed - def _get_explain_validator(self): - return PrestoExplainValidator("") + @property + def severity(self): + return QueryValidationSeverity.WARNING # Unused, severity is not changed - def _get_sqlglot_validators(self) -> List[BaseSQLGlotValidator]: - return [ - UnionAllValidator(), - ApproxDistinctValidator(), - RegexpLikeValidator(), - ] + def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]: + engine = get_query_engine_by_id(engine_id) + tables_per_statement, _ = process_query(query, language=engine.language) + return list(chain.from_iterable(tables_per_statement)) - def _get_sql_glot_validation_results( - self, query: str - ) -> List[QueryValidationResult]: - validation_suggestions = [] - - query_raw_tokens = None - for validator in self._get_sqlglot_validators(): - if query_raw_tokens is None: - query_raw_tokens = validator._tokenize_query(query) - validation_suggestions.extend( - validator.get_query_validation_results( - query, raw_tokens=query_raw_tokens - ) - ) + def _is_column_name_error(self, validation_result: QueryValidationResult) -> bool: + return bool( + re.match(r"Column .* cannot be resolved", validation_result.message) + ) - return validation_suggestions + def _get_column_name_from_position( + self, tokens: List[Token], query: str, start_line: int, start_ch: int + ) -> str: + column_error_start_index = self._get_query_index_by_coordinate( + query, start_line, start_ch + ) + for token in tokens: + if token.start == column_error_start_index: + return token.text + return None - def _get_presto_explain_validation_results( - self, query: str, uid: int, engine_id: int - ) -> List[QueryValidationResult]: - return self._get_explain_validator().validate(query, uid, engine_id) + def _search_columns_for_suggestion(self, columns: List[str], suggestion: str): + """Return the case-sensitive column name by searching the table's columns for the suggestion text""" + for col in columns: + if col.lower() == suggestion.lower(): + return col + return suggestion + + def _get_column_name_suggestion( + self, + validation_result: QueryValidationResult, + query: str, + tables_in_query: List[str], + raw_tokens: List[Token], + ): + fuzzy_column_name = self._get_column_name_from_position( + raw_tokens, query, validation_result.start_line, validation_result.start_ch + ) + if not fuzzy_column_name: + return None + results, count = get_column_name_suggestion(fuzzy_column_name, tables_in_query) + if count == 1: # Only return suggestion if there's a single match + table_result = results[0] + highlights = table_result.get("highlight", {}).get("columns", []) + if len(highlights) == 1: + column_suggestion = self._search_columns_for_suggestion( + table_result.get("columns"), highlights[0] + ) + return column_suggestion + + return None def validate( self, query: str, uid: int, engine_id: int, + raw_tokens: List[QueryValidationResult] = None, + **kwargs, ) -> List[QueryValidationResult]: - validation_results = [ - *self._get_presto_explain_validation_results(query, uid, engine_id), - *self._get_sql_glot_validation_results(query), - ] + if raw_tokens is None: + raw_tokens = self._tokenize_query(query) + validation_results = self._validator.validate( + query, uid, engine_id, raw_tokens=raw_tokens + ) + tables_in_query = self._get_tables_in_query(query, engine_id) + for result in validation_results: + # "Column .* cannot be resolved" -> to get all name errors + if self._is_column_name_error(result): + column_suggestion = self._get_column_name_suggestion( + result, query, tables_in_query, raw_tokens + ) + if column_suggestion: + result.suggestion = column_suggestion return validation_results + + +class PrestoOptimizingValidator(BaseQueryValidator): + def languages(self): + return ["presto", "trino"] + + def _get_explain_validator(self): + return PrestoExplainValidator("") + + def _get_decorated_validator(self) -> BaseQueryValidator: + return UnionAllValidator( + ApproxDistinctValidator( + RegexpLikeValidator(ColumnNameSuggester(self._get_explain_validator())) + ) + ) + + def validate( + self, query: str, uid: int, engine_id: int, **kwargs + ) -> List[QueryValidationResult]: + validator = self._get_decorated_validator() + return validator.validate(query, uid, engine_id) From bc4e5f96292f76ac566e7ae600cb3865f6694100 Mon Sep 17 00:00:00 2001 From: kgopal Date: Wed, 20 Sep 2023 13:54:20 -0400 Subject: [PATCH 2/6] fix regex --- .../validation/validators/presto_optimizing_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py index 0aee31210..9f50cd189 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py @@ -233,7 +233,7 @@ def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]: def _is_column_name_error(self, validation_result: QueryValidationResult) -> bool: return bool( - re.match(r"Column .* cannot be resolved", validation_result.message) + re.search(r"Column .* cannot be resolved", validation_result.message) ) def _get_column_name_from_position( From c326d5b330a87721fbadfb139a74df98df38cc88 Mon Sep 17 00:00:00 2001 From: kgopal Date: Wed, 20 Sep 2023 14:24:00 -0400 Subject: [PATCH 3/6] remove comment --- .../validation/validators/presto_optimizing_validator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py index 9f50cd189..e564370fb 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py @@ -293,7 +293,6 @@ def validate( ) tables_in_query = self._get_tables_in_query(query, engine_id) for result in validation_results: - # "Column .* cannot be resolved" -> to get all name errors if self._is_column_name_error(result): column_suggestion = self._get_column_name_suggestion( result, query, tables_in_query, raw_tokens From 32e18421367d7b5e2e60722e1fe77c3adbdb0f15 Mon Sep 17 00:00:00 2001 From: kgopal Date: Thu, 21 Sep 2023 16:56:47 -0400 Subject: [PATCH 4/6] fix tests --- .../validators/presto_optimizing_validator.py | 19 +- .../test_presto_optimizing_validator.py | 204 +++++++++++++++--- 2 files changed, 186 insertions(+), 37 deletions(-) diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py index e564370fb..e8cb205eb 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py @@ -5,7 +5,7 @@ from sqlglot.tokens import Token from typing import List -from lib.elasticsearch.search_table import get_column_name_suggestion +from lib.elasticsearch import search_table from lib.query_analysis.lineage import process_query from lib.query_analysis.validation.base_query_validator import ( BaseQueryValidator, @@ -237,12 +237,14 @@ def _is_column_name_error(self, validation_result: QueryValidationResult) -> boo ) def _get_column_name_from_position( - self, tokens: List[Token], query: str, start_line: int, start_ch: int + self, query: str, start_line: int, start_ch: int, raw_tokens: List[Token] ) -> str: + if raw_tokens is None: + raw_tokens = self._tokenize_query(query) column_error_start_index = self._get_query_index_by_coordinate( query, start_line, start_ch ) - for token in tokens: + for token in raw_tokens: if token.start == column_error_start_index: return token.text return None @@ -259,14 +261,19 @@ def _get_column_name_suggestion( validation_result: QueryValidationResult, query: str, tables_in_query: List[str], - raw_tokens: List[Token], + raw_tokens: List[Token] = None, ): fuzzy_column_name = self._get_column_name_from_position( - raw_tokens, query, validation_result.start_line, validation_result.start_ch + query, + validation_result.start_line, + validation_result.start_ch, + raw_tokens=raw_tokens, ) if not fuzzy_column_name: return None - results, count = get_column_name_suggestion(fuzzy_column_name, tables_in_query) + results, count = search_table.get_column_name_suggestion( + fuzzy_column_name, tables_in_query + ) if count == 1: # Only return suggestion if there's a single match table_result = results[0] highlights = table_result.get("highlight", {}).get("columns", []) diff --git a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py index 0abb7aa51..5679d8d14 100644 --- a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py +++ b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py @@ -1,5 +1,6 @@ from typing import List from unittest import TestCase +from unittest.mock import patch, MagicMock from lib.query_analysis.validation.base_query_validator import ( QueryValidationResult, @@ -8,6 +9,7 @@ ) from lib.query_analysis.validation.validators.presto_optimizing_validator import ( ApproxDistinctValidator, + ColumnNameSuggester, RegexpLikeValidator, UnionAllValidator, PrestoOptimizingValidator, @@ -15,6 +17,11 @@ class BaseValidatorTestCase(TestCase): + def _get_explain_validator_mock(self): + explain_validator_mock = MagicMock() + explain_validator_mock.validate.return_value = [] + return explain_validator_mock + def _verify_query_validation_results( self, validation_results: List[QueryValidationResult], @@ -75,12 +82,12 @@ def _get_approx_distinct_validation_result( class UnionAllValidatorTestCase(BaseValidatorTestCase): def setUp(self): - self._validator = UnionAllValidator() + self._validator = UnionAllValidator(self._get_explain_validator_mock()) def test_basic_union(self): query = "SELECT * FROM a \nUNION SELECT * FROM b" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_union_all_validation_result( 1, @@ -94,7 +101,7 @@ def test_basic_union(self): def test_multiple_unions(self): query = "SELECT * FROM a \nUNION SELECT * FROM b \nUNION SELECT * FROM c" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_union_all_validation_result( 1, @@ -113,26 +120,24 @@ def test_multiple_unions(self): def test_union_all(self): query = "SELECT * FROM a UNION ALL SELECT * FROM b" - self._verify_query_validation_results( - self._validator.get_query_validation_results(query), [] - ) + self._verify_query_validation_results(self._validator.validate(query, 0, 0), []) class ApproxDistinctValidatorTestCase(BaseValidatorTestCase): def setUp(self): - self._validator = ApproxDistinctValidator() + self._validator = ApproxDistinctValidator(self._get_explain_validator_mock()) def test_basic_count_distinct(self): query = "SELECT COUNT(DISTINCT x) FROM a" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [self._get_approx_distinct_validation_result(0, 7, 0, 20)], ) def test_count_not_followed_by_distinct(self): query = "SELECT \nCOUNT * FROM a" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [], ) @@ -141,7 +146,7 @@ def test_multiple_count_distincts(self): "SELECT \nCOUNT(DISTINCT y) FROM a UNION SELECT \nCOUNT(DISTINCT x) FROM b" ) self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_approx_distinct_validation_result(1, 0, 1, 13), self._get_approx_distinct_validation_result(2, 0, 2, 13), @@ -153,7 +158,7 @@ def test_count_distinct_in_where_clause(self): "SELECT \nCOUNT(DISTINCT a), b FROM table_a WHERE \nCOUNT(DISTINCT a) > 10" ) self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_approx_distinct_validation_result(1, 0, 1, 13), self._get_approx_distinct_validation_result(2, 0, 2, 13), @@ -163,12 +168,12 @@ def test_count_distinct_in_where_clause(self): class RegexpLikeValidatorTestCase(BaseValidatorTestCase): def setUp(self): - self._validator = RegexpLikeValidator() + self._validator = RegexpLikeValidator(self._get_explain_validator_mock()) def test_basic_combine_case(self): query = "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE \n'bar'" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_regexp_like_validation_result( 1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')" @@ -179,14 +184,14 @@ def test_basic_combine_case(self): def test_and_clause(self): query = "SELECT * from a WHERE \nx LIKE 'foo%' AND x LIKE \n'%bar'" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [], ) def test_more_than_two_phrases(self): query = "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE 'bar' OR x LIKE \n'baz'" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_regexp_like_validation_result( 1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar|baz')" @@ -197,7 +202,7 @@ def test_more_than_two_phrases(self): def test_different_column_names(self): query = "SELECT * from a WHERE \nx LIKE 'foo' OR y LIKE 'bar'" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [], ) @@ -206,7 +211,7 @@ def test_both_or_and(self): "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE \n'bar' AND y LIKE 'foo'" ) self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_regexp_like_validation_result( 1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')" @@ -217,7 +222,7 @@ def test_both_or_and(self): def test_multiple_suggestions(self): query = "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE \n'bar' AND \ny LIKE 'foo' OR y LIKE \n'bar'" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), [ self._get_regexp_like_validation_result( 1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')" @@ -231,72 +236,209 @@ def test_multiple_suggestions(self): def test_phrase_not_match(self): query = "SELECT * from a WHERE x LIKE 'foo' OR x = 'bar'" self._verify_query_validation_results( - self._validator.get_query_validation_results(query), + self._validator.validate(query, 0, 0), + [], + ) + + +class ColumnNameSuggesterTestCase(BaseValidatorTestCase): + def setUp(self): + self._validator = ColumnNameSuggester(MagicMock()) + + def test__is_column_name_error(self): + self.assertEqual( + self._validator._is_column_name_error( + QueryValidationResult( + 0, + 0, + QueryValidationSeverity.WARNING, + "Line 0:1 Column 'happyness' cannot be resolved", + ) + ), + True, + ) + self.assertEqual( + self._validator._is_column_name_error( + QueryValidationResult( + 0, + 0, + QueryValidationSeverity.WARNING, + "Line 0:1 Table 'world_happiness_rank' does not exist", + ) + ), + False, + ) + + def test_search_columns_for_suggestion(self): + self.assertEqual( + self._validator._search_columns_for_suggestion( + ["HappinessRank", "Country", "Region"], "country" + ), + "Country", + ) + self.assertEqual( + self._validator._search_columns_for_suggestion( + ["HappinessRank, Region"], "country" + ), + "country", + ) + + @patch( + "lib.elasticsearch.search_table.get_column_name_suggestion", + ) + def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): + validation_result = QueryValidationResult( + 0, + 7, + QueryValidationSeverity.WARNING, + "Line 0:1 Column 'happynessrank' cannot be resolved", + ) + query = "select happynessrank from main.world_happiness_report;" + # Test too many tables matched + mock_get_column_name_suggestion.return_value = [ + [ + { + "columns": ["HappinessRank"], + "highlight": {"columns": ["happinessrank"]}, + }, + { + "columns": ["HappinessRank"], + "highlight": {"columns": ["happinessrank1"]}, + }, + ], + 2, + ] + self.assertEqual( + self._validator._get_column_name_suggestion( + validation_result, + query, + ["main.world_happiness_report"], + ), + None, + ) + + # Test too many columns in a table matched + mock_get_column_name_suggestion.return_value = [ + [ + { + "columns": ["HappinessRank", "HappinessRank1"], + "highlight": {"columns": ["happinessrank", "happinessrank1"]}, + }, + ], + 1, + ] + self.assertEqual( + self._validator._get_column_name_suggestion( + validation_result, + query, + ["main.world_happiness_report"], + ), + None, + ) + + # Test single column matched + mock_get_column_name_suggestion.return_value = [ + [ + { + "columns": ["HappinessRank", "HappinessRank1"], + "highlight": {"columns": ["happinessrank"]}, + }, + ], + 1, + ] + self.assertEqual( + self._validator._get_column_name_suggestion( + validation_result, + query, + ["main.world_happiness_report"], + ), + "HappinessRank", + ) + + # Test no search results + mock_get_column_name_suggestion.return_value = [ [], + 0, + ] + self.assertEqual( + self._validator._get_column_name_suggestion( + validation_result, + query, + ["main.world_happiness_report"], + ), + None, ) class PrestoOptimizingValidatorTestCase(BaseValidatorTestCase): def setUp(self): + super(PrestoOptimizingValidatorTestCase, self).setUp() + patch_validator = patch.object( + ColumnNameSuggester, + "validate", + return_value=[], + ) + patch_validator.start() + self.addCleanup(patch_validator.stop) self._validator = PrestoOptimizingValidator("") def test_union_and_count_distinct(self): query = "SELECT \nCOUNT( DISTINCT x) from a \nUNION select \ncount(distinct y) from b" self._verify_query_validation_results( - self._validator._get_sql_glot_validation_results(query), + self._validator.validate(query, 0, 0), [ - self._get_union_all_validation_result(2, 0, 2, 4), self._get_approx_distinct_validation_result(1, 0, 1, 14), self._get_approx_distinct_validation_result(3, 0, 3, 13), + self._get_union_all_validation_result(2, 0, 2, 4), ], ) def test_union_and_regexp_like(self): query = "SELECT * from a WHERE \nx like 'foo' or x like \n'bar' \nUNION select * from b where y like 'foo' AND x like 'bar'" self._verify_query_validation_results( - self._validator._get_sql_glot_validation_results(query), + self._validator.validate(query, 0, 0), [ - self._get_union_all_validation_result(3, 0, 3, 4), self._get_regexp_like_validation_result( 1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')" ), + self._get_union_all_validation_result(3, 0, 3, 4), ], ) def test_count_distinct_and_regexp_like(self): query = "SELECT \nCOUNT( DISTINCT x) from a WHERE \nx LIKE 'foo' or x like \n'bar' and y like 'foo'" self._verify_query_validation_results( - self._validator._get_sql_glot_validation_results(query), + self._validator.validate(query, 0, 0), [ - self._get_approx_distinct_validation_result(1, 0, 1, 14), self._get_regexp_like_validation_result( 2, 0, 3, 4, "REGEXP_LIKE(x, 'foo|bar')" ), + self._get_approx_distinct_validation_result(1, 0, 1, 14), ], ) def test_all_errors(self): query = "SELECT \nCOUNT( DISTINCT x) from a WHERE \nx LIKE 'foo' or x like \n'bar' and y like 'foo' \nUNION select * from b" self._verify_query_validation_results( - self._validator._get_sql_glot_validation_results(query), + self._validator.validate(query, 0, 0), [ - self._get_union_all_validation_result(4, 0, 4, 4), - self._get_approx_distinct_validation_result(1, 0, 1, 14), self._get_regexp_like_validation_result( 2, 0, 3, 4, "REGEXP_LIKE(x, 'foo|bar')" ), + self._get_approx_distinct_validation_result(1, 0, 1, 14), + self._get_union_all_validation_result(4, 0, 4, 4), ], ) def test_extra_whitespace(self): query = "SELECT \n COUNT( DISTINCT x) from a WHERE \n\t x LIKE 'foo' or x like \n'bar' and y like 'foo' \n UNION select * from b" self._verify_query_validation_results( - self._validator._get_sql_glot_validation_results(query), + self._validator.validate(query, 0, 0), [ - self._get_union_all_validation_result(4, 5, 4, 9), - self._get_approx_distinct_validation_result(1, 2, 1, 16), self._get_regexp_like_validation_result( 2, 3, 3, 4, "REGEXP_LIKE(x, 'foo|bar')" ), + self._get_approx_distinct_validation_result(1, 2, 1, 16), + self._get_union_all_validation_result(4, 5, 4, 9), ], ) From 54272fbb3c58f0cc4b83bce6b6a2058a0d759ef3 Mon Sep 17 00:00:00 2001 From: kgopal Date: Mon, 25 Sep 2023 18:30:50 -0500 Subject: [PATCH 5/6] address pr comments, add table suggester --- .../server/lib/elasticsearch/search_table.py | 22 +++ .../validators/base_sqlglot_validator.py | 18 ++- .../validators/metadata_suggesters.py | 150 ++++++++++++++++++ .../validators/presto_optimizing_validator.py | 134 ++++------------ .../test_presto_optimizing_validator.py | 75 +++++---- 5 files changed, 254 insertions(+), 145 deletions(-) create mode 100644 querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py diff --git a/querybook/server/lib/elasticsearch/search_table.py b/querybook/server/lib/elasticsearch/search_table.py index 71448a38f..76dcffbb9 100644 --- a/querybook/server/lib/elasticsearch/search_table.py +++ b/querybook/server/lib/elasticsearch/search_table.py @@ -215,3 +215,25 @@ def get_column_name_suggestion( } return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True) + + +def get_table_name_suggestion(fuzzy_full_table_name: str) -> Tuple[Dict, int]: + """Given an invalid table name use fuzzy search to search the correctly-spelled table name""" + schema_name, fuzzy_table_name = fuzzy_full_table_name.split(".") + + search_query = { + "query": { + "bool": { + "must": [ + {"match": {"schema": schema_name}}, + { + "match": { + "name": {"query": fuzzy_table_name, "fuzziness": "AUTO"}, + } + }, + ] + } + }, + } + + return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True) diff --git a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py b/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py index 073d13ad8..16dab444d 100644 --- a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py @@ -72,6 +72,22 @@ def validate( uid: int, engine_id: int, raw_tokens: List[Token] = None, - **kwargs + **kwargs, ) -> List[QueryValidationResult]: raise NotImplementedError() + + +class BaseSQLGlotDecorator(BaseSQLGlotValidator): + def __init__(self, validator: BaseQueryValidator): + self._validator = validator + + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[Token] = None, + **kwargs, + ): + """Override this method to add suggestions to validation results""" + return self._validator.validate(query, uid, engine_id, **kwargs) diff --git a/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py b/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py new file mode 100644 index 000000000..6841391d0 --- /dev/null +++ b/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py @@ -0,0 +1,150 @@ +from abc import abstractmethod +import re +from itertools import chain +from typing import List + +from lib.elasticsearch import search_table +from lib.query_analysis.lineage import process_query +from lib.query_analysis.validation.base_query_validator import ( + QueryValidationResult, + QueryValidationSeverity, +) +from lib.query_analysis.validation.validators.base_sqlglot_validator import ( + BaseSQLGlotDecorator, +) +from logic.admin import get_query_engine_by_id + + +class BaseColumnNameSuggester(BaseSQLGlotDecorator): + @property + def severity(self): + return QueryValidationSeverity.WARNING # Unused, severity is not changed + + @property + def message(self): + return "" # Unused, message is not changed + + @property + @abstractmethod + def column_name_error_regex(self): + raise NotImplementedError() + + @abstractmethod + def get_column_name_from_error(self, validation_result: QueryValidationResult): + raise NotImplementedError() + + def _is_column_name_error(self, validation_result: QueryValidationResult) -> bool: + return bool(re.match(self.column_name_error_regex, validation_result.message)) + + def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]: + engine = get_query_engine_by_id(engine_id) + tables_per_statement, _ = process_query(query, language=engine.language) + return list(chain.from_iterable(tables_per_statement)) + + def _search_columns_for_suggestion(self, columns: List[str], suggestion: str): + """Return the case-sensitive column name by searching the table's columns for the suggestion text""" + for col in columns: + if col.lower() == suggestion.lower(): + return col + return suggestion + + def _suggest_column_name( + self, + validation_result: QueryValidationResult, + tables_in_query: List[str], + ): + """Takes validation result and tables in query to update validation result to provide column + name suggestion""" + fuzzy_column_name = self.get_column_name_from_error(validation_result) + if not fuzzy_column_name: + return None + results, count = search_table.get_column_name_suggestion( + fuzzy_column_name, tables_in_query + ) + if count == 1: # Only suggest column if there's a single match + table_result = results[0] + highlights = table_result.get("highlight", {}).get("columns", []) + if len(highlights) == 1: + column_suggestion = self._search_columns_for_suggestion( + table_result.get("columns"), highlights[0] + ) + validation_result.suggestion = column_suggestion + validation_result.end_line = validation_result.start_line + validation_result.end_ch = ( + validation_result.start_ch + len(fuzzy_column_name) - 1 + ) + + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[QueryValidationResult] = None, + **kwargs, + ) -> List[QueryValidationResult]: + if raw_tokens is None: + raw_tokens = self._tokenize_query(query) + validation_results = self._validator.validate( + query, uid, engine_id, raw_tokens=raw_tokens + ) + tables_in_query = self._get_tables_in_query(query, engine_id) + for result in validation_results: + if self._is_column_name_error(result): + self._suggest_column_name(result, tables_in_query) + return validation_results + + +class BaseTableNameSuggester(BaseSQLGlotDecorator): + @property + def severity(self): + return QueryValidationSeverity.WARNING # Unused, severity is not changed + + @property + def message(self): + return "" # Unused, message is not changed + + @property + @abstractmethod + def table_name_error_regex(self): + raise NotImplementedError() + + @abstractmethod + def get_full_table_name_from_error(self, validation_result: QueryValidationResult): + raise NotImplementedError() + + def _is_table_name_error(self, validation_result: QueryValidationResult) -> bool: + return bool(re.match(self.table_name_error_regex, validation_result.message)) + + def _suggest_table_name(self, validation_result: QueryValidationResult): + """Takes validation result and tables in query to update validation result to provide table + name suggestion""" + fuzzy_table_name = self.get_full_table_name_from_error(validation_result) + if not fuzzy_table_name: + return None + results, count = search_table.get_table_name_suggestion(fuzzy_table_name) + if count > 0: + table_result = results[0] # Get top match + table_suggestion = f"{table_result['schema']}.{table_result['name']}" + validation_result.suggestion = table_suggestion + validation_result.end_line = validation_result.start_line + validation_result.end_ch = ( + validation_result.start_ch + len(fuzzy_table_name) - 1 + ) + + def validate( + self, + query: str, + uid: int, + engine_id: int, + raw_tokens: List[QueryValidationResult] = None, + **kwargs, + ) -> List[QueryValidationResult]: + if raw_tokens is None: + raw_tokens = self._tokenize_query(query) + validation_results = self._validator.validate( + query, uid, engine_id, raw_tokens=raw_tokens + ) + for result in validation_results: + if self._is_table_name_error(result): + self._suggest_table_name(result) + return validation_results diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py index e8cb205eb..21cedc85d 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py @@ -1,12 +1,9 @@ import re -from itertools import chain from sqlglot import TokenType, Tokenizer from sqlglot.dialects import Trino from sqlglot.tokens import Token from typing import List -from lib.elasticsearch import search_table -from lib.query_analysis.lineage import process_query from lib.query_analysis.validation.base_query_validator import ( BaseQueryValidator, QueryValidationResult, @@ -16,15 +13,15 @@ PrestoExplainValidator, ) from lib.query_analysis.validation.validators.base_sqlglot_validator import ( - BaseSQLGlotValidator, + BaseSQLGlotDecorator, +) +from lib.query_analysis.validation.validators.metadata_suggesters import ( + BaseColumnNameSuggester, + BaseTableNameSuggester, ) -from logic.admin import get_query_engine_by_id - -class BasePrestoSQLGlotDecorator(BaseSQLGlotValidator): - def __init__(self, validator: BaseQueryValidator): - self._validator = validator +class BasePrestoSQLGlotDecorator(BaseSQLGlotDecorator): def languages(self): return ["presto", "trino"] @@ -32,17 +29,6 @@ def languages(self): def tokenizer(self) -> Tokenizer: return Trino.Tokenizer() - def validate( - self, - query: str, - uid: int, - engine_id: int, - raw_tokens: List[Token] = None, - **kwargs, - ): - """Override this method to add suggestions to validation results""" - return self._validator.validate(query, uid, engine_id, **kwargs) - class UnionAllValidator(BasePrestoSQLGlotDecorator): @property @@ -217,109 +203,45 @@ def validate( return validation_results -class ColumnNameSuggester(BasePrestoSQLGlotDecorator): +class PrestoColumnNameSuggester(BasePrestoSQLGlotDecorator, BaseColumnNameSuggester): @property - def message(self): - return "" # Unused, message is not changed + def column_name_error_regex(self): + return r"line \d+:\d+: Column '(.*)' cannot be resolved" - @property - def severity(self): - return QueryValidationSeverity.WARNING # Unused, severity is not changed + def get_column_name_from_error(self, validation_result: QueryValidationResult): + regex_result = re.match(self.column_name_error_regex, validation_result.message) + return regex_result.groups()[0] - def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]: - engine = get_query_engine_by_id(engine_id) - tables_per_statement, _ = process_query(query, language=engine.language) - return list(chain.from_iterable(tables_per_statement)) - def _is_column_name_error(self, validation_result: QueryValidationResult) -> bool: - return bool( - re.search(r"Column .* cannot be resolved", validation_result.message) - ) - - def _get_column_name_from_position( - self, query: str, start_line: int, start_ch: int, raw_tokens: List[Token] - ) -> str: - if raw_tokens is None: - raw_tokens = self._tokenize_query(query) - column_error_start_index = self._get_query_index_by_coordinate( - query, start_line, start_ch - ) - for token in raw_tokens: - if token.start == column_error_start_index: - return token.text - return None - - def _search_columns_for_suggestion(self, columns: List[str], suggestion: str): - """Return the case-sensitive column name by searching the table's columns for the suggestion text""" - for col in columns: - if col.lower() == suggestion.lower(): - return col - return suggestion - - def _get_column_name_suggestion( - self, - validation_result: QueryValidationResult, - query: str, - tables_in_query: List[str], - raw_tokens: List[Token] = None, - ): - fuzzy_column_name = self._get_column_name_from_position( - query, - validation_result.start_line, - validation_result.start_ch, - raw_tokens=raw_tokens, - ) - if not fuzzy_column_name: - return None - results, count = search_table.get_column_name_suggestion( - fuzzy_column_name, tables_in_query - ) - if count == 1: # Only return suggestion if there's a single match - table_result = results[0] - highlights = table_result.get("highlight", {}).get("columns", []) - if len(highlights) == 1: - column_suggestion = self._search_columns_for_suggestion( - table_result.get("columns"), highlights[0] - ) - return column_suggestion - - return None +class PrestoTableNameSuggester(BasePrestoSQLGlotDecorator, BaseTableNameSuggester): + @property + def table_name_error_regex(self): + return r"line \d+:\d+: Table '(.*)' does not exist" - def validate( - self, - query: str, - uid: int, - engine_id: int, - raw_tokens: List[QueryValidationResult] = None, - **kwargs, - ) -> List[QueryValidationResult]: - if raw_tokens is None: - raw_tokens = self._tokenize_query(query) - validation_results = self._validator.validate( - query, uid, engine_id, raw_tokens=raw_tokens - ) - tables_in_query = self._get_tables_in_query(query, engine_id) - for result in validation_results: - if self._is_column_name_error(result): - column_suggestion = self._get_column_name_suggestion( - result, query, tables_in_query, raw_tokens - ) - if column_suggestion: - result.suggestion = column_suggestion - return validation_results + def get_full_table_name_from_error(self, validation_result: QueryValidationResult): + regex_result = re.match(self.table_name_error_regex, validation_result.message) + return regex_result.groups()[0] class PrestoOptimizingValidator(BaseQueryValidator): def languages(self): return ["presto", "trino"] + @property + def tokenizer(self) -> Tokenizer: + return Trino.Tokenizer() + def _get_explain_validator(self): return PrestoExplainValidator("") def _get_decorated_validator(self) -> BaseQueryValidator: return UnionAllValidator( ApproxDistinctValidator( - RegexpLikeValidator(ColumnNameSuggester(self._get_explain_validator())) + RegexpLikeValidator( + PrestoTableNameSuggester( + PrestoColumnNameSuggester(self._get_explain_validator()) + ) + ) ) ) diff --git a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py index 5679d8d14..afce1fefb 100644 --- a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py +++ b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py @@ -9,7 +9,7 @@ ) from lib.query_analysis.validation.validators.presto_optimizing_validator import ( ApproxDistinctValidator, - ColumnNameSuggester, + PrestoColumnNameSuggester, RegexpLikeValidator, UnionAllValidator, PrestoOptimizingValidator, @@ -241,9 +241,9 @@ def test_phrase_not_match(self): ) -class ColumnNameSuggesterTestCase(BaseValidatorTestCase): +class PrestoColumnNameSuggesterTestCase(BaseValidatorTestCase): def setUp(self): - self._validator = ColumnNameSuggester(MagicMock()) + self._validator = PrestoColumnNameSuggester(MagicMock()) def test__is_column_name_error(self): self.assertEqual( @@ -252,7 +252,7 @@ def test__is_column_name_error(self): 0, 0, QueryValidationSeverity.WARNING, - "Line 0:1 Column 'happyness' cannot be resolved", + "line 0:1: Column 'happyness' cannot be resolved", ) ), True, @@ -263,7 +263,7 @@ def test__is_column_name_error(self): 0, 0, QueryValidationSeverity.WARNING, - "Line 0:1 Table 'world_happiness_rank' does not exist", + "line 0:1: Table 'world_happiness_rank' does not exist", ) ), False, @@ -283,18 +283,20 @@ def test_search_columns_for_suggestion(self): "country", ) - @patch( - "lib.elasticsearch.search_table.get_column_name_suggestion", - ) - def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): - validation_result = QueryValidationResult( + def _get_new_validation_result_obj(self): + return QueryValidationResult( 0, 7, QueryValidationSeverity.WARNING, - "Line 0:1 Column 'happynessrank' cannot be resolved", + "line 0:1: Column 'happynessrank' cannot be resolved", ) - query = "select happynessrank from main.world_happiness_report;" + + @patch( + "lib.elasticsearch.search_table.get_column_name_suggestion", + ) + def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): # Test too many tables matched + validation_result = self._get_new_validation_result_obj() mock_get_column_name_suggestion.return_value = [ [ { @@ -308,16 +310,14 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): ], 2, ] - self.assertEqual( - self._validator._get_column_name_suggestion( - validation_result, - query, - ["main.world_happiness_report"], - ), - None, + self._validator._suggest_column_name( + validation_result, + ["main.world_happiness_report"], ) + self.assertEqual(validation_result.suggestion, None) # Test too many columns in a table matched + validation_result = self._get_new_validation_result_obj() mock_get_column_name_suggestion.return_value = [ [ { @@ -327,16 +327,17 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): ], 1, ] + self._validator._suggest_column_name( + validation_result, + ["main.world_happiness_report"], + ), self.assertEqual( - self._validator._get_column_name_suggestion( - validation_result, - query, - ["main.world_happiness_report"], - ), + validation_result.suggestion, None, ) # Test single column matched + validation_result = self._get_new_validation_result_obj() mock_get_column_name_suggestion.return_value = [ [ { @@ -346,26 +347,24 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): ], 1, ] - self.assertEqual( - self._validator._get_column_name_suggestion( - validation_result, - query, - ["main.world_happiness_report"], - ), - "HappinessRank", - ) + self._validator._suggest_column_name( + validation_result, + ["main.world_happiness_report"], + ), + self.assertEqual(validation_result.suggestion, "HappinessRank") # Test no search results + validation_result = self._get_new_validation_result_obj() mock_get_column_name_suggestion.return_value = [ [], 0, ] + self._validator._suggest_column_name( + validation_result, + ["main.world_happiness_report"], + ), self.assertEqual( - self._validator._get_column_name_suggestion( - validation_result, - query, - ["main.world_happiness_report"], - ), + validation_result.suggestion, None, ) @@ -374,7 +373,7 @@ class PrestoOptimizingValidatorTestCase(BaseValidatorTestCase): def setUp(self): super(PrestoOptimizingValidatorTestCase, self).setUp() patch_validator = patch.object( - ColumnNameSuggester, + PrestoColumnNameSuggester, "validate", return_value=[], ) From ed06df8df5fad663316891117e8d8bb337dfc7f5 Mon Sep 17 00:00:00 2001 From: kgopal Date: Tue, 26 Sep 2023 11:03:28 -0500 Subject: [PATCH 6/6] fix pr comments, add test case --- .../server/lib/elasticsearch/search_table.py | 29 +++--- .../validators/metadata_suggesters.py | 43 +++----- .../validators/presto_optimizing_validator.py | 20 ++-- .../test_presto_optimizing_validator.py | 99 +++++++++++++++++-- 4 files changed, 130 insertions(+), 61 deletions(-) diff --git a/querybook/server/lib/elasticsearch/search_table.py b/querybook/server/lib/elasticsearch/search_table.py index 76dcffbb9..55ac8ee66 100644 --- a/querybook/server/lib/elasticsearch/search_table.py +++ b/querybook/server/lib/elasticsearch/search_table.py @@ -217,23 +217,26 @@ def get_column_name_suggestion( return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True) -def get_table_name_suggestion(fuzzy_full_table_name: str) -> Tuple[Dict, int]: +def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]: """Given an invalid table name use fuzzy search to search the correctly-spelled table name""" - schema_name, fuzzy_table_name = fuzzy_full_table_name.split(".") - search_query = { - "query": { - "bool": { - "must": [ - {"match": {"schema": schema_name}}, - { - "match": { - "name": {"query": fuzzy_table_name, "fuzziness": "AUTO"}, - } - }, - ] + schema_name, fuzzy_name = None, fuzzy_table_name + fuzzy_table_name_parts = fuzzy_table_name.split(".") + if len(fuzzy_table_name_parts) == 2: + schema_name, fuzzy_name = fuzzy_table_name_parts + + must_clause = [ + { + "match": { + "name": {"query": fuzzy_name, "fuzziness": "AUTO"}, } }, + ] + if schema_name: + must_clause.append({"match": {"schema": schema_name}}) + + search_query = { + "query": {"bool": {"must": must_clause}}, } return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True) diff --git a/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py b/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py index 6841391d0..726e2ce9d 100644 --- a/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py +++ b/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py @@ -1,7 +1,6 @@ from abc import abstractmethod -import re from itertools import chain -from typing import List +from typing import List, Optional from lib.elasticsearch import search_table from lib.query_analysis.lineage import process_query @@ -24,18 +23,14 @@ def severity(self): def message(self): return "" # Unused, message is not changed - @property @abstractmethod - def column_name_error_regex(self): + def get_column_name_from_error( + self, validation_result: QueryValidationResult + ) -> Optional[str]: + """Returns invalid column name if the validation result is a column name error, otherwise + returns None""" raise NotImplementedError() - @abstractmethod - def get_column_name_from_error(self, validation_result: QueryValidationResult): - raise NotImplementedError() - - def _is_column_name_error(self, validation_result: QueryValidationResult) -> bool: - return bool(re.match(self.column_name_error_regex, validation_result.message)) - def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]: engine = get_query_engine_by_id(engine_id) tables_per_statement, _ = process_query(query, language=engine.language) @@ -48,7 +43,7 @@ def _search_columns_for_suggestion(self, columns: List[str], suggestion: str): return col return suggestion - def _suggest_column_name( + def _suggest_column_name_if_needed( self, validation_result: QueryValidationResult, tables_in_query: List[str], @@ -57,7 +52,7 @@ def _suggest_column_name( name suggestion""" fuzzy_column_name = self.get_column_name_from_error(validation_result) if not fuzzy_column_name: - return None + return results, count = search_table.get_column_name_suggestion( fuzzy_column_name, tables_in_query ) @@ -89,8 +84,7 @@ def validate( ) tables_in_query = self._get_tables_in_query(query, engine_id) for result in validation_results: - if self._is_column_name_error(result): - self._suggest_column_name(result, tables_in_query) + self._suggest_column_name_if_needed(result, tables_in_query) return validation_results @@ -103,24 +97,20 @@ def severity(self): def message(self): return "" # Unused, message is not changed - @property - @abstractmethod - def table_name_error_regex(self): - raise NotImplementedError() - @abstractmethod def get_full_table_name_from_error(self, validation_result: QueryValidationResult): + """Returns invalid table name if the validation result is a table name error, otherwise + returns None""" raise NotImplementedError() - def _is_table_name_error(self, validation_result: QueryValidationResult) -> bool: - return bool(re.match(self.table_name_error_regex, validation_result.message)) - - def _suggest_table_name(self, validation_result: QueryValidationResult): + def _suggest_table_name_if_needed( + self, validation_result: QueryValidationResult + ) -> Optional[str]: """Takes validation result and tables in query to update validation result to provide table name suggestion""" fuzzy_table_name = self.get_full_table_name_from_error(validation_result) if not fuzzy_table_name: - return None + return results, count = search_table.get_table_name_suggestion(fuzzy_table_name) if count > 0: table_result = results[0] # Get top match @@ -145,6 +135,5 @@ def validate( query, uid, engine_id, raw_tokens=raw_tokens ) for result in validation_results: - if self._is_table_name_error(result): - self._suggest_table_name(result) + self._suggest_table_name_if_needed(result) return validation_results diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py index 21cedc85d..6cf87de45 100644 --- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py +++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py @@ -204,23 +204,19 @@ def validate( class PrestoColumnNameSuggester(BasePrestoSQLGlotDecorator, BaseColumnNameSuggester): - @property - def column_name_error_regex(self): - return r"line \d+:\d+: Column '(.*)' cannot be resolved" - def get_column_name_from_error(self, validation_result: QueryValidationResult): - regex_result = re.match(self.column_name_error_regex, validation_result.message) - return regex_result.groups()[0] + regex_result = re.match( + r"line \d+:\d+: Column '(.*)' cannot be resolved", validation_result.message + ) + return regex_result.groups()[0] if regex_result else None class PrestoTableNameSuggester(BasePrestoSQLGlotDecorator, BaseTableNameSuggester): - @property - def table_name_error_regex(self): - return r"line \d+:\d+: Table '(.*)' does not exist" - def get_full_table_name_from_error(self, validation_result: QueryValidationResult): - regex_result = re.match(self.table_name_error_regex, validation_result.message) - return regex_result.groups()[0] + regex_result = re.match( + r"line \d+:\d+: Table '(.*)' does not exist", validation_result.message + ) + return regex_result.groups()[0] if regex_result else None class PrestoOptimizingValidator(BaseQueryValidator): diff --git a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py index afce1fefb..93a98bbb6 100644 --- a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py +++ b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py @@ -10,6 +10,7 @@ from lib.query_analysis.validation.validators.presto_optimizing_validator import ( ApproxDistinctValidator, PrestoColumnNameSuggester, + PrestoTableNameSuggester, RegexpLikeValidator, UnionAllValidator, PrestoOptimizingValidator, @@ -245,9 +246,9 @@ class PrestoColumnNameSuggesterTestCase(BaseValidatorTestCase): def setUp(self): self._validator = PrestoColumnNameSuggester(MagicMock()) - def test__is_column_name_error(self): + def test_get_column_name_from_error(self): self.assertEqual( - self._validator._is_column_name_error( + self._validator.get_column_name_from_error( QueryValidationResult( 0, 0, @@ -255,10 +256,10 @@ def test__is_column_name_error(self): "line 0:1: Column 'happyness' cannot be resolved", ) ), - True, + "happyness", ) self.assertEqual( - self._validator._is_column_name_error( + self._validator.get_column_name_from_error( QueryValidationResult( 0, 0, @@ -266,7 +267,7 @@ def test__is_column_name_error(self): "line 0:1: Table 'world_happiness_rank' does not exist", ) ), - False, + None, ) def test_search_columns_for_suggestion(self): @@ -310,7 +311,7 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): ], 2, ] - self._validator._suggest_column_name( + self._validator._suggest_column_name_if_needed( validation_result, ["main.world_happiness_report"], ) @@ -327,7 +328,7 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): ], 1, ] - self._validator._suggest_column_name( + self._validator._suggest_column_name_if_needed( validation_result, ["main.world_happiness_report"], ), @@ -347,7 +348,7 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): ], 1, ] - self._validator._suggest_column_name( + self._validator._suggest_column_name_if_needed( validation_result, ["main.world_happiness_report"], ), @@ -359,7 +360,7 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): [], 0, ] - self._validator._suggest_column_name( + self._validator._suggest_column_name_if_needed( validation_result, ["main.world_happiness_report"], ), @@ -369,6 +370,86 @@ def test__get_column_name_suggestion(self, mock_get_column_name_suggestion): ) +class PrestoTableNameSuggesterTestCase(BaseValidatorTestCase): + def setUp(self): + self._validator = PrestoTableNameSuggester(MagicMock()) + + def test_get_full_table_name_from_error(self): + self.assertEquals( + self._validator.get_full_table_name_from_error( + QueryValidationResult( + 0, + 0, + QueryValidationSeverity.WARNING, + "line 0:1: Table 'world_happiness_15' does not exist", + ) + ), + "world_happiness_15", + ) + self.assertEquals( + self._validator.get_full_table_name_from_error( + QueryValidationResult( + 0, + 0, + QueryValidationSeverity.WARNING, + "line 0:1: column 'happiness_rank' cannot be resolved", + ) + ), + None, + ) + + @patch( + "lib.elasticsearch.search_table.get_table_name_suggestion", + ) + def test__suggest_table_name_if_needed_single_hit(self, mock_table_suggestion): + validation_result = QueryValidationResult( + 0, + 0, + QueryValidationSeverity.WARNING, + "line 0:1: Table 'world_happiness_15' does not exist", + ) + mock_table_suggestion.return_value = [ + {"schema": "main", "name": "world_happiness_rank_2015"} + ], 1 + self._validator._suggest_table_name_if_needed(validation_result) + self.assertEquals( + validation_result.suggestion, "main.world_happiness_rank_2015" + ) + + @patch( + "lib.elasticsearch.search_table.get_table_name_suggestion", + ) + def test__suggest_table_name_if_needed_multiple_hits(self, mock_table_suggestion): + validation_result = QueryValidationResult( + 0, + 0, + QueryValidationSeverity.WARNING, + "line 0:1: Table 'world_happiness_15' does not exist", + ) + mock_table_suggestion.return_value = [ + {"schema": "main", "name": "world_happiness_rank_2015"}, + {"schema": "main", "name": "world_happiness_rank_2016"}, + ], 2 + self._validator._suggest_table_name_if_needed(validation_result) + self.assertEquals( + validation_result.suggestion, "main.world_happiness_rank_2015" + ) + + @patch( + "lib.elasticsearch.search_table.get_table_name_suggestion", + ) + def test__suggest_table_name_if_needed_no_hits(self, mock_table_suggestion): + validation_result = QueryValidationResult( + 0, + 0, + QueryValidationSeverity.WARNING, + "line 0:1: Table 'world_happiness_15' does not exist", + ) + mock_table_suggestion.return_value = [], 0 + self._validator._suggest_table_name_if_needed(validation_result) + self.assertEquals(validation_result.suggestion, None) + + class PrestoOptimizingValidatorTestCase(BaseValidatorTestCase): def setUp(self): super(PrestoOptimizingValidatorTestCase, self).setUp()