diff --git a/UPDATING.md b/UPDATING.md index 22674e49a3f2e..a911b4b541291 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -62,7 +62,6 @@ flag for the legacy datasource editor (DISABLE_LEGACY_DATASOURCE_EDITOR) in conf ### Deprecations -- [19078](https://github.com/apache/superset/pull/19078): Creation of old shorturl links has been deprecated in favor of a new permalink feature that solves the long url problem (old shorturls will still work, though!). By default, new permalinks use UUID4 as the key. However, to use serial ids similar to the old shorturls, add the following to your `superset_config.py`: `PERMALINK_KEY_TYPE = "id"`. - [18960](https://github.com/apache/superset/pull/18960): Persisting URL params in chart metadata is no longer supported. To set a default value for URL params in Jinja code, use the optional second argument: `url_param("my-param", "my-default-value")`. ### Other diff --git a/requirements/base.txt b/requirements/base.txt index fb16470fd399a..6162303b05e1b 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -118,6 +118,8 @@ graphlib-backport==1.0.3 # via apache-superset gunicorn==20.1.0 # via apache-superset +hashids==1.3.1 + # via apache-superset holidays==0.10.3 # via apache-superset humanize==3.11.0 diff --git a/setup.py b/setup.py index 977d1b3e7ae63..81068e1f860da 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,7 @@ def get_git_sha() -> str: "geopy", "graphlib-backport", "gunicorn>=20.1.0", + "hashids>=1.3.1, <2", "holidays==0.10.3", # PINNED! https://github.com/dr-prodigy/python-holidays/issues/406 "humanize", "isodate", diff --git a/superset/config.py b/superset/config.py index ee7a4d16e1383..b578c99354a5b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -43,7 +43,6 @@ from superset.constants import CHANGE_ME_SECRET_KEY from superset.jinja_context import BaseTemplateProcessor -from superset.key_value.types import KeyType from superset.stats_logger import DummyStatsLogger from superset.superset_typing import CacheConfig from superset.utils.core import is_test, parse_boolean_string @@ -600,8 +599,6 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # store cache keys by datasource UID (via CacheKey) for custom processing/invalidation STORE_CACHE_KEYS_IN_METADATA_DB = False -PERMALINK_KEY_TYPE: KeyType = "uuid" - # CORS Options ENABLE_CORS = False CORS_OPTIONS: Dict[Any, Any] = {} diff --git a/superset/dashboards/filter_state/commands/create.py b/superset/dashboards/filter_state/commands/create.py index a37ee072fa75e..137623027a1fc 100644 --- a/superset/dashboards/filter_state/commands/create.py +++ b/superset/dashboards/filter_state/commands/create.py @@ -18,10 +18,11 @@ from superset.dashboards.dao import DashboardDAO from superset.extensions import cache_manager +from superset.key_value.utils import random_key from superset.temporary_cache.commands.create import CreateTemporaryCacheCommand from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.parameters import CommandParameters -from superset.temporary_cache.utils import cache_key, random_key +from superset.temporary_cache.utils import cache_key class CreateFilterStateCommand(CreateTemporaryCacheCommand): diff --git a/superset/dashboards/filter_state/commands/update.py b/superset/dashboards/filter_state/commands/update.py index 6a9cd3931c9c8..d27277f9afb97 100644 --- a/superset/dashboards/filter_state/commands/update.py +++ b/superset/dashboards/filter_state/commands/update.py @@ -20,11 +20,12 @@ from superset.dashboards.dao import DashboardDAO from superset.extensions import cache_manager +from superset.key_value.utils import random_key from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError from superset.temporary_cache.commands.parameters import CommandParameters from superset.temporary_cache.commands.update import UpdateTemporaryCacheCommand -from superset.temporary_cache.utils import cache_key, random_key +from superset.temporary_cache.utils import cache_key class UpdateFilterStateCommand(UpdateTemporaryCacheCommand): diff --git a/superset/dashboards/permalink/api.py b/superset/dashboards/permalink/api.py index 2236c4de2a1af..978a63cbcf9e8 100644 --- a/superset/dashboards/permalink/api.py +++ b/superset/dashboards/permalink/api.py @@ -16,7 +16,7 @@ # under the License. import logging -from flask import current_app, g, request, Response +from flask import g, request, Response from flask_appbuilder.api import BaseApi, expose, protect, safe from marshmallow import ValidationError @@ -101,11 +101,10 @@ def post(self, pk: str) -> Response: 500: $ref: '#/components/responses/500' """ - key_type = current_app.config["PERMALINK_KEY_TYPE"] try: state = self.add_model_schema.load(request.json) key = CreateDashboardPermalinkCommand( - actor=g.user, dashboard_id=pk, state=state, key_type=key_type, + actor=g.user, dashboard_id=pk, state=state, ).run() http_origin = request.headers.environ.get("HTTP_ORIGIN") url = f"{http_origin}/superset/dashboard/p/{key}/" @@ -158,10 +157,7 @@ def get(self, key: str) -> Response: $ref: '#/components/responses/500' """ try: - key_type = current_app.config["PERMALINK_KEY_TYPE"] - value = GetDashboardPermalinkCommand( - actor=g.user, key=key, key_type=key_type - ).run() + value = GetDashboardPermalinkCommand(actor=g.user, key=key).run() if not value: return self.response_404() return self.response(200, **value) diff --git a/superset/dashboards/permalink/commands/base.py b/superset/dashboards/permalink/commands/base.py index 2c0343810e024..f4dc4f0726110 100644 --- a/superset/dashboards/permalink/commands/base.py +++ b/superset/dashboards/permalink/commands/base.py @@ -17,7 +17,13 @@ from abc import ABC from superset.commands.base import BaseCommand +from superset.key_value.shared_entries import get_permalink_salt +from superset.key_value.types import KeyValueResource, SharedKey class BaseDashboardPermalinkCommand(BaseCommand, ABC): - resource = "dashboard_permalink" + resource = KeyValueResource.DASHBOARD_PERMALINK + + @property + def salt(self) -> str: + return get_permalink_salt(SharedKey.DASHBOARD_PERMALINK_SALT) diff --git a/superset/dashboards/permalink/commands/create.py b/superset/dashboards/permalink/commands/create.py index a97f228dd83d1..8a0f6d5973a3d 100644 --- a/superset/dashboards/permalink/commands/create.py +++ b/superset/dashboards/permalink/commands/create.py @@ -24,23 +24,18 @@ from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError from superset.dashboards.permalink.types import DashboardPermalinkState from superset.key_value.commands.create import CreateKeyValueCommand -from superset.key_value.types import KeyType +from superset.key_value.utils import encode_permalink_key logger = logging.getLogger(__name__) class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): def __init__( - self, - actor: User, - dashboard_id: str, - state: DashboardPermalinkState, - key_type: KeyType, + self, actor: User, dashboard_id: str, state: DashboardPermalinkState, ): self.actor = actor self.dashboard_id = dashboard_id self.state = state - self.key_type = key_type def run(self) -> str: self.validate() @@ -50,12 +45,10 @@ def run(self) -> str: "dashboardId": self.dashboard_id, "state": self.state, } - return CreateKeyValueCommand( - actor=self.actor, - resource=self.resource, - value=value, - key_type=self.key_type, + key = CreateKeyValueCommand( + actor=self.actor, resource=self.resource, value=value, ).run() + return encode_permalink_key(key=key.id, salt=self.salt) except SQLAlchemyError as ex: logger.exception("Error running create command") raise DashboardPermalinkCreateFailedError() from ex diff --git a/superset/dashboards/permalink/commands/get.py b/superset/dashboards/permalink/commands/get.py index 6cb2749cbacb9..24bf77834a60c 100644 --- a/superset/dashboards/permalink/commands/get.py +++ b/superset/dashboards/permalink/commands/get.py @@ -27,25 +27,21 @@ from superset.dashboards.permalink.types import DashboardPermalinkValue from superset.key_value.commands.get import GetKeyValueCommand from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError -from superset.key_value.types import KeyType +from superset.key_value.utils import decode_permalink_id logger = logging.getLogger(__name__) class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand): - def __init__( - self, actor: User, key: str, key_type: KeyType, - ): + def __init__(self, actor: User, key: str): self.actor = actor self.key = key - self.key_type = key_type def run(self) -> Optional[DashboardPermalinkValue]: self.validate() try: - command = GetKeyValueCommand( - resource=self.resource, key=self.key, key_type=self.key_type - ) + key = decode_permalink_id(self.key, salt=self.salt) + command = GetKeyValueCommand(resource=self.resource, key=key) value: Optional[DashboardPermalinkValue] = command.run() if value: DashboardDAO.get_by_id_or_slug(value["dashboardId"]) diff --git a/superset/dashboards/permalink/schemas.py b/superset/dashboards/permalink/schemas.py index 91d60b02c23b7..0e373ce85bd0c 100644 --- a/superset/dashboards/permalink/schemas.py +++ b/superset/dashboards/permalink/schemas.py @@ -19,7 +19,7 @@ class DashboardPermalinkPostSchema(Schema): filterState = fields.Dict( - required=True, allow_none=False, description="Native filter state", + required=False, allow_none=True, description="Native filter state", ) urlParams = fields.List( fields.Tuple( diff --git a/superset/dashboards/permalink/types.py b/superset/dashboards/permalink/types.py index 815c5bfe91d47..e93076ba23785 100644 --- a/superset/dashboards/permalink/types.py +++ b/superset/dashboards/permalink/types.py @@ -18,7 +18,7 @@ class DashboardPermalinkState(TypedDict): - filterState: Dict[str, Any] + filterState: Optional[Dict[str, Any]] hash: Optional[str] urlParams: Optional[List[Tuple[str, str]]] diff --git a/superset/explore/form_data/commands/create.py b/superset/explore/form_data/commands/create.py index a325241c6641d..7b1f866c505df 100644 --- a/superset/explore/form_data/commands/create.py +++ b/superset/explore/form_data/commands/create.py @@ -24,8 +24,9 @@ from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.utils import check_access from superset.extensions import cache_manager +from superset.key_value.utils import random_key from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError -from superset.temporary_cache.utils import cache_key, random_key +from superset.temporary_cache.utils import cache_key from superset.utils.schema import validate_json logger = logging.getLogger(__name__) diff --git a/superset/explore/form_data/commands/update.py b/superset/explore/form_data/commands/update.py index 0c986ee102cb0..279722971f5fb 100644 --- a/superset/explore/form_data/commands/update.py +++ b/superset/explore/form_data/commands/update.py @@ -26,11 +26,12 @@ from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.utils import check_access from superset.extensions import cache_manager +from superset.key_value.utils import random_key from superset.temporary_cache.commands.exceptions import ( TemporaryCacheAccessDeniedError, TemporaryCacheUpdateFailedError, ) -from superset.temporary_cache.utils import cache_key, random_key +from superset.temporary_cache.utils import cache_key from superset.utils.schema import validate_json logger = logging.getLogger(__name__) diff --git a/superset/explore/permalink/api.py b/superset/explore/permalink/api.py index 025b1a45481c8..9a03b71150e35 100644 --- a/superset/explore/permalink/api.py +++ b/superset/explore/permalink/api.py @@ -16,7 +16,7 @@ # under the License. import logging -from flask import current_app, g, request, Response +from flask import g, request, Response from flask_appbuilder.api import BaseApi, expose, protect, safe from marshmallow import ValidationError @@ -98,12 +98,9 @@ def post(self) -> Response: 500: $ref: '#/components/responses/500' """ - key_type = current_app.config["PERMALINK_KEY_TYPE"] try: state = self.add_model_schema.load(request.json) - key = CreateExplorePermalinkCommand( - actor=g.user, state=state, key_type=key_type, - ).run() + key = CreateExplorePermalinkCommand(actor=g.user, state=state).run() http_origin = request.headers.environ.get("HTTP_ORIGIN") url = f"{http_origin}/superset/explore/p/{key}/" return self.response(201, key=key, url=url) @@ -159,10 +156,7 @@ def get(self, key: str) -> Response: $ref: '#/components/responses/500' """ try: - key_type = current_app.config["PERMALINK_KEY_TYPE"] - value = GetExplorePermalinkCommand( - actor=g.user, key=key, key_type=key_type - ).run() + value = GetExplorePermalinkCommand(actor=g.user, key=key).run() if not value: return self.response_404() return self.response(200, **value) diff --git a/superset/explore/permalink/commands/base.py b/superset/explore/permalink/commands/base.py index 01a96405da026..bef9546e21686 100644 --- a/superset/explore/permalink/commands/base.py +++ b/superset/explore/permalink/commands/base.py @@ -17,7 +17,13 @@ from abc import ABC from superset.commands.base import BaseCommand +from superset.key_value.shared_entries import get_permalink_salt +from superset.key_value.types import KeyValueResource, SharedKey class BaseExplorePermalinkCommand(BaseCommand, ABC): - resource = "explore_permalink" + resource: KeyValueResource = KeyValueResource.EXPLORE_PERMALINK + + @property + def salt(self) -> str: + return get_permalink_salt(SharedKey.EXPLORE_PERMALINK_SALT) diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 936f20063b9bd..38e91fb105b27 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -24,18 +24,17 @@ from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError from superset.explore.utils import check_access from superset.key_value.commands.create import CreateKeyValueCommand -from superset.key_value.types import KeyType +from superset.key_value.utils import encode_permalink_key logger = logging.getLogger(__name__) class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): - def __init__(self, actor: User, state: Dict[str, Any], key_type: KeyType): + def __init__(self, actor: User, state: Dict[str, Any]): self.actor = actor self.chart_id: Optional[int] = state["formData"].get("slice_id") self.datasource: str = state["formData"]["datasource"] self.state = state - self.key_type = key_type def run(self) -> str: self.validate() @@ -49,12 +48,10 @@ def run(self) -> str: "state": self.state, } command = CreateKeyValueCommand( - actor=self.actor, - resource=self.resource, - value=value, - key_type=self.key_type, + actor=self.actor, resource=self.resource, value=value, ) - return command.run() + key = command.run() + return encode_permalink_key(key=key.id, salt=self.salt) except SQLAlchemyError as ex: logger.exception("Error running create command") raise ExplorePermalinkCreateFailedError() from ex diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index e22ab8332f3dc..15a2d495cd6d7 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -27,24 +27,22 @@ from superset.explore.utils import check_access from superset.key_value.commands.get import GetKeyValueCommand from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError -from superset.key_value.types import KeyType +from superset.key_value.utils import decode_permalink_id logger = logging.getLogger(__name__) class GetExplorePermalinkCommand(BaseExplorePermalinkCommand): - def __init__( - self, actor: User, key: str, key_type: KeyType, - ): + def __init__(self, actor: User, key: str): self.actor = actor self.key = key - self.key_type = key_type def run(self) -> Optional[ExplorePermalinkValue]: self.validate() try: + key = decode_permalink_id(self.key, salt=self.salt) value: Optional[ExplorePermalinkValue] = GetKeyValueCommand( - resource=self.resource, key=self.key, key_type=self.key_type + resource=self.resource, key=key, ).run() if value: chart_id: Optional[int] = value.get("chartId") diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py index 156f7771fbce7..1e5cff7ee3ccf 100644 --- a/superset/extensions/metastore_cache.py +++ b/superset/extensions/metastore_cache.py @@ -16,7 +16,6 @@ # under the License. from datetime import datetime, timedelta -from hashlib import md5 from typing import Any, Dict, List, Optional from uuid import UUID, uuid3 @@ -24,10 +23,10 @@ from flask_caching import BaseCache from superset.key_value.exceptions import KeyValueCreateFailedError -from superset.key_value.types import KeyType +from superset.key_value.types import KeyValueResource +from superset.key_value.utils import get_uuid_namespace -RESOURCE = "superset_metastore_cache" -KEY_TYPE: KeyType = "uuid" +RESOURCE = KeyValueResource.METASTORE_CACHE class SupersetMetastoreCache(BaseCache): @@ -39,15 +38,12 @@ def __init__(self, namespace: UUID, default_timeout: int = 300) -> None: def factory( cls, app: Flask, config: Dict[str, Any], args: List[Any], kwargs: Dict[str, Any] ) -> BaseCache: - # base namespace for generating deterministic UUIDs - md5_obj = md5() seed = config.get("CACHE_KEY_PREFIX", "") - md5_obj.update(seed.encode("utf-8")) - kwargs["namespace"] = UUID(md5_obj.hexdigest()) + kwargs["namespace"] = get_uuid_namespace(seed) return cls(*args, **kwargs) - def get_key(self, key: str) -> str: - return str(uuid3(self.namespace, key)) + def get_key(self, key: str) -> UUID: + return uuid3(self.namespace, key) @staticmethod def _prune() -> None: @@ -70,7 +66,6 @@ def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: UpsertKeyValueCommand( resource=RESOURCE, - key_type=KEY_TYPE, key=self.get_key(key), value=value, expires_on=self._get_expiry(timeout), @@ -85,7 +80,6 @@ def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: CreateKeyValueCommand( resource=RESOURCE, value=value, - key_type=KEY_TYPE, key=self.get_key(key), expires_on=self._get_expiry(timeout), ).run() @@ -98,9 +92,7 @@ def get(self, key: str) -> Any: # pylint: disable=import-outside-toplevel from superset.key_value.commands.get import GetKeyValueCommand - return GetKeyValueCommand( - resource=RESOURCE, key_type=KEY_TYPE, key=self.get_key(key), - ).run() + return GetKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() def has(self, key: str) -> bool: entry = self.get(key) @@ -112,6 +104,4 @@ def delete(self, key: str) -> Any: # pylint: disable=import-outside-toplevel from superset.key_value.commands.delete import DeleteKeyValueCommand - return DeleteKeyValueCommand( - resource=RESOURCE, key_type=KEY_TYPE, key=self.get_key(key), - ).run() + return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() diff --git a/superset/key_value/commands/create.py b/superset/key_value/commands/create.py index e3c228adbac56..613fabcdeb1f7 100644 --- a/superset/key_value/commands/create.py +++ b/superset/key_value/commands/create.py @@ -17,7 +17,7 @@ import logging import pickle from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, Union from uuid import UUID from flask_appbuilder.security.sqla.models import User @@ -27,27 +27,24 @@ from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyType -from superset.key_value.utils import extract_key +from superset.key_value.types import Key, KeyValueResource logger = logging.getLogger(__name__) class CreateKeyValueCommand(BaseCommand): actor: Optional[User] - resource: str + resource: KeyValueResource value: Any - key_type: KeyType - key: Optional[str] + key: Optional[Union[int, UUID]] expires_on: Optional[datetime] def __init__( self, - resource: str, + resource: KeyValueResource, value: Any, - key_type: KeyType = "uuid", actor: Optional[User] = None, - key: Optional[str] = None, + key: Optional[Union[int, UUID]] = None, expires_on: Optional[datetime] = None, ): """ @@ -55,7 +52,6 @@ def __init__( :param resource: the resource (dashboard, chart etc) :param value: the value to persist in the key-value store - :param key_type: the type of the key to return :param actor: the user performing the command :param key: id of entry (autogenerated if undefined) :param expires_on: entry expiration time @@ -64,11 +60,10 @@ def __init__( self.resource = resource self.actor = actor self.value = value - self.key_type = key_type self.key = key self.expires_on = expires_on - def run(self) -> str: + def run(self) -> Key: try: return self.create() except SQLAlchemyError as ex: @@ -79,9 +74,9 @@ def run(self) -> str: def validate(self) -> None: pass - def create(self) -> str: + def create(self) -> Key: entry = KeyValueEntry( - resource=self.resource, + resource=self.resource.value, value=pickle.dumps(self.value), created_on=datetime.now(), created_by_fk=None @@ -91,12 +86,12 @@ def create(self) -> str: ) if self.key is not None: try: - if self.key_type == "uuid": - entry.uuid = UUID(self.key) + if isinstance(self.key, UUID): + entry.uuid = self.key else: - entry.id = int(self.key) + entry.id = self.key except ValueError as ex: raise KeyValueCreateFailedError() from ex db.session.add(entry) db.session.commit() - return extract_key(entry, self.key_type) + return Key(id=entry.id, uuid=entry.uuid) diff --git a/superset/key_value/commands/delete.py b/superset/key_value/commands/delete.py index 06cf4230c1ee2..f8ad291714ae1 100644 --- a/superset/key_value/commands/delete.py +++ b/superset/key_value/commands/delete.py @@ -15,40 +15,35 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Optional +from typing import Union +from uuid import UUID -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset import db from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueDeleteFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyType +from superset.key_value.types import KeyValueResource from superset.key_value.utils import get_filter logger = logging.getLogger(__name__) class DeleteKeyValueCommand(BaseCommand): - key: str - key_type: KeyType - resource: str + key: Union[int, UUID] + resource: KeyValueResource - def __init__( - self, resource: str, key: str, key_type: KeyType = "uuid", - ): + def __init__(self, resource: KeyValueResource, key: Union[int, UUID]): """ Delete a key-value pair :param resource: the resource (dashboard, chart etc) :param key: the key to delete - :param key_type: the type of key :return: was the entry deleted or not """ self.resource = resource self.key = key - self.key_type = key_type def run(self) -> bool: try: @@ -62,7 +57,7 @@ def validate(self) -> None: pass def delete(self) -> bool: - filter_ = get_filter(self.resource, self.key, self.key_type) + filter_ = get_filter(self.resource, self.key) entry = ( db.session.query(KeyValueEntry) .filter_by(**filter_) diff --git a/superset/key_value/commands/delete_expired.py b/superset/key_value/commands/delete_expired.py index 09507397e187f..4031d13968302 100644 --- a/superset/key_value/commands/delete_expired.py +++ b/superset/key_value/commands/delete_expired.py @@ -17,20 +17,22 @@ import logging from datetime import datetime +from sqlalchemy import and_ from sqlalchemy.exc import SQLAlchemyError from superset import db from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueDeleteFailedError from superset.key_value.models import KeyValueEntry +from superset.key_value.types import KeyValueResource logger = logging.getLogger(__name__) class DeleteExpiredKeyValueCommand(BaseCommand): - resource: str + resource: KeyValueResource - def __init__(self, resource: str): + def __init__(self, resource: KeyValueResource): """ Delete all expired key-value pairs @@ -50,11 +52,15 @@ def run(self) -> None: def validate(self) -> None: pass - @staticmethod - def delete_expired() -> None: + def delete_expired(self) -> None: ( db.session.query(KeyValueEntry) - .filter(KeyValueEntry.expires_on <= datetime.now()) + .filter( + and_( + KeyValueEntry.resource == self.resource.value, + KeyValueEntry.expires_on <= datetime.now(), + ) + ) .delete() ) db.session.commit() diff --git a/superset/key_value/commands/get.py b/superset/key_value/commands/get.py index b0530b976ba7c..01560949e37ff 100644 --- a/superset/key_value/commands/get.py +++ b/superset/key_value/commands/get.py @@ -18,7 +18,8 @@ import logging import pickle from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, Union +from uuid import UUID from sqlalchemy.exc import SQLAlchemyError @@ -26,29 +27,26 @@ from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueGetFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyType +from superset.key_value.types import KeyValueResource from superset.key_value.utils import get_filter logger = logging.getLogger(__name__) class GetKeyValueCommand(BaseCommand): - key: str - key_type: KeyType - resource: str + resource: KeyValueResource + key: Union[int, UUID] - def __init__(self, resource: str, key: str, key_type: KeyType = "uuid"): + def __init__(self, resource: KeyValueResource, key: Union[int, UUID]): """ Retrieve a key value entry :param resource: the resource (dashboard, chart etc) :param key: the key to retrieve - :param key_type: the type of the key to retrieve :return: the value associated with the key if present """ self.resource = resource self.key = key - self.key_type = key_type def run(self) -> Any: try: @@ -61,7 +59,7 @@ def validate(self) -> None: pass def get(self) -> Optional[Any]: - filter_ = get_filter(self.resource, self.key, self.key_type) + filter_ = get_filter(self.resource, self.key) entry = ( db.session.query(KeyValueEntry) .filter_by(**filter_) diff --git a/superset/key_value/commands/update.py b/superset/key_value/commands/update.py index b739cfea86041..7333b48c5cc34 100644 --- a/superset/key_value/commands/update.py +++ b/superset/key_value/commands/update.py @@ -18,7 +18,8 @@ import logging import pickle from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, Union +from uuid import UUID from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError @@ -27,27 +28,25 @@ from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueUpdateFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyType -from superset.key_value.utils import extract_key, get_filter +from superset.key_value.types import Key, KeyValueResource +from superset.key_value.utils import get_filter logger = logging.getLogger(__name__) class UpdateKeyValueCommand(BaseCommand): actor: Optional[User] - resource: str + resource: KeyValueResource value: Any - key: str - key_type: KeyType + key: Union[int, UUID] expires_on: Optional[datetime] def __init__( self, - resource: str, - key: str, + resource: KeyValueResource, + key: Union[int, UUID], value: Any, actor: Optional[User] = None, - key_type: KeyType = "uuid", expires_on: Optional[datetime] = None, ): """ @@ -57,7 +56,6 @@ def __init__( :param key: the key to update :param value: the value to persist in the key-value store :param actor: the user performing the command - :param key_type: the type of the key to update :param expires_on: entry expiration time :return: the key associated with the updated value """ @@ -65,10 +63,9 @@ def __init__( self.resource = resource self.key = key self.value = value - self.key_type = key_type self.expires_on = expires_on - def run(self) -> Optional[str]: + def run(self) -> Optional[Key]: try: return self.update() except SQLAlchemyError as ex: @@ -79,8 +76,8 @@ def run(self) -> Optional[str]: def validate(self) -> None: pass - def update(self) -> Optional[str]: - filter_ = get_filter(self.resource, self.key, self.key_type) + def update(self) -> Optional[Key]: + filter_ = get_filter(self.resource, self.key) entry: KeyValueEntry = ( db.session.query(KeyValueEntry) .filter_by(**filter_) @@ -96,6 +93,6 @@ def update(self) -> Optional[str]: ) db.session.merge(entry) db.session.commit() - return extract_key(entry, self.key_type) + return Key(id=entry.id, uuid=entry.uuid) return None diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py index 4afc4c38e424a..aa495f7cc77c1 100644 --- a/superset/key_value/commands/upsert.py +++ b/superset/key_value/commands/upsert.py @@ -18,7 +18,8 @@ import logging import pickle from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, Union +from uuid import UUID from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError @@ -28,27 +29,25 @@ from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.exceptions import KeyValueUpdateFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyType -from superset.key_value.utils import extract_key, get_filter +from superset.key_value.types import Key, KeyValueResource +from superset.key_value.utils import get_filter logger = logging.getLogger(__name__) class UpsertKeyValueCommand(BaseCommand): actor: Optional[User] - resource: str + resource: KeyValueResource value: Any - key: str - key_type: KeyType + key: Union[int, UUID] expires_on: Optional[datetime] def __init__( self, - resource: str, - key: str, + resource: KeyValueResource, + key: Union[int, UUID], value: Any, actor: Optional[User] = None, - key_type: KeyType = "uuid", expires_on: Optional[datetime] = None, ): """ @@ -66,10 +65,9 @@ def __init__( self.resource = resource self.key = key self.value = value - self.key_type = key_type self.expires_on = expires_on - def run(self) -> Optional[str]: + def run(self) -> Optional[Key]: try: return self.upsert() except SQLAlchemyError as ex: @@ -80,8 +78,8 @@ def run(self) -> Optional[str]: def validate(self) -> None: pass - def upsert(self) -> Optional[str]: - filter_ = get_filter(self.resource, self.key, self.key_type) + def upsert(self) -> Optional[Key]: + filter_ = get_filter(self.resource, self.key) entry: KeyValueEntry = ( db.session.query(KeyValueEntry) .filter_by(**filter_) @@ -97,12 +95,11 @@ def upsert(self) -> Optional[str]: ) db.session.merge(entry) db.session.commit() - return extract_key(entry, self.key_type) + return Key(entry.id, entry.uuid) else: return CreateKeyValueCommand( resource=self.resource, value=self.value, - key_type=self.key_type, actor=self.actor, key=self.key, expires_on=self.expires_on, diff --git a/superset/key_value/shared_entries.py b/superset/key_value/shared_entries.py new file mode 100644 index 0000000000000..5dda89a7b3163 --- /dev/null +++ b/superset/key_value/shared_entries.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Optional +from uuid import uuid3 + +from superset.key_value.types import KeyValueResource, SharedKey +from superset.key_value.utils import get_uuid_namespace, random_key +from superset.utils.memoized import memoized + +RESOURCE = KeyValueResource.APP +NAMESPACE = get_uuid_namespace("") + + +def get_shared_value(key: SharedKey) -> Optional[Any]: + # pylint: disable=import-outside-toplevel + from superset.key_value.commands.get import GetKeyValueCommand + + uuid_key = uuid3(NAMESPACE, key) + return GetKeyValueCommand(RESOURCE, key=uuid_key).run() + + +def set_shared_value(key: SharedKey, value: Any) -> None: + # pylint: disable=import-outside-toplevel + from superset.key_value.commands.create import CreateKeyValueCommand + + uuid_key = uuid3(NAMESPACE, key) + CreateKeyValueCommand(resource=RESOURCE, value=value, key=uuid_key).run() + + +@memoized +def get_permalink_salt(key: SharedKey) -> str: + salt = get_shared_value(key) + if salt is None: + salt = random_key() + set_shared_value(key, value=salt) + return salt diff --git a/superset/key_value/types.py b/superset/key_value/types.py index d36520ddbfb75..c3064fbef4d42 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. from dataclasses import dataclass -from typing import Literal, Optional, TypedDict +from enum import Enum +from typing import Optional, TypedDict from uuid import UUID @@ -25,10 +26,19 @@ class Key: uuid: Optional[UUID] -KeyType = Literal["id", "uuid"] - - class KeyValueFilter(TypedDict, total=False): resource: str id: Optional[int] uuid: Optional[UUID] + + +class KeyValueResource(str, Enum): + APP = "app" + DASHBOARD_PERMALINK = "dashboard_permalink" + EXPLORE_PERMALINK = "explore_permalink" + METASTORE_CACHE = "superset_metastore_cache" + + +class SharedKey(str, Enum): + DASHBOARD_PERMALINK_SALT = "dashboard_permalink_salt" + EXPLORE_PERMALINK_SALT = "explore_permalink_salt" diff --git a/superset/key_value/utils.py b/superset/key_value/utils.py index 50aa34918e434..b2e8e729b0466 100644 --- a/superset/key_value/utils.py +++ b/superset/key_value/utils.py @@ -14,44 +14,52 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Literal +from __future__ import annotations + +from hashlib import md5 +from secrets import token_urlsafe +from typing import Union from uuid import UUID -from flask import current_app +import hashids +from flask_babel import gettext as _ from superset.key_value.exceptions import KeyValueParseKeyError -from superset.key_value.models import KeyValueEntry -from superset.key_value.types import Key, KeyType, KeyValueFilter - - -def parse_permalink_key(key: str) -> Key: - key_type: Literal["id", "uuid"] = current_app.config["PERMALINK_KEY_TYPE"] - if key_type == "id": - return Key(id=int(key), uuid=None) - return Key(id=None, uuid=UUID(key)) +from superset.key_value.types import KeyValueFilter, KeyValueResource +HASHIDS_MIN_LENGTH = 11 -def format_permalink_key(key: Key) -> str: - """ - return the string representation of the key - :param key: a key object with either a numerical or uuid key - :return: a formatted string - """ - return str(key.id if key.id is not None else key.uuid) +def random_key() -> str: + return token_urlsafe(48) -def extract_key(entry: KeyValueEntry, key_type: KeyType) -> str: - return str(entry.id if key_type == "id" else entry.uuid) - - -def get_filter(resource: str, key: str, key_type: KeyType) -> KeyValueFilter: +def get_filter(resource: KeyValueResource, key: Union[int, UUID]) -> KeyValueFilter: try: - filter_: KeyValueFilter = {"resource": resource} - if key_type == "uuid": - filter_["uuid"] = UUID(key) + filter_: KeyValueFilter = {"resource": resource.value} + if isinstance(key, UUID): + filter_["uuid"] = key else: - filter_["id"] = int(key) + filter_["id"] = key return filter_ except ValueError as ex: raise KeyValueParseKeyError() from ex + + +def encode_permalink_key(key: int, salt: str) -> str: + obj = hashids.Hashids(salt, min_length=HASHIDS_MIN_LENGTH) + return obj.encode(key) + + +def decode_permalink_id(key: str, salt: str) -> int: + obj = hashids.Hashids(salt, min_length=HASHIDS_MIN_LENGTH) + ids = obj.decode(key) + if len(ids) == 1: + return ids[0] + raise KeyValueParseKeyError(_("Invalid permalink key")) + + +def get_uuid_namespace(seed: str) -> UUID: + md5_obj = md5() + md5_obj.update(seed.encode("utf-8")) + return UUID(md5_obj.hexdigest()) diff --git a/superset/temporary_cache/utils.py b/superset/temporary_cache/utils.py index 2f2f71f957e08..9ba2a8d036077 100644 --- a/superset/temporary_cache/utils.py +++ b/superset/temporary_cache/utils.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from secrets import token_urlsafe from typing import Any SEPARATOR = ";" @@ -22,7 +21,3 @@ def cache_key(*args: Any) -> str: return SEPARATOR.join(str(arg) for arg in args) - - -def random_key() -> str: - return token_urlsafe(48) diff --git a/superset/views/core.py b/superset/views/core.py index c5232e7e2e798..a3356e77aa08d 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -748,8 +748,7 @@ def explore( form_data_key = request.args.get("form_data_key") if key is not None: - key_type = config["PERMALINK_KEY_TYPE"] - command = GetExplorePermalinkCommand(g.user, key, key_type) + command = GetExplorePermalinkCommand(g.user, key) try: permalink_value = command.run() if permalink_value: @@ -2008,9 +2007,8 @@ def dashboard( def dashboard_permalink( # pylint: disable=no-self-use self, key: str, ) -> FlaskResponse: - key_type = config["PERMALINK_KEY_TYPE"] try: - value = GetDashboardPermalinkCommand(g.user, key, key_type).run() + value = GetDashboardPermalinkCommand(g.user, key).run() except DashboardPermalinkGetFailedError as ex: flash(__("Error: %(msg)s", msg=ex.message), "danger") return redirect("/dashboard/list/") diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py index bd821165dc337..7a8a2906bcdfb 100644 --- a/tests/integration_tests/dashboards/permalink/api_tests.py +++ b/tests/integration_tests/dashboards/permalink/api_tests.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. import json +from typing import Iterator from unittest.mock import patch +from uuid import uuid3 import pytest from flask_appbuilder.security.sqla.models import User @@ -24,8 +26,9 @@ from superset import db from superset.dashboards.commands.exceptions import DashboardAccessDeniedError from superset.key_value.models import KeyValueEntry +from superset.key_value.types import KeyValueResource +from superset.key_value.utils import decode_permalink_id from superset.models.dashboard import Dashboard -from superset.models.slice import Slice from tests.integration_tests.base_tests import login from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( @@ -35,7 +38,7 @@ from tests.integration_tests.test_app import app STATE = { - "filterState": {"FILTER_1": "foo",}, + "filterState": {"FILTER_1": "foo"}, "hash": "my-anchor", } @@ -48,7 +51,22 @@ def dashboard_id(load_world_bank_dashboard_with_slices) -> int: return dashboard.id -def test_post(client, dashboard_id: int): +@pytest.fixture +def permalink_salt() -> Iterator[str]: + from superset.key_value.shared_entries import get_permalink_salt, get_uuid_namespace + from superset.key_value.types import SharedKey + + key = SharedKey.DASHBOARD_PERMALINK_SALT + salt = get_permalink_salt(key) + yield salt + namespace = get_uuid_namespace(salt) + db.session.query(KeyValueEntry).filter_by( + resource=KeyValueResource.APP, uuid=uuid3(namespace, key), + ) + db.session.commit() + + +def test_post(client, dashboard_id: int, permalink_salt: str) -> None: login(client, "admin") resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) assert resp.status_code == 201 @@ -56,7 +74,8 @@ def test_post(client, dashboard_id: int): key = data["key"] url = data["url"] assert key in url - db.session.query(KeyValueEntry).filter_by(uuid=key).delete() + id_ = decode_permalink_id(key, permalink_salt) + db.session.query(KeyValueEntry).filter_by(id=id_).delete() db.session.commit() @@ -76,7 +95,7 @@ def test_post_invalid_schema(client, dashboard_id: int): assert resp.status_code == 400 -def test_get(client, dashboard_id: int): +def test_get(client, dashboard_id: int, permalink_salt: str): login(client, "admin") resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) data = json.loads(resp.data.decode("utf-8")) @@ -86,5 +105,6 @@ def test_get(client, dashboard_id: int): result = json.loads(resp.data.decode("utf-8")) assert result["dashboardId"] == str(dashboard_id) assert result["state"] == STATE - db.session.query(KeyValueEntry).filter_by(uuid=key).delete() + id_ = decode_permalink_id(key, permalink_salt) + db.session.query(KeyValueEntry).filter_by(id=id_).delete() db.session.commit() diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index 37b0d2455d80c..a992b36daf523 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -16,14 +16,16 @@ # under the License. import json import pickle -from typing import Any, Dict -from uuid import UUID +from typing import Any, Dict, Iterator +from uuid import uuid3 import pytest from sqlalchemy.orm import Session from superset import db from superset.key_value.models import KeyValueEntry +from superset.key_value.types import KeyValueResource +from superset.key_value.utils import decode_permalink_id, encode_permalink_key from superset.models.slice import Slice from tests.integration_tests.base_tests import login from tests.integration_tests.fixtures.client import client @@ -51,7 +53,22 @@ def form_data(chart) -> Dict[str, Any]: } -def test_post(client, form_data): +@pytest.fixture +def permalink_salt() -> Iterator[str]: + from superset.key_value.shared_entries import get_permalink_salt, get_uuid_namespace + from superset.key_value.types import SharedKey + + key = SharedKey.EXPLORE_PERMALINK_SALT + salt = get_permalink_salt(key) + yield salt + namespace = get_uuid_namespace(salt) + db.session.query(KeyValueEntry).filter_by( + resource=KeyValueResource.APP, uuid=uuid3(namespace, key), + ) + db.session.commit() + + +def test_post(client, form_data: Dict[str, Any], permalink_salt: str): login(client, "admin") resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data}) assert resp.status_code == 201 @@ -59,7 +76,8 @@ def test_post(client, form_data): key = data["key"] url = data["url"] assert key in url - db.session.query(KeyValueEntry).filter_by(uuid=key).delete() + id_ = decode_permalink_id(key, permalink_salt) + db.session.query(KeyValueEntry).filter_by(id=id_).delete() db.session.commit() @@ -69,21 +87,18 @@ def test_post_access_denied(client, form_data): assert resp.status_code == 404 -def test_get_missing_chart(client, chart): +def test_get_missing_chart(client, chart, permalink_salt: str) -> None: from superset.key_value.models import KeyValueEntry - key = 1234 - uuid_key = "e2ea9d19-7988-4862-aa69-c3a1a7628cb9" + chart_id = 1234 entry = KeyValueEntry( - id=int(key), - uuid=UUID("e2ea9d19-7988-4862-aa69-c3a1a7628cb9"), - resource="explore_permalink", + resource=KeyValueResource.EXPLORE_PERMALINK, value=pickle.dumps( { - "chartId": key, + "chartId": chart_id, "datasetId": chart.datasource.id, "formData": { - "slice_id": key, + "slice_id": chart_id, "datasource": f"{chart.datasource.id}__{chart.datasource.type}", }, } @@ -91,20 +106,21 @@ def test_get_missing_chart(client, chart): ) db.session.add(entry) db.session.commit() + key = encode_permalink_key(entry.id, permalink_salt) login(client, "admin") - resp = client.get(f"api/v1/explore/permalink/{uuid_key}") + resp = client.get(f"api/v1/explore/permalink/{key}") assert resp.status_code == 404 db.session.delete(entry) db.session.commit() -def test_post_invalid_schema(client): +def test_post_invalid_schema(client) -> None: login(client, "admin") resp = client.post(f"api/v1/explore/permalink", json={"abc": 123}) assert resp.status_code == 400 -def test_get(client, form_data): +def test_get(client, form_data: Dict[str, Any], permalink_salt: str) -> None: login(client, "admin") resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data}) data = json.loads(resp.data.decode("utf-8")) @@ -113,5 +129,6 @@ def test_get(client, form_data): assert resp.status_code == 200 result = json.loads(resp.data.decode("utf-8")) assert result["state"]["formData"] == form_data - db.session.query(KeyValueEntry).filter_by(uuid=key).delete() + id_ = decode_permalink_id(key, permalink_salt) + db.session.query(KeyValueEntry).filter_by(id=id_).delete() db.session.commit() diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py index 22a1b517485c8..2718aa822c3e4 100644 --- a/tests/integration_tests/key_value/commands/create_test.py +++ b/tests/integration_tests/key_value/commands/create_test.py @@ -36,12 +36,8 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.models import KeyValueEntry - key = CreateKeyValueCommand( - actor=admin, resource=RESOURCE, value=VALUE, key_type="id", - ).run() - entry = ( - db.session.query(KeyValueEntry).filter_by(id=int(key)).autoflush(False).one() - ) + key = CreateKeyValueCommand(actor=admin, resource=RESOURCE, value=VALUE).run() + entry = db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() assert pickle.loads(entry.value) == VALUE assert entry.created_by_fk == admin.id db.session.delete(entry) @@ -52,11 +48,9 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.models import KeyValueEntry - key = CreateKeyValueCommand( - actor=admin, resource=RESOURCE, value=VALUE, key_type="uuid", - ).run() + key = CreateKeyValueCommand(actor=admin, resource=RESOURCE, value=VALUE).run() entry = ( - db.session.query(KeyValueEntry).filter_by(uuid=UUID(key)).autoflush(False).one() + db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one() ) assert pickle.loads(entry.value) == VALUE assert entry.created_by_fk == admin.id diff --git a/tests/integration_tests/key_value/commands/delete_test.py b/tests/integration_tests/key_value/commands/delete_test.py index a98d941175153..67623461246f6 100644 --- a/tests/integration_tests/key_value/commands/delete_test.py +++ b/tests/integration_tests/key_value/commands/delete_test.py @@ -30,8 +30,8 @@ if TYPE_CHECKING: from superset.key_value.models import KeyValueEntry -ID_KEY = "234" -UUID_KEY = "5aae143c-44f1-478e-9153-ae6154df333a" +ID_KEY = 234 +UUID_KEY = UUID("5aae143c-44f1-478e-9153-ae6154df333a") @pytest.fixture @@ -39,10 +39,7 @@ def key_value_entry() -> KeyValueEntry: from superset.key_value.models import KeyValueEntry entry = KeyValueEntry( - id=int(ID_KEY), - uuid=UUID(UUID_KEY), - resource=RESOURCE, - value=pickle.dumps(VALUE), + id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, value=pickle.dumps(VALUE), ) db.session.add(entry) db.session.commit() @@ -55,10 +52,7 @@ def test_delete_id_entry( from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.key_value.models import KeyValueEntry - assert ( - DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id",).run() - is True - ) + assert DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() is True def test_delete_uuid_entry( @@ -67,10 +61,7 @@ def test_delete_uuid_entry( from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.key_value.models import KeyValueEntry - assert ( - DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY, key_type="uuid").run() - is True - ) + assert DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() is True def test_delete_entry_missing( @@ -79,7 +70,4 @@ def test_delete_entry_missing( from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.key_value.models import KeyValueEntry - assert ( - DeleteKeyValueCommand(resource=RESOURCE, key="456", key_type="id").run() - is False - ) + assert DeleteKeyValueCommand(resource=RESOURCE, key=456).run() is False diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py index 44e12f7854cb2..de77a6c46badb 100644 --- a/tests/integration_tests/key_value/commands/fixtures.py +++ b/tests/integration_tests/key_value/commands/fixtures.py @@ -26,14 +26,15 @@ from sqlalchemy.orm import Session from superset.extensions import db +from superset.key_value.types import KeyValueResource from tests.integration_tests.test_app import app if TYPE_CHECKING: from superset.key_value.models import KeyValueEntry -ID_KEY = "123" -UUID_KEY = "3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc" -RESOURCE = "my_resource" +ID_KEY = 123 +UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc") +RESOURCE = KeyValueResource.APP VALUE = {"foo": "bar"} @@ -42,10 +43,7 @@ def key_value_entry() -> Generator[KeyValueEntry, None, None]: from superset.key_value.models import KeyValueEntry entry = KeyValueEntry( - id=int(ID_KEY), - uuid=UUID(UUID_KEY), - resource=RESOURCE, - value=pickle.dumps(VALUE), + id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, value=pickle.dumps(VALUE), ) db.session.add(entry) db.session.commit() diff --git a/tests/integration_tests/key_value/commands/get_test.py b/tests/integration_tests/key_value/commands/get_test.py index 20efa9dfbd4c5..c2c85e987534f 100644 --- a/tests/integration_tests/key_value/commands/get_test.py +++ b/tests/integration_tests/key_value/commands/get_test.py @@ -39,7 +39,7 @@ def test_get_id_entry(app_context: AppContext, key_value_entry: KeyValueEntry) -> None: from superset.key_value.commands.get import GetKeyValueCommand - value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id").run() + value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() assert value == VALUE @@ -48,7 +48,7 @@ def test_get_uuid_entry( ) -> None: from superset.key_value.commands.get import GetKeyValueCommand - value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY, key_type="uuid").run() + value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() assert value == VALUE @@ -57,7 +57,7 @@ def test_get_id_entry_missing( ) -> None: from superset.key_value.commands.get import GetKeyValueCommand - value = GetKeyValueCommand(resource=RESOURCE, key="456", key_type="id").run() + value = GetKeyValueCommand(resource=RESOURCE, key=456).run() assert value is None @@ -74,7 +74,7 @@ def test_get_expired_entry(app_context: AppContext) -> None: ) db.session.add(entry) db.session.commit() - value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id").run() + value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() assert value is None db.session.delete(entry) db.session.commit() @@ -94,7 +94,7 @@ def test_get_future_expiring_entry(app_context: AppContext) -> None: ) db.session.add(entry) db.session.commit() - value = GetKeyValueCommand(resource=RESOURCE, key=str(id_), key_type="id").run() + value = GetKeyValueCommand(resource=RESOURCE, key=id_).run() assert value == VALUE db.session.delete(entry) db.session.commit() diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py index 1fbc84d59e332..36de8972a0c72 100644 --- a/tests/integration_tests/key_value/commands/update_test.py +++ b/tests/integration_tests/key_value/commands/update_test.py @@ -46,12 +46,10 @@ def test_update_id_entry( from superset.key_value.models import KeyValueEntry key = UpdateKeyValueCommand( - actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, key_type="id", + actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, ).run() - assert key == ID_KEY - entry = ( - db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one() - ) + assert key.id == ID_KEY + entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one() assert pickle.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -63,25 +61,20 @@ def test_update_uuid_entry( from superset.key_value.models import KeyValueEntry key = UpdateKeyValueCommand( - actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, key_type="uuid", + actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, ).run() - assert key == UUID_KEY + assert key.uuid == UUID_KEY entry = ( - db.session.query(KeyValueEntry) - .filter_by(uuid=UUID(UUID_KEY)) - .autoflush(False) - .one() + db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() ) assert pickle.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id -def test_update_missing_entry( - app_context: AppContext, admin: User, key_value_entry: KeyValueEntry, -) -> None: +def test_update_missing_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.update import UpdateKeyValueCommand key = UpdateKeyValueCommand( - actor=admin, resource=RESOURCE, key="456", value=NEW_VALUE, key_type="id", + actor=admin, resource=RESOURCE, key=456, value=NEW_VALUE, ).run() assert key is None diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py index 3221147839d1c..8038614ce5aa9 100644 --- a/tests/integration_tests/key_value/commands/upsert_test.py +++ b/tests/integration_tests/key_value/commands/upsert_test.py @@ -46,9 +46,9 @@ def test_upsert_id_entry( from superset.key_value.models import KeyValueEntry key = UpsertKeyValueCommand( - actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, key_type="id", + actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, ).run() - assert key == ID_KEY + assert key.id == ID_KEY entry = ( db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one() ) @@ -63,28 +63,23 @@ def test_upsert_uuid_entry( from superset.key_value.models import KeyValueEntry key = UpsertKeyValueCommand( - actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, key_type="uuid", + actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, ).run() - assert key == UUID_KEY + assert key.uuid == UUID_KEY entry = ( - db.session.query(KeyValueEntry) - .filter_by(uuid=UUID(UUID_KEY)) - .autoflush(False) - .one() + db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() ) assert pickle.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id -def test_upsert_missing_entry( - app_context: AppContext, admin: User, key_value_entry: KeyValueEntry, -) -> None: +def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.upsert import UpsertKeyValueCommand from superset.key_value.models import KeyValueEntry key = UpsertKeyValueCommand( - actor=admin, resource=RESOURCE, key="456", value=NEW_VALUE, key_type="id", + actor=admin, resource=RESOURCE, key=456, value=NEW_VALUE, ).run() - assert key == "456" + assert key.id == 456 db.session.query(KeyValueEntry).filter_by(id=456).delete() db.session.commit() diff --git a/tests/unit_tests/key_value/utils_test.py b/tests/unit_tests/key_value/utils_test.py index f5ad0958bc749..5d78f6361c02c 100644 --- a/tests/unit_tests/key_value/utils_test.py +++ b/tests/unit_tests/key_value/utils_test.py @@ -16,102 +16,45 @@ # under the License. from __future__ import annotations -import json -from typing import TYPE_CHECKING -from unittest.mock import patch from uuid import UUID -if TYPE_CHECKING: - from superset.key_value.models import KeyValueEntry - import pytest -from flask.ctx import AppContext - -from superset.key_value.types import Key - -RESOURCE = "my-resource" -UUID_KEY = "3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc" -ID_KEY = "123" - - -@pytest.fixture -def key_value_entry(app_context: AppContext): - from superset.key_value.models import KeyValueEntry - - return KeyValueEntry( - id=int(ID_KEY), uuid=UUID(UUID_KEY), value=json.dumps({"foo": "bar"}), - ) - - -def test_parse_permalink_key_uuid_valid(app_context: AppContext) -> None: - from superset.key_value.utils import parse_permalink_key - - assert parse_permalink_key(UUID_KEY) == Key(id=None, uuid=UUID(UUID_KEY)) - - -def test_parse_permalink_key_id_invalid(app_context: AppContext) -> None: - from superset.key_value.utils import parse_permalink_key - - with pytest.raises(ValueError): - parse_permalink_key(ID_KEY) - - -@patch("superset.key_value.utils.current_app.config", {"PERMALINK_KEY_TYPE": "id"}) -def test_parse_permalink_key_id_valid(app_context: AppContext) -> None: - from superset.key_value.utils import parse_permalink_key - - assert parse_permalink_key(ID_KEY) == Key(id=int(ID_KEY), uuid=None) - - -@patch("superset.key_value.utils.current_app.config", {"PERMALINK_KEY_TYPE": "id"}) -def test_parse_permalink_key_uuid_invalid(app_context: AppContext) -> None: - from superset.key_value.utils import parse_permalink_key - with pytest.raises(ValueError): - parse_permalink_key(UUID_KEY) +from superset.key_value.exceptions import KeyValueParseKeyError +from superset.key_value.types import KeyValueResource +RESOURCE = KeyValueResource.APP +UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc") +ID_KEY = 123 -def test_format_permalink_key_uuid(app_context: AppContext) -> None: - from superset.key_value.utils import format_permalink_key - assert format_permalink_key(Key(id=None, uuid=UUID(UUID_KEY))) == UUID_KEY - - -def test_format_permalink_key_id(app_context: AppContext) -> None: - from superset.key_value.utils import format_permalink_key - - assert format_permalink_key(Key(id=int(ID_KEY), uuid=None)) == ID_KEY - - -def test_extract_key_uuid( - app_context: AppContext, key_value_entry: KeyValueEntry, -) -> None: - from superset.key_value.utils import extract_key - - assert extract_key(key_value_entry, "id") == ID_KEY - - -def test_extract_key_id( - app_context: AppContext, key_value_entry: KeyValueEntry, -) -> None: - from superset.key_value.utils import extract_key - - assert extract_key(key_value_entry, "uuid") == UUID_KEY - - -def test_get_filter_uuid(app_context: AppContext,) -> None: +def test_get_filter_uuid() -> None: from superset.key_value.utils import get_filter - assert get_filter(resource=RESOURCE, key=UUID_KEY, key_type="uuid",) == { + assert get_filter(resource=RESOURCE, key=UUID_KEY) == { "resource": RESOURCE, - "uuid": UUID(UUID_KEY), + "uuid": UUID_KEY, } -def test_get_filter_id(app_context: AppContext,) -> None: +def test_get_filter_id() -> None: from superset.key_value.utils import get_filter - assert get_filter(resource=RESOURCE, key=ID_KEY, key_type="id",) == { + assert get_filter(resource=RESOURCE, key=ID_KEY) == { "resource": RESOURCE, - "id": int(ID_KEY), + "id": ID_KEY, } + + +def test_encode_permalink_id_valid() -> None: + from superset.key_value.utils import encode_permalink_key + + salt = "abc" + assert encode_permalink_key(1, salt) == "AyBn4lm9qG8" + + +def test_decode_permalink_id_invalid() -> None: + from superset.key_value.utils import decode_permalink_id + + with pytest.raises(KeyValueParseKeyError): + decode_permalink_id("foo", "bar")