(
+ state => state.explore?.slice?.slice_id,
+ );
+
+ // state needed for refreshing dashboard
+ const chartList = useSelector(state => Object.keys(state.charts));
+ const dashboardId = useSelector(state => state.dashboardInfo.id);
+
+ const dispatch = useDispatch();
+
+ useEffect(() => {
+ /* Listen for messages from the OAuth2 tab.
+ *
+ * After OAuth2 is successfull the opened tab will send a message before
+ * closing itself. Once we receive the message we can retrigger the
+ * original query in SQL Lab, explore, or in a dashboard.
+ */
+ const redirectUrl = new URL(extra.redirect_uri);
+ const handleMessage = (event: MessageEvent) => {
+ if (
+ event.origin === redirectUrl.origin &&
+ event.data.tabId === extra.tab_id &&
+ event.source === oAuthTab.current
+ ) {
+ if (source === 'sqllab' && query) {
+ dispatch(reRunQuery(query));
+ } else if (source === 'explore' && chartId) {
+ dispatch(triggerQuery(true, chartId));
+ } else if (source === 'dashboard') {
+ dispatch(onRefresh(chartList, true, 0, dashboardId));
+ }
+ }
+ };
+ window.addEventListener('message', handleMessage);
+
+ return () => {
+ window.removeEventListener('message', handleMessage);
+ };
+ }, [
+ source,
+ extra.redirect_uri,
+ extra.tab_id,
+ dispatch,
+ query,
+ chartId,
+ chartList,
+ dashboardId,
+ ]);
+
+ const body = (
+
+ This database uses OAuth2 for authentication. Please click the link above
+ to grant Apache Superset permission to access the data. Your personal
+ access token will be stored encrypted and used only for queries run by
+ you.
+
+ );
+ const subtitle = (
+ <>
+ You need to{' '}
+
+ provide authorization
+ {' '}
+ in order to run this query.
+ >
+ );
+
+ return (
+
+ );
+}
+
+export default OAuth2RedirectMessage;
diff --git a/superset-frontend/src/components/ErrorMessage/types.ts b/superset-frontend/src/components/ErrorMessage/types.ts
index 7c4c3fe94a68e..a27c4aff45da9 100644
--- a/superset-frontend/src/components/ErrorMessage/types.ts
+++ b/superset-frontend/src/components/ErrorMessage/types.ts
@@ -56,6 +56,8 @@ export const ErrorTypeEnum = {
QUERY_SECURITY_ACCESS_ERROR: 'QUERY_SECURITY_ACCESS_ERROR',
MISSING_OWNERSHIP_ERROR: 'MISSING_OWNERSHIP_ERROR',
DASHBOARD_SECURITY_ACCESS_ERROR: 'DASHBOARD_SECURITY_ACCESS_ERROR',
+ OAUTH2_REDIRECT: 'OAUTH2_REDIRECT',
+ OAUTH2_REDIRECT_ERROR: 'OAUTH2_REDIRECT_ERROR',
// Other errors
BACKEND_TIMEOUT_ERROR: 'BACKEND_TIMEOUT_ERROR',
diff --git a/superset-frontend/src/setup/setupErrorMessages.ts b/superset-frontend/src/setup/setupErrorMessages.ts
index 59842f190adae..f393a36f1cce9 100644
--- a/superset-frontend/src/setup/setupErrorMessages.ts
+++ b/superset-frontend/src/setup/setupErrorMessages.ts
@@ -23,6 +23,7 @@ import DatabaseErrorMessage from 'src/components/ErrorMessage/DatabaseErrorMessa
import MarshmallowErrorMessage from 'src/components/ErrorMessage/MarshmallowErrorMessage';
import ParameterErrorMessage from 'src/components/ErrorMessage/ParameterErrorMessage';
import DatasetNotFoundErrorMessage from 'src/components/ErrorMessage/DatasetNotFoundErrorMessage';
+import OAuth2RedirectMessage from 'src/components/ErrorMessage/OAuth2RedirectMessage';
import setupErrorMessagesExtra from './setupErrorMessagesExtra';
@@ -149,5 +150,9 @@ export default function setupErrorMessages() {
ErrorTypeEnum.MARSHMALLOW_ERROR,
MarshmallowErrorMessage,
);
+ errorMessageComponentRegistry.registerValue(
+ ErrorTypeEnum.OAUTH2_REDIRECT,
+ OAuth2RedirectMessage,
+ );
setupErrorMessagesExtra();
}
diff --git a/superset/commands/chart/data/get_data_command.py b/superset/commands/chart/data/get_data_command.py
index 971c343cba4e8..ad53a03f285d0 100644
--- a/superset/commands/chart/data/get_data_command.py
+++ b/superset/commands/chart/data/get_data_command.py
@@ -48,7 +48,6 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
except CacheLoadError as ex:
raise ChartDataCacheLoadError(ex.message) from ex
- # TODO: QueryContext should support SIP-40 style errors
for query in payload["queries"]:
if query.get("error"):
raise ChartDataQueryFailedError(
diff --git a/superset/config.py b/superset/config.py
index 197e4bac4296e..b95505451f986 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1392,6 +1392,24 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# one here.
TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
+# Details needed for databases that allows user to authenticate using personal
+# OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more
+# information
+DATABASE_OAUTH2_CREDENTIALS = {
+ # "Google Sheets": {
+ # "CLIENT_ID": "XXX.apps.googleusercontent.com",
+ # "CLIENT_SECRET": "GOCSPX-YYY",
+ # },
+}
+# OAuth2 state is encoded in a JWT using the alogorithm below.
+DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
+# By default the redirect URI points to /api/v1/database/oauth2/ and doesn't have to be
+# specified. If you're running multiple Superset instances you might want to have a
+# proxy handling the redirects, since redirect URIs need to be registered in the OAuth2
+# applications. In that case, the proxy can forward the request to the correct instance
+# by looking at the `default_redirect_uri` attribute in the OAuth2 state object.
+# DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"
+
# Enable/disable CSP warning
CONTENT_SECURITY_POLICY_WARNING = True
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index dd9334d9d06ec..0dfd830b64ab4 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -87,6 +87,8 @@
DatasetInvalidPermissionEvaluationException,
QueryClauseValidationException,
QueryObjectValidationError,
+ SupersetErrorException,
+ SupersetErrorsException,
SupersetGenericDBErrorException,
SupersetSecurityException,
)
@@ -1744,6 +1746,11 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
try:
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
+ except (SupersetErrorException, SupersetErrorsException) as ex:
+ # SupersetError(s) exception should not be captured; instead, they should
+ # bubble up to the Flask error handler so they are returned as proper SIP-40
+ # errors. This is particularly important for database OAuth2, see SIP-85.
+ raise ex
except Exception as ex: # pylint: disable=broad-except
df = pd.DataFrame()
status = QueryStatus.FAILED
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 58a90e6ecaed3..d0922e40f3c4f 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -145,7 +145,7 @@ def get_columns_description(
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
- db_engine_spec.execute(cursor, query)
+ db_engine_spec.execute(cursor, query, database.id)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 1e44a52106f9f..770be61b09ad9 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -17,19 +17,27 @@
# pylint: disable=too-many-lines
import json
import logging
-from datetime import datetime
+from datetime import datetime, timedelta
from io import BytesIO
from typing import Any, cast, Optional
from zipfile import is_zipfile, ZipFile
+import jwt
from deprecation import deprecated
-from flask import request, Response, send_file
+from flask import (
+ current_app,
+ make_response,
+ render_template,
+ request,
+ Response,
+ send_file,
+)
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
-from superset import app, event_logger
+from superset import app, db, event_logger
from superset.commands.database.create import CreateDatabaseCommand
from superset.commands.database.delete import DeleteDatabaseCommand
from superset.commands.database.exceptions import (
@@ -88,10 +96,11 @@
)
from superset.databases.utils import get_table_metadata
from superset.db_engine_specs import get_available_engine_specs
+from superset.db_engine_specs.base import OAuth2State
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import SupersetErrorsException, SupersetException
+from superset.exceptions import OAuth2Error, SupersetErrorsException, SupersetException
from superset.extensions import security_manager
-from superset.models.core import Database
+from superset.models.core import Database, DatabaseUserOAuth2Tokens
from superset.superset_typing import FlaskResponse
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
from superset.utils.ssh_tunnel import mask_password_info
@@ -106,6 +115,7 @@
logger = logging.getLogger(__name__)
+# pylint: disable=too-many-public-methods
class DatabaseRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(Database)
@@ -127,7 +137,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"delete_ssh_tunnel",
"schemas_access_for_file_upload",
"get_connection",
+ "oauth2",
}
+
resource_name = "database"
class_permission_name = "Database"
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
@@ -1049,6 +1061,79 @@ def validate_sql(self, pk: int) -> FlaskResponse:
except DatabaseNotFoundError:
return self.response_404()
+ @expose("/oauth2/", methods=["GET"])
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.oauth2",
+ log_to_statsd=True,
+ )
+ def oauth2(self) -> FlaskResponse:
+ """
+ ---
+ get:
+ summary: >-
+ Receive personal access tokens from OAuth2
+ description: ->
+ Receive and store personal access tokens from OAuth for user-level
+ authorization
+ parameters:
+ - in: query
+ name: state
+ - in: query
+ name: code
+ - in: query
+ name: scope
+ - in: query
+ name: error
+ responses:
+ 200:
+ description: A dummy self-closing HTML page
+ content:
+ text/html:
+ schema:
+ type: string
+ 400:
+ $ref: '#/components/responses/400'
+ 500:
+ $ref: '#/components/responses/500'
+ """
+ parameters = request.args.to_dict()
+ if "error" in parameters:
+ raise OAuth2Error(parameters["error"])
+
+ state = cast(
+ OAuth2State,
+ jwt.decode(
+ parameters["state"].replace("%2E", "."),
+ current_app.config["SECRET_KEY"],
+ algorithms=[current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"]],
+ ),
+ )
+
+ # exchange code for access/refresh tokens
+ database = db.session.query(Database).filter_by(id=state["database_id"]).one()
+ token_response = database.db_engine_spec.get_oauth2_token(
+ parameters["code"],
+ state,
+ )
+
+ # store tokens
+ token = DatabaseUserOAuth2Tokens(
+ user_id=state["user_id"],
+ database_id=database.id,
+ access_token=token_response["access_token"],
+ access_token_expiration=datetime.now()
+ + timedelta(seconds=token_response["expires_in"]),
+ refresh_token=token_response.get("refresh_token"),
+ )
+ db.session.add(token)
+ db.session.commit()
+
+ # return blank page that closes itself
+ return make_response(
+ render_template("superset/oauth2.html", tab_id=state["tab_id"]),
+ 200,
+ )
+
@expose("/export/", methods=("GET",))
@protect()
@safe
diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md
index ee4c4ce9e5b25..40d9d3bbf4ccc 100644
--- a/superset/db_engine_specs/README.md
+++ b/superset/db_engine_specs/README.md
@@ -529,6 +529,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
url: URL,
impersonate_user: bool,
username: str | None,
+ access_token: str | None,
) -> URL:
if impersonate_user and username is not None:
user = security_manager.find_user(username=username)
@@ -542,6 +543,70 @@ The method `get_url_for_impersonation` updates the SQLAlchemy URI before every q
Alternatively, it's also possible to impersonate users by implementing the `update_impersonation_config`. This is a class method which modifies `connect_args` in place. You can use either method, and ideally they [should be consolidated in a single one](https://github.com/apache/superset/issues/24910).
+### OAuth2
+
+Support for authenticating to a database using personal OAuth2 access tokens was introduced in [SIP-85](https://github.com/apache/superset/issues/20300). The Google Sheets DB engine spec is the reference implementation.
+
+To add support for OAuth2 to a DB engine spec, the following attribute and methods are needed:
+
+```python
+class BaseEngineSpec:
+
+ oauth2_exception = OAuth2RedirectError
+
+ @staticmethod
+ def is_oauth2_enabled() -> bool:
+ return False
+
+ @staticmethod
+ def get_oauth2_authorization_uri(state: OAuth2State) -> str:
+ raise NotImplementedError()
+
+ @staticmethod
+ def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
+ raise NotImplementedError()
+
+ @staticmethod
+ def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
+ raise NotImplementedError()
+```
+
+The `oauth2_exception` is an exception that is raised by `cursor.execute` when OAuth2 is needed. This will start the OAuth2 dance when `BaseEngineSpec.execute` is called, by returning the custom error `OAUTH2_REDIRECT` to the frontend. If the database driver doesn't have a specific exception, it might be necessary to overload the `execute` method in the DB engine spec, so that the `BaseEngineSpec.start_oauth2_dance` method gets called whenever OAuth2 is needed.
+
+The first method, `is_oauth2_enabled`, is used to inform if the database supports OAuth2. This can be dynamic; for example, the Google Sheets DB engine spec checks if the Superset configuration has the necessary section:
+
+```python
+from flask import current_app
+
+
+class GSheetsEngineSpec(ShillelaghEngineSpec):
+ @staticmethod
+ def is_oauth2_enabled() -> bool:
+ return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
+```
+
+Where the configuration for OAuth2 would look like this:
+
+```python
+# superset_config.py
+DATABASE_OAUTH2_CREDENTIALS = {
+ "Google Sheets": {
+ "CLIENT_ID": "XXX.apps.googleusercontent.com",
+ "CLIENT_SECRET": "GOCSPX-YYY",
+ },
+}
+DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
+DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"
+```
+
+The second method, `get_oauth2_authorization_uri`, is responsible for building the URL where the user is sent to initiate OAuth2. This method receives a `state`. The state is an encoded JWT that is passed to the OAuth2 provider, and is received unmodified when the user is redirected back to Superset. The default state contains the user ID and the database ID, so that Superset can know where to store the received OAuth2 tokens.
+
+Additionally, the state also contains a `tab_id`, which is a random UUID4 used as a shared secret for communication between browser tabs. When OAuth2 starts, Superset will open a new browser tab, where the user will grant permissions to Superset. When authentication is complete and successfull this opened tab will send a message to the original tab, so that the original query can be re-run. The `tab_id` is sent by the opened tab and verified by the original tab to prevent malicious messages from other sites. As an additional security measure the origin of the message should match the OAuth2 redirect URL.
+
+State also contains a `defaul_redirect_uri`, which is the enpoint in Supeset that receives the tokens from the OAuth2 provider (`/api/v1/database/oauth2/`). The redirect URL can be overwritten in the config file via the `DATABASE_OAUTH2_REDIRECT_URI` parameter. This might be useful where you have multiple Superset instances. Since the OAuth2 provider requires the redirect URL to be registered a priori, it might be easier (or needed) to register a single URL for a proxy service; the proxy service can then inspect the JWT and redirect the request to `defaul_redirect_uri`.
+
+Finally, `get_oauth2_token` and `get_oauth2_fresh_token` are used to actually retrieve a token and refresh an expired token, respectively.
+
### File upload
When a DB engine spec supports file upload it declares so via the `supports_file_upload` class attribute. The base class implementation is very generic and should work for any database that has support for `CREATE TABLE`. It leverages Pandas and the [`df_to_sql`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html) method.
@@ -615,7 +680,7 @@ SELECT * FROM my_table
The table `my_table` should live in the `dev` schema. In order to do that, it's necessary to modify the SQLAlchemy URI before running the query. Since different databases have different ways of doing that, this functionality is implemented via the `adjust_engine_params` class method. The method receives the SQLAlchemy URI and `connect_args`, as well as the schema in which the query should run. It then returns a potentially modified URI and `connect_args` to ensure that the query runs in the specified schema.
-When a DB engine specs implements `adjust_engine_params` it should have the class attribute `supports_dynamic_schema` set to true. This is critical for security, since **it allows Superset to know to which schema any unqualified table names belong to**. For example, in the query above, if the database supports dynamic schema, Superset would check to see if the user running the query has access to `dev.my_table`. On the other hand, if the database doesn't support dynamic schema, Superset would sue the default database schema instead of `dev`.
+When a DB engine specs implements `adjust_engine_params` it should have the class attribute `supports_dynamic_schema` set to true. This is critical for security, since **it allows Superset to know to which schema any unqualified table names belong to**. For example, in the query above, if the database supports dynamic schema, Superset would check to see if the user running the query has access to `dev.my_table`. On the other hand, if the database doesn't support dynamic schema, Superset would use the default database schema instead of `dev`.
Implementing this method is also important for usability. When the method is not implemented selecting the schema in SQL Lab has no effect on the schema in which the query runs, resulting in a confusing results when using unqualified table names.
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index e8790bdcd4f77..d380bfe958602 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -33,13 +33,14 @@
TypedDict,
Union,
)
+from uuid import uuid4
import pandas as pd
import sqlparse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
-from flask import current_app
+from flask import current_app, g, url_for
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
@@ -59,6 +60,7 @@
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils import core as utils
@@ -170,6 +172,31 @@ class MetricType(TypedDict, total=False):
extra: str | None
+class OAuth2TokenResponse(TypedDict, total=False):
+ """
+ Type for an OAuth2 response when exchanging or refreshing tokens.
+ """
+
+ access_token: str
+ expires_in: int
+ scope: str
+ token_type: str
+
+ # only present when exchanging code for refresh/access tokens
+ refresh_token: str
+
+
+class OAuth2State(TypedDict):
+ """
+ Type for the state passed during OAuth2.
+ """
+
+ database_id: int
+ user_id: int
+ default_redirect_uri: str
+ tab_id: str
+
+
class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""Abstract class for database engine specific configurations
@@ -397,6 +424,79 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Can the catalog be changed on a per-query basis?
supports_dynamic_catalog = False
+ # Driver-specific exception that should be mapped to OAuth2RedirectError
+ oauth2_exception = OAuth2RedirectError
+
+ @staticmethod
+ def is_oauth2_enabled() -> bool:
+ return False
+
+ @classmethod
+ def start_oauth2_dance(cls, database_id: int) -> None:
+ """
+ Start the OAuth2 dance.
+
+ This method will raise a custom exception that is captured by the frontend to
+ start the OAuth2 authentication. The frontend will open a new tab where the user
+ can authorize Superset to access the database. Once the user has authorized, the
+ tab sends a message to the original tab informing that authorization was
+ successfull (or not), and then closes. The original tab will automatically
+ re-run the query after authorization.
+ """
+ tab_id = str(uuid4())
+ default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True)
+ redirect_uri = current_app.config.get(
+ "DATABASE_OAUTH2_REDIRECT_URI",
+ default_redirect_uri,
+ )
+
+ # The state is passed to the OAuth2 provider, and sent back to Superset after
+ # the user authorizes the access. The redirect endpoint in Superset can then
+ # inspect the state to figure out to which user/database the access token
+ # belongs to.
+ state: OAuth2State = {
+ # Database ID and user ID are the primary key associated with the token.
+ "database_id": database_id,
+ "user_id": g.user.id,
+ # In multi-instance deployments there might be a single proxy handling
+ # redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since the OAuth2
+ # application requires every redirect URL to be registered a priori, this
+ # allows OAuth2 to be used where new instances are being constantly
+ # deployed. The proxy can extract `default_redirect_uri` from the state and
+ # then forward the token to the instance that initiated the authentication.
+ "default_redirect_uri": default_redirect_uri,
+ # When OAuth2 is complete the browser tab where OAuth2 happened will send a
+ # message to the original browser tab informing that the process was
+ # successfull. To allow cross-tab commmunication in a safe way we assign a
+ # UUID to the original tab, and the second tab will use it when sending the
+ # message.
+ "tab_id": tab_id,
+ }
+ oauth_url = cls.get_oauth2_authorization_uri(state)
+
+ raise OAuth2RedirectError(oauth_url, tab_id, redirect_uri)
+
+ @staticmethod
+ def get_oauth2_authorization_uri(state: OAuth2State) -> str:
+ """
+ Return URI for initial OAuth2 request.
+ """
+ raise NotImplementedError()
+
+ @staticmethod
+ def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
+ """
+ Exchange authorization code for refresh/access tokens.
+ """
+ raise NotImplementedError()
+
+ @staticmethod
+ def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
+ """
+ Refresh an access token that has expired.
+ """
+ raise NotImplementedError()
+
@classmethod
def get_allows_alias_in_select(
cls, database: Database # pylint: disable=unused-argument
@@ -1079,7 +1179,12 @@ def handle_cursor(cls, cursor: Any, query: Query) -> None:
# TODO: Fix circular import error caused by importing sql_lab.Query
@classmethod
- def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None:
+ def execute_with_cursor(
+ cls,
+ cursor: Any,
+ sql: str,
+ query: Query,
+ ) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
@@ -1090,7 +1195,7 @@ def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None:
in a timely manner and facilitate operations such as query stop
"""
logger.debug("Query %d: Running query: %s", query.id, sql)
- cls.execute(cursor, sql, async_=True)
+ cls.execute(cursor, sql, query.database.id, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query)
@@ -1536,7 +1641,11 @@ def estimate_query_cost(
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: str | None
+ cls,
+ url: URL,
+ impersonate_user: bool,
+ username: str | None,
+ access_token: str | None, # pylint: disable=unused-argument
) -> URL:
"""
Return a modified URL with the username set.
@@ -1544,6 +1653,7 @@ def get_url_for_impersonation(
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
+ :param access_token: Personal access token
"""
if impersonate_user and username is not None:
url = url.set(username=username)
@@ -1572,6 +1682,7 @@ def execute( # pylint: disable=unused-argument
cls,
cursor: Any,
query: str,
+ database_id: int,
**kwargs: Any,
) -> None:
"""
@@ -1579,6 +1690,7 @@ def execute( # pylint: disable=unused-argument
:param cursor: Cursor instance
:param query: Query to execute
+ :param database_id: ID of the database where the query will run
:param kwargs: kwargs to be passed to cursor.execute()
:return:
"""
@@ -1589,6 +1701,10 @@ def execute( # pylint: disable=unused-argument
cursor.arraysize = cls.arraysize
try:
cursor.execute(query)
+ except cls.oauth2_exception as ex:
+ if cls.is_oauth2_enabled() and g.user:
+ cls.start_oauth2_dance(database_id)
+ raise cls.get_dbapi_mapped_exception(ex) from ex
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 276ff5b185448..e99d4a27f4a6e 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -100,7 +100,11 @@ def get_schema_from_engine_params(
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: str | None
+ cls,
+ url: URL,
+ impersonate_user: bool,
+ username: str | None,
+ access_token: str | None,
) -> URL:
"""
Return a modified URL with the username set.
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 18349f4314910..db8c1f221d3cf 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -23,21 +23,27 @@
import re
from re import Pattern
from typing import Any, TYPE_CHECKING, TypedDict
+from urllib.parse import urlencode
+import jwt
import pandas as pd
+import urllib3
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
-from flask import g
+from flask import current_app, g
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from requests import Session
+from shillelagh.adapters.api.gsheets.lib import SCOPES
+from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import URL
from superset import db, security_manager
from superset.constants import PASSWORD_MASK
from superset.databases.schemas import encrypted_field_properties, EncryptedString
+from superset.db_engine_specs.base import OAuth2State, OAuth2TokenResponse
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
@@ -56,6 +62,7 @@
SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P.*?)": syntax error')
ma_plugin = MarshmallowPlugin()
+http = urllib3.PoolManager()
class GSheetsParametersSchema(Schema):
@@ -104,18 +111,28 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
supports_file_upload = True
+ # exception raised by shillelagh that should trigger OAuth2
+ oauth2_exception = UnauthenticatedError
+
@classmethod
def get_url_for_impersonation(
cls,
url: URL,
impersonate_user: bool,
username: str | None,
+ access_token: str | None,
) -> URL:
- if impersonate_user and username is not None:
+ if not impersonate_user:
+ return url
+
+ if username is not None:
user = security_manager.find_user(username=username)
if user and user.email:
url = url.update_query_dict({"subject": user.email})
+ if access_token:
+ url = url.update_query_dict({"access_token": access_token})
+
return url
@classmethod
@@ -136,6 +153,89 @@ def extra_table_metadata(
return {"metadata": metadata["extra"]}
+ @staticmethod
+ def is_oauth2_enabled() -> bool:
+ """
+ Return if OAuth2 is enabled for GSheets.
+ """
+ return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
+
+ @staticmethod
+ def get_oauth2_authorization_uri(state: OAuth2State) -> str:
+ """
+ Return URI for initial OAuth2 request.
+
+ https://developers.google.com/identity/protocols/oauth2/web-server#creatingclient
+ """
+ config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"]
+ redirect_uri = current_app.config.get(
+ "DATABASE_OAUTH2_REDIRECT_URI",
+ state["default_redirect_uri"],
+ )
+
+ encoded_state = jwt.encode(
+ payload=state,
+ key=current_app.config["SECRET_KEY"],
+ algorithm=current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"],
+ )
+ # periods in the state break Google OAuth2 for some reason
+ encoded_state = encoded_state.replace(".", "%2E")
+
+ params = {
+ "scope": " ".join(SCOPES),
+ "access_type": "offline",
+ "include_granted_scopes": "false",
+ "response_type": "code",
+ "state": encoded_state,
+ "redirect_uri": redirect_uri,
+ "client_id": config["CLIENT_ID"],
+ "prompt": "consent",
+ }
+ return "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
+
+ @staticmethod
+ def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
+ """
+ Exchange authorization code for refresh/access tokens.
+ """
+ config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"]
+ redirect_uri = current_app.config.get(
+ "DATABASE_OAUTH2_REDIRECT_URI",
+ state["default_redirect_uri"],
+ )
+
+ response = http.request(
+ "POST",
+ "https://oauth2.googleapis.com/token",
+ fields={
+ "code": code,
+ "client_id": config["CLIENT_ID"],
+ "client_secret": config["CLIENT_SECRET"],
+ "redirect_uri": redirect_uri,
+ "grant_type": "authorization_code",
+ },
+ )
+ return json.loads(response.data.decode("utf-8"))
+
+ @staticmethod
+ def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
+ """
+ Refresh an access token that has expired.
+ """
+ config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"]
+
+ response = http.request(
+ "POST",
+ "https://oauth2.googleapis.com/token",
+ fields={
+ "client_id": config["CLIENT_ID"],
+ "client_secret": config["CLIENT_SECRET"],
+ "refresh_token": refresh_token,
+ "grant_type": "refresh_token",
+ },
+ )
+ return json.loads(response.data.decode("utf-8"))
+
@classmethod
def build_sqlalchemy_uri(
cls,
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 9222d55db0171..a97dd88aefdda 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -505,7 +505,11 @@ def select_star( # pylint: disable=too-many-arguments
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: str | None
+ cls,
+ url: URL,
+ impersonate_user: bool,
+ username: str | None,
+ access_token: str | None,
) -> URL:
"""
Return a modified URL with the username set.
@@ -547,7 +551,10 @@ def update_impersonation_config(
@staticmethod
def execute( # type: ignore
- cursor, query: str, async_: bool = False
+ cursor,
+ query: str,
+ database_id: int,
+ async_: bool = False,
): # pylint: disable=arguments-differ
kwargs = {"async": async_}
cursor.execute(query, **kwargs)
diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py
index 9e5f728a6f84a..8cda5b5861836 100644
--- a/superset/db_engine_specs/impala.py
+++ b/superset/db_engine_specs/impala.py
@@ -93,6 +93,7 @@ def execute(
cls,
cursor: Any,
query: str,
+ database_id: int,
**kwargs: Any,
) -> None:
try:
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 44f8f9668a224..0df8d53f4f236 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -1271,7 +1271,7 @@ def get_create_view(
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
- cls.execute(cursor, sql)
+ cls.execute(cursor, sql, database.id)
rows = cls.fetch_data(cursor, 1)
return rows[0][0]
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 6d95f9589e49a..4513d63c606b1 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -132,7 +132,11 @@ def update_impersonation_config(
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: str | None
+ cls,
+ url: URL,
+ impersonate_user: bool,
+ username: str | None,
+ access_token: str | None,
) -> URL:
"""
Return a modified URL with the username set.
@@ -191,7 +195,12 @@ def handle_cursor(cls, cursor: Cursor, query: Query) -> None:
super().handle_cursor(cursor=cursor, query=query)
@classmethod
- def execute_with_cursor(cls, cursor: Cursor, sql: str, query: Query) -> None:
+ def execute_with_cursor(
+ cls,
+ cursor: Cursor,
+ sql: str,
+ query: Query,
+ ) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
@@ -210,7 +219,7 @@ def _execute(results: dict[str, Any], event: threading.Event) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)
try:
- cls.execute(cursor, sql)
+ cls.execute(cursor, sql, query.database.id)
except Exception as ex: # pylint: disable=broad-except
results["error"] = ex
finally:
diff --git a/superset/errors.py b/superset/errors.py
index 7c383891676db..b792a120da5ae 100644
--- a/superset/errors.py
+++ b/superset/errors.py
@@ -67,6 +67,8 @@ class SupersetErrorType(StrEnum):
USER_ACTIVITY_SECURITY_ACCESS_ERROR = "USER_ACTIVITY_SECURITY_ACCESS_ERROR"
DASHBOARD_SECURITY_ACCESS_ERROR = "DASHBOARD_SECURITY_ACCESS_ERROR"
CHART_SECURITY_ACCESS_ERROR = "CHART_SECURITY_ACCESS_ERROR"
+ OAUTH2_REDIRECT = "OAUTH2_REDIRECT"
+ OAUTH2_REDIRECT_ERROR = "OAUTH2_REDIRECT_ERROR"
# Other errors
BACKEND_TIMEOUT_ERROR = "BACKEND_TIMEOUT_ERROR"
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 0ce72e2e1a6f1..91a4656595cd5 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -312,3 +312,47 @@ def __init__(self, sql: str, engine: Optional[str] = None):
extra={"sql": sql, "engine": engine},
)
super().__init__(error)
+
+
+class OAuth2RedirectError(SupersetErrorException):
+ """
+ Exception used to start OAuth2 dance for personal tokens.
+
+ The exception requires 3 parameters:
+
+ - The URL that starts the OAuth2 dance.
+ - The UUID of the browser tab where OAuth2 started, so that the newly opened tab
+ where OAuth2 happens can communicate with the original tab to inform that OAuth2
+ was successfull (or not).
+ - The redirect URL, so that the original tab can validate that the message from the
+ second tab is coming from a valid origin.
+
+ See the `OAuth2RedirectMessage.tsx` component for more details of how this
+ information is handled.
+ """
+
+ def __init__(self, url: str, tab_id: str, redirect_uri: str):
+ super().__init__(
+ SupersetError(
+ message="You don't have permission to access the data.",
+ error_type=SupersetErrorType.OAUTH2_REDIRECT,
+ level=ErrorLevel.WARNING,
+ extra={"url": url, "tab_id": tab_id, "redirect_uri": redirect_uri},
+ )
+ )
+
+
+class OAuth2Error(SupersetErrorException):
+ """
+ Exception for when OAuth2 goes wrong.
+ """
+
+ def __init__(self, error: str):
+ super().__init__(
+ SupersetError(
+ message="Something went wrong while doing OAuth2",
+ error_type=SupersetErrorType.OAUTH2_REDIRECT_ERROR,
+ level=ErrorLevel.ERROR,
+ extra={"error": error},
+ )
+ )
diff --git a/superset/migrations/versions/2024-03-20_16-02_678eefb4ab44_add_access_token_table.py b/superset/migrations/versions/2024-03-20_16-02_678eefb4ab44_add_access_token_table.py
new file mode 100644
index 0000000000000..9ee87fc2231f0
--- /dev/null
+++ b/superset/migrations/versions/2024-03-20_16-02_678eefb4ab44_add_access_token_table.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add access token table
+
+Revision ID: 678eefb4ab44
+Revises: be1b217cd8cd
+Create Date: 2024-03-20 16:02:58.515915
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = "678eefb4ab44"
+down_revision = "be1b217cd8cd"
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy_utils import EncryptedType
+
+
+def upgrade():
+ op.create_table(
+ "database_user_oauth2_tokens",
+ sa.Column("created_on", sa.DateTime(), nullable=True),
+ sa.Column("changed_on", sa.DateTime(), nullable=True),
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("database_id", sa.Integer(), nullable=False),
+ sa.Column(
+ "access_token",
+ EncryptedType(),
+ nullable=True,
+ ),
+ sa.Column("access_token_expiration", sa.DateTime(), nullable=True),
+ sa.Column(
+ "refresh_token",
+ EncryptedType(),
+ nullable=True,
+ ),
+ sa.Column("created_by_fk", sa.Integer(), nullable=True),
+ sa.Column("changed_by_fk", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["changed_by_fk"],
+ ["ab_user.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["created_by_fk"],
+ ["ab_user.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["database_id"],
+ ["dbs.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["ab_user.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+
+
+def downgrade():
+ op.drop_table("database_user_oauth2_tokens")
diff --git a/superset/models/core.py b/superset/models/core.py
index 71a6e9d04237b..54780cac74e82 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -75,6 +75,7 @@
from superset.utils import cache as cache_util, core as utils
from superset.utils.backports import StrEnum
from superset.utils.core import get_username
+from superset.utils.oauth2 import get_oauth2_access_token
config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@@ -461,6 +462,11 @@ def _get_sqla_engine(
)
effective_username = self.get_effective_user(sqlalchemy_url)
+ access_token = (
+ get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec)
+ if hasattr(g, "user") and hasattr(g.user, "id")
+ else None
+ )
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
@@ -468,6 +474,7 @@ def _get_sqla_engine(
sqlalchemy_url,
self.impersonate_user,
effective_username,
+ access_token,
)
masked_url = self.get_password_masked_url(sqlalchemy_url)
@@ -588,7 +595,7 @@ def _log_query(sql: str) -> None:
database=None,
)
_log_query(sql_)
- self.db_engine_spec.execute(cursor, sql_)
+ self.db_engine_spec.execute(cursor, sql_, self.id)
cursor.fetchall()
if mutate_after_split:
@@ -598,10 +605,10 @@ def _log_query(sql: str) -> None:
database=None,
)
_log_query(last_sql)
- self.db_engine_spec.execute(cursor, last_sql)
+ self.db_engine_spec.execute(cursor, last_sql, self.id)
else:
_log_query(sqls[-1])
- self.db_engine_spec.execute(cursor, sqls[-1])
+ self.db_engine_spec.execute(cursor, sqls[-1], self.id)
data = self.db_engine_spec.fetch_data(cursor)
result_set = SupersetResultSet(
@@ -978,6 +985,26 @@ def make_sqla_column_compatible(
sqla.event.listen(Database, "after_delete", security_manager.database_after_delete)
+class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
+ """
+ Store OAuth2 tokens, for authenticating to DBs using user personal tokens.
+ """
+
+ __tablename__ = "database_user_oauth2_tokens"
+
+ id = Column(Integer, primary_key=True)
+
+ user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=False)
+ user = relationship(security_manager.user_model, foreign_keys=[user_id])
+
+ database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
+ database = relationship("Database", foreign_keys=[database_id])
+
+ access_token = Column(encrypted_field_factory.create(Text), nullable=True)
+ access_token_expiration = Column(DateTime, nullable=True)
+ refresh_token = Column(encrypted_field_factory.create(Text), nullable=True)
+
+
class Log(Model): # pylint: disable=too-few-public-methods
"""ORM object used to log Superset actions to the database"""
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 1b883a77cfbbc..3fa80fbc45179 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -41,7 +41,11 @@
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import SupersetErrorException, SupersetErrorsException
+from superset.exceptions import (
+ OAuth2RedirectError,
+ SupersetErrorException,
+ SupersetErrorsException,
+)
from superset.extensions import celery_app
from superset.models.core import Database
from superset.models.sql_lab import Query
@@ -188,7 +192,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
return handle_query_error(ex, query)
-def execute_sql_statement(
+def execute_sql_statement( # pylint: disable=too-many-statements
sql_statement: str,
query: Query,
cursor: Any,
@@ -308,6 +312,9 @@ def execute_sql_statement(
level=ErrorLevel.ERROR,
)
) from ex
+ except OAuth2RedirectError as ex:
+ # user needs to authenticate with OAuth2 in order to run query
+ raise ex
except Exception as ex:
# query is stopped in another thread/worker
# stopping raises expected exceptions which we should skip
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index fed1ff3bfae62..8c815ad63ed34 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -73,7 +73,7 @@ def validate_statement(
from pyhive.exc import DatabaseError
try:
- db_engine_spec.execute(cursor, sql)
+ db_engine_spec.execute(cursor, sql, database.id)
polled = cursor.poll()
while polled:
logger.info("polling presto for validation progress")
diff --git a/superset/templates/superset/oauth2.html b/superset/templates/superset/oauth2.html
new file mode 100644
index 0000000000000..e3562758db876
--- /dev/null
+++ b/superset/templates/superset/oauth2.html
@@ -0,0 +1,31 @@
+{#
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-#}
+
+
+
+
+
+
+
+ You can close this window and re-run the query.
+
+
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
new file mode 100644
index 0000000000000..e391d527a3cc6
--- /dev/null
+++ b/superset/utils/oauth2.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+
+from superset import db
+from superset.db_engine_specs.base import BaseEngineSpec
+
+
+def get_oauth2_access_token(
+ database_id: int,
+ user_id: int,
+ db_engine_spec: type[BaseEngineSpec],
+) -> str | None:
+ """
+ Return a valid OAuth2 access token.
+
+ If the token exists but is expired and a refresh token is available the function will
+ return a fresh token and store it in the database for further requests.
+ """
+ # pylint: disable=import-outside-toplevel
+ from superset.models.core import DatabaseUserOAuth2Tokens
+
+ token = (
+ db.session.query(DatabaseUserOAuth2Tokens)
+ .filter_by(user_id=user_id, database_id=database_id)
+ .one_or_none()
+ )
+ if token is None:
+ return None
+
+ if token.access_token and token.access_token_expiration < datetime.now():
+ return token.access_token
+
+ if token.refresh_token:
+ # refresh access token
+ token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token)
+
+ # store new access token; note that the refresh token might be revoked, in which
+ # case there would be no access token in the response
+ if "access_token" in token_response:
+ token.access_token = token_response["access_token"]
+ token.access_token_expiration = datetime.now() + timedelta(
+ seconds=token_response["expires_in"]
+ )
+ db.session.add(token)
+
+ return token.access_token
+
+ # since the access token is expired and there's no refresh token, delete the entry
+ db.session.delete(token)
+
+ return None
diff --git a/tests/unit_tests/connectors/__init__.py b/tests/unit_tests/connectors/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/connectors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/connectors/sqla/__init__.py b/tests/unit_tests/connectors/sqla/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/connectors/sqla/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py
new file mode 100644
index 0000000000000..00b4b0a31545b
--- /dev/null
+++ b/tests/unit_tests/connectors/sqla/models_test.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.connectors.sqla.models import SqlaTable
+from superset.exceptions import OAuth2RedirectError
+from superset.superset_typing import QueryObjectDict
+
+
+def test_query_bubbles_errors(mocker: MockerFixture) -> None:
+ """
+ Test that the `query` method bubbles exceptions correctly.
+
+ When a user needs to authenticate via OAuth2 to access data, a custom exception is
+ raised. The exception needs to bubble up all the way to the frontend as a SIP-40
+ compliant payload with the error type `DATABASE_OAUTH2_REDIRECT_URI` so that the
+ frontend can initiate the OAuth2 authentication.
+
+ This tests verifies that the method does not capture these exceptions; otherwise the
+ user will be never be prompted to authenticate via OAuth2.
+ """
+ database = mocker.MagicMock()
+ database.get_df.side_effect = OAuth2RedirectError(
+ url="http://example.com",
+ tab_id="1234",
+ redirect_uri="http://redirect.example.com",
+ )
+
+ sqla_table = SqlaTable(
+ table_name="my_sqla_table",
+ columns=[],
+ metrics=[],
+ database=database,
+ )
+ mocker.patch.object(
+ sqla_table,
+ "get_query_str_extended",
+ return_value=mocker.MagicMock(sql="SELECT * FROM my_sqla_table"),
+ )
+ query_obj: QueryObjectDict = {
+ "granularity": None,
+ "from_dttm": None,
+ "to_dttm": None,
+ "groupby": ["id", "username", "email"],
+ "metrics": [],
+ "is_timeseries": False,
+ "filter": [],
+ }
+ with pytest.raises(OAuth2RedirectError):
+ sqla_table.query(query_obj)
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index c0e1723fd8361..cba5ca7f4acc1 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -18,6 +18,7 @@
# pylint: disable=unused-argument, import-outside-toplevel, line-too-long
import json
+from datetime import datetime
from io import BytesIO
from typing import Any
from unittest.mock import Mock
@@ -25,10 +26,12 @@
import pytest
from flask import current_app
+from freezegun import freeze_time
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
+from superset.db_engine_specs.sqlite import SqliteEngineSpec
def test_filter_by_uuid(
@@ -638,3 +641,70 @@ def _base_filter(query):
# Ensure that the filter has been called once
assert base_filter_mock.call_count == 1
+
+
+def test_oauth2_happy_path(
+ mocker: MockFixture,
+ session: Session,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the OAuth2 endpoint.
+ """
+ from superset.databases.api import DatabaseRestApi, DatabaseUserOAuth2Tokens
+ from superset.models.core import Database
+
+ DatabaseRestApi.datamodel.session = session
+
+ # create table for databases
+ Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
+ db.session.add(
+ Database(
+ database_name="my_db",
+ sqlalchemy_uri="sqlite://",
+ uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
+ )
+ )
+ db.session.commit()
+
+ get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
+ get_oauth2_token.return_value = {
+ "access_token": "YYY",
+ "expires_in": 3600,
+ "refresh_token": "ZZZ",
+ }
+
+ state = {
+ "user_id": 1,
+ "database_id": 1,
+ "tab_id": 42,
+ }
+ jwt = mocker.patch("superset.databases.api.jwt")
+ jwt.decode.return_value = state
+
+ mocker.patch("superset.databases.api.render_template", return_value="OK")
+
+ with freeze_time("2024-01-01T00:00:00Z"):
+ response = client.get(
+ "/api/v1/database/oauth2/",
+ query_string={
+ "state": "some%2Estate",
+ "code": "XXX",
+ },
+ )
+
+ assert response.status_code == 200
+ jwt.decode.assert_called_with(
+ "some.state",
+ "not-a-secret",
+ algorithms=["HS256"],
+ )
+ get_oauth2_token.assert_called_with("XXX", state)
+
+ token = db.session.query(DatabaseUserOAuth2Tokens).one()
+ assert token.user_id == 1
+ assert token.database_id == 1
+ assert token.access_token == "YYY"
+ assert token.access_token_expiration == datetime(2024, 1, 1, 1, 0)
+ assert token.refresh_token == "ZZZ"
diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py
index 3f28341f2643d..94b70ba5264ec 100644
--- a/tests/unit_tests/db_engine_specs/test_clickhouse.py
+++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py
@@ -65,8 +65,9 @@ def test_execute_connection_error() -> None:
cursor.execute.side_effect = NewConnectionError(
HTTPConnection("localhost"), "Exception with sensitive data"
)
- with pytest.raises(SupersetDBAPIDatabaseError) as ex:
- ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1")
+ with pytest.raises(SupersetDBAPIDatabaseError) as excinfo:
+ ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1", 1)
+ assert str(excinfo.value) == "Connection failed"
@pytest.mark.parametrize(
diff --git a/tests/unit_tests/db_engine_specs/test_databend.py b/tests/unit_tests/db_engine_specs/test_databend.py
index 9c494492d9ad1..06fab791884e9 100644
--- a/tests/unit_tests/db_engine_specs/test_databend.py
+++ b/tests/unit_tests/db_engine_specs/test_databend.py
@@ -66,8 +66,9 @@ def test_execute_connection_error() -> None:
cursor.execute.side_effect = NewConnectionError(
HTTPConnection("Dummypool"), "Exception with sensitive data"
)
- with pytest.raises(SupersetDBAPIDatabaseError) as ex:
- DatabendEngineSpec.execute(cursor, "SELECT col1 from table1")
+ with pytest.raises(SupersetDBAPIDatabaseError) as excinfo:
+ DatabendEngineSpec.execute(cursor, "SELECT col1 from table1", 1)
+ assert str(excinfo.value) == "Connection failed"
@pytest.mark.parametrize(
diff --git a/tests/unit_tests/db_engine_specs/test_drill.py b/tests/unit_tests/db_engine_specs/test_drill.py
index c7463dcf1faa8..c0d2601006bd5 100644
--- a/tests/unit_tests/db_engine_specs/test_drill.py
+++ b/tests/unit_tests/db_engine_specs/test_drill.py
@@ -38,7 +38,7 @@ def test_odbc_impersonation() -> None:
url = URL.create("drill+odbc")
username = "DoAsUser"
- url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
+ url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
assert url.query["DelegationUID"] == username
@@ -54,7 +54,7 @@ def test_jdbc_impersonation() -> None:
url = URL.create("drill+jdbc")
username = "DoAsUser"
- url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
+ url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
assert url.query["impersonation_target"] == username
@@ -70,7 +70,7 @@ def test_sadrill_impersonation() -> None:
url = URL.create("drill+sadrill")
username = "DoAsUser"
- url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
+ url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
assert url.query["impersonation_target"] == username
@@ -90,7 +90,7 @@ def test_invalid_impersonation() -> None:
username = "DoAsUser"
with pytest.raises(SupersetDBAPIProgrammingError):
- DrillEngineSpec.get_url_for_impersonation(url, True, username)
+ DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
@pytest.mark.parametrize(
diff --git a/tests/unit_tests/db_engine_specs/test_elasticsearch.py b/tests/unit_tests/db_engine_specs/test_elasticsearch.py
index 0c1597766948b..ed80454d3c699 100644
--- a/tests/unit_tests/db_engine_specs/test_elasticsearch.py
+++ b/tests/unit_tests/db_engine_specs/test_elasticsearch.py
@@ -101,6 +101,8 @@ def test_opendistro_strip_comments() -> None:
mock_cursor.execute.return_value = []
OpenDistroEngineSpec.execute(
- mock_cursor, "-- some comment \nSELECT 1\n --other comment"
+ mock_cursor,
+ "-- some comment \nSELECT 1\n --other comment",
+ 1,
)
mock_cursor.execute.assert_called_once_with("SELECT 1\n")
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 83e7c373c8e43..3e5a80881549a 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -55,7 +55,9 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
db_engine_spec.execute_with_cursor.assert_called_with(
- cursor, "SELECT 42 AS answer LIMIT 2", query
+ cursor,
+ "SELECT 42 AS answer LIMIT 2",
+ query,
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
@@ -104,7 +106,9 @@ def test_execute_sql_statement_with_rls(
force=True,
)
db_engine_spec.execute_with_cursor.assert_called_with(
- cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query
+ cursor,
+ "SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
+ query,
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)