diff --git a/caravel/source_registry.py b/caravel/source_registry.py
index 04b3f17c28cc6..dc1a170d7c68e 100644
--- a/caravel/source_registry.py
+++ b/caravel/source_registry.py
@@ -12,3 +12,11 @@ def register_sources(cls, datasource_config):
for class_name in class_names:
source_class = getattr(module_obj, class_name)
cls.sources[source_class.type] = source_class
+
+ @classmethod
+ def get_datasource(cls, datasource_type, datasource_id, session):
+ return (
+ session.query(cls.sources[datasource_type])
+ .filter_by(id=datasource_id)
+ .one()
+ )
diff --git a/caravel/templates/caravel/standalone.html b/caravel/templates/caravel/standalone.html
index 700492504321f..ac44f9c62b0c8 100644
--- a/caravel/templates/caravel/standalone.html
+++ b/caravel/templates/caravel/standalone.html
@@ -1,6 +1,6 @@
- {{viz.token}}
+ {{ viz.token }}
diff --git a/caravel/views.py b/caravel/views.py
index c9574d29808d9..a355eacda5a42 100755
--- a/caravel/views.py
+++ b/caravel/views.py
@@ -26,7 +26,6 @@
from flask_appbuilder.models.sqla.filters import BaseFilter
from sqlalchemy import create_engine
-from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.routing import BaseConverter
from wtforms.validators import ValidationError
@@ -244,7 +243,8 @@ def apply(self, query, func): # noqa
druid_datasources = []
for perm in perms:
match = re.search(r'\(id:(\d+)\)', perm)
- druid_datasources.append(match.group(1))
+ if match:
+ druid_datasources.append(match.group(1))
qry = query.filter(self.model.id.in_(druid_datasources))
return qry
@@ -672,6 +672,7 @@ class DruidClusterModelView(CaravelModelView, DeleteMixin): # noqa
'broker_port': _("Broker Port"),
'broker_endpoint': _("Broker Endpoint"),
}
+
def pre_add(self, db):
utils.merge_perm(sm, 'database_access', db.perm)
@@ -699,7 +700,8 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
list_columns = [
'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified']
edit_columns = [
- 'slice_name', 'description', 'viz_type', 'owners', 'dashboards', 'params', 'cache_timeout']
+ 'slice_name', 'description', 'viz_type', 'owners', 'dashboards',
+ 'params', 'cache_timeout']
base_order = ('changed_on', 'desc')
description_columns = {
'description': Markup(
@@ -1099,61 +1101,80 @@ def approve(self):
session.commit()
return redirect('/accessrequestsmodelview/list/')
+ def get_viz(
+ self,
+ slice_id=None,
+ args=None,
+ datasource_type=None,
+ datasource_id=None):
+ if slice_id:
+ slc = db.session.query(models.Slice).filter_by(id=slice_id).one()
+ return slc.get_viz()
+ else:
+ viz_type = args.get('viz_type', 'table')
+ datasource = SourceRegistry.get_datasource(
+ datasource_type, datasource_id, db.session)
+ viz_obj = viz.viz_types[viz_type](datasource, request.args)
+ return viz_obj
+
@has_access
- @expose("/explore////")
- @expose("/explore///")
- @expose("/datasource///") # Legacy url
+ @expose("/slice//")
+ def slice(self, slice_id):
+ viz_obj = self.get_viz(slice_id)
+ return redirect(viz_obj.get_url(**request.args))
+
+ @has_access_api
+ @expose("/explore_json///")
+ def explore_json(self, datasource_type, datasource_id):
+ viz_obj = self.get_viz(
+ datasource_type=datasource_type,
+ datasource_id=datasource_id,
+ args=request.args)
+ if not self.datasource_access(viz_obj.datasource):
+ return Response(
+ json.dumps(
+ {'error': _("You don't have access to this datasource")}),
+ status=404,
+ mimetype="application/json")
+ return Response(
+ viz_obj.get_json(),
+ status=200,
+ mimetype="application/json")
+
@log_this
- def explore(self, datasource_type, datasource_id, slice_id=None):
+ @has_access
+ @expose("/explore///")
+ def explore(self, datasource_type, datasource_id):
+ viz_type = request.args.get("viz_type")
+ slice_id = request.args.get('slice_id')
+ slc = db.session.query(models.Slice).filter_by(id=slice_id).first()
+
error_redirect = '/slicemodelview/list/'
datasource_class = SourceRegistry.sources[datasource_type]
datasources = db.session.query(datasource_class).all()
datasources = sorted(datasources, key=lambda ds: ds.full_name)
- datasource = [ds for ds in datasources if int(datasource_id) == ds.id]
- datasource = datasource[0] if datasource else None
- if not datasource:
+ viz_obj = self.get_viz(
+ datasource_type=datasource_type,
+ datasource_id=datasource_id,
+ args=request.args)
+
+ if not viz_obj.datasource:
flash(DATASOURCE_MISSING_ERR, "alert")
return redirect(error_redirect)
- if not self.datasource_access(datasource):
+ if not self.datasource_access(viz_obj.datasource):
flash(
- __(get_datasource_access_error_msg(datasource.name)), "danger")
+ __(get_datasource_access_error_msg(viz_obj.datasource.name)),
+ "danger")
return redirect(
'caravel/request_access/?'
'datasource_type={datasource_type}&'
'datasource_id={datasource_id}&'
''.format(**locals()))
- request_args_multi_dict = request.args # MultiDict
-
- slice_id = slice_id or request_args_multi_dict.get("slice_id")
- slc = None
- # build viz_obj and get it's params
- if slice_id:
- slc = db.session.query(models.Slice).filter_by(id=slice_id).first()
- try:
- viz_obj = slc.get_viz(
- url_params_multidict=request_args_multi_dict)
- except Exception as e:
- logging.exception(e)
- flash(utils.error_msg_from_exception(e), "danger")
- return redirect(error_redirect)
- else:
- viz_type = request_args_multi_dict.get("viz_type")
- if not viz_type and datasource.default_endpoint:
- return redirect(datasource.default_endpoint)
- # default to table if no default endpoint and no viz_type
- viz_type = viz_type or "table"
- # validate viz params
- try:
- viz_obj = viz.viz_types[viz_type](
- datasource, request_args_multi_dict)
- except Exception as e:
- logging.exception(e)
- flash(utils.error_msg_from_exception(e), "danger")
- return redirect(error_redirect)
- slice_params_multi_dict = ImmutableMultiDict(viz_obj.orig_form_data)
+ if not viz_type and viz_obj.datasource.default_endpoint:
+ return redirect(viz_obj.datasource.default_endpoint)
# slc perms
slice_add_perm = self.can_access('can_add', 'SliceModelView')
@@ -1161,45 +1182,29 @@ def explore(self, datasource_type, datasource_id, slice_id=None):
slice_download_perm = self.can_access('can_download', 'SliceModelView')
# handle save or overwrite
- action = slice_params_multi_dict.get('action')
+ action = request.args.get('action')
if action in ('saveas', 'overwrite'):
return self.save_or_overwrite_slice(
- slice_params_multi_dict, slc, slice_add_perm, slice_edit_perm)
+ request.args, slc, slice_add_perm, slice_edit_perm)
# handle different endpoints
- if slice_params_multi_dict.get("json") == "true":
- if config.get("DEBUG"):
- # Allows for nice debugger stack traces in debug mode
- return Response(
- viz_obj.get_json(),
- status=200,
- mimetype="application/json")
- try:
- return Response(
- viz_obj.get_json(),
- status=200,
- mimetype="application/json")
- except Exception as e:
- logging.exception(e)
- return json_error_response(utils.error_msg_from_exception(e))
-
- elif slice_params_multi_dict.get("csv") == "true":
+ if request.args.get("csv") == "true":
payload = viz_obj.get_csv()
return Response(
payload,
status=200,
headers=generate_download_headers("csv"),
mimetype="application/csv")
+ elif request.args.get("standalone") == "true":
+ return self.render_template("caravel/standalone.html", viz=viz_obj)
else:
- if slice_params_multi_dict.get("standalone") == "true":
- template = "caravel/standalone.html"
- else:
- template = "caravel/explore.html"
return self.render_template(
- template, viz=viz_obj, slice=slc, datasources=datasources,
+ "caravel/explore.html",
+ viz=viz_obj, slice=slc, datasources=datasources,
can_add=slice_add_perm, can_edit=slice_edit_perm,
can_download=slice_download_perm,
- userid=g.user.get_id() if g.user else '')
+ userid=g.user.get_id() if g.user else ''
+ )
@has_access
@expose("/exploreV2////")
@@ -1705,7 +1710,11 @@ def sqllab_viz(self):
data = json.loads(request.args.get('data'))
table_name = data.get('datasourceName')
viz_type = data.get('chartType')
- table = db.session.query(models.SqlaTable).filter_by(table_name=table_name).first()
+ table = (
+ db.session.query(models.SqlaTable)
+ .filter_by(table_name=table_name)
+ .first()
+ )
if not table:
table = models.SqlaTable(table_name=table_name)
table.database_id = data.get('dbId')
diff --git a/caravel/viz.py b/caravel/viz.py
index 524a462acf479..ebb4b4f9c69cf 100755
--- a/caravel/viz.py
+++ b/caravel/viz.py
@@ -104,7 +104,7 @@ def flat_form_fields(cls):
def reassignments(self):
pass
- def get_url(self, for_cache_key=False, **kwargs):
+ def get_url(self, for_cache_key=False, json_endpoint=False, **kwargs):
"""Returns the URL for the viz
:param for_cache_key: when getting the url as the identifier to hash
@@ -140,8 +140,12 @@ def get_url(self, for_cache_key=False, **kwargs):
for item in v:
od.add(key, item)
+ base_endpoint = '/caravel/explore'
+ if json_endpoint:
+ base_endpoint = '/caravel/explore_json'
+
href = Href(
- '/caravel/explore/{self.datasource.type}/'
+ '{base_endpoint}/{self.datasource.type}/'
'{self.datasource.id}/'.format(**locals()))
if for_cache_key and 'force' in od:
del od['force']
@@ -373,7 +377,7 @@ def get_data(self):
@property
def json_endpoint(self):
- return self.get_url(json="true")
+ return self.get_url(json_endpoint=True)
@property
def cache_key(self):
@@ -1261,7 +1265,6 @@ class HistogramViz(BaseViz):
}
}
-
def query_obj(self):
"""Returns the query object for this visualization"""
d = super(HistogramViz, self).query_obj()
@@ -1272,7 +1275,6 @@ def query_obj(self):
d['columns'] = [numeric_column]
return d
-
def get_df(self, query_obj=None):
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
@@ -1289,7 +1291,6 @@ def get_df(self, query_obj=None):
df = df.fillna(0)
return df
-
def get_data(self):
"""Returns the chart data"""
df = self.get_df()
diff --git a/run_specific_test.sh b/run_specific_test.sh
index e78ca5ab6aa9c..c63a459d27aea 100755
--- a/run_specific_test.sh
+++ b/run_specific_test.sh
@@ -5,4 +5,4 @@ export CARAVEL_CONFIG=tests.caravel_test_config
set -e
caravel/bin/caravel version -v
export SOLO_TEST=1
-nosetests tests.core_tests:CoreTests.test_public_user_dashboard_access
+nosetests tests.core_tests:CoreTests.test_slice_endpoint
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 7c1d09aae9464..1e22934e27bf4 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -81,6 +81,12 @@ def __init__(self, *args, **kwargs):
utils.init(caravel)
+ def get_or_create(self, cls, criteria, session):
+ obj = session.query(cls).filter_by(**criteria).first()
+ if not obj:
+ obj = cls(**criteria)
+ return obj
+
def login(self, username='admin', password='general'):
resp = self.client.post(
'/login/',
@@ -104,6 +110,15 @@ def get_latest_query(self, sql):
session.close()
return query
+ def get_slice(self, slice_name, session):
+ slc = (
+ session.query(models.Slice)
+ .filter_by(slice_name=slice_name)
+ .one()
+ )
+ session.expunge_all()
+ return slc
+
def get_resp(self, url):
"""Shortcut to get the parsed results while following redirects"""
resp = self.client.get(url, follow_redirects=True)
@@ -124,11 +139,6 @@ def get_access_requests(self, username, ds_type, ds_id):
def logout(self):
self.client.get('/logout/', follow_redirects=True)
- def test_welcome(self):
- self.login()
- resp = self.client.get('/caravel/welcome')
- assert 'Welcome' in resp.data.decode('utf-8')
-
def setup_public_access_for_dashboard(self, table_name):
public_role = appbuilder.sm.find_role('Public')
perms = db.session.query(ab_models.PermissionView).all()
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 0f16b3a77d2aa..01c502d82c2f3 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -44,6 +44,33 @@ def setUp(self):
def tearDown(self):
pass
+ def test_welcome(self):
+ self.login()
+ resp = self.client.get('/caravel/welcome')
+ assert 'Welcome' in resp.data.decode('utf-8')
+
+ def test_slice_endpoint(self):
+ self.login(username='admin')
+ slc = self.get_slice("Girls", db.session)
+ resp = self.get_resp('/caravel/slice/{}/'.format(slc.id))
+ assert 'Time Column' in resp
+ assert 'List Roles' in resp
+
+ # Testing overrides
+ resp = self.get_resp(
+ '/caravel/slice/{}/?standalone=true'.format(slc.id))
+ assert 'List Roles' not in resp
+
+ def test_endpoints_for_a_slice(self):
+ self.login(username='admin')
+ slc = self.get_slice("Girls", db.session)
+
+ resp = self.get_resp(slc.viz.csv_endpoint)
+ assert 'Jennifer,' in resp
+
+ resp = self.get_resp(slc.viz.json_endpoint)
+ assert '"Jennifer"' in resp
+
def test_admin_only_permissions(self):
def assert_admin_permission_in(role_name, assert_func):
role = sm.find_role(role_name)
@@ -73,13 +100,7 @@ def assert_admin_view_menus_in(role_name, assert_func):
def test_save_slice(self):
self.login(username='admin')
-
- slc = (
- db.session.query(models.Slice.id)
- .filter_by(slice_name="Energy Sankey")
- .first())
- slice_id = slc.id
-
+ slice_id = self.get_slice("Energy Sankey", db.session).id
copy_name = "Test Sankey Save"
tbl_id = self.table_ids.get('energy_usage')
url = (
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index 9d1857be13b1c..b1e70486e7936 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -14,7 +14,6 @@
from caravel.models import DruidCluster, DruidDatasource
from .base_tests import CaravelTestCase
-from flask_appbuilder.security.sqla import models as ab_models
SEGMENT_METADATA = [{
@@ -118,25 +117,40 @@ def test_client(self, PyDruid):
datasource_id))
assert "[test_cluster].[test_datasource]" in resp.data.decode('utf-8')
- resp = self.client.get(
- '/caravel/explore/druid/{}/?viz_type=table&granularity=one+day&'
+ url = (
+ '/caravel/explore_json/druid/{}/?viz_type=table&granularity=one+day&'
'druid_time_origin=&since=7+days+ago&until=now&row_limit=5000&'
'include_search=false&metrics=count&groupby=name&flt_col_0=dim1&'
'flt_op_0=in&flt_eq_0=&slice_id=&slice_name=&collapsed_fieldsets=&'
'action=&datasource_name=test_datasource&datasource_id={}&'
- 'datasource_type=druid&previous_viz_type=table&json=true&'
+ 'datasource_type=druid&previous_viz_type=table&'
'force=true'.format(datasource_id, datasource_id))
- assert "Canada" in resp.data.decode('utf-8')
+ resp = self.get_resp(url)
+ assert "Canada" in resp
def test_druid_sync_from_config(self):
+ CLUSTER_NAME = 'new_druid'
self.login()
- cluster = DruidCluster(cluster_name="new_druid")
- db.session.add(cluster)
+ cluster = self.get_or_create(
+ DruidCluster,
+ {'cluster_name': CLUSTER_NAME},
+ db.session)
+
+ db.session.merge(cluster)
+ db.session.commit()
+
+ ds = (
+ db.session.query(DruidDatasource)
+ .filter_by(datasource_name='test_click')
+ .first()
+ )
+ if ds:
+ db.session.delete(ds)
db.session.commit()
cfg = {
"user": "admin",
- "cluster": "new_druid",
+ "cluster": CLUSTER_NAME,
"config": {
"name": "test_click",
"dimensions": ["affiliate_id", "campaign", "first_seen"],
@@ -152,30 +166,24 @@ def test_druid_sync_from_config(self):
}
}
}
- resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg))
-
- druid_ds = db.session.query(DruidDatasource).filter_by(
- datasource_name="test_click").first()
- assert set([c.column_name for c in druid_ds.columns]) == set(
- ["affiliate_id", "campaign", "first_seen"])
- assert set([m.metric_name for m in druid_ds.metrics]) == set(
- ["count", "sum"])
- assert resp.status_code == 201
-
- # datasource exists, not changes required
- resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg))
- druid_ds = db.session.query(DruidDatasource).filter_by(
- datasource_name="test_click").first()
- assert set([c.column_name for c in druid_ds.columns]) == set(
- ["affiliate_id", "campaign", "first_seen"])
- assert set([m.metric_name for m in druid_ds.metrics]) == set(
- ["count", "sum"])
- assert resp.status_code == 201
+ def check():
+ resp = self.client.post('/caravel/sync_druid/', data=json.dumps(cfg))
+ druid_ds = db.session.query(DruidDatasource).filter_by(
+ datasource_name="test_click").first()
+ col_names = set([c.column_name for c in druid_ds.columns])
+ assert {"affiliate_id", "campaign", "first_seen"} == col_names
+ metric_names = {m.metric_name for m in druid_ds.metrics}
+ assert {"count", "sum"} == metric_names
+ assert resp.status_code == 201
+
+ check()
+ # checking twice to make sure a second sync yields the same results
+ check()
# datasource exists, add new metrics and dimentions
cfg = {
"user": "admin",
- "cluster": "new_druid",
+ "cluster": CLUSTER_NAME,
"config": {
"name": "test_click",
"dimensions": ["affiliate_id", "second_seen"],
@@ -200,26 +208,33 @@ def test_druid_sync_from_config(self):
assert resp.status_code == 201
def test_filter_druid_datasource(self):
- gamma_ds = DruidDatasource(
- datasource_name="datasource_for_gamma",
- )
- db.session.add(gamma_ds)
- no_gamma_ds = DruidDatasource(
- datasource_name="datasource_not_for_gamma",
- )
- db.session.add(no_gamma_ds)
- db.session.commit()
+ CLUSTER_NAME = 'new_druid'
+ cluster = self.get_or_create(
+ DruidCluster,
+ {'cluster_name': CLUSTER_NAME},
+ db.session)
+ db.session.merge(cluster)
+
+ gamma_ds = self.get_or_create(
+ DruidDatasource, {'datasource_name': 'datasource_for_gamma'},
+ db.session)
+ gamma_ds.cluster = cluster
+ db.session.merge(gamma_ds)
+
+ no_gamma_ds = self.get_or_create(
+ DruidDatasource, {'datasource_name': 'datasource_not_for_gamma'},
+ db.session)
+ no_gamma_ds.cluster = cluster
+ db.session.merge(no_gamma_ds)
+
utils.merge_perm(sm, 'datasource_access', gamma_ds.perm)
utils.merge_perm(sm, 'datasource_access', no_gamma_ds.perm)
+
db.session.commit()
- gamma_ds_permission_view = (
- db.session.query(ab_models.PermissionView)
- .join(ab_models.ViewMenu)
- .filter(ab_models.ViewMenu.name == gamma_ds.perm)
- .first()
- )
- sm.add_permission_role(sm.find_role('Gamma'), gamma_ds_permission_view)
+ perm = sm.find_permission_view_menu('datasource_access', gamma_ds.perm)
+ sm.add_permission_role(sm.find_role('Gamma'), perm)
+ db.session.commit()
self.login(username='gamma')
url = '/druiddatasourcemodelview/list/'
@@ -227,13 +242,6 @@ def test_filter_druid_datasource(self):
assert 'datasource_for_gamma' in resp
assert 'datasource_not_for_gamma' not in resp
- def test_add_filter(self, username='admin'):
- # navigate to energy_usage slice with "Electricity,heat" in filter values
- data = (
- "/caravel/explore/table/1/?viz_type=table&groupby=source&metric=count&flt_col_1=source&flt_op_1=in&flt_eq_1=%27Electricity%2Cheat%27"
- "&userid=1&datasource_name=energy_usage&datasource_id=1&datasource_type=tablerdo_save=saveas")
- assert "source" in self.get_resp(data)
-
if __name__ == '__main__':
unittest.main()