feat: improve embedded data table in text reports (apache#16335)
* feat: improve HTML table in text reports

* Remove unused import

* Update tests

* Fix test
betodealmeida authored Aug 19, 2021
1 parent d895134 commit 90b67b4
Showing 11 changed files with 295 additions and 428 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,tabulate,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,cron_descriptor,croniter,cryptography,dateutil,deprecation,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pgsanity,pkg_resources,polyline,prison,progress,pyarrow,pyhive,pyparsing,pytest,pytest_mock,pytz,redis,requests,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,urllib3,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

15 changes: 6 additions & 9 deletions superset/charts/api.py
@@ -499,15 +499,6 @@ def send_chart_response(
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format

# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if (
result_type == ChartDataResultType.POST_PROCESSED
and result_format == ChartDataResultFormat.CSV
):
result = apply_post_process(result, form_data)

if result_format == ChartDataResultFormat.CSV:
# Verify user has permission to export CSV file
if not security_manager.can_access("can_csv", "Superset"):
@@ -518,6 +509,12 @@ def send_chart_response(
return CsvResponse(data, headers=generate_download_headers("csv"))

if result_format == ChartDataResultFormat.JSON:
# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if result_type == ChartDataResultType.POST_PROCESSED:
result = apply_post_process(result, form_data)

response_data = simplejson.dumps(
{"result": result["queries"]},
default=json_int_dttm_ser,
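With this change the post-processing (e.g. pivoting) happens in the JSON branch, so a report, or any client, can ask the chart data endpoint for post-processed JSON instead of CSV. Below is a minimal sketch of such a request, assuming a hypothetical host, chart id and session cookie; the path and query-parameter values mirror the get_url_path call in superset/reports/commands/execute.py later in this diff.

import requests

BASE_URL = "https://superset.example.org"  # hypothetical deployment
CHART_ID = 42                              # hypothetical chart id

# format=json and type=post_processed correspond to ChartDataResultFormat.JSON
# and ChartDataResultType.POST_PROCESSED used by the report executor.
url = (
    f"{BASE_URL}/api/v1/chart/{CHART_ID}/data/"
    "?format=json&type=post_processed"
)
response = requests.get(url, cookies={"session": "<auth cookie>"}, timeout=30)
response.raise_for_status()

# Each query now carries the post-processed table rather than the raw result.
for query in response.json()["result"]:
    print(query["colnames"], query["rowcount"])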
39 changes: 20 additions & 19 deletions superset/charts/post_processing.py
@@ -26,7 +26,6 @@
for these chart types.
"""

from io import StringIO
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
@@ -126,11 +125,13 @@ def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-s
total = df.sum(axis=axis["columns"])
df = df.astype(total.dtypes).div(total, axis=axis["rows"])

if show_rows_total:
# convert to a MultiIndex to simplify logic
if not isinstance(df.columns, pd.MultiIndex):
df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns])
# convert to a MultiIndex to simplify logic
if not isinstance(df.index, pd.MultiIndex):
df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index])
if not isinstance(df.columns, pd.MultiIndex):
df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns])

if show_rows_total:
# add subtotal for each group and overall total; we start from the
# overall group, and iterate deeper into subgroups
groups = df.columns
@@ -146,10 +147,6 @@ def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-s
df.insert(int(slice_.stop), subtotal_name, subtotal)

if rows and show_columns_total:
# convert to a MultiIndex to simplify logic
if not isinstance(df.index, pd.MultiIndex):
df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index])

# add subtotal for each group and overall total; we start from the
# overall group, and iterate deeper into subgroups
groups = df.index
@@ -279,24 +276,28 @@ def apply_post_process(
post_processor = post_processors[viz_type]

for query in result["queries"]:
df = pd.read_csv(StringIO(query["data"]))
df = pd.DataFrame.from_dict(query["data"])
processed_df = post_processor(df, form_data)

# flatten column names
query["colnames"] = list(processed_df.columns)
query["indexnames"] = list(processed_df.index)
query["coltypes"] = extract_dataframe_dtypes(processed_df)
query["rowcount"] = len(processed_df.index)

# flatten columns/index so we can encode data as JSON
processed_df.columns = [
" ".join(str(name) for name in column).strip()
if isinstance(column, tuple)
else column
for column in processed_df.columns
]
processed_df.index = [
" ".join(str(name) for name in index).strip()
if isinstance(index, tuple)
else index
for index in processed_df.index
]

buf = StringIO()
processed_df.to_csv(buf)
buf.seek(0)

query["data"] = buf.getvalue()
query["colnames"] = list(processed_df.columns)
query["coltypes"] = extract_dataframe_dtypes(processed_df)
query["rowcount"] = len(processed_df.index)
query["data"] = processed_df.to_dict()

return result
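The CSV round-trip (read_csv/to_csv) is replaced above by a dict round-trip, which only works because tuple column and index labels from pivoted frames are flattened to plain strings first. A standalone illustration of that step, not taken from the PR, using a small made-up DataFrame:

import json
import pandas as pd

df = pd.DataFrame(
    {
        "state": ["CA", "CA", "NY", "NY"],
        "gender": ["boy", "girl", "boy", "girl"],
        "num_boys": [10, 0, 7, 0],
        "num_girls": [0, 12, 0, 9],
    }
)
pivoted = df.pivot_table(
    index="state", columns="gender", values=["num_boys", "num_girls"], aggfunc="sum"
)

# pivoted.columns is now a MultiIndex such as ("num_boys", "boy"); tuple keys
# cannot be JSON-encoded, so join them into flat strings the same way
# apply_post_process does before calling to_dict().
pivoted.columns = [
    " ".join(str(name) for name in col).strip() if isinstance(col, tuple) else col
    for col in pivoted.columns
]

payload = pivoted.to_dict()  # JSON-friendly nested dict
print(json.dumps(payload, indent=2, default=int))  # default=int copes with NumPy scalars

rebuilt = pd.DataFrame.from_dict(payload)  # consumer side, as in apply_post_process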
1 change: 1 addition & 0 deletions superset/common/query_actions.py
@@ -101,6 +101,7 @@ def _get_full(
status = payload["status"]
if status != QueryStatus.FAILED:
payload["colnames"] = list(df.columns)
payload["indexnames"] = list(df.index)
payload["coltypes"] = extract_dataframe_dtypes(df)
payload["data"] = query_context.get_data(df)
del payload["df"]
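Illustrative only (the values below are made up): with the extra "indexnames" key the per-query payload now exposes the DataFrame row labels next to the column labels, so downstream consumers such as the report e-mail table can rebuild pivoted tables with their row headers.

payload = {
    "colnames": ["boy", "girl"],
    "indexnames": ["CA", "NY"],
    # ... "coltypes", "data", etc. as before
}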
8 changes: 8 additions & 0 deletions superset/reports/commands/exceptions.py
@@ -116,6 +116,10 @@ class ReportScheduleCsvFailedError(CommandException):
message = _("Report Schedule execution failed when generating a csv.")


class ReportScheduleDataFrameFailedError(CommandException):
message = _("Report Schedule execution failed when generating a dataframe.")


class ReportScheduleExecuteUnexpectedError(CommandException):
message = _("Report Schedule execution got an unexpected error.")

@@ -171,6 +175,10 @@ class ReportScheduleCsvTimeout(CommandException):
message = _("A timeout occurred while generating a csv.")


class ReportScheduleDataFrameTimeout(CommandException):
message = _("A timeout occurred while generating a dataframe.")


class ReportScheduleAlertGracePeriodError(CommandException):
message = _("Alert fired during grace period.")

82 changes: 57 additions & 25 deletions superset/reports/commands/execute.py
@@ -17,7 +17,6 @@
import json
import logging
from datetime import datetime, timedelta
from io import BytesIO
from typing import Any, List, Optional
from uuid import UUID

@@ -45,6 +44,8 @@
ReportScheduleAlertGracePeriodError,
ReportScheduleCsvFailedError,
ReportScheduleCsvTimeout,
ReportScheduleDataFrameFailedError,
ReportScheduleDataFrameTimeout,
ReportScheduleExecuteUnexpectedError,
ReportScheduleNotFoundError,
ReportScheduleNotificationError,
@@ -65,7 +66,7 @@
from superset.reports.notifications.exceptions import NotificationError
from superset.utils.celery import session_scope
from superset.utils.core import ChartDataResultFormat, ChartDataResultType
from superset.utils.csv import get_chart_csv_data
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.screenshots import (
BaseScreenshot,
ChartScreenshot,
@@ -137,17 +138,23 @@ def create_log( # pylint: disable=too-many-arguments
self._session.commit()

def _get_url(
self, user_friendly: bool = False, csv: bool = False, **kwargs: Any
self,
user_friendly: bool = False,
result_format: Optional[ChartDataResultFormat] = None,
**kwargs: Any,
) -> str:
"""
Get the url for this report schedule: chart or dashboard
"""
if self._report_schedule.chart:
if csv:
if result_format in {
ChartDataResultFormat.CSV,
ChartDataResultFormat.JSON,
}:
return get_url_path(
"ChartRestApi.get_data",
pk=self._report_schedule.chart_id,
format=ChartDataResultFormat.CSV.value,
format=result_format.value,
type=ChartDataResultType.POST_PROCESSED.value,
)
return get_url_path(
@@ -213,28 +220,14 @@ def _get_screenshot(self) -> bytes:
return image_data

def _get_csv_data(self) -> bytes:
url = self._get_url(csv=True)
url = self._get_url(result_format=ChartDataResultFormat.CSV)
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
self._get_user()
)

# To load CSV data from the endpoint the chart must have been saved
# with its query context. For charts without saved query context we
# get a screenshot to force the chart to produce and save the query
# context.
if self._report_schedule.chart.query_context is None:
logger.warning("No query context found, taking a screenshot to generate it")
try:
self._get_screenshot()
except (
ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout,
) as ex:
raise ReportScheduleCsvFailedError(
"Unable to fetch CSV data because the chart has no query context "
"saved, and an error occurred when fetching it via a screenshot. "
"Please try loading the chart and saving it again."
) from ex
self._update_query_context()

try:
logger.info("Getting chart from %s", url)
@@ -251,11 +244,50 @@

def _get_embedded_data(self) -> pd.DataFrame:
"""
Return data as an HTML table, to embed in the email.
Return data as a Pandas dataframe, to embed in notifications as a table.
"""
url = self._get_url(result_format=ChartDataResultFormat.JSON)
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
self._get_user()
)

if self._report_schedule.chart.query_context is None:
logger.warning("No query context found, taking a screenshot to generate it")
self._update_query_context()

try:
logger.info("Getting chart from %s", url)
dataframe = get_chart_dataframe(url, auth_cookies)
except SoftTimeLimitExceeded as ex:
raise ReportScheduleDataFrameTimeout() from ex
except Exception as ex:
raise ReportScheduleDataFrameFailedError(
f"Failed generating dataframe {str(ex)}"
) from ex
if dataframe is None:
raise ReportScheduleCsvFailedError()
return dataframe

def _update_query_context(self) -> None:
"""
Update chart query context.
To load CSV data from the endpoint the chart must have been saved
with its query context. For charts without saved query context we
get a screenshot to force the chart to produce and save the query
context.
"""
buf = BytesIO(self._get_csv_data())
df = pd.read_csv(buf)
return df
try:
self._get_screenshot()
except (
ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout,
) as ex:
raise ReportScheduleCsvFailedError(
"Unable to fetch data because the chart has no query context "
"saved, and an error occurred when fetching it via a screenshot. "
"Please try loading the chart and saving it again."
) from ex

def _get_notification_content(self) -> NotificationContent:
"""
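get_chart_dataframe is imported from superset/utils/csv.py, which is not expanded in this excerpt. A rough, hypothetical sketch of what such a helper plausibly does, based on the JSON payload produced by apply_post_process above; the real implementation may differ.

from typing import Dict, Optional

import pandas as pd
import requests


def fetch_chart_dataframe(
    url: str, auth_cookies: Dict[str, str]
) -> Optional[pd.DataFrame]:
    """Hypothetical stand-in for superset.utils.csv.get_chart_dataframe."""
    response = requests.get(url, cookies=auth_cookies, timeout=60)
    if response.status_code != 200:
        return None
    # "data" holds the dict written by DataFrame.to_dict() in apply_post_process,
    # so from_dict() restores the table, including its flattened row index.
    query = response.json()["result"][0]
    return pd.DataFrame.from_dict(query["data"])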
44 changes: 31 additions & 13 deletions superset/reports/notifications/email.py
@@ -17,6 +17,7 @@
# under the License.
import json
import logging
import textwrap
from dataclasses import dataclass
from email.utils import make_msgid, parseaddr
from typing import Any, Dict, Optional
@@ -33,6 +34,7 @@
logger = logging.getLogger(__name__)

TABLE_TAGS = ["table", "th", "tr", "td", "thead", "tbody", "tfoot"]
TABLE_ATTRIBUTES = ["colspan", "rowspan", "halign", "border", "class"]


@dataclass
@@ -79,24 +81,40 @@ def _get_content(self) -> EmailContent:
if self._content.embedded_data is not None:
df = self._content.embedded_data
html_table = bleach.clean(
df.to_html(na_rep="", index=False), tags=TABLE_TAGS
df.to_html(na_rep="", index=True),
tags=TABLE_TAGS,
attributes=TABLE_ATTRIBUTES,
)
else:
html_table = ""

body = __(
"""
<p>%(description)s</p>
<b><a href="%(url)s">Explore in Superset</a></b><p></p>
%(html_table)s
%(img_tag)s
""",
description=description,
url=self._content.url,
html_table=html_table,
img_tag='<img width="1000px" src="cid:{}">'.format(msgid)
call_to_action = __("Explore in Superset")
img_tag = (
f'<img width="1000px" src="cid:{msgid}">'
if self._content.screenshot
else "",
else ""
)
body = textwrap.dedent(
f"""
<html>
<head>
<style type="text/css">
table, th, td {{
border-collapse: collapse;
border-color: rgb(200, 212, 227);
color: rgb(42, 63, 95);
padding: 4px 8px;
}}
</style>
</head>
<body>
<p>{description}</p>
<b><a href="{self._content.url}">{call_to_action}</a></b><p></p>
{html_table}
{img_tag}
</body>
</html>
"""
)
if self._content.screenshot:
image = {msgid: self._content.screenshot}
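A small standalone example of the sanitization step, not part of the diff: a DataFrame with MultiIndex columns renders colspan/rowspan header cells in to_html(), which is presumably why the attribute allow-list above now has to include them.

import bleach
import pandas as pd

TABLE_TAGS = ["table", "th", "tr", "td", "thead", "tbody", "tfoot"]
TABLE_ATTRIBUTES = ["colspan", "rowspan", "halign", "border", "class"]

# Tuple keys give the frame MultiIndex columns, so to_html() emits colspan cells.
df = pd.DataFrame(
    {("sum", "boy"): [10, 7], ("sum", "girl"): [12, 9]},
    index=pd.Index(["CA", "NY"], name="state"),
)

html_table = bleach.clean(
    df.to_html(na_rep="", index=True),
    tags=TABLE_TAGS,
    attributes=TABLE_ATTRIBUTES,
)
# Tags or attributes outside the allow-lists (scripts, inline styles, ...)
# are escaped or stripped before the table is embedded in the e-mail body.
print(html_table)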