diff --git a/superset/assets/src/explore/controls.jsx b/superset/assets/src/explore/controls.jsx
index bc5dc4689d964..66d563b343e68 100644
--- a/superset/assets/src/explore/controls.jsx
+++ b/superset/assets/src/explore/controls.jsx
@@ -47,7 +47,6 @@ import {
import * as v from './validators';
import { colorPrimary, ALL_COLOR_SCHEMES, spectrums } from '../modules/colors';
import { defaultViewport } from '../modules/geo';
-import MetricOption from '../components/MetricOption';
import ColumnOption from '../components/ColumnOption';
import OptionDescription from '../components/OptionDescription';
import { t } from '../locales';
@@ -116,6 +115,32 @@ const groupByControl = {
},
};
+const metrics = {
+ type: 'MetricsControl',
+ multi: true,
+ label: t('Metrics'),
+ validators: [v.nonEmpty],
+ default: (c) => {
+ const metric = mainMetric(c.savedMetrics);
+ return metric ? [metric] : null;
+ },
+ mapStateToProps: (state) => {
+ const datasource = state.datasource;
+ return {
+ columns: datasource ? datasource.columns : [],
+ savedMetrics: datasource ? datasource.metrics : [],
+ datasourceType: datasource && datasource.type,
+ };
+ },
+ description: t('One or many metrics to display'),
+};
+const metric = {
+ ...metrics,
+ multi: false,
+ label: t('Metric'),
+ default: props => mainMetric(props.savedMetrics),
+};
+
const sandboxUrl = (
'https://github.com/apache/incubator-superset/' +
'blob/master/superset/assets/src/modules/sandbox.js');
@@ -152,6 +177,11 @@ function jsFunctionControl(label, description, extraDescr = null, height = 100,
}
export const controls = {
+
+ metrics,
+
+ metric,
+
datasource: {
type: 'DatasourceControl',
label: t('Datasource'),
@@ -169,36 +199,11 @@ export const controls = {
description: t('The type of visualization to display'),
},
- metrics: {
- type: 'MetricsControl',
- multi: true,
- label: t('Metrics'),
- validators: [v.nonEmpty],
- default: (c) => {
- const metric = mainMetric(c.savedMetrics);
- return metric ? [metric] : null;
- },
- mapStateToProps: (state) => {
- const datasource = state.datasource;
- return {
- columns: datasource ? datasource.columns : [],
- savedMetrics: datasource ? datasource.metrics : [],
- datasourceType: datasource && datasource.type,
- };
- },
- description: t('One or many metrics to display'),
- },
-
percent_metrics: {
- type: 'SelectControl',
+ ...metrics,
multi: true,
label: t('Percentage Metrics'),
- valueKey: 'metric_name',
- optionRenderer: m => ,
- valueRenderer: m => ,
- mapStateToProps: state => ({
- options: (state.datasource) ? state.datasource.metrics : [],
- }),
+ validators: [],
description: t('Metrics for which percentage of total are to be displayed'),
},
@@ -262,33 +267,11 @@ export const controls = {
renderTrigger: true,
},
- metric: {
- type: 'MetricsControl',
- multi: false,
- label: t('Metric'),
- clearable: false,
- validators: [v.nonEmpty],
- default: props => mainMetric(props.savedMetrics),
- mapStateToProps: state => ({
- columns: state.datasource ? state.datasource.columns : [],
- savedMetrics: state.datasource ? state.datasource.metrics : [],
- datasourceType: state.datasource && state.datasource.type,
- }),
- },
-
metric_2: {
- type: 'SelectControl',
+ ...metric,
label: t('Right Axis Metric'),
- default: null,
- validators: [v.nonEmpty],
clearable: true,
description: t('Choose a metric for right axis'),
- valueKey: 'metric_name',
- optionRenderer: m => ,
- valueRenderer: m => ,
- mapStateToProps: state => ({
- options: (state.datasource) ? state.datasource.metrics : [],
- }),
},
stacked_style: {
@@ -508,13 +491,10 @@ export const controls = {
},
secondary_metric: {
- type: 'SelectControl',
+ ...metric,
label: t('Color Metric'),
default: null,
description: t('A metric to use for color'),
- mapStateToProps: state => ({
- choices: (state.datasource) ? state.datasource.metrics_combo : [],
- }),
},
select_country: {
type: 'SelectControl',
@@ -1105,44 +1085,23 @@ export const controls = {
},
x: {
- type: 'SelectControl',
+ ...metric,
label: t('X Axis'),
description: t('Metric assigned to the [X] axis'),
default: null,
- validators: [v.nonEmpty],
- optionRenderer: m => ,
- valueRenderer: m => ,
- valueKey: 'metric_name',
- mapStateToProps: state => ({
- options: (state.datasource) ? state.datasource.metrics : [],
- }),
},
y: {
- type: 'SelectControl',
+ ...metric,
label: t('Y Axis'),
default: null,
- validators: [v.nonEmpty],
description: t('Metric assigned to the [Y] axis'),
- optionRenderer: m => ,
- valueRenderer: m => ,
- valueKey: 'metric_name',
- mapStateToProps: state => ({
- options: (state.datasource) ? state.datasource.metrics : [],
- }),
},
size: {
- type: 'SelectControl',
+ ...metric,
label: t('Bubble Size'),
default: null,
- validators: [v.nonEmpty],
- optionRenderer: m => ,
- valueRenderer: m => ,
- valueKey: 'metric_name',
- mapStateToProps: state => ({
- options: (state.datasource) ? state.datasource.metrics : [],
- }),
},
url: {
diff --git a/superset/data/__init__.py b/superset/data/__init__.py
index 30b588f020d53..d3d7da86417c0 100644
--- a/superset/data/__init__.py
+++ b/superset/data/__init__.py
@@ -1168,10 +1168,10 @@ def load_multiformat_time_series_data():
obj.fetch_metadata()
tbl = obj
- print("Creating some slices")
+ print("Creating Heatmap charts")
for i, col in enumerate(tbl.columns):
slice_data = {
- "metric": 'count',
+ "metrics": ['count'],
"granularity_sqla": col.column_name,
"granularity_sqla": "day",
"row_limit": config.get("ROW_LIMIT"),
diff --git a/superset/viz.py b/superset/viz.py
index f971ba9dffeac..aa502a1a660ae 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -43,6 +43,11 @@
config = app.config
stats_logger = config.get('STATS_LOGGER')
+METRIC_KEYS = [
+ 'metric', 'metrics', 'percent_metrics', 'metric_2', 'secondary_metric',
+ 'x', 'y', 'size',
+]
+
class BaseViz(object):
@@ -66,13 +71,6 @@ def __init__(self, datasource, form_data, force=False):
self.query = ''
self.token = self.form_data.get(
'token', 'token_' + uuid.uuid4().hex[:8])
- metrics = self.form_data.get('metrics') or []
- self.metrics = []
- for metric in metrics:
- if isinstance(metric, dict):
- self.metrics.append(metric['label'])
- else:
- self.metrics.append(metric)
self.groupby = self.form_data.get('groupby') or []
self.time_shift = timedelta()
@@ -90,6 +88,29 @@ def __init__(self, datasource, form_data, force=False):
self._any_cached_dttm = None
self._extra_chart_data = None
+ self.process_metrics()
+
+ def process_metrics(self):
+ self.metric_dict = {}
+ fd = self.form_data
+ for mkey in METRIC_KEYS:
+ val = fd.get(mkey)
+ if val:
+ if not isinstance(val, list):
+ val = [val]
+ for o in val:
+ self.metric_dict[self.get_metric_label(o)] = o
+
+ # Cast to list needed to return serializable object in py3
+ self.all_metrics = list(self.metric_dict.values())
+ self.metric_labels = list(self.metric_dict.keys())
+
+ def get_metric_label(self, metric):
+ if isinstance(metric, string_types):
+ return metric
+ if isinstance(metric, dict):
+ return metric.get('label')
+
@staticmethod
def handle_js_int_overflow(data):
for d in data.get('records', dict()):
@@ -202,7 +223,7 @@ def query_obj(self):
"""Building a query object"""
form_data = self.form_data
gb = form_data.get('groupby') or []
- metrics = form_data.get('metrics') or []
+ metrics = self.all_metrics or []
columns = form_data.get('columns') or []
groupby = []
for o in gb + columns:
@@ -346,7 +367,7 @@ def cache_key(self, query_obj):
and replace them with the use-provided inputs to bounds, which
may we time-relative (as in "5 days ago" or "now").
"""
- cache_dict = copy.deepcopy(query_obj)
+ cache_dict = copy.copy(query_obj)
for k in ['from_dttm', 'to_dttm']:
del cache_dict[k]
@@ -520,7 +541,7 @@ def query_obj(self):
'Choose either fields to [Group By] and [Metrics] or '
'[Columns], not both'))
- sort_by = fd.get('timeseries_limit_metric')
+ sort_by = fd.get('timeseries_limit_metric') or []
if fd.get('all_columns'):
d['columns'] = fd.get('all_columns')
d['groupby'] = []
@@ -535,7 +556,7 @@ def query_obj(self):
if 'percent_metrics' in fd:
d['metrics'] = d['metrics'] + list(filter(
lambda m: m not in d['metrics'],
- fd['percent_metrics'],
+ fd['percent_metrics'] or [],
))
d['is_timeseries'] = self.should_be_timeseries()
@@ -551,7 +572,8 @@ def get_data(self, df):
del df[DTTM_ALIAS]
# Sum up and compute percentages for all percent metrics
- percent_metrics = fd.get('percent_metrics', [])
+ percent_metrics = fd.get('percent_metrics') or []
+
if len(percent_metrics):
percent_metrics = list(filter(lambda m: m in df, percent_metrics))
metric_sums = {
@@ -611,10 +633,10 @@ def query_obj(self):
def get_data(self, df):
fd = self.form_data
- values = self.metrics
columns = None
+ values = self.metric_labels
if fd.get('groupby'):
- values = self.metrics[0]
+ values = self.metric_labels[0]
columns = fd.get('groupby')
pt = df.pivot_table(
index=DTTM_ALIAS,
@@ -780,7 +802,7 @@ def get_data(self, df):
data = {}
records = df.to_dict('records')
- for metric in self.metrics:
+ for metric in self.metric_labels:
data[metric] = {
str(obj[DTTM_ALIAS].value / 10**9): obj.get(metric)
for obj in records
@@ -1109,7 +1131,7 @@ def to_series(self, df, classed='', title_suffix=''):
if (
isinstance(series_title, (list, tuple)) and
len(series_title) > 1 and
- len(self.metrics) == 1):
+ len(self.metric_labels) == 1):
# Removing metric from series name if only one metric
series_title = series_title[1:]
if title_suffix:
@@ -1393,10 +1415,11 @@ class DistributionPieViz(NVD3Viz):
is_timeseries = False
def get_data(self, df):
+ metric = self.metric_labels[0]
df = df.pivot_table(
index=self.groupby,
- values=[self.metrics[0]])
- df.sort_values(by=self.metrics[0], ascending=False, inplace=True)
+ values=[metric])
+ df.sort_values(by=metric, ascending=False, inplace=True)
df = df.reset_index()
df.columns = ['x', 'y']
return df.to_dict(orient='records')
@@ -1468,14 +1491,15 @@ def query_obj(self):
def get_data(self, df):
fd = self.form_data
+ metrics = self.metric_labels
- row = df.groupby(self.groupby).sum()[self.metrics[0]].copy()
+ row = df.groupby(self.groupby).sum()[metrics[0]].copy()
row.sort_values(ascending=False, inplace=True)
columns = fd.get('columns') or []
pt = df.pivot_table(
index=self.groupby,
columns=columns,
- values=self.metrics)
+ values=metrics)
if fd.get('contribution'):
pt = pt.fillna(0)
pt = pt.T
@@ -1487,7 +1511,7 @@ def get_data(self, df):
continue
if isinstance(name, string_types):
series_title = name
- elif len(self.metrics) > 1:
+ elif len(metrics) > 1:
series_title = ', '.join(name)
else:
l = [str(s) for s in name[1:]] # noqa: E741
@@ -1664,7 +1688,7 @@ def query_obj(self):
def get_data(self, df):
fd = self.form_data
cols = [fd.get('entity')]
- metric = fd.get('metric')
+ metric = self.metric_labels[0]
cols += [metric]
ndf = df[cols]
df = ndf
@@ -1836,7 +1860,7 @@ def get_data(self, df):
fd = self.form_data
x = fd.get('all_columns_x')
y = fd.get('all_columns_y')
- v = fd.get('metric')
+ v = self.metric_labels[0]
if x == y:
df.columns = ['x', 'y', 'v']
else: