From f70d301f0d37fe6c4f94cd7d6026a72a56087d34 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Fri, 7 Oct 2016 16:24:39 -0700 Subject: [PATCH] Refactor the explore view (#1252) * Refactor the explore view * Fixing the tests * Addressing comments --- caravel/source_registry.py | 8 ++ caravel/templates/caravel/standalone.html | 2 +- caravel/views.py | 143 ++++++++++++---------- caravel/viz.py | 13 +- run_specific_test.sh | 2 +- tests/base_tests.py | 20 ++- tests/core_tests.py | 35 ++++-- tests/druid_tests.py | 110 +++++++++-------- 8 files changed, 195 insertions(+), 138 deletions(-) diff --git a/caravel/source_registry.py b/caravel/source_registry.py index 04b3f17c28cc6..dc1a170d7c68e 100644 --- a/caravel/source_registry.py +++ b/caravel/source_registry.py @@ -12,3 +12,11 @@ def register_sources(cls, datasource_config): for class_name in class_names: source_class = getattr(module_obj, class_name) cls.sources[source_class.type] = source_class + + @classmethod + def get_datasource(cls, datasource_type, datasource_id, session): + return ( + session.query(cls.sources[datasource_type]) + .filter_by(id=datasource_id) + .one() + ) diff --git a/caravel/templates/caravel/standalone.html b/caravel/templates/caravel/standalone.html index 700492504321f..ac44f9c62b0c8 100644 --- a/caravel/templates/caravel/standalone.html +++ b/caravel/templates/caravel/standalone.html @@ -1,6 +1,6 @@ - {{viz.token}} + {{ viz.token }} diff --git a/caravel/views.py b/caravel/views.py index c9574d29808d9..a355eacda5a42 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -26,7 +26,6 @@ from flask_appbuilder.models.sqla.filters import BaseFilter from sqlalchemy import create_engine -from werkzeug.datastructures import ImmutableMultiDict from werkzeug.routing import BaseConverter from wtforms.validators import ValidationError @@ -244,7 +243,8 @@ def apply(self, query, func): # noqa druid_datasources = [] for perm in perms: match = re.search(r'\(id:(\d+)\)', perm) - druid_datasources.append(match.group(1)) + if match: + druid_datasources.append(match.group(1)) qry = query.filter(self.model.id.in_(druid_datasources)) return qry @@ -672,6 +672,7 @@ class DruidClusterModelView(CaravelModelView, DeleteMixin): # noqa 'broker_port': _("Broker Port"), 'broker_endpoint': _("Broker Endpoint"), } + def pre_add(self, db): utils.merge_perm(sm, 'database_access', db.perm) @@ -699,7 +700,8 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa list_columns = [ 'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified'] edit_columns = [ - 'slice_name', 'description', 'viz_type', 'owners', 'dashboards', 'params', 'cache_timeout'] + 'slice_name', 'description', 'viz_type', 'owners', 'dashboards', + 'params', 'cache_timeout'] base_order = ('changed_on', 'desc') description_columns = { 'description': Markup( @@ -1099,61 +1101,80 @@ def approve(self): session.commit() return redirect('/accessrequestsmodelview/list/') + def get_viz( + self, + slice_id=None, + args=None, + datasource_type=None, + datasource_id=None): + if slice_id: + slc = db.session.query(models.Slice).filter_by(id=slice_id).one() + return slc.get_viz() + else: + viz_type = args.get('viz_type', 'table') + datasource = SourceRegistry.get_datasource( + datasource_type, datasource_id, db.session) + viz_obj = viz.viz_types[viz_type](datasource, request.args) + return viz_obj + @has_access - @expose("/explore////") - @expose("/explore///") - @expose("/datasource///") # Legacy url + @expose("/slice//") + def slice(self, slice_id): + viz_obj = self.get_viz(slice_id) + return redirect(viz_obj.get_url(**request.args)) + + @has_access_api + @expose("/explore_json///") + def explore_json(self, datasource_type, datasource_id): + viz_obj = self.get_viz( + datasource_type=datasource_type, + datasource_id=datasource_id, + args=request.args) + if not self.datasource_access(viz_obj.datasource): + return Response( + json.dumps( + {'error': _("You don't have access to this datasource")}), + status=404, + mimetype="application/json") + return Response( + viz_obj.get_json(), + status=200, + mimetype="application/json") + @log_this - def explore(self, datasource_type, datasource_id, slice_id=None): + @has_access + @expose("/explore///") + def explore(self, datasource_type, datasource_id): + viz_type = request.args.get("viz_type") + slice_id = request.args.get('slice_id') + slc = db.session.query(models.Slice).filter_by(id=slice_id).first() + error_redirect = '/slicemodelview/list/' datasource_class = SourceRegistry.sources[datasource_type] datasources = db.session.query(datasource_class).all() datasources = sorted(datasources, key=lambda ds: ds.full_name) - datasource = [ds for ds in datasources if int(datasource_id) == ds.id] - datasource = datasource[0] if datasource else None - if not datasource: + viz_obj = self.get_viz( + datasource_type=datasource_type, + datasource_id=datasource_id, + args=request.args) + + if not viz_obj.datasource: flash(DATASOURCE_MISSING_ERR, "alert") return redirect(error_redirect) - if not self.datasource_access(datasource): + if not self.datasource_access(viz_obj.datasource): flash( - __(get_datasource_access_error_msg(datasource.name)), "danger") + __(get_datasource_access_error_msg(viz_obj.datasource.name)), + "danger") return redirect( 'caravel/request_access/?' 'datasource_type={datasource_type}&' 'datasource_id={datasource_id}&' ''.format(**locals())) - request_args_multi_dict = request.args # MultiDict - - slice_id = slice_id or request_args_multi_dict.get("slice_id") - slc = None - # build viz_obj and get it's params - if slice_id: - slc = db.session.query(models.Slice).filter_by(id=slice_id).first() - try: - viz_obj = slc.get_viz( - url_params_multidict=request_args_multi_dict) - except Exception as e: - logging.exception(e) - flash(utils.error_msg_from_exception(e), "danger") - return redirect(error_redirect) - else: - viz_type = request_args_multi_dict.get("viz_type") - if not viz_type and datasource.default_endpoint: - return redirect(datasource.default_endpoint) - # default to table if no default endpoint and no viz_type - viz_type = viz_type or "table" - # validate viz params - try: - viz_obj = viz.viz_types[viz_type]( - datasource, request_args_multi_dict) - except Exception as e: - logging.exception(e) - flash(utils.error_msg_from_exception(e), "danger") - return redirect(error_redirect) - slice_params_multi_dict = ImmutableMultiDict(viz_obj.orig_form_data) + if not viz_type and viz_obj.datasource.default_endpoint: + return redirect(viz_obj.datasource.default_endpoint) # slc perms slice_add_perm = self.can_access('can_add', 'SliceModelView') @@ -1161,45 +1182,29 @@ def explore(self, datasource_type, datasource_id, slice_id=None): slice_download_perm = self.can_access('can_download', 'SliceModelView') # handle save or overwrite - action = slice_params_multi_dict.get('action') + action = request.args.get('action') if action in ('saveas', 'overwrite'): return self.save_or_overwrite_slice( - slice_params_multi_dict, slc, slice_add_perm, slice_edit_perm) + request.args, slc, slice_add_perm, slice_edit_perm) # handle different endpoints - if slice_params_multi_dict.get("json") == "true": - if config.get("DEBUG"): - # Allows for nice debugger stack traces in debug mode - return Response( - viz_obj.get_json(), - status=200, - mimetype="application/json") - try: - return Response( - viz_obj.get_json(), - status=200, - mimetype="application/json") - except Exception as e: - logging.exception(e) - return json_error_response(utils.error_msg_from_exception(e)) - - elif slice_params_multi_dict.get("csv") == "true": + if request.args.get("csv") == "true": payload = viz_obj.get_csv() return Response( payload, status=200, headers=generate_download_headers("csv"), mimetype="application/csv") + elif request.args.get("standalone") == "true": + return self.render_template("caravel/standalone.html", viz=viz_obj) else: - if slice_params_multi_dict.get("standalone") == "true": - template = "caravel/standalone.html" - else: - template = "caravel/explore.html" return self.render_template( - template, viz=viz_obj, slice=slc, datasources=datasources, + "caravel/explore.html", + viz=viz_obj, slice=slc, datasources=datasources, can_add=slice_add_perm, can_edit=slice_edit_perm, can_download=slice_download_perm, - userid=g.user.get_id() if g.user else '') + userid=g.user.get_id() if g.user else '' + ) @has_access @expose("/exploreV2////") @@ -1705,7 +1710,11 @@ def sqllab_viz(self): data = json.loads(request.args.get('data')) table_name = data.get('datasourceName') viz_type = data.get('chartType') - table = db.session.query(models.SqlaTable).filter_by(table_name=table_name).first() + table = ( + db.session.query(models.SqlaTable) + .filter_by(table_name=table_name) + .first() + ) if not table: table = models.SqlaTable(table_name=table_name) table.database_id = data.get('dbId') diff --git a/caravel/viz.py b/caravel/viz.py index 524a462acf479..ebb4b4f9c69cf 100755 --- a/caravel/viz.py +++ b/caravel/viz.py @@ -104,7 +104,7 @@ def flat_form_fields(cls): def reassignments(self): pass - def get_url(self, for_cache_key=False, **kwargs): + def get_url(self, for_cache_key=False, json_endpoint=False, **kwargs): """Returns the URL for the viz :param for_cache_key: when getting the url as the identifier to hash @@ -140,8 +140,12 @@ def get_url(self, for_cache_key=False, **kwargs): for item in v: od.add(key, item) + base_endpoint = '/caravel/explore' + if json_endpoint: + base_endpoint = '/caravel/explore_json' + href = Href( - '/caravel/explore/{self.datasource.type}/' + '{base_endpoint}/{self.datasource.type}/' '{self.datasource.id}/'.format(**locals())) if for_cache_key and 'force' in od: del od['force'] @@ -373,7 +377,7 @@ def get_data(self): @property def json_endpoint(self): - return self.get_url(json="true") + return self.get_url(json_endpoint=True) @property def cache_key(self): @@ -1261,7 +1265,6 @@ class HistogramViz(BaseViz): } } - def query_obj(self): """Returns the query object for this visualization""" d = super(HistogramViz, self).query_obj() @@ -1272,7 +1275,6 @@ def query_obj(self): d['columns'] = [numeric_column] return d - def get_df(self, query_obj=None): """Returns a pandas dataframe based on the query object""" if not query_obj: @@ -1289,7 +1291,6 @@ def get_df(self, query_obj=None): df = df.fillna(0) return df - def get_data(self): """Returns the chart data""" df = self.get_df() diff --git a/run_specific_test.sh b/run_specific_test.sh index e78ca5ab6aa9c..c63a459d27aea 100755 --- a/run_specific_test.sh +++ b/run_specific_test.sh @@ -5,4 +5,4 @@ export CARAVEL_CONFIG=tests.caravel_test_config set -e caravel/bin/caravel version -v export SOLO_TEST=1 -nosetests tests.core_tests:CoreTests.test_public_user_dashboard_access +nosetests tests.core_tests:CoreTests.test_slice_endpoint diff --git a/tests/base_tests.py b/tests/base_tests.py index 7c1d09aae9464..1e22934e27bf4 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -81,6 +81,12 @@ def __init__(self, *args, **kwargs): utils.init(caravel) + def get_or_create(self, cls, criteria, session): + obj = session.query(cls).filter_by(**criteria).first() + if not obj: + obj = cls(**criteria) + return obj + def login(self, username='admin', password='general'): resp = self.client.post( '/login/', @@ -104,6 +110,15 @@ def get_latest_query(self, sql): session.close() return query + def get_slice(self, slice_name, session): + slc = ( + session.query(models.Slice) + .filter_by(slice_name=slice_name) + .one() + ) + session.expunge_all() + return slc + def get_resp(self, url): """Shortcut to get the parsed results while following redirects""" resp = self.client.get(url, follow_redirects=True) @@ -124,11 +139,6 @@ def get_access_requests(self, username, ds_type, ds_id): def logout(self): self.client.get('/logout/', follow_redirects=True) - def test_welcome(self): - self.login() - resp = self.client.get('/caravel/welcome') - assert 'Welcome' in resp.data.decode('utf-8') - def setup_public_access_for_dashboard(self, table_name): public_role = appbuilder.sm.find_role('Public') perms = db.session.query(ab_models.PermissionView).all() diff --git a/tests/core_tests.py b/tests/core_tests.py index 0f16b3a77d2aa..01c502d82c2f3 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -44,6 +44,33 @@ def setUp(self): def tearDown(self): pass + def test_welcome(self): + self.login() + resp = self.client.get('/caravel/welcome') + assert 'Welcome' in resp.data.decode('utf-8') + + def test_slice_endpoint(self): + self.login(username='admin') + slc = self.get_slice("Girls", db.session) + resp = self.get_resp('/caravel/slice/{}/'.format(slc.id)) + assert 'Time Column' in resp + assert 'List Roles' in resp + + # Testing overrides + resp = self.get_resp( + '/caravel/slice/{}/?standalone=true'.format(slc.id)) + assert 'List Roles' not in resp + + def test_endpoints_for_a_slice(self): + self.login(username='admin') + slc = self.get_slice("Girls", db.session) + + resp = self.get_resp(slc.viz.csv_endpoint) + assert 'Jennifer,' in resp + + resp = self.get_resp(slc.viz.json_endpoint) + assert '"Jennifer"' in resp + def test_admin_only_permissions(self): def assert_admin_permission_in(role_name, assert_func): role = sm.find_role(role_name) @@ -73,13 +100,7 @@ def assert_admin_view_menus_in(role_name, assert_func): def test_save_slice(self): self.login(username='admin') - - slc = ( - db.session.query(models.Slice.id) - .filter_by(slice_name="Energy Sankey") - .first()) - slice_id = slc.id - + slice_id = self.get_slice("Energy Sankey", db.session).id copy_name = "Test Sankey Save" tbl_id = self.table_ids.get('energy_usage') url = ( diff --git a/tests/druid_tests.py b/tests/druid_tests.py index 9d1857be13b1c..b1e70486e7936 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -14,7 +14,6 @@ from caravel.models import DruidCluster, DruidDatasource from .base_tests import CaravelTestCase -from flask_appbuilder.security.sqla import models as ab_models SEGMENT_METADATA = [{ @@ -118,25 +117,40 @@ def test_client(self, PyDruid): datasource_id)) assert "[test_cluster].[test_datasource]" in resp.data.decode('utf-8') - resp = self.client.get( - '/caravel/explore/druid/{}/?viz_type=table&granularity=one+day&' + url = ( + '/caravel/explore_json/druid/{}/?viz_type=table&granularity=one+day&' 'druid_time_origin=&since=7+days+ago&until=now&row_limit=5000&' 'include_search=false&metrics=count&groupby=name&flt_col_0=dim1&' 'flt_op_0=in&flt_eq_0=&slice_id=&slice_name=&collapsed_fieldsets=&' 'action=&datasource_name=test_datasource&datasource_id={}&' - 'datasource_type=druid&previous_viz_type=table&json=true&' + 'datasource_type=druid&previous_viz_type=table&' 'force=true'.format(datasource_id, datasource_id)) - assert "Canada" in resp.data.decode('utf-8') + resp = self.get_resp(url) + assert "Canada" in resp def test_druid_sync_from_config(self): + CLUSTER_NAME = 'new_druid' self.login() - cluster = DruidCluster(cluster_name="new_druid") - db.session.add(cluster) + cluster = self.get_or_create( + DruidCluster, + {'cluster_name': CLUSTER_NAME}, + db.session) + + db.session.merge(cluster) + db.session.commit() + + ds = ( + db.session.query(DruidDatasource) + .filter_by(datasource_name='test_click') + .first() + ) + if ds: + db.session.delete(ds) db.session.commit() cfg = { "user": "admin", - "cluster": "new_druid", + "cluster": CLUSTER_NAME, "config": { "name": "test_click", "dimensions": ["affiliate_id", "campaign", "first_seen"], @@ -152,30 +166,24 @@ def test_druid_sync_from_config(self): } } } - resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg)) - - druid_ds = db.session.query(DruidDatasource).filter_by( - datasource_name="test_click").first() - assert set([c.column_name for c in druid_ds.columns]) == set( - ["affiliate_id", "campaign", "first_seen"]) - assert set([m.metric_name for m in druid_ds.metrics]) == set( - ["count", "sum"]) - assert resp.status_code == 201 - - # datasource exists, not changes required - resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg)) - druid_ds = db.session.query(DruidDatasource).filter_by( - datasource_name="test_click").first() - assert set([c.column_name for c in druid_ds.columns]) == set( - ["affiliate_id", "campaign", "first_seen"]) - assert set([m.metric_name for m in druid_ds.metrics]) == set( - ["count", "sum"]) - assert resp.status_code == 201 + def check(): + resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg)) + druid_ds = db.session.query(DruidDatasource).filter_by( + datasource_name="test_click").first() + col_names = set([c.column_name for c in druid_ds.columns]) + assert {"affiliate_id", "campaign", "first_seen"} == col_names + metric_names = {m.metric_name for m in druid_ds.metrics} + assert {"count", "sum"} == metric_names + assert resp.status_code == 201 + + check() + # checking twice to make sure a second sync yields the same results + check() # datasource exists, add new metrics and dimentions cfg = { "user": "admin", - "cluster": "new_druid", + "cluster": CLUSTER_NAME, "config": { "name": "test_click", "dimensions": ["affiliate_id", "second_seen"], @@ -200,26 +208,33 @@ def test_druid_sync_from_config(self): assert resp.status_code == 201 def test_filter_druid_datasource(self): - gamma_ds = DruidDatasource( - datasource_name="datasource_for_gamma", - ) - db.session.add(gamma_ds) - no_gamma_ds = DruidDatasource( - datasource_name="datasource_not_for_gamma", - ) - db.session.add(no_gamma_ds) - db.session.commit() + CLUSTER_NAME = 'new_druid' + cluster = self.get_or_create( + DruidCluster, + {'cluster_name': CLUSTER_NAME}, + db.session) + db.session.merge(cluster) + + gamma_ds = self.get_or_create( + DruidDatasource, {'datasource_name': 'datasource_for_gamma'}, + db.session) + gamma_ds.cluster = cluster + db.session.merge(gamma_ds) + + no_gamma_ds = self.get_or_create( + DruidDatasource, {'datasource_name': 'datasource_not_for_gamma'}, + db.session) + no_gamma_ds.cluster = cluster + db.session.merge(no_gamma_ds) + utils.merge_perm(sm, 'datasource_access', gamma_ds.perm) utils.merge_perm(sm, 'datasource_access', no_gamma_ds.perm) + db.session.commit() - gamma_ds_permission_view = ( - db.session.query(ab_models.PermissionView) - .join(ab_models.ViewMenu) - .filter(ab_models.ViewMenu.name == gamma_ds.perm) - .first() - ) - sm.add_permission_role(sm.find_role('Gamma'), gamma_ds_permission_view) + perm = sm.find_permission_view_menu('datasource_access', gamma_ds.perm) + sm.add_permission_role(sm.find_role('Gamma'), perm) + db.session.commit() self.login(username='gamma') url = '/druiddatasourcemodelview/list/' @@ -227,13 +242,6 @@ def test_filter_druid_datasource(self): assert 'datasource_for_gamma' in resp assert 'datasource_not_for_gamma' not in resp - def test_add_filter(self, username='admin'): - # navigate to energy_usage slice with "Electricity,heat" in filter values - data = ( - "/caravel/explore/table/1/?viz_type=table&groupby=source&metric=count&flt_col_1=source&flt_op_1=in&flt_eq_1=%27Electricity%2Cheat%27" - "&userid=1&datasource_name=energy_usage&datasource_id=1&datasource_type=tablerdo_save=saveas") - assert "source" in self.get_resp(data) - if __name__ == '__main__': unittest.main()