Skip to content

Commit

Permalink
treating floats like doubles for druid versions lower than 11.0.0 (#5030
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Gabe Lyons authored and mistercrunch committed May 21, 2018
1 parent 9f66dae commit 1c9474b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ ignore-mixin-members=yes
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuilder.security.sqla.PermissionView.role,flask_appbuilder.Model.metadata,flask_appbuilder.Base.metadata
ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuilder.security.sqla.PermissionView.role,flask_appbuilder.Model.metadata,flask_appbuilder.Base.metadata,distutils

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
Expand Down
37 changes: 27 additions & 10 deletions superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import OrderedDict
from copy import deepcopy
from datetime import datetime, timedelta
from distutils.version import LooseVersion
import json
import logging
from multiprocessing.pool import ThreadPool
Expand Down Expand Up @@ -899,8 +900,8 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic
missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)

@classmethod
def metrics_and_post_aggs(cls, metrics, metrics_dict):
@staticmethod
def metrics_and_post_aggs(metrics, metrics_dict, druid_version=None):
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
Expand All @@ -920,9 +921,13 @@ def metrics_and_post_aggs(cls, metrics, metrics_dict):
for postagg_name in postagg_names:
postagg = metrics_dict[postagg_name]
visited_postaggs.add(postagg_name)
cls.resolve_postagg(
DruidDatasource.resolve_postagg(
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict)
aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs)
aggs = DruidDatasource.get_aggregations(
metrics_dict,
saved_agg_names,
adhoc_agg_configs,
)
return aggs, post_aggs

def values_for_column(self,
Expand Down Expand Up @@ -997,11 +1002,12 @@ def _add_filter_from_pre_query_data(self, df, dimensions, dim_filter):

@staticmethod
def druid_type_from_adhoc_metric(adhoc_metric):
column_type = adhoc_metric.get('column').get('type').lower()
aggregate = adhoc_metric.get('aggregate').lower()
if (aggregate == 'count'):
column_type = adhoc_metric['column']['type'].lower()
aggregate = adhoc_metric['aggregate'].lower()

if aggregate == 'count':
return 'count'
if (aggregate == 'count_distinct'):
if aggregate == 'count_distinct':
return 'cardinality'
else:
return column_type + aggregate.capitalize()
Expand Down Expand Up @@ -1132,6 +1138,17 @@ def run_query( # noqa / druid
metrics_dict = {m.metric_name: m for m in self.metrics}
columns_dict = {c.column_name: c for c in self.columns}

if (
self.cluster and
LooseVersion(self.cluster.get_druid_version()) < LooseVersion('0.11.0')
):
for metric in metrics:
if (
utils.is_adhoc_metric(metric) and
metric['column']['type'].upper() == 'FLOAT'
):
metric['column']['type'] = 'DOUBLE'

aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics,
metrics_dict)
Expand Down Expand Up @@ -1187,7 +1204,7 @@ def run_query( # noqa / druid
pre_qry = deepcopy(qry)
if timeseries_limit_metric:
order_by = timeseries_limit_metric
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs(
[timeseries_limit_metric],
metrics_dict)
if phase == 1:
Expand Down Expand Up @@ -1256,7 +1273,7 @@ def run_query( # noqa / druid

if timeseries_limit_metric:
order_by = timeseries_limit_metric
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs(
[timeseries_limit_metric],
metrics_dict)
if phase == 1:
Expand Down
4 changes: 2 additions & 2 deletions tests/druid_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def __reduce__(self):
},
]

DruidCluster.get_druid_version = lambda _: '0.9.1'


class DruidTests(SupersetTestCase):

Expand Down Expand Up @@ -114,7 +116,6 @@ def get_cluster(self, PyDruid):

db.session.add(cluster)
cluster.get_datasources = PickableMock(return_value=['test_datasource'])
cluster.get_druid_version = PickableMock(return_value='0.9.1')

return cluster

Expand Down Expand Up @@ -324,7 +325,6 @@ def test_sync_druid_perm(self, PyDruid):
cluster.get_datasources = PickableMock(
return_value=['test_datasource'],
)
cluster.get_druid_version = PickableMock(return_value='0.9.1')

cluster.refresh_datasources()
cluster.datasources[0].merge_flag = True
Expand Down

0 comments on commit 1c9474b

Please sign in to comment.