diff --git a/superset/models/core.py b/superset/models/core.py index 8448c7ba54e49..ebce5fcb3103c 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -31,6 +31,7 @@ from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint from sqlalchemy_utils import EncryptedType +import sqlparse from superset import app, db, db_engine_specs, security_manager, utils from superset.connectors.connector_registry import ConnectorRegistry @@ -690,9 +691,13 @@ def get_quoter(self): return self.get_dialect().identifier_preparer.quote def get_df(self, sql, schema): - sql = sql.strip().strip(';') + sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)] eng = self.get_sqla_engine(schema=schema) - df = pd.read_sql_query(sql, eng) + + for i in range(len(sqls) - 1): + eng.execute(sqls[i]) + + df = pd.read_sql_query(sqls[-1], eng) def needs_conversion(df_series): if df_series.empty: diff --git a/tests/model_tests.py b/tests/model_tests.py index 8af104f57c9d0..411fdd5ed38ae 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -106,6 +106,26 @@ def test_grains_dict(self): self.assertEquals(d.get('P1D').function, 'DATE({col})') self.assertEquals(d.get('Time Column').function, '{col}') + def test_single_statement(self): + main_db = self.get_main_database(db.session) + + if main_db.backend == 'mysql': + df = main_db.get_df('SELECT 1', None) + self.assertEquals(df.iat[0, 0], 1) + + df = main_db.get_df('SELECT 1;', None) + self.assertEquals(df.iat[0, 0], 1) + + def test_multi_statement(self): + main_db = self.get_main_database(db.session) + + if main_db.backend == 'mysql': + df = main_db.get_df('USE superset; SELECT 1', None) + self.assertEquals(df.iat[0, 0], 1) + + df = main_db.get_df("USE superset; SELECT ';';", None) + self.assertEquals(df.iat[0, 0], ';') + class SqlaTableModelTestCase(SupersetTestCase):