Skip to content

Commit

Permalink
fix(viz): improve dtype inference logic (apache#12933)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed Feb 4, 2021
1 parent 58392f4 commit 87dce0c
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 40 deletions.
28 changes: 7 additions & 21 deletions superset-frontend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions superset-frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@
"@superset-ui/legacy-preset-chart-big-number": "^0.17.5",
"@superset-ui/legacy-preset-chart-deckgl": "^0.4.1",
"@superset-ui/legacy-preset-chart-nvd3": "^0.17.5",
"@superset-ui/plugin-chart-echarts": "^0.17.5",
"@superset-ui/plugin-chart-table": "^0.17.5",
"@superset-ui/plugin-chart-echarts": "^0.17.6",
"@superset-ui/plugin-chart-table": "^0.17.6",
"@superset-ui/plugin-chart-word-cloud": "^0.17.5",
"@superset-ui/preset-chart-xy": "^0.17.5",
"@vx/responsive": "^0.0.195",
Expand Down
2 changes: 1 addition & 1 deletion superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def get_single_payload(
status = payload["status"]
if status != utils.QueryStatus.FAILED:
payload["colnames"] = list(df.columns)
payload["coltypes"] = utils.serialize_pandas_dtypes(df.dtypes)
payload["coltypes"] = utils.extract_dataframe_dtypes(df)
payload["data"] = self.get_data(df)
del payload["df"]

Expand Down
35 changes: 23 additions & 12 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from flask_appbuilder.security.sqla.models import Role, User
from flask_babel import gettext as __
from flask_babel.speaklater import LazyString
from pandas.api.types import infer_dtype
from sqlalchemy import event, exc, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
Expand Down Expand Up @@ -1401,19 +1402,29 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
return columns


def serialize_pandas_dtypes(dtypes: List[np.dtype]) -> List[GenericDataType]:
"""Serialize pandas/numpy dtypes to JavaScript types"""
mapping = {
"object": GenericDataType.STRING,
"category": GenericDataType.STRING,
"datetime64[ns]": GenericDataType.TEMPORAL,
"int64": GenericDataType.NUMERIC,
"in32": GenericDataType.NUMERIC,
"float64": GenericDataType.NUMERIC,
"float32": GenericDataType.NUMERIC,
"bool": GenericDataType.BOOLEAN,
def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
"""Serialize pandas/numpy dtypes to generic types"""

# omitting string types as those will be the default type
inferred_type_map: Dict[str, GenericDataType] = {
"floating": GenericDataType.NUMERIC,
"integer": GenericDataType.NUMERIC,
"mixed-integer-float": GenericDataType.NUMERIC,
"decimal": GenericDataType.NUMERIC,
"boolean": GenericDataType.BOOLEAN,
"datetime64": GenericDataType.TEMPORAL,
"datetime": GenericDataType.TEMPORAL,
"date": GenericDataType.TEMPORAL,
}
return [mapping.get(str(x), GenericDataType.STRING) for x in dtypes]

generic_types: List[GenericDataType] = []
for column in df.columns:
series = df[column]
inferred_type = infer_dtype(series)
generic_type = inferred_type_map.get(inferred_type, GenericDataType.STRING)
generic_types.append(generic_type)

return generic_types


def indexed(
Expand Down
43 changes: 39 additions & 4 deletions tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import json
import os
import re
from typing import Any, Tuple, List
from unittest.mock import Mock, patch
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices

import numpy
import numpy as np
import pandas as pd
import pytest
from flask import Flask, g
import marshmallow
Expand All @@ -44,6 +46,7 @@
convert_legacy_filters_into_adhoc,
create_ssl_cert_file,
format_timedelta,
GenericDataType,
get_form_data_token,
get_iterable,
get_email_address_list,
Expand All @@ -57,6 +60,7 @@
merge_request_params,
parse_ssl_cert,
parse_js_uri_path_item,
extract_dataframe_dtypes,
split,
TimeRangeEndpoint,
validate_json,
Expand Down Expand Up @@ -113,9 +117,9 @@ def test_json_iso_dttm_ser(self):
json_iso_dttm_ser("this is not a date")

def test_base_json_conv(self):
assert isinstance(base_json_conv(numpy.bool_(1)), bool) is True
assert isinstance(base_json_conv(numpy.int64(1)), int) is True
assert isinstance(base_json_conv(numpy.array([1, 2, 3])), list) is True
assert isinstance(base_json_conv(np.bool_(1)), bool) is True
assert isinstance(base_json_conv(np.int64(1)), int) is True
assert isinstance(base_json_conv(np.array([1, 2, 3])), list) is True
assert isinstance(base_json_conv(set([1])), list) is True
assert isinstance(base_json_conv(Decimal("1.0")), float) is True
assert isinstance(base_json_conv(uuid.uuid4()), str) is True
Expand Down Expand Up @@ -1066,3 +1070,34 @@ def test_get_form_data_token(self):
assert get_form_data_token({"token": "token_abcdefg1"}) == "token_abcdefg1"
generated_token = get_form_data_token({})
assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None

def test_extract_dataframe_dtypes(self):
cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = (
("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]),
(
"dttm",
GenericDataType.TEMPORAL,
[datetime(2021, 2, 4, 1, 1, 1), datetime(2021, 2, 4, 1, 1, 1)],
),
("str", GenericDataType.STRING, ["foo", "foo"]),
("int", GenericDataType.NUMERIC, [1, 1]),
("float", GenericDataType.NUMERIC, [0.5, 0.5]),
("mixed-int-float", GenericDataType.NUMERIC, [0.5, 1.0]),
("bool", GenericDataType.BOOLEAN, [True, False]),
("mixed-str-int", GenericDataType.STRING, ["abc", 1.0]),
("obj", GenericDataType.STRING, [{"a": 1}, {"a": 1}]),
("dt_null", GenericDataType.TEMPORAL, [None, date(2021, 2, 4)]),
(
"dttm_null",
GenericDataType.TEMPORAL,
[None, datetime(2021, 2, 4, 1, 1, 1)],
),
("str_null", GenericDataType.STRING, [None, "foo"]),
("int_null", GenericDataType.NUMERIC, [None, 1]),
("float_null", GenericDataType.NUMERIC, [None, 0.5]),
("bool_null", GenericDataType.BOOLEAN, [None, False]),
("obj_null", GenericDataType.STRING, [None, {"a": 1}]),
)

df = pd.DataFrame(data={col[0]: col[2] for col in cols})
assert extract_dataframe_dtypes(df) == [col[1] for col in cols]

0 comments on commit 87dce0c

Please sign in to comment.