From dd1c7d2bb41f42c703af509084f50fe4aeb9c1f1 Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Mon, 6 Nov 2023 16:59:41 -0800 Subject: [PATCH 1/2] move execution_id to global context --- superset/reports/commands/execute.py | 58 +++++++++++-- superset/reports/notifications/email.py | 8 +- superset/reports/notifications/slack.py | 10 ++- superset/utils/decorators.py | 38 +++++++- superset/utils/webdriver.py | 5 +- tests/unit_tests/notifications/email_tests.py | 83 ++++++++++++++++++ tests/unit_tests/notifications/slack_tests.py | 87 +++++++++++++++++++ tests/unit_tests/utils/test_decorators.py | 49 +++++++++++ 8 files changed, 327 insertions(+), 11 deletions(-) create mode 100644 tests/unit_tests/notifications/slack_tests.py diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index 301bac4531575..8cb15bb1d2fcb 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -73,6 +73,7 @@ from superset.utils.celery import session_scope from superset.utils.core import HeaderDataType, override_user from superset.utils.csv import get_chart_csv_data, get_chart_dataframe +from superset.utils.decorators import context from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot from superset.utils.urls import get_url_path @@ -83,6 +84,7 @@ class BaseReportState: current_states: list[ReportState] = [] initial: bool = False + @context() def __init__( self, session: Session, @@ -234,7 +236,12 @@ def _get_screenshots(self) -> list[bytes]: try: image = screenshot.get_screenshot(user=user) except SoftTimeLimitExceeded as ex: - logger.warning("A timeout occurred while taking a screenshot.") + logger.warning( + "A timeout occurred while taking a screenshot.", + extra={ + "execution_id": self._execution_id, + }, + ) raise ReportScheduleScreenshotTimeout() from ex except Exception as ex: raise ReportScheduleScreenshotFailedError( @@ -254,11 +261,23 @@ def _get_csv_data(self) -> bytes: 
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user) if self._report_schedule.chart.query_context is None: - logger.warning("No query context found, taking a screenshot to generate it") + logger.warning( + "No query context found, taking a screenshot to generate it", + extra={ + "execution_id": self._execution_id, + }, + ) self._update_query_context() try: - logger.info("Getting chart from %s as user %s", url, user.username) + logger.info( + "Getting chart from %s as user %s", + url, + user.username, + extra={ + "execution_id": self._execution_id, + }, + ) csv_data = get_chart_csv_data(chart_url=url, auth_cookies=auth_cookies) except SoftTimeLimitExceeded as ex: raise ReportScheduleCsvTimeout() from ex @@ -283,11 +302,23 @@ def _get_embedded_data(self) -> pd.DataFrame: auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user) if self._report_schedule.chart.query_context is None: - logger.warning("No query context found, taking a screenshot to generate it") + logger.warning( + "No query context found, taking a screenshot to generate it", + extra={ + "execution_id": self._execution_id, + }, + ) self._update_query_context() try: - logger.info("Getting chart from %s as user %s", url, user.username) + logger.info( + "Getting chart from %s as user %s", + url, + user.username, + extra={ + "execution_id": self._execution_id, + }, + ) dataframe = get_chart_dataframe(url, auth_cookies) except SoftTimeLimitExceeded as ex: raise ReportScheduleDataFrameTimeout() from ex @@ -440,7 +471,12 @@ def _send( if notification_errors: # log all errors but raise based on the most severe for error in notification_errors: - logger.warning(str(error)) + logger.warning( + str(error), + extra={ + "execution_id": self._execution_id, + }, + ) if any(error.level == ErrorLevel.ERROR for error in notification_errors): raise ReportScheduleSystemErrorsException(errors=notification_errors) @@ -466,7 +502,9 @@ def send_error(self, name: str, message: str) -> 
None: logger.info( "header_data in notifications for alerts and reports %s, taskid, %s", header_data, - self._execution_id, + extra={ + "execution_id": self._execution_id, + }, ) notification_content = NotificationContent( name=name, text=message, header_data=header_data @@ -725,6 +763,9 @@ def run(self) -> None: "Running report schedule %s as user %s", self._execution_id, username, + extra={ + "execution_id": self._execution_id, + }, ) ReportScheduleStateMachine( session, self._execution_id, self._model, self._scheduled_dttm @@ -740,6 +781,9 @@ def validate(self, session: Session = None) -> None: "session is validated: id %s, executionid: %s", self._model_id, self._execution_id, + extra={ + "execution_id": self._execution_id, + }, ) self._model = ( session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none() diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py index 1b9e4ade72f01..78afd54b8c827 100644 --- a/superset/reports/notifications/email.py +++ b/superset/reports/notifications/email.py @@ -22,6 +22,7 @@ from typing import Any, Optional import nh3 +from flask import g from flask_babel import gettext as __ from superset import app @@ -190,6 +191,7 @@ def send(self) -> None: subject = self._get_subject() content = self._get_content() to = self._get_to() + global_context = getattr(g, "context", {}) or {} try: send_email_smtp( to, @@ -205,7 +207,11 @@ def send(self) -> None: header_data=content.header_data, ) logger.info( - "Report sent to email, notification content is %s", content.header_data + "Report sent to email, notification content is %s", + content.header_data, + extra={ + "execution_id": global_context.get("execution_id"), + }, ) except SupersetErrorsException as ex: raise NotificationError( diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py index a769622b57640..f227e72432df0 100644 --- a/superset/reports/notifications/slack.py +++ 
b/superset/reports/notifications/slack.py @@ -22,6 +22,7 @@ import backoff import pandas as pd +from flask import g from flask_babel import gettext as __ from slack_sdk import WebClient from slack_sdk.errors import ( @@ -166,6 +167,8 @@ def send(self) -> None: channel = self._get_channel() body = self._get_body() file_type = "csv" if self._content.csv else "png" + global_context = getattr(g, "context", {}) or {} + try: token = app.config["SLACK_API_TOKEN"] if callable(token): @@ -183,7 +186,12 @@ def send(self) -> None: ) else: client.chat_postMessage(channel=channel, text=body) - logger.info("Report sent to slack") + logger.info( + "Report sent to slack", + extra={ + "execution_id": global_context.get("execution_id"), + }, + ) except ( BotUserAccessError, SlackRequestError, diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 9c21e3b5ec5a5..9b4ae0d5e6e6e 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -20,8 +20,9 @@ from collections.abc import Iterator from contextlib import contextmanager from typing import Any, Callable, TYPE_CHECKING +from uuid import UUID -from flask import current_app, Response +from flask import current_app, g, Response from superset.utils import core as utils from superset.utils.dates import now_as_float @@ -111,3 +112,38 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: def on_security_exception(self: Any, ex: Exception) -> Response: return self.response(403, **{"message": utils.error_msg_from_exception(ex)}) + + +def context( + slice_id: int | None = None, + dashboard_id: int | None = None, + execution_id: str | UUID | None = None, +) -> Callable[..., Any]: + """ + Takes arguments and adds them to the global context. 
+ This is for logging purposes only and values should not be relied on or mutated + """ + + def decorate(f: Callable[..., Any]) -> Callable[..., Any]: + def wrapped(*args: Any, **kwargs: Any) -> Any: + if not hasattr(g, "context"): + g.context = {} + available_context_values = ["slice_id", "dashboard_id", "execution_id"] + context_data = { + key: val + for key, val in kwargs.items() + if key in available_context_values + } + + # if values are passed in to decorator directly, add them to context + # by overriding values from kwargs + for val in available_context_values: + if locals().get(val) is not None: + context_data[val] = locals()[val] + + g.context.update(context_data) + return f(*args, **kwargs) + + return wrapped + + return decorate diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py index 4353319072287..579e444800818 100644 --- a/superset/utils/webdriver.py +++ b/superset/utils/webdriver.py @@ -23,7 +23,7 @@ from time import sleep from typing import Any, TYPE_CHECKING -from flask import current_app +from flask import current_app, g from selenium.common.exceptions import ( StaleElementReferenceException, TimeoutException, @@ -227,6 +227,9 @@ def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | Non "Taking a PNG screenshot of url %s as user %s", url, user.username, + extra={ + "execution_id": getattr(g, "context", {}).get("execution_id"), + }, ) if current_app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]: unexpected_errors = WebDriverPlaywright.find_unexpected_errors(page) diff --git a/tests/unit_tests/notifications/email_tests.py b/tests/unit_tests/notifications/email_tests.py index 697a9bac40c86..0af04a333ffa1 100644 --- a/tests/unit_tests/notifications/email_tests.py +++ b/tests/unit_tests/notifications/email_tests.py @@ -14,7 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import uuid +from unittest.mock import MagicMock, patch, PropertyMock + import pandas as pd +from flask import g def test_render_description_with_html() -> None: @@ -55,3 +59,82 @@ def test_render_description_with_html() -> None: in email_body ) assert '<a href="http://www.example.com">333</a>' in email_body + + +@patch("superset.reports.notifications.email.send_email_smtp") +@patch("superset.reports.notifications.email.logger") +def test_send_email( + logger_mock: MagicMock, + send_email_mock: MagicMock, +) -> None: + # `superset.models.helpers`, a dependency of following imports, + # requires app context + from superset.reports.models import ReportRecipients, ReportRecipientType + from superset.reports.notifications.base import NotificationContent + from superset.reports.notifications.email import EmailNotification + + execution_id = uuid.uuid4() + g.context = {"execution_id": execution_id} + content = NotificationContent( + name="test alert", + embedded_data=pd.DataFrame( + { + "A": [1, 2, 3], + "B": [4, 5, 6], + "C": ["111", "222", '333'], + } + ), + description='

This is a test alert


', + header_data={ + "notification_format": "PNG", + "notification_type": "Alert", + "owners": [1], + "notification_source": None, + "chart_id": None, + "dashboard_id": None, + }, + ) + EmailNotification( + recipient=ReportRecipients( + type=ReportRecipientType.EMAIL, + recipient_config_json='{"target": "foo@bar.com"}', + ), + content=content, + ).send() + + logger_mock.info.assert_called_with( + "Report sent to email, notification content is %s", + { + "notification_format": "PNG", + "notification_type": "Alert", + "owners": [1], + "notification_source": None, + "chart_id": None, + "dashboard_id": None, + }, + extra={"execution_id": execution_id}, + ) + # clear logger_mock and g.context + logger_mock.reset_mock() + g.context = None + + # test with no execution_id + EmailNotification( + recipient=ReportRecipients( + type=ReportRecipientType.EMAIL, + recipient_config_json='{"target": "foo@bar.com"}', + ), + content=content, + ).send() + logger_mock.info.assert_called_with( + "Report sent to email, notification content is %s", + { + "notification_format": "PNG", + "notification_type": "Alert", + "owners": [1], + "notification_source": None, + "chart_id": None, + "dashboard_id": None, + }, + extra={"execution_id": None}, + ) diff --git a/tests/unit_tests/notifications/slack_tests.py b/tests/unit_tests/notifications/slack_tests.py new file mode 100644 index 0000000000000..6a6b90a2db7b4 --- /dev/null +++ b/tests/unit_tests/notifications/slack_tests.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import uuid +from unittest.mock import MagicMock, patch + +import pandas as pd +from flask import g + + +@patch("superset.reports.notifications.slack.logger") +def test_send_slack( + logger_mock: MagicMock, +) -> None: + # `superset.models.helpers`, a dependency of following imports, + # requires app context + from superset.reports.models import ReportRecipients, ReportRecipientType + from superset.reports.notifications.base import NotificationContent + from superset.reports.notifications.slack import SlackNotification, WebClient + + execution_id = uuid.uuid4() + g.context = {"execution_id": execution_id} + content = NotificationContent( + name="test alert", + embedded_data=pd.DataFrame( + { + "A": [1, 2, 3], + "B": [4, 5, 6], + "C": ["111", "222", '333'], + } + ), + description='

This is a test alert


', + header_data={ + "notification_format": "PNG", + "notification_type": "Alert", + "owners": [1], + "notification_source": None, + "chart_id": None, + "dashboard_id": None, + }, + ) + with patch.object( + WebClient, "chat_postMessage", return_value=True + ) as chat_post_message_mock: + SlackNotification( + recipient=ReportRecipients( + type=ReportRecipientType.SLACK, + recipient_config_json='{"target": "some_channel"}', + ), + content=content, + ).send() + logger_mock.info.assert_called_with( + "Report sent to slack", extra={"execution_id": execution_id} + ) + chat_post_message_mock.assert_called_with( + channel="some_channel", + text="""*test alert* + +

This is a test alert


+ + + +``` +| | A | B | C | +|---:|----:|----:|:-----------------------------------------| +| 0 | 1 | 4 | 111 | +| 1 | 2 | 5 | 222 | +| 2 | 3 | 6 | 333 | +``` +""", + ) + + # reset g.context + g.context = None diff --git a/tests/unit_tests/utils/test_decorators.py b/tests/unit_tests/utils/test_decorators.py index f600334414eeb..4d0ad0acc8e9c 100644 --- a/tests/unit_tests/utils/test_decorators.py +++ b/tests/unit_tests/utils/test_decorators.py @@ -16,12 +16,14 @@ # under the License. +import uuid from contextlib import nullcontext from inspect import isclass from typing import Any, Optional from unittest.mock import call, Mock, patch import pytest +from flask import g from superset import app from superset.utils import decorators @@ -85,3 +87,50 @@ def my_func(response: ResponseValues, *args: Any, **kwargs: Any) -> str: with cm: my_func(response_value, 1, 2) mock.assert_called_once_with(expected_result, 1) + + +# write a TDD test for a decorator that will log context to a global scope. It will write dashboard_id, slice_id, and execution_id to a global scope. 
+def test_context_decorator() -> None: + @decorators.context() + def myfunc(*args, **kwargs) -> str: + return "test" + + # should be able to add values to the decorator function directly + @decorators.context(slice_id=1, dashboard_id=1, execution_id=uuid.uuid4()) + def myfunc_with_kwargs(*args, **kwargs) -> str: + return "test" + + # should not add any data to the global.context scope + myfunc(1, 1) + assert g.context == {} + + # should add dashboard_id to the global.context scope + myfunc(1, 1, dashboard_id=1) + assert g.context == {"dashboard_id": 1} + g.context = {} + + # should add slice_id to the global.context scope + myfunc(1, 1, slice_id=1) + assert g.context == {"slice_id": 1} + g.context = {} + + # should add execution_id to the global.context scope + myfunc(1, 1, execution_id=1) + assert g.context == {"execution_id": 1} + g.context = {} + + # should add all three to the global.context scope + myfunc(1, 1, dashboard_id=1, slice_id=1, execution_id=1) + assert g.context == {"dashboard_id": 1, "slice_id": 1, "execution_id": 1} + g.context = {} + + # should overwrite existing values in the global.context scope + g.context = {"dashboard_id": 2, "slice_id": 2, "execution_id": 2} + myfunc(1, 1, dashboard_id=3, slice_id=3, execution_id=3) + assert g.context == {"dashboard_id": 3, "slice_id": 3, "execution_id": 3} + g.context = {} + + # should be able to add values to the decorator function directly + myfunc_with_kwargs(slice_id=1, dashboard_id=1, execution_id=1) + assert g.context == {"dashboard_id": 1, "slice_id": 1, "execution_id": 1} + g.context = {} From 6ae796779d038808b25727bc8e8ddf845b7d634f Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Mon, 6 Nov 2023 17:47:48 -0800 Subject: [PATCH 2/2] move notification header to global context --- superset/reports/commands/execute.py | 34 +-- superset/reports/notifications/base.py | 2 - superset/reports/notifications/email.py | 10 +- superset/reports/notifications/utils.py | 200 ++++++++++++++++++ 
superset/reports/types.py | 11 +- superset/tasks/scheduler.py | 8 +- superset/utils/core.py | 134 ------------ superset/utils/decorators.py | 15 +- tests/integration_tests/email_tests.py | 2 +- tests/integration_tests/utils_tests.py | 22 +- tests/unit_tests/notifications/email_tests.py | 36 +--- tests/unit_tests/notifications/slack_tests.py | 8 - tests/unit_tests/notifications/utils_tests.py | 33 +++ tests/unit_tests/utils/test_decorators.py | 1 - 14 files changed, 279 insertions(+), 237 deletions(-) create mode 100644 superset/reports/notifications/utils.py create mode 100644 tests/unit_tests/notifications/utils_tests.py diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index 8cb15bb1d2fcb..a660db5edccfa 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -63,7 +63,6 @@ ReportRecipientType, ReportSchedule, ReportScheduleType, - ReportSourceFormat, ReportState, ) from superset.reports.notifications import create_notification @@ -71,7 +70,7 @@ from superset.reports.notifications.exceptions import NotificationError from superset.tasks.utils import get_executor from superset.utils.celery import session_scope -from superset.utils.core import HeaderDataType, override_user +from superset.utils.core import override_user from superset.utils.csv import get_chart_csv_data, get_chart_dataframe from superset.utils.decorators import context from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot @@ -351,27 +350,6 @@ def _update_query_context(self) -> None: "Please try loading the chart and saving it again." 
) from ex - def _get_log_data(self) -> HeaderDataType: - chart_id = None - dashboard_id = None - report_source = None - if self._report_schedule.chart: - report_source = ReportSourceFormat.CHART - chart_id = self._report_schedule.chart_id - else: - report_source = ReportSourceFormat.DASHBOARD - dashboard_id = self._report_schedule.dashboard_id - - log_data: HeaderDataType = { - "notification_type": self._report_schedule.type, - "notification_source": report_source, - "notification_format": self._report_schedule.report_format, - "chart_id": chart_id, - "dashboard_id": dashboard_id, - "owners": self._report_schedule.owners, - } - return log_data - def _get_notification_content(self) -> NotificationContent: """ Gets a notification content, this is composed by a title and a screenshot @@ -382,7 +360,6 @@ def _get_notification_content(self) -> NotificationContent: embedded_data = None error_text = None screenshot_data = [] - header_data = self._get_log_data() url = self._get_url(user_friendly=True) if ( feature_flag_manager.is_feature_enabled("ALERTS_ATTACH_REPORTS") @@ -403,7 +380,6 @@ def _get_notification_content(self) -> NotificationContent: return NotificationContent( name=self._report_schedule.name, text=error_text, - header_data=header_data, ) if ( @@ -430,7 +406,6 @@ def _get_notification_content(self) -> NotificationContent: description=self._report_schedule.description, csv=csv_data, embedded_data=embedded_data, - header_data=header_data, ) def _send( @@ -498,16 +473,15 @@ def send_error(self, name: str, message: str) -> None: :raises: CommandException """ - header_data = self._get_log_data() logger.info( - "header_data in notifications for alerts and reports %s, taskid, %s", - header_data, + "An error for a notification occurred, sending error notification", extra={ "execution_id": self._execution_id, }, ) notification_content = NotificationContent( - name=name, text=message, header_data=header_data + name=name, + text=message, ) # filter recipients to 
recipients who are also owners diff --git a/superset/reports/notifications/base.py b/superset/reports/notifications/base.py index 640b326fc53d8..2e014e029eae9 100644 --- a/superset/reports/notifications/base.py +++ b/superset/reports/notifications/base.py @@ -20,13 +20,11 @@ import pandas as pd from superset.reports.models import ReportRecipients, ReportRecipientType -from superset.utils.core import HeaderDataType @dataclass class NotificationContent: name: str - header_data: HeaderDataType # this is optional to account for error states csv: Optional[bytes] = None # bytes for csv file screenshots: Optional[list[bytes]] = None # bytes for a list of screenshots text: Optional[str] = None diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py index 78afd54b8c827..a022aca1db400 100644 --- a/superset/reports/notifications/email.py +++ b/superset/reports/notifications/email.py @@ -30,7 +30,7 @@ from superset.reports.models import ReportRecipientType from superset.reports.notifications.base import BaseNotification from superset.reports.notifications.exceptions import NotificationError -from superset.utils.core import HeaderDataType, send_email_smtp +from superset.reports.notifications.utils import send_email_smtp from superset.utils.decorators import statsd_gauge logger = logging.getLogger(__name__) @@ -68,7 +68,6 @@ @dataclass class EmailContent: body: str - header_data: Optional[HeaderDataType] = None data: Optional[dict[str, Any]] = None images: Optional[dict[str, bytes]] = None @@ -173,7 +172,6 @@ def _get_content(self) -> EmailContent: body=body, images=images, data=csv_data, - header_data=self._content.header_data, ) def _get_subject(self) -> str: @@ -204,13 +202,13 @@ def send(self) -> None: bcc="", mime_subtype="related", dryrun=False, - header_data=content.header_data, ) logger.info( - "Report sent to email, notification content is %s", - content.header_data, + "Report sent to email", extra={ "execution_id": 
global_context.get("execution_id"), + "dashboard_id": global_context.get("dashboard_id"), + "chart_id": global_context.get("chart_id"), }, ) except SupersetErrorsException as ex: diff --git a/superset/reports/notifications/utils.py b/superset/reports/notifications/utils.py new file mode 100644 index 0000000000000..6761f50bdb560 --- /dev/null +++ b/superset/reports/notifications/utils.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +import logging +import os +import re +import smtplib +import ssl +from email.mime.application import MIMEApplication +from email.mime.image import MIMEImage +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +from email.utils import formatdate +from typing import Any, Optional + +from flask import g + +from superset import db +from superset.reports.models import ReportSchedule, ReportSourceFormat +from superset.reports.types import HeaderDataType + +logger = logging.getLogger(__name__) + + +def get_email_address_list(address_string: str) -> list[str]: + address_string_list: list[str] = [] + if isinstance(address_string, str): + address_string_list = re.split(r",|\s|;", address_string) + return [x.strip() for x in address_string_list if x.strip()] + + +def _get_log_data() -> HeaderDataType: + global_context = getattr(g, "context", {}) or {} + chart_id = global_context.get("chart_id") + dashboard_id = global_context.get("dashboard_id") + report_schedule_id = global_context.get("report_schedule_id") + report_source: str = "" + report_format: str = "" + report_type: str = "" + owners: list[int] = [] + + # intentionally creating a new session to + # keep the logging session separate from the + # main session + session = db.create_scoped_session() + report_schedule = ( + session.query(ReportSchedule).filter_by(id=report_schedule_id).one_or_none() + ) + session.close() + + if report_schedule is not None: + report_type = report_schedule.type + report_format = report_schedule.report_format + owners = report_schedule.owners + report_source = ( + ReportSourceFormat.DASHBOARD + if report_schedule.dashboard_id + else ReportSourceFormat.CHART + ) + + log_data: HeaderDataType = { + "notification_type": report_type, + "notification_source": report_source, + "notification_format": report_format, + "chart_id": chart_id, + "dashboard_id": dashboard_id, + "owners": owners, + } + return log_data + + +def send_email_smtp( + to: str, + subject: 
str, + html_content: str, + config: dict[str, Any], + files: Optional[list[str]] = None, + data: Optional[dict[str, str]] = None, + images: Optional[dict[str, bytes]] = None, + dryrun: bool = False, + cc: Optional[str] = None, + bcc: Optional[str] = None, + mime_subtype: str = "mixed", +) -> None: + """ + Send an email with html content, eg: + send_email_smtp( + 'test@example.com', 'foo', 'Foo bar',['/dev/null'], dryrun=True) + """ + smtp_mail_from = config["SMTP_MAIL_FROM"] + smtp_mail_to = get_email_address_list(to) + + msg = MIMEMultipart(mime_subtype) + msg["Subject"] = subject + msg["From"] = smtp_mail_from + msg["To"] = ", ".join(smtp_mail_to) + + msg.preamble = "This is a multi-part message in MIME format." + + recipients = smtp_mail_to + if cc: + smtp_mail_cc = get_email_address_list(cc) + msg["CC"] = ", ".join(smtp_mail_cc) + recipients = recipients + smtp_mail_cc + + if bcc: + # don't add bcc in header + smtp_mail_bcc = get_email_address_list(bcc) + recipients = recipients + smtp_mail_bcc + + msg["Date"] = formatdate(localtime=True) + mime_text = MIMEText(html_content, "html") + msg.attach(mime_text) + + # Attach files by reading them from disk + for fname in files or []: + basename = os.path.basename(fname) + with open(fname, "rb") as f: + msg.attach( + MIMEApplication( + f.read(), + Content_Disposition=f"attachment; filename='{basename}'", + Name=basename, + ) + ) + + # Attach any files passed directly + for name, body in (data or {}).items(): + msg.attach( + MIMEApplication( + body, Content_Disposition=f"attachment; filename='{name}'", Name=name + ) + ) + + # Attach any inline images, which may be required for display in + # HTML content (inline) + for msgid, imgdata in (images or {}).items(): + formatted_time = formatdate(localtime=True) + file_name = f"{subject} {formatted_time}" + image = MIMEImage(imgdata, name=file_name) + image.add_header("Content-ID", f"<{msgid}>") + image.add_header("Content-Disposition", "inline") + msg.attach(image) + 
msg_mutator = config["EMAIL_HEADER_MUTATOR"] + # the base notification returns the message without any editing. + header_data = _get_log_data() + new_msg = msg_mutator(msg, **(header_data or {})) + send_mime_email(smtp_mail_from, recipients, new_msg, config, dryrun=dryrun) + + +def send_mime_email( + e_from: str, + e_to: list[str], + mime_msg: MIMEMultipart, + config: dict[str, Any], + dryrun: bool = False, +) -> None: + smtp_host = config["SMTP_HOST"] + smtp_port = config["SMTP_PORT"] + smtp_user = config["SMTP_USER"] + smtp_password = config["SMTP_PASSWORD"] + smtp_starttls = config["SMTP_STARTTLS"] + smtp_ssl = config["SMTP_SSL"] + smtp_ssl_server_auth = config["SMTP_SSL_SERVER_AUTH"] + + if dryrun: + logger.info("Dryrun enabled, email notification content is below:") + logger.info(mime_msg.as_string()) + return + + # Default ssl context is SERVER_AUTH using the default system + # root CA certificates + ssl_context = ssl.create_default_context() if smtp_ssl_server_auth else None + smtp = ( + smtplib.SMTP_SSL(smtp_host, smtp_port, context=ssl_context) + if smtp_ssl + else smtplib.SMTP(smtp_host, smtp_port) + ) + if smtp_starttls: + smtp.starttls(context=ssl_context) + if smtp_user and smtp_password: + smtp.login(smtp_user, smtp_password) + logger.debug("Sent an email to %s", str(e_to)) + smtp.sendmail(e_from, e_to, mime_msg.as_string()) + smtp.quit() diff --git a/superset/reports/types.py b/superset/reports/types.py index d487e3ad23766..d71897379dc56 100644 --- a/superset/reports/types.py +++ b/superset/reports/types.py @@ -14,10 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import TypedDict +from typing import Optional, TypedDict from superset.dashboards.permalink.types import DashboardPermalinkState class ReportScheduleExtra(TypedDict): dashboard: DashboardPermalinkState + + +class HeaderDataType(TypedDict): + notification_format: str + owners: list[int] + notification_type: str + notification_source: Optional[str] + chart_id: Optional[int] + dashboard_id: Optional[int] diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py index f3cc270b86347..cc4b9de54409c 100644 --- a/superset/tasks/scheduler.py +++ b/superset/tasks/scheduler.py @@ -31,6 +31,7 @@ from superset.tasks.cron_util import cron_schedule_window from superset.utils.celery import session_scope from superset.utils.core import LoggerLevel +from superset.utils.decorators import context from superset.utils.log import get_logger_from_status logger = logging.getLogger(__name__) @@ -78,6 +79,7 @@ def scheduler() -> None: @celery_app.task(name="reports.execute", bind=True) +@context() def execute(self: Celery.task, report_schedule_id: int) -> None: stats_logger: BaseStatsLogger = app.config["STATS_LOGGER"] stats_logger.incr("reports.execute") @@ -90,6 +92,7 @@ def execute(self: Celery.task, report_schedule_id: int) -> None: "Executing alert/report, task id: %s, scheduled_dttm: %s", task_id, scheduled_dttm, + extra={"execution_id": task_id}, ) AsyncExecuteReportScheduleCommand( task_id, @@ -98,7 +101,9 @@ def execute(self: Celery.task, report_schedule_id: int) -> None: ).run() except ReportScheduleUnexpectedError: logger.exception( - "An unexpected occurred while executing the report: %s", task_id + "An unexpected occurred while executing the report: %s", + task_id, + extra={"execution_id": task_id}, ) self.update_state(state="FAILURE") except CommandException as ex: @@ -107,6 +112,7 @@ def execute(self: Celery.task, report_schedule_id: int) -> None: f"A downstream {level} occurred " f"while generating a report: {task_id}. 
{ex.message}", exc_info=True, + extra={"execution_id": task_id}, ) if level == LoggerLevel.EXCEPTION: self.update_state(state="FAILURE") diff --git a/superset/utils/core.py b/superset/utils/core.py index 67edabe626d6f..310c086acfcf6 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -28,9 +28,7 @@ import platform import re import signal -import smtplib import sqlite3 -import ssl import tempfile import threading import traceback @@ -164,15 +162,6 @@ class LoggerLevel(StrEnum): EXCEPTION = "exception" -class HeaderDataType(TypedDict): - notification_format: str - owners: list[int] - notification_type: str - notification_source: str | None - chart_id: int | None - dashboard_id: int | None - - class DatasourceDict(TypedDict): type: str # todo(hugh): update this to be DatasourceType id: int @@ -815,129 +804,6 @@ def set_sqlite_pragma( # pylint: disable=unused-argument cursor.execute("PRAGMA foreign_keys=ON") -def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many-locals - to: str, - subject: str, - html_content: str, - config: dict[str, Any], - files: list[str] | None = None, - data: dict[str, str] | None = None, - images: dict[str, bytes] | None = None, - dryrun: bool = False, - cc: str | None = None, - bcc: str | None = None, - mime_subtype: str = "mixed", - header_data: HeaderDataType | None = None, -) -> None: - """ - Send an email with html content, eg: - send_email_smtp( - 'test@example.com', 'foo', 'Foo bar',['/dev/null'], dryrun=True) - """ - smtp_mail_from = config["SMTP_MAIL_FROM"] - smtp_mail_to = get_email_address_list(to) - - msg = MIMEMultipart(mime_subtype) - msg["Subject"] = subject - msg["From"] = smtp_mail_from - msg["To"] = ", ".join(smtp_mail_to) - - msg.preamble = "This is a multi-part message in MIME format." 
- - recipients = smtp_mail_to - if cc: - smtp_mail_cc = get_email_address_list(cc) - msg["CC"] = ", ".join(smtp_mail_cc) - recipients = recipients + smtp_mail_cc - - if bcc: - # don't add bcc in header - smtp_mail_bcc = get_email_address_list(bcc) - recipients = recipients + smtp_mail_bcc - - msg["Date"] = formatdate(localtime=True) - mime_text = MIMEText(html_content, "html") - msg.attach(mime_text) - - # Attach files by reading them from disk - for fname in files or []: - basename = os.path.basename(fname) - with open(fname, "rb") as f: - msg.attach( - MIMEApplication( - f.read(), - Content_Disposition=f"attachment; filename='{basename}'", - Name=basename, - ) - ) - - # Attach any files passed directly - for name, body in (data or {}).items(): - msg.attach( - MIMEApplication( - body, Content_Disposition=f"attachment; filename='{name}'", Name=name - ) - ) - - # Attach any inline images, which may be required for display in - # HTML content (inline) - for msgid, imgdata in (images or {}).items(): - formatted_time = formatdate(localtime=True) - file_name = f"{subject} {formatted_time}" - image = MIMEImage(imgdata, name=file_name) - image.add_header("Content-ID", f"<{msgid}>") - image.add_header("Content-Disposition", "inline") - msg.attach(image) - msg_mutator = config["EMAIL_HEADER_MUTATOR"] - # the base notification returns the message without any editing. 
- new_msg = msg_mutator(msg, **(header_data or {})) - send_mime_email(smtp_mail_from, recipients, new_msg, config, dryrun=dryrun) - - -def send_mime_email( - e_from: str, - e_to: list[str], - mime_msg: MIMEMultipart, - config: dict[str, Any], - dryrun: bool = False, -) -> None: - smtp_host = config["SMTP_HOST"] - smtp_port = config["SMTP_PORT"] - smtp_user = config["SMTP_USER"] - smtp_password = config["SMTP_PASSWORD"] - smtp_starttls = config["SMTP_STARTTLS"] - smtp_ssl = config["SMTP_SSL"] - smtp_ssl_server_auth = config["SMTP_SSL_SERVER_AUTH"] - - if dryrun: - logger.info("Dryrun enabled, email notification content is below:") - logger.info(mime_msg.as_string()) - return - - # Default ssl context is SERVER_AUTH using the default system - # root CA certificates - ssl_context = ssl.create_default_context() if smtp_ssl_server_auth else None - smtp = ( - smtplib.SMTP_SSL(smtp_host, smtp_port, context=ssl_context) - if smtp_ssl - else smtplib.SMTP(smtp_host, smtp_port) - ) - if smtp_starttls: - smtp.starttls(context=ssl_context) - if smtp_user and smtp_password: - smtp.login(smtp_user, smtp_password) - logger.debug("Sent an email to %s", str(e_to)) - smtp.sendmail(e_from, e_to, mime_msg.as_string()) - smtp.quit() - - -def get_email_address_list(address_string: str) -> list[str]: - address_string_list: list[str] = [] - if isinstance(address_string, str): - address_string_list = re.split(r",|\s|;", address_string) - return [x.strip() for x in address_string_list if x.strip()] - - def choicify(values: Iterable[Any]) -> list[tuple[Any, Any]]: """Takes an iterable and makes an iterable of tuples with it""" return [(v, v) for v in values] diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 9b4ae0d5e6e6e..b29851fe55d9f 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -114,11 +114,7 @@ def on_security_exception(self: Any, ex: Exception) -> Response: return self.response(403, **{"message": 
utils.error_msg_from_exception(ex)}) -def context( - slice_id: int | None = None, - dashboard_id: int | None = None, - execution_id: str | UUID | None = None, -) -> Callable[..., Any]: +def context(**ctx_kwargs: int | str | UUID | None) -> Callable[..., Any]: """ Takes arguments and adds them to the global context. This is for logging purposes only and values should not be relied on or mutated @@ -128,7 +124,12 @@ def decorate(f: Callable[..., Any]) -> Callable[..., Any]: def wrapped(*args: Any, **kwargs: Any) -> Any: if not hasattr(g, "context"): g.context = {} - available_context_values = ["slice_id", "dashboard_id", "execution_id"] + available_context_values = [ + "slice_id", + "dashboard_id", + "execution_id", + "report_schedule_id", + ] context_data = { key: val for key, val in kwargs.items() @@ -138,7 +139,7 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: # if values are passed in to decorator directly, add them to context # by overriding values from kwargs for val in available_context_values: - if locals().get(val) is not None: - context_data[val] = locals()[val] + if ctx_kwargs.get(val) is not None: + context_data[val] = ctx_kwargs[val] g.context.update(context_data) diff --git a/tests/integration_tests/email_tests.py b/tests/integration_tests/email_tests.py index 7c7cc1683089f..140d90a1c6509 100644 --- a/tests/integration_tests/email_tests.py +++ b/tests/integration_tests/email_tests.py @@ -25,7 +25,7 @@ from unittest import mock from superset import app -from superset.utils import core as utils +from superset.reports.notifications import utils from tests.integration_tests.base_tests import SupersetTestCase from .utils import read_fixture diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 6f8a7ed457628..a78da73fc34eb 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -45,17 +45,17 @@ from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.utils.core 
import ( + DTTM_ALIAS, + DateColumn, + GenericDataType, + as_list, base_json_conv, cast_to_num, convert_legacy_filters_into_adhoc, create_ssl_cert_file, - DTTM_ALIAS, extract_dataframe_dtypes, format_timedelta, - GenericDataType, get_form_data_token, - as_list, - get_email_address_list, get_stacktrace, json_int_dttm_ser, json_iso_dttm_ser, @@ -63,14 +63,14 @@ merge_extra_form_data, merge_request_params, normalize_dttm_col, - parse_ssl_cert, parse_js_uri_path_item, + parse_ssl_cert, split, validate_json, zlib_compress, zlib_decompress, - DateColumn, ) + from superset.utils.database import get_or_create_db from superset.utils import schema from superset.utils.hashing import md5_sha_from_str @@ -893,16 +893,6 @@ def test_ssl_certificate_file_creation(self): self.assertIn(expected_filename, path) self.assertTrue(os.path.exists(path)) - def test_get_email_address_list(self): - self.assertEqual(get_email_address_list("a@a"), ["a@a"]) - self.assertEqual(get_email_address_list(" a@a "), ["a@a"]) - self.assertEqual(get_email_address_list("a@a\n"), ["a@a"]) - self.assertEqual(get_email_address_list(",a@a;"), ["a@a"]) - self.assertEqual( - get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f"), - ["a@a", "b@b", "c@c", "a-c@c", "d@d", "f@f"], - ) - def test_get_form_data_default(self) -> None: with app.test_request_context(): form_data, slc = get_form_data() diff --git a/tests/unit_tests/notifications/email_tests.py b/tests/unit_tests/notifications/email_tests.py index 0af04a333ffa1..80055374639c2 100644 --- a/tests/unit_tests/notifications/email_tests.py +++ b/tests/unit_tests/notifications/email_tests.py @@ -38,14 +38,6 @@ def test_render_description_with_html() -> None: } ), description='

This is a test alert


', - header_data={ - "notification_format": "PNG", - "notification_type": "Alert", - "owners": [1], - "notification_source": None, - "chart_id": None, - "dashboard_id": None, - }, ) email_body = ( EmailNotification( @@ -85,14 +77,6 @@ def test_send_email( } ), description='

This is a test alert


', - header_data={ - "notification_format": "PNG", - "notification_type": "Alert", - "owners": [1], - "notification_source": None, - "chart_id": None, - "dashboard_id": None, - }, ) EmailNotification( recipient=ReportRecipients( @@ -103,16 +87,12 @@ def test_send_email( ).send() logger_mock.info.assert_called_with( - "Report sent to email, notification content is %s", - { - "notification_format": "PNG", - "notification_type": "Alert", - "owners": [1], - "notification_source": None, + "Report sent to email", + extra={ + "execution_id": execution_id, "chart_id": None, "dashboard_id": None, }, - extra={"execution_id": execution_id}, ) # clear logger_mock and g.context logger_mock.reset_mock() @@ -127,14 +107,10 @@ def test_send_email( content=content, ).send() logger_mock.info.assert_called_with( - "Report sent to email, notification content is %s", - { - "notification_format": "PNG", - "notification_type": "Alert", - "owners": [1], - "notification_source": None, + "Report sent to email", + extra={ + "execution_id": None, "chart_id": None, "dashboard_id": None, }, - extra={"execution_id": None}, ) diff --git a/tests/unit_tests/notifications/slack_tests.py b/tests/unit_tests/notifications/slack_tests.py index 6a6b90a2db7b4..4a496e988fc6c 100644 --- a/tests/unit_tests/notifications/slack_tests.py +++ b/tests/unit_tests/notifications/slack_tests.py @@ -43,14 +43,6 @@ def test_send_slack( } ), description='

This is a test alert


', - header_data={ - "notification_format": "PNG", - "notification_type": "Alert", - "owners": [1], - "notification_source": None, - "chart_id": None, - "dashboard_id": None, - }, ) with patch.object( WebClient, "chat_postMessage", return_value=True diff --git a/tests/unit_tests/notifications/utils_tests.py b/tests/unit_tests/notifications/utils_tests.py new file mode 100644 index 0000000000000..b8cf4b2024ddf --- /dev/null +++ b/tests/unit_tests/notifications/utils_tests.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from superset.reports.notifications.utils import get_email_address_list + + +def test_get_email_address_list(): + assert get_email_address_list("a@a") == ["a@a"] + assert get_email_address_list(" a@a ") == ["a@a"] + assert get_email_address_list("a@a\n") == ["a@a"] + assert get_email_address_list(",a@a;") == ["a@a"] + assert get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f") == [ + "a@a", + "b@b", + "c@c", + "a-c@c", + "d@d", + "f@f", + ] diff --git a/tests/unit_tests/utils/test_decorators.py b/tests/unit_tests/utils/test_decorators.py index 4d0ad0acc8e9c..48341bedcb8d7 100644 --- a/tests/unit_tests/utils/test_decorators.py +++ b/tests/unit_tests/utils/test_decorators.py @@ -89,7 +89,6 @@ def my_func(response: ResponseValues, *args: Any, **kwargs: Any) -> str: mock.assert_called_once_with(expected_result, 1) -# write a TDD test for a decorator that will log context to a global scope. It will write dashboard_id, slice_id, and execution_id to a global scope. def test_context_decorator() -> None: @decorators.context() def myfunc(*args, **kwargs) -> str: