Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add column name suggestions to presto validator #1330

Merged
merged 6 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions querybook/server/lib/elasticsearch/search_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,25 @@ def get_column_name_suggestion(
}

return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)


def get_table_name_suggestion(fuzzy_full_table_name: str) -> Tuple[Dict, int]:
"""Given an invalid table name use fuzzy search to search the correctly-spelled table name"""
schema_name, fuzzy_table_name = fuzzy_full_table_name.split(".")
kgopal492 marked this conversation as resolved.
Show resolved Hide resolved

search_query = {
"query": {
"bool": {
"must": [
{"match": {"schema": schema_name}},
{
"match": {
"name": {"query": fuzzy_table_name, "fuzziness": "AUTO"},
}
},
]
}
},
}

return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,22 @@ def validate(
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
**kwargs
**kwargs,
) -> List[QueryValidationResult]:
raise NotImplementedError()


class BaseSQLGlotDecorator(BaseSQLGlotValidator):
def __init__(self, validator: BaseQueryValidator):
self._validator = validator

def validate(
self,
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
**kwargs,
):
"""Override this method to add suggestions to validation results"""
return self._validator.validate(query, uid, engine_id, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from abc import abstractmethod
import re
from itertools import chain
from typing import List

from lib.elasticsearch import search_table
from lib.query_analysis.lineage import process_query
from lib.query_analysis.validation.base_query_validator import (
QueryValidationResult,
QueryValidationSeverity,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
BaseSQLGlotDecorator,
)
from logic.admin import get_query_engine_by_id


class BaseColumnNameSuggester(BaseSQLGlotDecorator):
@property
def severity(self):
return QueryValidationSeverity.WARNING # Unused, severity is not changed

@property
def message(self):
return "" # Unused, message is not changed

@property
@abstractmethod
def column_name_error_regex(self):
raise NotImplementedError()

@abstractmethod
def get_column_name_from_error(self, validation_result: QueryValidationResult):
raise NotImplementedError()

def _is_column_name_error(self, validation_result: QueryValidationResult) -> bool:
return bool(re.match(self.column_name_error_regex, validation_result.message))

def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]:
engine = get_query_engine_by_id(engine_id)
tables_per_statement, _ = process_query(query, language=engine.language)
return list(chain.from_iterable(tables_per_statement))

def _search_columns_for_suggestion(self, columns: List[str], suggestion: str):
"""Return the case-sensitive column name by searching the table's columns for the suggestion text"""
for col in columns:
if col.lower() == suggestion.lower():
return col
return suggestion

def _suggest_column_name(
self,
validation_result: QueryValidationResult,
tables_in_query: List[str],
):
"""Takes validation result and tables in query to update validation result to provide column
name suggestion"""
fuzzy_column_name = self.get_column_name_from_error(validation_result)
if not fuzzy_column_name:
return None
results, count = search_table.get_column_name_suggestion(
fuzzy_column_name, tables_in_query
)
if count == 1: # Only suggest column if there's a single match
table_result = results[0]
highlights = table_result.get("highlight", {}).get("columns", [])
if len(highlights) == 1:
column_suggestion = self._search_columns_for_suggestion(
table_result.get("columns"), highlights[0]
)
validation_result.suggestion = column_suggestion
validation_result.end_line = validation_result.start_line
validation_result.end_ch = (
validation_result.start_ch + len(fuzzy_column_name) - 1
)

