feat: improve embedded data table in text reports #16335

Merged Aug 19, 2021 (4 commits)

Changes from 1 commit

2 changes: 1 addition & 1 deletion setup.cfg
@@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,tabulate,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

15 changes: 6 additions & 9 deletions superset/charts/api.py
@@ -499,15 +499,6 @@ def send_chart_response(
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format

# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if (
result_type == ChartDataResultType.POST_PROCESSED
and result_format == ChartDataResultFormat.CSV
):
result = apply_post_process(result, form_data)

if result_format == ChartDataResultFormat.CSV:
# Verify user has permission to export CSV file
if not security_manager.can_access("can_csv", "Superset"):
@@ -518,6 +509,12 @@
return CsvResponse(data, headers=generate_download_headers("csv"))

if result_format == ChartDataResultFormat.JSON:
# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if result_type == ChartDataResultType.POST_PROCESSED:
result = apply_post_process(result, form_data)

response_data = simplejson.dumps(
{"result": result["queries"]},
default=json_int_dttm_ser,
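The report scheduler (see execute.py further down) now requests this post-processed JSON payload directly from the chart data endpoint. Below is a minimal sketch of such a request; the host, chart id, and cookie value are placeholders, and the /api/v1/chart/<pk>/data/ route behind ChartRestApi.get_data is assumed from the URL built in execute.py.

```python
# Illustrative only: fetch post-processed chart data as JSON, the way the
# report scheduler does. Host, chart id 123 and the cookie value are
# placeholders; real auth cookies come from the machine auth provider.
import requests

auth_cookies = {"session": "<session-cookie>"}
url = (
    "https://superset.example.com/api/v1/chart/123/data/"
    "?format=json&type=post_processed"
)

response = requests.get(url, cookies=auth_cookies, timeout=60)
response.raise_for_status()

# send_chart_response serializes {"result": result["queries"]}, so each entry
# carries colnames, indexnames, coltypes, rowcount and the processed data.
first_query = response.json()["result"][0]
print(first_query["colnames"], first_query["rowcount"])
```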
38 changes: 20 additions & 18 deletions superset/charts/post_processing.py
@@ -126,11 +126,13 @@ def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-s
total = df.sum(axis=axis["columns"])
df = df.astype(total.dtypes).div(total, axis=axis["rows"])

if show_rows_total:
# convert to a MultiIndex to simplify logic
if not isinstance(df.columns, pd.MultiIndex):
df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns])
# convert to a MultiIndex to simplify logic
if not isinstance(df.index, pd.MultiIndex):
df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index])
if not isinstance(df.columns, pd.MultiIndex):
df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns])

if rows and show_rows_total:
# add subtotal for each group and overall total; we start from the
# overall group, and iterate deeper into subgroups
groups = df.columns
@@ -146,10 +148,6 @@
df.insert(int(slice_.stop), subtotal_name, subtotal)

if rows and show_columns_total:
# convert to a MultiIndex to simplify logic
if not isinstance(df.index, pd.MultiIndex):
df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index])

# add subtotal for each group and overall total; we start from the
# overall group, and iterate deeper into subgroups
groups = df.index
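The MultiIndex conversion now happens once, up front, for both axes, so the subtotal logic can treat single-level and multi-level frames uniformly. A tiny standalone illustration (not from the PR) of what that conversion does to plain labels:

```python
# Illustrative only: plain labels become single-element MultiIndex tuples,
# so the group/subtotal slicing works the same for flat and pivoted frames.
import pandas as pd

df = pd.DataFrame({"boy": [1, 3], "girl": [2, 4]}, index=["CA", "NY"])

if not isinstance(df.columns, pd.MultiIndex):
    df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns])
if not isinstance(df.index, pd.MultiIndex):
    df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index])

print(df.columns)  # single-element tuples: ('boy',), ('girl',)
print(df.index)    # single-element tuples: ('CA',), ('NY',)
```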
@@ -279,24 +277,28 @@ def apply_post_process(
post_processor = post_processors[viz_type]

for query in result["queries"]:
df = pd.read_csv(StringIO(query["data"]))
df = pd.DataFrame.from_dict(query["data"])
processed_df = post_processor(df, form_data)

# flatten column names
query["colnames"] = list(processed_df.columns)
query["indexnames"] = list(processed_df.index)
query["coltypes"] = extract_dataframe_dtypes(processed_df)
query["rowcount"] = len(processed_df.index)

# flatten columns/index so we can encode data as JSON
processed_df.columns = [
" ".join(str(name) for name in column).strip()
if isinstance(column, tuple)
else column
for column in processed_df.columns
]
processed_df.index = [
" ".join(str(name) for name in index).strip()
if isinstance(index, tuple)
else index
for index in processed_df.index
]

buf = StringIO()
processed_df.to_csv(buf)
buf.seek(0)

query["data"] = buf.getvalue()
query["colnames"] = list(processed_df.columns)
query["coltypes"] = extract_dataframe_dtypes(processed_df)
query["rowcount"] = len(processed_df.index)
query["data"] = processed_df.to_dict()

return result
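The switch from a CSV round trip to from_dict/to_dict means the flattening of tuple column and index labels is what keeps the payload JSON-serializable. A small standalone illustration (not part of the PR, with made-up values) of that flattening step on a pivoted frame:

```python
# Illustrative only: pivoted data has tuple (MultiIndex) column and index
# labels, which do not survive JSON encoding; apply_post_process flattens
# them into space-joined strings before calling to_dict().
import pandas as pd

df = pd.DataFrame(
    [[1, 2], [3, 4]],
    columns=pd.MultiIndex.from_tuples([("SUM(num)", "boy"), ("SUM(num)", "girl")]),
    index=pd.MultiIndex.from_tuples([("CA",), ("NY",)]),
)

df.columns = [
    " ".join(str(name) for name in col).strip() if isinstance(col, tuple) else col
    for col in df.columns
]
df.index = [
    " ".join(str(name) for name in idx).strip() if isinstance(idx, tuple) else idx
    for idx in df.index
]

print(df.to_dict())
# {'SUM(num) boy': {'CA': 1, 'NY': 3}, 'SUM(num) girl': {'CA': 2, 'NY': 4}}
```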
1 change: 1 addition & 0 deletions superset/common/query_actions.py
@@ -101,6 +101,7 @@ def _get_full(
status = payload["status"]
if status != QueryStatus.FAILED:
payload["colnames"] = list(df.columns)
payload["indexnames"] = list(df.index)
payload["coltypes"] = extract_dataframe_dtypes(df)
payload["data"] = query_context.get_data(df)
del payload["df"]
8 changes: 8 additions & 0 deletions superset/reports/commands/exceptions.py
@@ -116,6 +116,10 @@ class ReportScheduleCsvFailedError(CommandException):
message = _("Report Schedule execution failed when generating a csv.")


class ReportScheduleDataFrameFailedError(CommandException):
message = _("Report Schedule execution failed when generating a dataframe.")


class ReportScheduleExecuteUnexpectedError(CommandException):
message = _("Report Schedule execution got an unexpected error.")

@@ -171,6 +175,10 @@ class ReportScheduleCsvTimeout(CommandException):
message = _("A timeout occurred while generating a csv.")


class ReportScheduleDataFrameTimeout(CommandException):
message = _("A timeout occurred while generating a dataframe.")


class ReportScheduleAlertGracePeriodError(CommandException):
message = _("Alert fired during grace period.")

82 changes: 57 additions & 25 deletions superset/reports/commands/execute.py
@@ -17,7 +17,6 @@
import json
import logging
from datetime import datetime, timedelta
from io import BytesIO
from typing import Any, List, Optional
from uuid import UUID

@@ -45,6 +44,8 @@
ReportScheduleAlertGracePeriodError,
ReportScheduleCsvFailedError,
ReportScheduleCsvTimeout,
ReportScheduleDataFrameFailedError,
ReportScheduleDataFrameTimeout,
ReportScheduleExecuteUnexpectedError,
ReportScheduleNotFoundError,
ReportScheduleNotificationError,
@@ -65,7 +66,7 @@
from superset.reports.notifications.exceptions import NotificationError
from superset.utils.celery import session_scope
from superset.utils.core import ChartDataResultFormat, ChartDataResultType
from superset.utils.csv import get_chart_csv_data
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.screenshots import (
BaseScreenshot,
ChartScreenshot,
@@ -137,17 +138,23 @@ def create_log( # pylint: disable=too-many-arguments
self._session.commit()

def _get_url(
self, user_friendly: bool = False, csv: bool = False, **kwargs: Any
self,
user_friendly: bool = False,
result_format: Optional[ChartDataResultFormat] = None,
**kwargs: Any,
) -> str:
"""
Get the url for this report schedule: chart or dashboard
"""
if self._report_schedule.chart:
if csv:
if result_format in {
ChartDataResultFormat.CSV,
ChartDataResultFormat.JSON,
}:
return get_url_path(
"ChartRestApi.get_data",
pk=self._report_schedule.chart_id,
format=ChartDataResultFormat.CSV.value,
format=result_format.value,
type=ChartDataResultType.POST_PROCESSED.value,
)
return get_url_path(
@@ -213,28 +220,14 @@ def _get_screenshot(self) -> bytes:
return image_data

def _get_csv_data(self) -> bytes:
url = self._get_url(csv=True)
url = self._get_url(result_format=ChartDataResultFormat.CSV)
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
self._get_user()
)

# To load CSV data from the endpoint the chart must have been saved
# with its query context. For charts without saved query context we
# get a screenshot to force the chart to produce and save the query
# context.
if self._report_schedule.chart.query_context is None:
logger.warning("No query context found, taking a screenshot to generate it")
try:
self._get_screenshot()
except (
ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout,
) as ex:
raise ReportScheduleCsvFailedError(
"Unable to fetch CSV data because the chart has no query context "
"saved, and an error occurred when fetching it via a screenshot. "
"Please try loading the chart and saving it again."
) from ex
self._update_query_context()

try:
logger.info("Getting chart from %s", url)
@@ -251,11 +244,50 @@

def _get_embedded_data(self) -> pd.DataFrame:
"""
Return data as an HTML table, to embed in the email.
Return data as a Pandas dataframe, to embed in notifications as a table.
"""
url = self._get_url(result_format=ChartDataResultFormat.JSON)
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
self._get_user()
)

if self._report_schedule.chart.query_context is None:
logger.warning("No query context found, taking a screenshot to generate it")
self._update_query_context()

try:
logger.info("Getting chart from %s", url)
dataframe = get_chart_dataframe(url, auth_cookies)
except SoftTimeLimitExceeded as ex:
raise ReportScheduleDataFrameTimeout() from ex
except Exception as ex:
raise ReportScheduleDataFrameFailedError(
f"Failed generating dataframe {str(ex)}"
) from ex
if dataframe is None:
raise ReportScheduleCsvFailedError()
return dataframe

def _update_query_context(self) -> None:
"""
Update chart query context.

To load CSV data from the endpoint the chart must have been saved
with its query context. For charts without saved query context we
get a screenshot to force the chart to produce and save the query
context.
"""
buf = BytesIO(self._get_csv_data())
df = pd.read_csv(buf)
return df
try:
self._get_screenshot()
except (
ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout,
) as ex:
raise ReportScheduleCsvFailedError(
"Unable to fetch data because the chart has no query context "
"saved, and an error occurred when fetching it via a screenshot. "
"Please try loading the chart and saving it again."
) from ex

def _get_notification_content(self) -> NotificationContent:
"""
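get_chart_dataframe is imported from superset.utils.csv, but its body is not part of this diff. Below is a rough sketch of what it is assumed to do, inferred only from the call site above (fetch the JSON chart data endpoint with the report user's cookies and rebuild a DataFrame from the serialized payload); the real implementation may differ.

```python
# Assumed behavior only; the real get_chart_dataframe lives in
# superset/utils/csv.py and is not shown in this diff.
from typing import Dict, Optional

import pandas as pd
import requests


def get_chart_dataframe_sketch(
    chart_url: str, auth_cookies: Optional[Dict[str, str]] = None
) -> Optional[pd.DataFrame]:
    # Fetch the post-processed chart data as JSON using the report user's cookies.
    response = requests.get(chart_url, cookies=auth_cookies or {}, timeout=60)
    if response.status_code != 200:
        return None

    result = response.json()["result"][0]

    # Rebuild the dataframe from the serialized payload; the colnames and
    # indexnames added elsewhere in this PR could be used here to restore
    # the original column and index labels (omitted in this sketch).
    return pd.DataFrame.from_dict(result["data"])
```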
44 changes: 31 additions & 13 deletions superset/reports/notifications/email.py
@@ -17,6 +17,7 @@
# under the License.
import json
import logging
import textwrap
from dataclasses import dataclass
from email.utils import make_msgid, parseaddr
from typing import Any, Dict, Optional
@@ -33,6 +34,7 @@
logger = logging.getLogger(__name__)

TABLE_TAGS = ["table", "th", "tr", "td", "thead", "tbody", "tfoot"]
TABLE_ATTRIBUTES = ["colspan", "rowspan", "halign", "border", "class"]


@dataclass
@@ -79,24 +81,40 @@ def _get_content(self) -> EmailContent:
if self._content.embedded_data is not None:
df = self._content.embedded_data
html_table = bleach.clean(
df.to_html(na_rep="", index=False), tags=TABLE_TAGS
df.to_html(na_rep="", index=True),
tags=TABLE_TAGS,
attributes=TABLE_ATTRIBUTES,
)
else:
html_table = ""

body = __(
"""
<p>%(description)s</p>
<b><a href="%(url)s">Explore in Superset</a></b><p></p>
%(html_table)s
%(img_tag)s
""",
description=description,
url=self._content.url,
html_table=html_table,
img_tag='<img width="1000px" src="cid:{}">'.format(msgid)
call_to_action = __("Explore in Superset")
img_tag = (
f'<img width="1000px" src="cid:{msgid}">'
if self._content.screenshot
else "",
else ""
)
body = textwrap.dedent(
f"""
<html>
<head>
<style type="text/css">
table, th, td {{
border-collapse: collapse;
border-color: rgb(200, 212, 227);
color: rgb(42, 63, 95);
padding: 4px 8px;
}}
</style>
</head>
<body>
<p>{description}</p>
<b><a href="{self._content.url}">{call_to_action}</a></b><p></p>
{html_table}
{img_tag}
</body>
</html>
"""
)
if self._content.screenshot:
image = {msgid: self._content.screenshot}
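For reference, a minimal standalone version of the table-embedding step above, using a made-up dataframe: render it to HTML with the index kept (the index now carries the pivot row labels) and sanitize the markup, allowing only table tags plus a small set of table attributes (colspan, rowspan, etc.) so MultiIndex headers render correctly in email clients.

```python
# Minimal sketch of the sanitized HTML table sent in report emails.
# The dataframe contents here are made up for illustration.
import bleach
import pandas as pd

TABLE_TAGS = ["table", "th", "tr", "td", "thead", "tbody", "tfoot"]
TABLE_ATTRIBUTES = ["colspan", "rowspan", "halign", "border", "class"]

df = pd.DataFrame(
    {"SUM(num) boy": [1, 3], "SUM(num) girl": [2, 4]}, index=["CA", "NY"]
)

html_table = bleach.clean(
    df.to_html(na_rep="", index=True),
    tags=TABLE_TAGS,
    attributes=TABLE_ATTRIBUTES,
)
print(html_table)
```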