refactor: sql lab: handling command exceptions #16852

Merged · 3 commits · Sep 29, 2021
6 changes: 6 additions & 0 deletions superset/errors.py
@@ -218,3 +218,9 @@ def __post_init__(self) -> None:
]
}
)

def to_dict(self) -> Dict[str, Any]:
rv = {"message": self.message, "error_type": self.error_type}
if self.extra:
rv["extra"] = self.extra # type: ignore
return rv
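The to_dict added to SupersetError keeps the serialized payload minimal: message and error_type are always present, extra only when it is set. A minimal sketch of that behaviour, assuming SupersetError is a dataclass with (at least) these three fields; the enum member used below is illustrative:

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


class SupersetErrorType(str, Enum):
    # illustrative member; the real enum defines many more error types
    GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"


@dataclass
class SupersetError:
    message: str
    error_type: SupersetErrorType
    extra: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        # "extra" is included only when truthy, mirroring the diff above
        rv: Dict[str, Any] = {"message": self.message, "error_type": self.error_type}
        if self.extra:
            rv["extra"] = self.extra
        return rv


# SupersetError("Bad query", SupersetErrorType.GENERIC_COMMAND_ERROR).to_dict()
# -> {"message": "Bad query", "error_type": <SupersetErrorType.GENERIC_COMMAND_ERROR: 'GENERIC_COMMAND_ERROR'>}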
23 changes: 22 additions & 1 deletion superset/exceptions.py
@@ -28,17 +28,35 @@ class SupersetException(Exception):
message = ""

def __init__(
self, message: str = "", exception: Optional[Exception] = None,
self,
message: str = "",
exception: Optional[Exception] = None,
error_type: Optional[SupersetErrorType] = None,
) -> None:
if message:
self.message = message
self._exception = exception
self._error_type = error_type
super().__init__(self.message)

@property
def exception(self) -> Optional[Exception]:
return self._exception

@property
def error_type(self) -> Optional[SupersetErrorType]:
return self._error_type

def to_dict(self) -> Dict[str, Any]:
rv = {}
if hasattr(self, "message"):
rv["message"] = self.message
if self.error_type:
rv["error_type"] = self.error_type
if self.exception is not None and hasattr(self.exception, "to_dict"):
rv = {**rv, **self.exception.to_dict()} # type: ignore
return rv


class SupersetErrorException(SupersetException):
"""Exceptions with a single SupersetErrorType associated with them"""
@@ -49,6 +67,9 @@ def __init__(self, error: SupersetError, status: Optional[int] = None) -> None:
if status is not None:
self.status = status

def to_dict(self) -> Dict[str, Any]:
return self.error.to_dict()


class SupersetGenericErrorException(SupersetErrorException):
"""Exceptions that are too generic to have their own type"""
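The base-class to_dict added above defers to a wrapped exception's own to_dict when one exists, so a command-level wrapper can surface the payload of the error it carries. A self-contained sketch of that cascade, using simplified stand-ins for the real classes (the error_type handling is omitted for brevity):

from typing import Any, Dict, Optional


class WrappedError(Exception):
    """Stand-in for an exception that already knows how to serialize itself."""

    def to_dict(self) -> Dict[str, Any]:
        return {"message": "original failure", "error_type": "GENERIC_COMMAND_ERROR"}


class CommandException(Exception):
    """Stand-in for SupersetException: keeps its own message, defers to the cause."""

    def __init__(self, message: str = "", exception: Optional[Exception] = None) -> None:
        self.message = message
        self._exception = exception
        super().__init__(message)

    def to_dict(self) -> Dict[str, Any]:
        rv: Dict[str, Any] = {"message": self.message} if self.message else {}
        # when the wrapped exception can serialize itself, its payload takes precedence
        if self._exception is not None and hasattr(self._exception, "to_dict"):
            rv = {**rv, **self._exception.to_dict()}
        return rv


print(CommandException("command failed", WrappedError()).to_dict())
# {'message': 'original failure', 'error_type': 'GENERIC_COMMAND_ERROR'}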
123 changes: 61 additions & 62 deletions superset/sqllab/command.py
@@ -47,6 +47,7 @@
from superset.models.sql_lab import Query
from superset.queries.dao import QueryDAO
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.utils import core as utils
@@ -68,18 +69,18 @@


class ExecuteSqlCommand(BaseCommand):
execution_context: SqlJsonExecutionContext
log_params: Optional[Dict[str, Any]] = None
session: Session
_execution_context: SqlJsonExecutionContext
_log_params: Optional[Dict[str, Any]] = None
_session: Session

def __init__(
self,
execution_context: SqlJsonExecutionContext,
log_params: Optional[Dict[str, Any]] = None,
) -> None:
self.execution_context = execution_context
self.log_params = log_params
self.session = db.session()
self._execution_context = execution_context
self._log_params = log_params
self._session = db.session()

def validate(self) -> None:
pass
@@ -88,30 +89,29 @@ def run( # pylint: disable=too-many-statements,useless-suppression
self,
) -> CommandResult:
"""Runs arbitrary sql and returns data as json"""
try:
query = self._get_existing_query()
if self.is_query_handled(query):
self._execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()
return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
}
except (SqlLabException, SupersetErrorsException) as ex:
raise ex
except Exception as ex:
raise SqlLabException(self._execution_context, exception=ex) from ex

query = self._get_existing_query(self.execution_context, self.session)

if self.is_query_handled(query):
self.execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()

return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
}

@classmethod
def _get_existing_query(
cls, execution_context: SqlJsonExecutionContext, session: Session
) -> Optional[Query]:
def _get_existing_query(self) -> Optional[Query]:
query = (
session.query(Query)
self._session.query(Query)
.filter_by(
client_id=execution_context.client_id,
user_id=execution_context.user_id,
sql_editor_id=execution_context.sql_editor_id,
client_id=self._execution_context.client_id,
user_id=self._execution_context.user_id,
sql_editor_id=self._execution_context.sql_editor_id,
)
.one_or_none()
)
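The reworked run() establishes a single error-handling contract for the command: domain exceptions (SqlLabException, SupersetErrorsException) propagate untouched, while any unexpected failure is wrapped in a SqlLabException carrying the execution context. A sketch of that contract under simplified stand-ins for the real classes:

from typing import Any, Callable, Dict, Optional


class SqlLabException(Exception):
    """Stand-in for superset.sqllab.exceptions.SqlLabException."""

    def __init__(self, context: Any, exception: Optional[Exception] = None) -> None:
        self.context = context
        self.exception = exception
        super().__init__(str(exception) if exception else "SQL Lab error")


def run_command(execute: Callable[[], Dict[str, Any]], context: Any) -> Dict[str, Any]:
    try:
        return execute()
    except SqlLabException:
        raise  # already domain-specific: let the API layer translate it
    except Exception as ex:
        # anything unexpected is wrapped so callers deal with one exception family
        raise SqlLabException(context, exception=ex) from ex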
@@ -126,25 +126,24 @@ def is_query_handled(cls, query: Optional[Query]) -> bool:
]

def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
self.execution_context.set_database(self._get_the_query_db())
query = self.execution_context.create_query()
self._execution_context.set_database(self._get_the_query_db())
query = self._execution_context.create_query()
self._save_new_query(query)
try:
self._save_new_query(query)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query)
self.execution_context.set_query(query)
self._execution_context.set_query(query)
rendered_query = self._render_query()

self._set_query_limit_if_required(rendered_query)

return self._execute_query(rendered_query)
except Exception as ex:
query.status = QueryStatus.FAILED
self.session.commit()
self._session.commit()
raise ex

def _get_the_query_db(self) -> Database:
mydb = self.session.query(Database).get(self.execution_context.database_id)
mydb = self._session.query(Database).get(self._execution_context.database_id)
self._validate_query_db(mydb)
return mydb

@@ -160,12 +159,12 @@ def _validate_query_db(cls, database: Optional[Database]) -> None:

def _save_new_query(self, query: Query) -> None:
try:
self.session.add(query)
self.session.flush()
self.session.commit() # shouldn't be necessary
self._session.add(query)
self._session.flush()
self._session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex), exc_info=True)
self.session.rollback()
self._session.rollback()
if not query.id:
raise SupersetGenericErrorException(
__(
@@ -181,7 +180,7 @@ def _validate_access(self, query: Query) -> None:
query.set_extra_json_key("errors", [dataclasses.asdict(ex.error)])
query.status = QueryStatus.FAILED
query.error_message = ex.error.message
self.session.commit()
self._session.commit()
raise SupersetErrorException(ex.error, status=403) from ex

def _render_query(self) -> str:
@@ -205,18 +204,18 @@ def validate(
error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": self.execution_context.template_params,
"template_parameters": self._execution_context.template_params,
},
)

query = self.execution_context.query
query = self._execution_context.query

try:
template_processor = get_template_processor(
database=query.database, query=query
)
rendered_query = template_processor.process_template(
query.sql, **self.execution_context.template_params
query.sql, **self._execution_context.template_params
)
validate(rendered_query, template_processor)
except TemplateError as ex:
@@ -235,32 +234,32 @@ def _set_query_limit_if_required(self, rendered_query: str,) -> None:

def _is_required_to_set_limit(self) -> bool:
return not (
config.get("SQLLAB_CTAS_NO_LIMIT") and self.execution_context.select_as_cta
config.get("SQLLAB_CTAS_NO_LIMIT") and self._execution_context.select_as_cta
)

def _set_query_limit(self, rendered_query: str) -> None:
db_engine_spec = self.execution_context.database.db_engine_spec # type: ignore
db_engine_spec = self._execution_context.database.db_engine_spec # type: ignore
limits = [
db_engine_spec.get_limit_from_sql(rendered_query),
self.execution_context.limit,
self._execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
self._execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.QUERY
self._execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
self.execution_context.query.limiting_factor = (
self._execution_context.query.limiting_factor = (
LimitingFactor.QUERY_AND_DROPDOWN
)
self.execution_context.query.limit = min(
self._execution_context.query.limit = min(
lim for lim in limits if lim is not None
)

def _execute_query(self, rendered_query: str,) -> SqlJsonExecutionStatus:
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
# Async request.
if self.execution_context.is_run_asynchronous():
if self._execution_context.is_run_asynchronous():
return self._sql_json_async(rendered_query)

return self._sql_json_sync(rendered_query)
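The limit-selection block above chooses between a LIMIT embedded in the SQL text and the limit picked in the SQL Lab dropdown, records which one was binding, and applies the smaller value. A standalone sketch of that rule (the enum shows only the members used here; the real LimitingFactor has more):

from enum import Enum
from typing import Optional, Tuple


class LimitingFactor(str, Enum):
    DROPDOWN = "DROPDOWN"
    QUERY = "QUERY"
    QUERY_AND_DROPDOWN = "QUERY_AND_DROPDOWN"


def choose_limit(sql_limit: Optional[int], dropdown_limit: int) -> Tuple[int, LimitingFactor]:
    if sql_limit is None or sql_limit > dropdown_limit:
        factor = LimitingFactor.DROPDOWN  # the dropdown is the binding constraint
    elif dropdown_limit > sql_limit:
        factor = LimitingFactor.QUERY  # the LIMIT in the SQL text is smaller
    else:
        factor = LimitingFactor.QUERY_AND_DROPDOWN  # both agree
    effective = min(lim for lim in (sql_limit, dropdown_limit) if lim is not None)
    return effective, factor


# choose_limit(sql_limit=None, dropdown_limit=1000) -> (1000, LimitingFactor.DROPDOWN)
# choose_limit(sql_limit=50, dropdown_limit=1000)   -> (50, LimitingFactor.QUERY)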
@@ -271,7 +270,7 @@ def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
:param rendered_query: the rendered query to perform by workers
:return: A Flask Response
"""
query = self.execution_context.query
query = self._execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
@@ -285,8 +284,8 @@ def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)

# Explicitly forget the task to ensure the task metadata is removed from the
@@ -312,14 +311,14 @@ def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
query.set_extra_json_key("errors", [error_payload])
query.status = QueryStatus.FAILED
query.error_message = message
self.session.commit()
self._session.commit()

raise SupersetErrorException(error) from ex

# Update saved query with execution info from the query execution
QueryDAO.update_saved_query_exec_info(query_id)

self.session.commit()
self._session.commit()
return SqlJsonExecutionStatus.QUERY_IS_RUNNING

def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
Expand All @@ -329,7 +328,7 @@ def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
:param rendered_query: The rendered query (included templates)
:raises: SupersetTimeoutException
"""
query = self.execution_context.query
query = self._execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
@@ -339,7 +338,7 @@ def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)
self.execution_context.set_execution_result(data)
self._execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
@@ -362,7 +361,7 @@ def _get_sql_results_with_timeout(
def _get_sql_results_with_timeout(
self, timeout: int, rendered_query: str, timeout_msg: str,
) -> Optional[SqlResults]:
query = self.execution_context.query
query = self._execution_context.query
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
@@ -373,8 +372,8 @@ def _get_sql_results_with_timeout(
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)

@classmethod
Expand All @@ -389,9 +388,9 @@ def _create_payload_from_execution_context( # pylint: disable=invalid-name

if status == SqlJsonExecutionStatus.HAS_RESULTS:
return self._to_payload_results_based(
self.execution_context.get_execution_result() or {}
self._execution_context.get_execution_result() or {}
)
return self._to_payload_query_based(self.execution_context.query)
return self._to_payload_query_based(self._execution_context.query)

def _to_payload_results_based( # pylint: disable=no-self-use
self, execution_result: SqlResults