def validate(
self,
query: str,
uid: int,
engine_id: int,
raw_tokens: List[QueryValidationResult] = None,
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
tables_in_query = self._get_tables_in_query(query, engine_id)
for result in validation_results:
if self._is_column_name_error(result):
self._suggest_column_name(result, tables_in_query)
return validation_results


class BaseTableNameSuggester(BaseSQLGlotDecorator):
@property
def severity(self):
return QueryValidationSeverity.WARNING # Unused, severity is not changed

@property
def message(self):
return "" # Unused, message is not changed

@property
@abstractmethod
def table_name_error_regex(self):
raise NotImplementedError()

@abstractmethod
def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
raise NotImplementedError()

def _is_table_name_error(self, validation_result: QueryValidationResult) -> bool:
return bool(re.match(self.table_name_error_regex, validation_result.message))
kgopal492 marked this conversation as resolved.
Show resolved Hide resolved

def _suggest_table_name(self, validation_result: QueryValidationResult):
"""Takes validation result and tables in query to update validation result to provide table
name suggestion"""
fuzzy_table_name = self.get_full_table_name_from_error(validation_result)
if not fuzzy_table_name:
return None
results, count = search_table.get_table_name_suggestion(fuzzy_table_name)
if count > 0:
table_result = results[0] # Get top match
table_suggestion = f"{table_result['schema']}.{table_result['name']}"
validation_result.suggestion = table_suggestion
validation_result.end_line = validation_result.start_line
validation_result.end_ch = (
validation_result.start_ch + len(fuzzy_table_name) - 1
)

def validate(
self,
query: str,
uid: int,
engine_id: int,
raw_tokens: List[QueryValidationResult] = None,
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
for result in validation_results:
if self._is_table_name_error(result):
self._suggest_table_name(result)
return validation_results
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import re
from itertools import chain
from sqlglot import TokenType, Tokenizer
from sqlglot.dialects import Trino
from sqlglot.tokens import Token
from typing import List

from lib.elasticsearch import search_table
from lib.query_analysis.lineage import process_query
from lib.query_analysis.validation.base_query_validator import (
BaseQueryValidator,
QueryValidationResult,
Expand All @@ -16,33 +13,22 @@
PrestoExplainValidator,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
BaseSQLGlotValidator,
BaseSQLGlotDecorator,
)
from lib.query_analysis.validation.validators.metadata_suggesters import (
BaseColumnNameSuggester,
BaseTableNameSuggester,
)
from logic.admin import get_query_engine_by_id


class BasePrestoSQLGlotDecorator(BaseSQLGlotValidator):
def __init__(self, validator: BaseQueryValidator):
self._validator = validator

class BasePrestoSQLGlotDecorator(BaseSQLGlotDecorator):
def languages(self):
return ["presto", "trino"]

@property
def tokenizer(self) -> Tokenizer:
return Trino.Tokenizer()

def validate(
self,
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
**kwargs,
):
"""Override this method to add suggestions to validation results"""
return self._validator.validate(query, uid, engine_id, **kwargs)


class UnionAllValidator(BasePrestoSQLGlotDecorator):
@property
Expand Down Expand Up @@ -217,109 +203,45 @@ def validate(
return validation_results


class ColumnNameSuggester(BasePrestoSQLGlotDecorator):
class PrestoColumnNameSuggester(BasePrestoSQLGlotDecorator, BaseColumnNameSuggester):
@property
def message(self):
return "" # Unused, message is not changed
def column_name_error_regex(self):
return r"line \d+:\d+: Column '(.*)' cannot be resolved"

@property
def severity(self):
return QueryValidationSeverity.WARNING # Unused, severity is not changed
def get_column_name_from_error(self, validation_result: QueryValidationResult):
regex_result = re.match(self.column_name_error_regex, validation_result.message)
return regex_result.groups()[0]

def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]:
engine = get_query_engine_by_id(engine_id)
tables_per_statement, _ = process_query(query, language=engine.language)
return list(chain.from_iterable(tables_per_statement))

def _is_column_name_error(self, validation_result: QueryValidationResult) -> bool:
return bool(
re.search(r"Column .* cannot be resolved", validation_result.message)
)

def _get_column_name_from_position(
self, query: str, start_line: int, start_ch: int, raw_tokens: List[Token]
) -> str:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
column_error_start_index = self._get_query_index_by_coordinate(
query, start_line, start_ch
)
for token in raw_tokens:
if token.start == column_error_start_index:
return token.text
return None

def _search_columns_for_suggestion(self, columns: List[str], suggestion: str):
"""Return the case-sensitive column name by searching the table's columns for the suggestion text"""
for col in columns:
if col.lower() == suggestion.lower():
return col
return suggestion

def _get_column_name_suggestion(
self,
validation_result: QueryValidationResult,
query: str,
tables_in_query: List[str],
raw_tokens: List[Token] = None,
):
fuzzy_column_name = self._get_column_name_from_position(
query,
validation_result.start_line,
validation_result.start_ch,
raw_tokens=raw_tokens,
)
if not fuzzy_column_name:
return None
results, count = search_table.get_column_name_suggestion(
fuzzy_column_name, tables_in_query
)
if count == 1: # Only return suggestion if there's a single match
table_result = results[0]
highlights = table_result.get("highlight", {}).get("columns", [])
if len(highlights) == 1:
column_suggestion = self._search_columns_for_suggestion(
table_result.get("columns"), highlights[0]
)
return column_suggestion

return None
class PrestoTableNameSuggester(BasePrestoSQLGlotDecorator, BaseTableNameSuggester):
@property
def table_name_error_regex(self):
return r"line \d+:\d+: Table '(.*)' does not exist"

def validate(
self,
query: str,
uid: int,
engine_id: int,
raw_tokens: List[QueryValidationResult] = None,
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
tables_in_query = self._get_tables_in_query(query, engine_id)
for result in validation_results:
if self._is_column_name_error(result):
column_suggestion = self._get_column_name_suggestion(
result, query, tables_in_query, raw_tokens
)
if column_suggestion:
result.suggestion = column_suggestion
return validation_results
def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
regex_result = re.match(self.table_name_error_regex, validation_result.message)
return regex_result.groups()[0]


class PrestoOptimizingValidator(BaseQueryValidator):
def languages(self):
return ["presto", "trino"]

@property
def tokenizer(self) -> Tokenizer:
return Trino.Tokenizer()

def _get_explain_validator(self):
return PrestoExplainValidator("")

def _get_decorated_validator(self) -> BaseQueryValidator:
return UnionAllValidator(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the reason of changing from a list of validators to a chain of validators?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the decorator pattern so that we can add suggestions on top of the validation messages

ApproxDistinctValidator(
RegexpLikeValidator(ColumnNameSuggester(self._get_explain_validator()))
RegexpLikeValidator(
PrestoTableNameSuggester(
PrestoColumnNameSuggester(self._get_explain_validator())
)
)
)
)

Expand Down
Loading