diff --git a/superset/assets/src/explore/controls.jsx b/superset/assets/src/explore/controls.jsx index bc5dc4689d964..66d563b343e68 100644 --- a/superset/assets/src/explore/controls.jsx +++ b/superset/assets/src/explore/controls.jsx @@ -47,7 +47,6 @@ import { import * as v from './validators'; import { colorPrimary, ALL_COLOR_SCHEMES, spectrums } from '../modules/colors'; import { defaultViewport } from '../modules/geo'; -import MetricOption from '../components/MetricOption'; import ColumnOption from '../components/ColumnOption'; import OptionDescription from '../components/OptionDescription'; import { t } from '../locales'; @@ -116,6 +115,32 @@ const groupByControl = { }, }; +const metrics = { + type: 'MetricsControl', + multi: true, + label: t('Metrics'), + validators: [v.nonEmpty], + default: (c) => { + const metric = mainMetric(c.savedMetrics); + return metric ? [metric] : null; + }, + mapStateToProps: (state) => { + const datasource = state.datasource; + return { + columns: datasource ? datasource.columns : [], + savedMetrics: datasource ? datasource.metrics : [], + datasourceType: datasource && datasource.type, + }; + }, + description: t('One or many metrics to display'), +}; +const metric = { + ...metrics, + multi: false, + label: t('Metric'), + default: props => mainMetric(props.savedMetrics), +}; + const sandboxUrl = ( 'https://github.com/apache/incubator-superset/' + 'blob/master/superset/assets/src/modules/sandbox.js'); @@ -152,6 +177,11 @@ function jsFunctionControl(label, description, extraDescr = null, height = 100, } export const controls = { + + metrics, + + metric, + datasource: { type: 'DatasourceControl', label: t('Datasource'), @@ -169,36 +199,11 @@ export const controls = { description: t('The type of visualization to display'), }, - metrics: { - type: 'MetricsControl', - multi: true, - label: t('Metrics'), - validators: [v.nonEmpty], - default: (c) => { - const metric = mainMetric(c.savedMetrics); - return metric ? [metric] : null; - }, - mapStateToProps: (state) => { - const datasource = state.datasource; - return { - columns: datasource ? datasource.columns : [], - savedMetrics: datasource ? datasource.metrics : [], - datasourceType: datasource && datasource.type, - }; - }, - description: t('One or many metrics to display'), - }, - percent_metrics: { - type: 'SelectControl', + ...metrics, multi: true, label: t('Percentage Metrics'), - valueKey: 'metric_name', - optionRenderer: m => , - valueRenderer: m => , - mapStateToProps: state => ({ - options: (state.datasource) ? state.datasource.metrics : [], - }), + validators: [], description: t('Metrics for which percentage of total are to be displayed'), }, @@ -262,33 +267,11 @@ export const controls = { renderTrigger: true, }, - metric: { - type: 'MetricsControl', - multi: false, - label: t('Metric'), - clearable: false, - validators: [v.nonEmpty], - default: props => mainMetric(props.savedMetrics), - mapStateToProps: state => ({ - columns: state.datasource ? state.datasource.columns : [], - savedMetrics: state.datasource ? state.datasource.metrics : [], - datasourceType: state.datasource && state.datasource.type, - }), - }, - metric_2: { - type: 'SelectControl', + ...metric, label: t('Right Axis Metric'), - default: null, - validators: [v.nonEmpty], clearable: true, description: t('Choose a metric for right axis'), - valueKey: 'metric_name', - optionRenderer: m => , - valueRenderer: m => , - mapStateToProps: state => ({ - options: (state.datasource) ? state.datasource.metrics : [], - }), }, stacked_style: { @@ -508,13 +491,10 @@ export const controls = { }, secondary_metric: { - type: 'SelectControl', + ...metric, label: t('Color Metric'), default: null, description: t('A metric to use for color'), - mapStateToProps: state => ({ - choices: (state.datasource) ? state.datasource.metrics_combo : [], - }), }, select_country: { type: 'SelectControl', @@ -1105,44 +1085,23 @@ export const controls = { }, x: { - type: 'SelectControl', + ...metric, label: t('X Axis'), description: t('Metric assigned to the [X] axis'), default: null, - validators: [v.nonEmpty], - optionRenderer: m => , - valueRenderer: m => , - valueKey: 'metric_name', - mapStateToProps: state => ({ - options: (state.datasource) ? state.datasource.metrics : [], - }), }, y: { - type: 'SelectControl', + ...metric, label: t('Y Axis'), default: null, - validators: [v.nonEmpty], description: t('Metric assigned to the [Y] axis'), - optionRenderer: m => , - valueRenderer: m => , - valueKey: 'metric_name', - mapStateToProps: state => ({ - options: (state.datasource) ? state.datasource.metrics : [], - }), }, size: { - type: 'SelectControl', + ...metric, label: t('Bubble Size'), default: null, - validators: [v.nonEmpty], - optionRenderer: m => , - valueRenderer: m => , - valueKey: 'metric_name', - mapStateToProps: state => ({ - options: (state.datasource) ? state.datasource.metrics : [], - }), }, url: { diff --git a/superset/data/__init__.py b/superset/data/__init__.py index 30b588f020d53..d3d7da86417c0 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -1168,10 +1168,10 @@ def load_multiformat_time_series_data(): obj.fetch_metadata() tbl = obj - print("Creating some slices") + print("Creating Heatmap charts") for i, col in enumerate(tbl.columns): slice_data = { - "metric": 'count', + "metrics": ['count'], "granularity_sqla": col.column_name, "granularity_sqla": "day", "row_limit": config.get("ROW_LIMIT"), diff --git a/superset/viz.py b/superset/viz.py index f971ba9dffeac..aa502a1a660ae 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -43,6 +43,11 @@ config = app.config stats_logger = config.get('STATS_LOGGER') +METRIC_KEYS = [ + 'metric', 'metrics', 'percent_metrics', 'metric_2', 'secondary_metric', + 'x', 'y', 'size', +] + class BaseViz(object): @@ -66,13 +71,6 @@ def __init__(self, datasource, form_data, force=False): self.query = '' self.token = self.form_data.get( 'token', 'token_' + uuid.uuid4().hex[:8]) - metrics = self.form_data.get('metrics') or [] - self.metrics = [] - for metric in metrics: - if isinstance(metric, dict): - self.metrics.append(metric['label']) - else: - self.metrics.append(metric) self.groupby = self.form_data.get('groupby') or [] self.time_shift = timedelta() @@ -90,6 +88,29 @@ def __init__(self, datasource, form_data, force=False): self._any_cached_dttm = None self._extra_chart_data = None + self.process_metrics() + + def process_metrics(self): + self.metric_dict = {} + fd = self.form_data + for mkey in METRIC_KEYS: + val = fd.get(mkey) + if val: + if not isinstance(val, list): + val = [val] + for o in val: + self.metric_dict[self.get_metric_label(o)] = o + + # Cast to list needed to return serializable object in py3 + self.all_metrics = list(self.metric_dict.values()) + self.metric_labels = list(self.metric_dict.keys()) + + def get_metric_label(self, metric): + if isinstance(metric, string_types): + return metric + if isinstance(metric, dict): + return metric.get('label') + @staticmethod def handle_js_int_overflow(data): for d in data.get('records', dict()): @@ -202,7 +223,7 @@ def query_obj(self): """Building a query object""" form_data = self.form_data gb = form_data.get('groupby') or [] - metrics = form_data.get('metrics') or [] + metrics = self.all_metrics or [] columns = form_data.get('columns') or [] groupby = [] for o in gb + columns: @@ -346,7 +367,7 @@ def cache_key(self, query_obj): and replace them with the use-provided inputs to bounds, which may we time-relative (as in "5 days ago" or "now"). """ - cache_dict = copy.deepcopy(query_obj) + cache_dict = copy.copy(query_obj) for k in ['from_dttm', 'to_dttm']: del cache_dict[k] @@ -520,7 +541,7 @@ def query_obj(self): 'Choose either fields to [Group By] and [Metrics] or ' '[Columns], not both')) - sort_by = fd.get('timeseries_limit_metric') + sort_by = fd.get('timeseries_limit_metric') or [] if fd.get('all_columns'): d['columns'] = fd.get('all_columns') d['groupby'] = [] @@ -535,7 +556,7 @@ def query_obj(self): if 'percent_metrics' in fd: d['metrics'] = d['metrics'] + list(filter( lambda m: m not in d['metrics'], - fd['percent_metrics'], + fd['percent_metrics'] or [], )) d['is_timeseries'] = self.should_be_timeseries() @@ -551,7 +572,8 @@ def get_data(self, df): del df[DTTM_ALIAS] # Sum up and compute percentages for all percent metrics - percent_metrics = fd.get('percent_metrics', []) + percent_metrics = fd.get('percent_metrics') or [] + if len(percent_metrics): percent_metrics = list(filter(lambda m: m in df, percent_metrics)) metric_sums = { @@ -611,10 +633,10 @@ def query_obj(self): def get_data(self, df): fd = self.form_data - values = self.metrics columns = None + values = self.metric_labels if fd.get('groupby'): - values = self.metrics[0] + values = self.metric_labels[0] columns = fd.get('groupby') pt = df.pivot_table( index=DTTM_ALIAS, @@ -780,7 +802,7 @@ def get_data(self, df): data = {} records = df.to_dict('records') - for metric in self.metrics: + for metric in self.metric_labels: data[metric] = { str(obj[DTTM_ALIAS].value / 10**9): obj.get(metric) for obj in records @@ -1109,7 +1131,7 @@ def to_series(self, df, classed='', title_suffix=''): if ( isinstance(series_title, (list, tuple)) and len(series_title) > 1 and - len(self.metrics) == 1): + len(self.metric_labels) == 1): # Removing metric from series name if only one metric series_title = series_title[1:] if title_suffix: @@ -1393,10 +1415,11 @@ class DistributionPieViz(NVD3Viz): is_timeseries = False def get_data(self, df): + metric = self.metric_labels[0] df = df.pivot_table( index=self.groupby, - values=[self.metrics[0]]) - df.sort_values(by=self.metrics[0], ascending=False, inplace=True) + values=[metric]) + df.sort_values(by=metric, ascending=False, inplace=True) df = df.reset_index() df.columns = ['x', 'y'] return df.to_dict(orient='records') @@ -1468,14 +1491,15 @@ def query_obj(self): def get_data(self, df): fd = self.form_data + metrics = self.metric_labels - row = df.groupby(self.groupby).sum()[self.metrics[0]].copy() + row = df.groupby(self.groupby).sum()[metrics[0]].copy() row.sort_values(ascending=False, inplace=True) columns = fd.get('columns') or [] pt = df.pivot_table( index=self.groupby, columns=columns, - values=self.metrics) + values=metrics) if fd.get('contribution'): pt = pt.fillna(0) pt = pt.T @@ -1487,7 +1511,7 @@ def get_data(self, df): continue if isinstance(name, string_types): series_title = name - elif len(self.metrics) > 1: + elif len(metrics) > 1: series_title = ', '.join(name) else: l = [str(s) for s in name[1:]] # noqa: E741 @@ -1664,7 +1688,7 @@ def query_obj(self): def get_data(self, df): fd = self.form_data cols = [fd.get('entity')] - metric = fd.get('metric') + metric = self.metric_labels[0] cols += [metric] ndf = df[cols] df = ndf @@ -1836,7 +1860,7 @@ def get_data(self, df): fd = self.form_data x = fd.get('all_columns_x') y = fd.get('all_columns_y') - v = fd.get('metric') + v = self.metric_labels[0] if x == y: df.columns = ['x', 'y', 'v'] else: