Skip to content

Commit

Permalink
[sql] Adding lightweight Table class (#9649)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <[email protected]>
  • Loading branch information
john-bodley and John Bodley authored Apr 30, 2020
1 parent f7f60cc commit 3b0f8e9
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 169 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ colorama==0.4.3 # via apache-superset (setup.py), flask-appbuilder
contextlib2==0.6.0.post1 # via apache-superset (setup.py)
croniter==0.3.31 # via apache-superset (setup.py)
cryptography==2.8 # via apache-superset (setup.py)
dataclasses==0.6 # via apache-superset (setup.py)
decorator==4.4.1 # via retry
defusedxml==0.6.0 # via python3-openid
flask-appbuilder==2.3.2 # via apache-superset (setup.py)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def get_git_sha():
"contextlib2",
"croniter>=0.3.28",
"cryptography>=2.4.2",
"dataclasses<0.7",
"flask>=1.1.0, <2.0.0",
"flask-appbuilder>=2.3.2, <2.4.0",
"flask-caching",
Expand Down
97 changes: 25 additions & 72 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database
from superset.sql_parse import Table
from superset.viz import BaseViz

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -290,26 +291,23 @@ def get_datasource_access_link(self, datasource: "BaseDatasource") -> Optional[s

return conf.get("PERMISSION_INSTRUCTIONS_LINK")

def get_table_access_error_msg(self, tables: List[str]) -> str:
def get_table_access_error_msg(self, tables: Set["Table"]) -> str:
"""
Return the error message for the denied SQL tables.
Note the table names conform to the [[cluster.]schema.]table construct.
:param tables: The list of denied SQL table names
:param tables: The set of denied SQL tables
:returns: The error message
"""
quoted_tables = [f"`{t}`" for t in tables]

quoted_tables = [f"`{table}`" for table in tables]
return f"""You need access to the following tables: {", ".join(quoted_tables)},
`all_database_access` or `all_datasource_access` permission"""

def get_table_access_link(self, tables: List[str]) -> Optional[str]:
def get_table_access_link(self, tables: Set["Table"]) -> Optional[str]:
"""
Return the access link for the denied SQL tables.
Note the table names conform to the [[cluster.]schema.]table construct.
:param tables: The list of denied SQL table names
:param tables: The set of denied SQL tables
:returns: The access URL
"""

Expand All @@ -318,23 +316,19 @@ def get_table_access_link(self, tables: List[str]) -> Optional[str]:
return conf.get("PERMISSION_INSTRUCTIONS_LINK")

def can_access_datasource(
self, database: "Database", table_name: str, schema: Optional[str] = None
) -> bool:
return self._datasource_access_by_name(database, table_name, schema=schema)

def _datasource_access_by_name(
self, database: "Database", table_name: str, schema: Optional[str] = None
self, database: "Database", table: "Table", schema: Optional[str] = None
) -> bool:
"""
Return True if the user can access the SQL table, False otherwise.
:param database: The SQL database
:param table_name: The SQL table name
:param schema: The Superset schema
:param table: The SQL table
:param schema: The fallback SQL schema if not present in the table
:returns: Whether the user can access the SQL table
"""

from superset import db
from superset.connectors.sqla.models import SqlaTable

if self.database_access(database) or self.all_datasource_access():
return True
Expand All @@ -343,74 +337,33 @@ def _datasource_access_by_name(
if schema_perm and self.can_access("schema_access", schema_perm):
return True

datasources = ConnectorRegistry.query_datasources_by_name(
db.session, database, table_name, schema=schema
datasources = SqlaTable.query_datasources_by_name(
db.session, database, table.table, schema=table.schema or schema
)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
return False

def _get_schema_and_table(
self, table_in_query: str, schema: str
) -> Tuple[str, str]:
def rejected_tables(
self, sql: str, database: "Database", schema: str
) -> Set["Table"]:
"""
Return the SQL schema/table tuple associated with the table extracted from the
SQL query.
Note the table name conforms to the [[cluster.]schema.]table construct.
:param table_in_query: The SQL table name
:param schema: The fallback SQL schema if not present in the table name
:returns: The SQL schema/table tuple
"""

table_name_pieces = table_in_query.split(".")
if len(table_name_pieces) == 3:
return tuple(table_name_pieces[1:]) # type: ignore
elif len(table_name_pieces) == 2:
return tuple(table_name_pieces) # type: ignore
return (schema, table_name_pieces[0])

def _datasource_access_by_fullname(
self, database: "Database", table_in_query: str, schema: str
) -> bool:
"""
Return True if the user can access the table extracted from the SQL query, False
otherwise.
Note the table name conforms to the [[cluster.]schema.]table construct.
:param database: The Superset database
:param table_in_query: The SQL table name
:param schema: The fallback SQL schema, i.e., if not present in the table name
:returns: Whether the user can access the SQL table
"""

table_schema, table_name = self._get_schema_and_table(table_in_query, schema)
return self._datasource_access_by_name(
database, table_name, schema=table_schema
)

def rejected_tables(self, sql: str, database: "Database", schema: str) -> List[str]:
"""
Return the list of rejected SQL table names.
Note the rejected table names conform to the [[cluster.]schema.]table construct.
Return the set of rejected SQL tables.
:param sql: The SQL statement
:param database: The SQL database
:param schema: The SQL database schema
:returns: The rejected table names
:returns: The rejected tables
"""

superset_query = sql_parse.ParsedQuery(sql)
query = sql_parse.ParsedQuery(sql)

return [
t
for t in superset_query.tables
if not self._datasource_access_by_fullname(database, t, schema)
]
return {
table
for table in query.tables
if not self.can_access_datasource(database, table, schema)
}

def get_public_role(self) -> Optional[Any]: # Optional[self.role_model]
from superset import conf
Expand Down Expand Up @@ -493,7 +446,7 @@ def schemas_accessible_by_user(
.filter(or_(SqlaTable.perm.in_(perms)))
.distinct()
)
accessible_schemas.update([t.schema for t in tables])
accessible_schemas.update([table.schema for table in tables])

return [s for s in schemas if s in accessible_schemas]

Expand Down
79 changes: 53 additions & 26 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
import logging
from typing import List, Optional, Set
from urllib import parse

import sqlparse
from dataclasses import dataclass
from sqlparse.sql import (
Function,
Identifier,
Expand Down Expand Up @@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None


@dataclass(eq=True, frozen=True)
class Table:  # pylint: disable=too-few-public-methods
    """
    An immutable, hashable SQL table reference of the form
    [[catalog.]schema.]table.

    Only the table name is required; the schema and catalog default to None.
    """

    # The bare table name.
    table: str
    # The schema qualifying the table, if any.
    schema: Optional[str] = None
    # The catalog qualifying the schema, if any.
    catalog: Optional[str] = None

    def __str__(self) -> str:
        """
        Render the table as a dot-separated, fully qualified name.

        Parts are emitted in catalog, schema, table order and any part that is
        unset (or empty) is skipped.

        :returns: The encoded, fully qualified table name
        """

        encoded_parts = []
        for qualifier in (self.catalog, self.schema, self.table):
            if not qualifier:
                continue
            # Percent-encode everything, then encode "." explicitly (quote()
            # never escapes it) so a dot inside a part is not mistaken for
            # the separator between parts.
            quoted = parse.quote(qualifier, safe="")
            encoded_parts.append(quoted.replace(".", "%2E"))
        return ".".join(encoded_parts)


class ParsedQuery:
def __init__(self, sql_statement: str):
self.sql: str = sql_statement
self._table_names: Set[str] = set()
self._tables: Set[Table] = set()
self._alias_names: Set[str] = set()
self._limit: Optional[int] = None

Expand All @@ -70,12 +94,15 @@ def __init__(self, sql_statement: str):
self._limit = _extract_limit_from_query(statement)

@property
def tables(self) -> Set[str]:
if not self._table_names:
def tables(self) -> Set[Table]:
if not self._tables:
for statement in self._parsed:
self.__extract_from_token(statement)
self._table_names = self._table_names - self._alias_names
return self._table_names
self._extract_from_token(statement)

self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
return self._tables

@property
def limit(self) -> Optional[int]:
Expand Down Expand Up @@ -105,13 +132,13 @@ def get_statements(self) -> List[str]:
return statements

@staticmethod
def __get_full_name(tlist: TokenList) -> Optional[str]:
def _get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the full unquoted table name if valid, i.e., conforms to the following
[[cluster.]schema.]table construct.
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
:param tlist: The SQL tokens
:returns: The valid full table name
:returns: The table if the name conforms
"""

# Strip the alias if present.
Expand All @@ -127,28 +154,28 @@ def __get_full_name(tlist: TokenList) -> Optional[str]:

if (
len(tokens) in (1, 3, 5)
and all(imt(token, t=[Name, String]) for token in tokens[0::2])
and all(imt(token, t=[Name, String]) for token in tokens[::2])
and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
):
return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
return Table(*[remove_quotes(token.value) for token in tokens[::-2]])

return None

@staticmethod
def __is_identifier(token: Token) -> bool:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))

def __process_tokenlist(self, token_list: TokenList):
def _process_tokenlist(self, token_list: TokenList):
"""
Add table names to table set
:param token_list: TokenList to be processed
"""
# exclude subselects
if "(" not in str(token_list):
table_name = self.__get_full_name(token_list)
if table_name and not table_name.startswith(CTE_PREFIX):
self._table_names.add(table_name)
table = self._get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return

# store aliases
Expand All @@ -158,7 +185,7 @@ def __process_tokenlist(self, token_list: TokenList):
# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self.__extract_from_token(token_list)
self._extract_from_token(token_list)

def as_create_table(
self,
Expand All @@ -184,9 +211,9 @@ def as_create_table(
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
return exec_sql

def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
"""
Populate self._table_names from token
Populate self._tables from token
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
Expand All @@ -196,8 +223,8 @@ def __extract_from_token(self, token: Token): # pylint: disable=too-many-branch
table_name_preceding_token = False

for item in token.tokens:
if item.is_group and not self.__is_identifier(item):
self.__extract_from_token(item)
if item.is_group and not self._is_identifier(item):
self._extract_from_token(item)

if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
Expand All @@ -212,15 +239,15 @@ def __extract_from_token(self, token: Token): # pylint: disable=too-many-branch

if table_name_preceding_token:
if isinstance(item, Identifier):
self.__process_tokenlist(item)
self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
self.__process_tokenlist(token2)
self._process_tokenlist(token2)
elif isinstance(item, IdentifierList):
for token2 in item.tokens:
if not self.__is_identifier(token2):
self.__extract_from_token(item)
if not self._is_identifier(token2):
self._extract_from_token(item)

def set_or_update_query_limit(self, new_limit: int) -> str:
"""Returns the query with the specified limit.
Expand Down
6 changes: 4 additions & 2 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
check_sqlalchemy_uri,
DBSecurityException,
)
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.utils import core as utils, dashboard_import_export
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
Expand Down Expand Up @@ -2083,7 +2083,9 @@ def select_star(self, database_id, table_name, schema=None):
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
table_name = utils.parse_js_uri_path_item(table_name)
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(database, table_name, schema):
if not self.appbuilder.sm.can_access_datasource(
database, Table(table_name, schema), schema
):
stats_logger.incr(
f"deprecated.{self.__class__.__name__}.select_star.permission_denied"
)
Expand Down
3 changes: 2 additions & 1 deletion superset/views/database/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from flask_babel import lazy_gettext as _

from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import parse_js_uri_path_item

logger = logging.getLogger(__name__)
Expand All @@ -45,7 +46,7 @@ def wraps(self, pk: int, table_name: str, schema_name: Optional[str] = None):
return self.response_404()
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(
database, table_name_parsed, schema_name_parsed
database, Table(table_name_parsed, schema_name_parsed), schema_name_parsed
):
self.stats_logger.incr(
f"permisssion_denied_{self.__class__.__name__}.select_star"
Expand Down
Loading

0 comments on commit 3b0f8e9

Please sign in to comment.