diff --git a/.pylintrc b/.pylintrc index 64ecf9b9cb71b..8814957194bac 100644 --- a/.pylintrc +++ b/.pylintrc @@ -120,7 +120,7 @@ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / stateme [BASIC] # Good variable names which should always be accepted, separated by a comma -good-names=_,df,ex,f,i,id,j,k,l,o,pk,Run,ts,v,x +good-names=_,df,ex,f,i,id,j,k,l,o,pk,Run,ts,v,x,y # Bad variable names which should always be refused, separated by a comma bad-names=fd,foo,bar,baz,toto,tutu,tata diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py index e9ff2de041a23..0cf600280b799 100644 --- a/RELEASING/changelog.py +++ b/RELEASING/changelog.py @@ -13,9 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# pylint: disable=no-value-for-parameter - import csv as lib_csv import os import re diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py index b7431358dd9f3..9d2cbe4306a9e 100644 --- a/superset/commands/importers/v1/examples.py +++ b/superset/commands/importers/v1/examples.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=protected-access - from typing import Any, Dict, List, Set, Tuple from marshmallow import Schema @@ -78,6 +76,7 @@ def run(self) -> None: @classmethod def _get_uuids(cls) -> Set[str]: + # pylint: disable=protected-access return ( ImportDatabasesCommand._get_uuids() | ImportDatasetsCommand._get_uuids() @@ -85,9 +84,8 @@ def _get_uuids(cls) -> Set[str]: | ImportDashboardsCommand._get_uuids() ) - # pylint: disable=too-many-locals, arguments-differ @staticmethod - def _import( + def _import( # pylint: disable=too-many-locals, arguments-differ session: Session, configs: Dict[str, Any], overwrite: bool = False, diff --git a/superset/config.py b/superset/config.py index 36c6e04a6b018..590ac7a3c5d40 100644 --- a/superset/config.py +++ b/superset/config.py @@ -38,6 +38,7 @@ from flask import Blueprint from flask_appbuilder.security.manager import AUTH_DB from pandas.io.parsers import STR_NA_VALUES +from werkzeug.local import LocalProxy from superset.jinja_context import BaseTemplateProcessor from superset.stats_logger import DummyStatsLogger @@ -178,9 +179,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # Note: the default impl leverages SqlAlchemyUtils' EncryptedType, which defaults # to AES-128 under the covers using the app's SECRET_KEY as key material. # -# pylint: disable=C0103 -SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER = SQLAlchemyUtilsAdapter - +SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER = ( # pylint: disable=invalid-name + SQLAlchemyUtilsAdapter +) # The limit of queries fetched for query search QUERY_SEARCH_LIMIT = 1000 @@ -839,7 +840,7 @@ class CeleryConfig: # pylint: disable=too-few-public-methods CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/" # Function that creates upload directory dynamically based on the # database used, user and schema provided. 
-def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( +def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name database: "Database", user: "models.User", # pylint: disable=unused-argument schema: Optional[str], @@ -984,7 +985,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # def SQL_QUERY_MUTATOR(sql, user_name, security_manager, database): # dttm = datetime.now().isoformat() # return f"-- [SQL LAB] {username} {dttm}\n{sql}" -SQL_QUERY_MUTATOR = None +def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument + sql: str, + user_name: Optional[str], + security_manager: LocalProxy, + database: "Database", +) -> str: + return sql + # Enable / disable scheduled email reports # diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py index eca14d481cf4e..03a3a42ec08cc 100644 --- a/superset/connectors/druid/views.py +++ b/superset/connectors/druid/views.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-ancestors import json import logging from datetime import datetime @@ -62,7 +61,9 @@ def ensure_enabled(self) -> None: raise NotFound() -class DruidColumnInlineView(CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView): +class DruidColumnInlineView( # pylint: disable=too-many-ancestors + CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView, +): datamodel = SQLAInterface(models.DruidColumn) include_route_methods = RouteMethod.RELATED_VIEW_SET @@ -149,7 +150,9 @@ def post_add(self, item: "DruidColumnInlineView") -> None: self.post_update(item) -class DruidMetricInlineView(CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView): +class DruidMetricInlineView( # pylint: disable=too-many-ancestors + CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView, +): datamodel = SQLAInterface(models.DruidMetric) include_route_methods = RouteMethod.RELATED_VIEW_SET @@ -202,7 +205,7 @@ class DruidMetricInlineView(CompactCRUDMixin, EnsureEnabledMixin, SupersetModelV edit_form_extra_fields = add_form_extra_fields -class DruidClusterModelView( +class DruidClusterModelView( # pylint: disable=too-many-ancestors EnsureEnabledMixin, SupersetModelView, DeleteMixin, YamlExportMixin, ): datamodel = SQLAInterface(models.DruidCluster) @@ -266,7 +269,7 @@ def _delete(self, pk: int) -> None: DeleteMixin._delete(self, pk) -class DruidDatasourceModelView( +class DruidDatasourceModelView( # pylint: disable=too-many-ancestors EnsureEnabledMixin, DatasourceModelView, DeleteMixin, YamlExportMixin, ): datamodel = SQLAInterface(models.DruidDatasource) diff --git a/superset/dashboards/filters.py b/superset/dashboards/filters.py index 12cdc7fe40b29..db20ec7004ee8 100644 --- a/superset/dashboards/filters.py +++ b/superset/dashboards/filters.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -# pylint: disable=too-few-public-methods from typing import Any, Optional from flask_appbuilder.security.sqla.models import Role @@ -31,7 +29,7 @@ from superset.views.base_api import BaseFavoriteFilter -class DashboardTitleOrSlugFilter(BaseFilter): +class DashboardTitleOrSlugFilter(BaseFilter): # pylint: disable=too-few-public-methods name = _("Title or Slug") arg_name = "title_or_slug" @@ -47,7 +45,9 @@ def apply(self, query: Query, value: Any) -> Query: ) -class DashboardFavoriteFilter(BaseFavoriteFilter): +class DashboardFavoriteFilter( # pylint: disable=too-few-public-methods + BaseFavoriteFilter +): """ Custom filter for the GET list that filters all dashboards that a user has favored """ @@ -57,7 +57,7 @@ class DashboardFavoriteFilter(BaseFavoriteFilter): model = Dashboard -class DashboardAccessFilter(BaseFilter): +class DashboardAccessFilter(BaseFilter): # pylint: disable=too-few-public-methods """ List dashboards with the following criteria: 1. Those which the user owns @@ -140,7 +140,7 @@ def apply(self, query: Query, value: Any) -> Query: return query -class FilterRelatedRoles(BaseFilter): +class FilterRelatedRoles(BaseFilter): # pylint: disable=too-few-public-methods """ A filter to allow searching for related roles of a resource. diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index fce4d68bc3064..f10ded906cf91 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines,unused-argument +# pylint: disable=too-many-lines import json import logging import re @@ -328,7 +328,9 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: return new_exception(str(exception)) @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate( # pylint: disable=unused-argument + cls, extra: Dict[str, Any], + ) -> bool: return False @classmethod @@ -581,8 +583,8 @@ def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any] return indexes @classmethod - def extra_table_metadata( - cls, database: "Database", table_name: str, schema_name: str + def extra_table_metadata( # pylint: disable=unused-argument + cls, database: "Database", table_name: str, schema_name: str, ) -> Dict[str, Any]: """ Returns engine-specific table metadata @@ -683,7 +685,9 @@ def df_to_sql( df.to_sql(con=engine, **to_sql_kwargs) @classmethod - def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]: + def convert_dttm( # pylint: disable=unused-argument + cls, target_type: str, dttm: datetime, + ) -> Optional[str]: """ Convert Python datetime object to a SQL expression @@ -815,8 +819,8 @@ def get_schema_names(cls, inspector: Inspector) -> List[str]: return sorted(inspector.get_schema_names()) @classmethod - def get_table_names( - cls, database: "Database", inspector: Inspector, schema: Optional[str] + def get_table_names( # pylint: disable=unused-argument + cls, database: "Database", inspector: Inspector, schema: Optional[str], ) -> List[str]: """ Get all tables from schema @@ -831,8 +835,8 @@ def get_table_names( return sorted(tables) @classmethod - def get_view_names( - cls, database: "Database", inspector: Inspector, schema: Optional[str] + def get_view_names( # pylint: disable=unused-argument + cls, database: "Database", inspector: Inspector, schema: Optional[str], ) -> 
List[str]: """ Get all views from schema @@ -885,7 +889,7 @@ def get_columns( return inspector.get_columns(table_name, schema) @classmethod - def where_latest_partition( # pylint: disable=too-many-arguments + def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument cls, table_name: str, schema: Optional[str], @@ -1072,7 +1076,9 @@ def update_impersonation_config( """ @classmethod - def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None: + def execute( # pylint: disable=unused-argument + cls, cursor: Any, query: str, **kwargs: Any, + ) -> None: """ Execute a SQL query @@ -1201,7 +1207,9 @@ def column_datatype_to_string( return sqla_column_type.compile(dialect=dialect).upper() @classmethod - def get_function_names(cls, database: "Database") -> List[str]: + def get_function_names( # pylint: disable=unused-argument + cls, database: "Database", + ) -> List[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -1224,7 +1232,9 @@ def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]: return data @staticmethod - def mutate_db_for_connection_test(database: "Database") -> None: + def mutate_db_for_connection_test( # pylint: disable=unused-argument + database: "Database", + ) -> None: """ Some databases require passing additional parameters for validating database connections. This method makes it possible to mutate the database instance prior @@ -1271,7 +1281,7 @@ def is_select_query(cls, parsed_query: ParsedQuery) -> bool: @classmethod @memoized - def get_column_spec( + def get_column_spec( # pylint: disable=unused-argument cls, native_type: Optional[str], source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, @@ -1320,7 +1330,9 @@ def has_implicit_cancel(cls) -> bool: return False @classmethod - def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + def get_cancel_query_id( # pylint: disable=unused-argument + cls, cursor: Any, query: Query, + ) -> Optional[str]: """ Select identifiers from the database engine that uniquely identifies the queries to cancel. The identifier is typically a session id, process id @@ -1334,7 +1346,9 @@ def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: return None @classmethod - def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + def cancel_query( # pylint: disable=unused-argument + cls, cursor: Any, query: Query, cancel_query_id: str, + ) -> bool: """ Cancel query in the underlying database. @@ -1407,7 +1421,7 @@ class BasicParametersMixin: encryption_parameters: Dict[str, str] = {} @classmethod - def build_sqlalchemy_uri( + def build_sqlalchemy_uri( # pylint: disable=unused-argument cls, parameters: BasicParametersType, encryted_extra: Optional[Dict[str, str]] = None, @@ -1432,7 +1446,7 @@ def build_sqlalchemy_uri( ) @classmethod - def get_parameters_from_uri( + def get_parameters_from_uri( # pylint: disable=unused-argument cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None ) -> BasicParametersType: url = make_url(uri) diff --git a/superset/db_engines/hive.py b/superset/db_engines/hive.py index 37fe50f5e21e6..0c4094fe3ef2c 100644 --- a/superset/db_engines/hive.py +++ b/superset/db_engines/hive.py @@ -20,9 +20,8 @@ from pyhive.hive import Cursor from TCLIService.ttypes import TFetchOrientation -# pylint: disable=protected-access # TODO: contribute back to pyhive. 
-def fetch_logs( +def fetch_logs( # pylint: disable=protected-access self: "Cursor", _max_rows: int = 1024, orientation: Optional["TFetchOrientation"] = None, diff --git a/superset/examples/deck.py b/superset/examples/deck.py index a3d137bb06d71..a0e3246176141 100644 --- a/superset/examples/deck.py +++ b/superset/examples/deck.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-statements import json from superset import db @@ -172,7 +171,7 @@ }""" -def load_deck_dash() -> None: +def load_deck_dash() -> None: # pylint: disable=too-many-statements print("Loading deck.gl dashboard") slices = [] table = get_table_connector_registry() diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index d94ac76aa777a..d19c1d131fff9 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -58,8 +58,7 @@ logger = logging.getLogger(__name__) -# pylint: disable=R0904 -class SupersetAppInitializer: +class SupersetAppInitializer: # pylint: disable=too-many-public-methods def __init__(self, app: SupersetApp) -> None: super().__init__() diff --git a/superset/models/core.py b/superset/models/core.py index 129923b05ccf6..77e6c356e1338 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=line-too-long,unused-argument """A collection of ORM sqlalchemy models for Superset""" import enum import json @@ -244,7 +243,7 @@ def parameters(self) -> Dict[str, Any]: uri = make_url(self.sqlalchemy_uri_decrypted) encrypted_extra = self.get_encrypted_extra() try: - parameters = self.db_engine_spec.get_parameters_from_uri(uri, encrypted_extra=encrypted_extra) # type: ignore + parameters = self.db_engine_spec.get_parameters_from_uri(uri, encrypted_extra=encrypted_extra) # type: ignore # pylint: disable=line-too-long,useless-suppression except Exception: # pylint: disable=broad-except parameters = {} @@ -479,7 +478,7 @@ def inspector(self) -> Inspector: key=lambda self, *args, **kwargs: f"db:{self.id}:schema:None:table_list", cache=cache_manager.data_cache, ) - def get_all_table_names_in_database( + def get_all_table_names_in_database( # pylint: disable=unused-argument self, cache: bool = False, cache_timeout: Optional[bool] = None, @@ -494,7 +493,7 @@ def get_all_table_names_in_database( key=lambda self, *args, **kwargs: f"db:{self.id}:schema:None:view_list", cache=cache_manager.data_cache, ) - def get_all_view_names_in_database( + def get_all_view_names_in_database( # pylint: disable=unused-argument self, cache: bool = False, cache_timeout: Optional[bool] = None, @@ -506,10 +505,10 @@ def get_all_view_names_in_database( return self.db_engine_spec.get_all_datasource_names(self, "view") @cache_util.memoized_func( - key=lambda self, schema, *args, **kwargs: f"db:{self.id}:schema:{schema}:table_list", + key=lambda self, schema, *args, **kwargs: f"db:{self.id}:schema:{schema}:table_list", # pylint: disable=line-too-long,useless-suppression cache=cache_manager.data_cache, ) - def get_all_table_names_in_schema( + def get_all_table_names_in_schema( # pylint: disable=unused-argument self, schema: str, cache: bool = False, @@ -539,10 +538,10 @@ def get_all_table_names_in_schema( return [] @cache_util.memoized_func( - key=lambda self, schema, *args, 
**kwargs: f"db:{self.id}:schema:{schema}:view_list", + key=lambda self, schema, *args, **kwargs: f"db:{self.id}:schema:{schema}:view_list", # pylint: disable=line-too-long,useless-suppression cache=cache_manager.data_cache, ) - def get_all_view_names_in_schema( + def get_all_view_names_in_schema( # pylint: disable=unused-argument self, schema: str, cache: bool = False, @@ -573,7 +572,7 @@ def get_all_view_names_in_schema( key=lambda self, *args, **kwargs: f"db:{self.id}:schema_list", cache=cache_manager.data_cache, ) - def get_all_schema_names( + def get_all_schema_names( # pylint: disable=unused-argument self, cache: bool = False, cache_timeout: Optional[int] = None, diff --git a/superset/models/slice.py b/superset/models/slice.py index 9093cfa43acc7..6bf05ffc87fdd 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -53,7 +53,7 @@ logger = logging.getLogger(__name__) -class Slice( # pylint: disable=too-many-public-methods, too-many-instance-attributes +class Slice( # pylint: disable=too-many-public-methods Model, AuditMixinNullable, ImportExportMixin ): """A slice is essentially a report or a view on data""" @@ -74,7 +74,7 @@ class Slice( # pylint: disable=too-many-public-methods, too-many-instance-attri # the last time a user has saved the chart, changed_on is referencing # when the database row was last written last_saved_at = Column(DateTime, nullable=True) - last_saved_by_fk = Column(Integer, ForeignKey("ab_user.id"), nullable=True,) + last_saved_by_fk = Column(Integer, ForeignKey("ab_user.id"), nullable=True) last_saved_by = relationship( security_manager.user_model, foreign_keys=[last_saved_by_fk] ) diff --git a/superset/security/manager.py b/superset/security/manager.py index 4340a4f659bb0..ff25a2c1dcdcf 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=too-few-public-methods,too-many-lines +# pylint: disable=too-many-lines """A set of constants and methods to manage permissions and security""" import logging import re @@ -78,7 +78,7 @@ logger = logging.getLogger(__name__) -class SupersetSecurityListWidget(ListWidget): +class SupersetSecurityListWidget(ListWidget): # pylint: disable=too-few-public-methods """ Redeclaring to avoid circular imports """ @@ -86,7 +86,7 @@ class SupersetSecurityListWidget(ListWidget): template = "superset/fab_overrides/list.html" -class SupersetRoleListWidget(ListWidget): +class SupersetRoleListWidget(ListWidget): # pylint: disable=too-few-public-methods """ Role model view from FAB already uses a custom list widget override So we override the override diff --git a/superset/sql_lab.py b/superset/sql_lab.py index deb7a09725937..ec0b5d04a5616 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -30,7 +30,6 @@ from celery.exceptions import SoftTimeLimitExceeded from flask_babel import gettext as __ from sqlalchemy.orm import Session -from werkzeug.local import LocalProxy from superset import app, results_backend, results_backend_use_msgpack, security_manager from superset.dataframe import df_to_records @@ -38,7 +37,6 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetErrorException, SupersetErrorsException from superset.extensions import celery_app -from superset.models.core import Database from superset.models.sql_lab import LimitingFactor, Query from superset.result_set import SupersetResultSet from superset.sql_parse import CtasMethod, ParsedQuery @@ -52,25 +50,13 @@ from superset.utils.dates import now_as_float from superset.utils.decorators import stats_timing - -# pylint: disable=unused-argument, redefined-outer-name -def dummy_sql_query_mutator( - sql: str, - user_name: Optional[str], - security_manager: LocalProxy, - database: Database, -) -> str: - """A no-op version of SQL_QUERY_MUTATOR""" - return sql - - config = app.config stats_logger = config["STATS_LOGGER"] SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"] SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60 SQL_MAX_ROW = config["SQL_MAX_ROW"] SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"] -SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR") or dummy_sql_query_mutator +SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"] log_query = config["QUERY_LOGGER"] logger = logging.getLogger(__name__) cancel_query_key = "cancel_query" @@ -192,8 +178,7 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query, session) -# pylint: disable=too-many-arguments, too-many-locals, too-many-statements -def execute_sql_statement( +def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements sql_statement: str, query: Query, user_name: Optional[str], diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py index c477568b6394e..de29a96e8e3f8 100644 --- a/superset/sql_validators/base.py +++ b/superset/sql_validators/base.py @@ -14,15 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -# pylint: disable=too-few-public-methods - from typing import Any, Dict, List, Optional from superset.models.core import Database -class SQLValidationAnnotation: +class SQLValidationAnnotation: # pylint: disable=too-few-public-methods """Represents a single annotation (error/warning) in an SQL querytext""" def __init__( @@ -47,7 +44,7 @@ def to_dict(self) -> Dict[str, Any]: } -class BaseSQLValidator: +class BaseSQLValidator: # pylint: disable=too-few-public-methods """BaseSQLValidator defines the interface for checking that a given sql query is valid for a given database engine.""" diff --git a/superset/sqllab/command.py b/superset/sqllab/command.py index 1513719534a7d..c984de8e2da70 100644 --- a/superset/sqllab/command.py +++ b/superset/sqllab/command.py @@ -124,7 +124,7 @@ def is_query_handled(cls, query: Optional[Query]) -> bool: QueryStatus.TIMED_OUT, ] - def _run_sql_json_exec_from_scratch(self,) -> SqlJsonExecutionStatus: + def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus: self.execution_context.set_database(self._get_the_query_db()) query = self.execution_context.create_query() try: @@ -183,7 +183,7 @@ def _validate_access(self, query: Query) -> None: self.session.commit() raise SupersetErrorException(ex.error, status=403) from ex - def _render_query(self,) -> str: + def _render_query(self) -> str: def validate( rendered_query: str, template_processor: BaseTemplateProcessor ) -> None: @@ -232,7 +232,7 @@ def _set_query_limit_if_required(self, rendered_query: str,) -> None: if self._is_required_to_set_limit(): self._set_query_limit(rendered_query) - def _is_required_to_set_limit(self,) -> bool: + def _is_required_to_set_limit(self) -> bool: return not ( config.get("SQLLAB_CTAS_NO_LIMIT") and self.execution_context.select_as_cta ) @@ -382,10 +382,8 @@ def _is_store_results(cls, query: Query) -> bool: is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta ) - def _create_payload_from_execution_context( - # pylint: disable=invalid-name, no-self-use - self, - status: SqlJsonExecutionStatus, + def _create_payload_from_execution_context( # pylint: disable=invalid-name + self, status: SqlJsonExecutionStatus, ) -> str: if status == SqlJsonExecutionStatus.HAS_RESULTS: diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 1529f10631f8e..18094323ec1ec 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -68,18 +68,17 @@ def load_chart_data_into_cache( async_query_manager.update_job( job_metadata, async_query_manager.STATUS_DONE, result_url=result_url, ) - except SoftTimeLimitExceeded as exc: - logger.warning("A timeout occurred while loading chart data, error: %s", exc) - raise exc - except Exception as exc: + except SoftTimeLimitExceeded as ex: + logger.warning("A timeout occurred while loading chart data, error: %s", ex) + raise ex + except Exception as ex: # TODO: QueryContext should support SIP-40 style errors - # pylint: disable=no-member - error = exc.message if hasattr(exc, "message") else str(exc) # type: ignore + error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member errors = [{"message": error}] async_query_manager.update_job( job_metadata, async_query_manager.STATUS_ERROR, errors=errors ) - raise exc + raise ex @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout) @@ -127,16 +126,14 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals except SoftTimeLimitExceeded as ex: 
logger.warning("A timeout occurred while loading explore json, error: %s", ex) raise ex - except Exception as exc: - # pylint: disable=no-member - if isinstance(exc, SupersetVizException): - # pylint: disable=no-member - errors = exc.errors + except Exception as ex: + if isinstance(ex, SupersetVizException): + errors = ex.errors # pylint: disable=no-member else: - error = exc.message if hasattr(exc, "message") else str(exc) # type: ignore + error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member errors = [error] async_query_manager.update_job( job_metadata, async_query_manager.STATUS_ERROR, errors=errors ) - raise exc + raise ex diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 32b421a47055a..ee73df5fde14e 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-few-public-methods - import json import logging from typing import Any, Dict, List, Optional, Union @@ -85,7 +83,7 @@ def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str return f"{baseurl}{chart.get_explore_url(overrides=extra_filters)}" -class Strategy: +class Strategy: # pylint: disable=too-few-public-methods """ A cache warm up strategy. @@ -115,7 +113,7 @@ def get_urls(self) -> List[str]: raise NotImplementedError("Subclasses must implement get_urls!") -class DummyStrategy(Strategy): +class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods """ Warm up all charts. @@ -140,7 +138,7 @@ def get_urls(self) -> List[str]: return [get_url(chart) for chart in charts] -class TopNDashboardsStrategy(Strategy): +class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods """ Warm up charts in the top-n dashboards. @@ -187,7 +185,7 @@ def get_urls(self) -> List[str]: return urls -class DashboardTagsStrategy(Strategy): +class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods """ Warm up charts in dashboards with custom tags. 
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index be5be222481e7..05506d077a973 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -26,6 +26,7 @@ from collections import namedtuple from datetime import datetime, timedelta from email.utils import make_msgid, parseaddr +from enum import Enum from typing import ( Any, Callable, @@ -72,8 +73,6 @@ from superset.utils.screenshots import ChartScreenshot, WebDriverProxy from superset.utils.urls import get_url_path -# pylint: disable=too-few-public-methods - if TYPE_CHECKING: from flask_appbuilder.security.sqla.models import User from werkzeug.datastructures import TypeConversionDict @@ -571,7 +570,7 @@ def schedule_alert_query( raise RuntimeError("Unknown report type") -class AlertState: +class AlertState(str, Enum): ERROR = "error" TRIGGER = "trigger" PASS = "pass" diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py index 2bbfc185acbc8..802c185d9ec64 100644 --- a/superset/utils/date_parser.py +++ b/superset/utils/date_parser.py @@ -138,8 +138,7 @@ def parse_past_timedelta( ) -# pylint: disable=too-many-arguments, too-many-locals, too-many-branches -def get_since_until( +def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches time_range: Optional[str] = None, since: Optional[str] = None, until: Optional[str] = None, diff --git a/superset/utils/logging_configurator.py b/superset/utils/logging_configurator.py index da45553f6cde9..5753b3941ccf5 100644 --- a/superset/utils/logging_configurator.py +++ b/superset/utils/logging_configurator.py @@ -24,8 +24,7 @@ logger = logging.getLogger(__name__) -# pylint: disable=too-few-public-methods -class LoggingConfigurator(abc.ABC): +class LoggingConfigurator(abc.ABC): # pylint: disable=too-few-public-methods @abc.abstractmethod def configure_logging( self, app_config: flask.config.Config, debug_mode: bool @@ -33,7 +32,9 @@ def configure_logging( pass -class DefaultLoggingConfigurator(LoggingConfigurator): +class DefaultLoggingConfigurator( # pylint: disable=too-few-public-methods + LoggingConfigurator +): def configure_logging( self, app_config: flask.config.Config, debug_mode: bool ) -> None: diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 6de9eca60bea1..52838f0478881 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -67,8 +67,9 @@ days_range = (MAXIMUM_DATE - MINIMUM_DATE).days -# pylint: disable=too-many-return-statements, too-many-branches -def get_type_generator(sqltype: sqlalchemy.sql.sqltypes) -> Callable[[], Any]: +def get_type_generator( # pylint: disable=too-many-return-statements,too-many-branches + sqltype: sqlalchemy.sql.sqltypes, +) -> Callable[[], Any]: if isinstance(sqltype, sqlalchemy.dialects.mysql.types.TINYINT): return lambda: random.choice([0, 1]) diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index ad6c6afedade1..a141c56b6892d 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -470,9 +470,8 @@ def diff( return _append_columns(df, df_diff, columns) -# pylint: disable=too-many-arguments @validate_column_args("source_columns", "compare_columns") -def compare( +def compare( # pylint: disable=too-many-arguments df: DataFrame, source_columns: List[str], compare_columns: List[str], diff --git a/superset/views/base.py b/superset/views/base.py index 9d27861e5f2c7..26985a2e175af 100644 --- a/superset/views/base.py +++ 
b/superset/views/base.py @@ -376,8 +376,9 @@ def common_bootstrap_payload() -> Dict[str, Any]: return bootstrap_data -# pylint: disable=invalid-name -def get_error_level_from_status_code(status: int) -> ErrorLevel: +def get_error_level_from_status_code( # pylint: disable=invalid-name + status: int, +) -> ErrorLevel: if status < 400: return ErrorLevel.INFO if status < 500: diff --git a/superset/views/core.py b/superset/views/core.py index 962cc81509bef..4ed7e89cdbd30 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=comparison-with-callable,line-too-long,too-many-lines +# pylint: disable=too-many-lines from __future__ import annotations import logging @@ -326,7 +326,7 @@ def clean_fulfilled_requests(session: Session) -> None: requests = ( session.query(DAR) - .filter( + .filter( # pylint: disable=comparison-with-callable DAR.datasource_id == datasource_id, DAR.datasource_type == datasource_type, DAR.created_by_fk == requested_by.id, @@ -1537,7 +1537,9 @@ def created_dashboards( # pylint: disable=no-self-use Dash = Dashboard qry = ( db.session.query(Dash) - .filter(or_(Dash.created_by_fk == user_id, Dash.changed_by_fk == user_id)) + .filter( # pylint: disable=comparison-with-callable + or_(Dash.created_by_fk == user_id, Dash.changed_by_fk == user_id) + ) .order_by(Dash.changed_on.desc()) ) payload = [ @@ -1581,7 +1583,7 @@ def user_slices( # pylint: disable=no-self-use ), isouter=True, ) - .filter( + .filter( # pylint: disable=comparison-with-callable or_( Slice.id.in_(owner_ids_query), Slice.created_by_fk == user_id, @@ -1617,7 +1619,9 @@ def created_slices( # pylint: disable=no-self-use user_id = g.user.get_id() qry = ( db.session.query(Slice) - .filter(or_(Slice.created_by_fk == user_id, Slice.changed_by_fk == user_id)) + .filter( # pylint: disable=comparison-with-callable + or_(Slice.created_by_fk == user_id, Slice.changed_by_fk == user_id) + ) .order_by(Slice.changed_on.desc()) ) payload = [ @@ -1859,7 +1863,8 @@ def dashboard( """ Server side rendering for a dashboard :param dashboard_id_or_slug: identifier for dashboard. used in the decorators - :param add_extra_log_payload: added by `log_this_with_manual_updates`, set a default value to appease pylint + :param add_extra_log_payload: added by `log_this_with_manual_updates`, set a + default value to appease pylint :param dashboard: added by `check_dashboard_access` """ if not dashboard: @@ -2422,10 +2427,8 @@ def sql_json(self) -> FlaskResponse: command_result: CommandResult = command.run() return self._create_response_from_execution_context(command_result) - def _create_response_from_execution_context( - # pylint: disable=invalid-name, no-self-use - self, - command_result: CommandResult, + def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use + self, command_result: CommandResult, ) -> FlaskResponse: status_code = 200 diff --git a/superset/views/filters.py b/superset/views/filters.py index 3594d2107507e..3a503e66b614a 100644 --- a/superset/views/filters.py +++ b/superset/views/filters.py @@ -23,10 +23,9 @@ from superset import security_manager -# pylint: disable=too-few-public-methods +class FilterRelatedOwners(BaseFilter): # pylint: disable=too-few-public-methods -class FilterRelatedOwners(BaseFilter): """ A filter to allow searching for related owners of a resource. 
diff --git a/superset/viz.py b/superset/viz.py index 82e74064e6c31..cff9eda8d85aa 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=C,R,W,too-many-lines,useless-suppression +# pylint: disable=too-many-lines """This module contains the 'Viz' objects These objects represent the backend of all the visualizations that Superset can render. """ import copy -import inspect +import dataclasses import logging import math import re @@ -53,7 +53,7 @@ from geopy.point import Point from pandas.tseries.frequencies import to_offset -from superset import app, db, is_feature_enabled +from superset import app, is_feature_enabled from superset.constants import NULL_STRING from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -64,7 +64,6 @@ SupersetSecurityException, ) from superset.extensions import cache_manager, security_manager -from superset.models.cache import CacheKey from superset.models.helpers import QueryResult from superset.typing import Metric, QueryObjectDict, VizData, VizPayload from superset.utils import core as utils, csv @@ -81,9 +80,6 @@ from superset.utils.dates import datetime_to_epoch from superset.utils.hashing import md5_sha_from_str -import dataclasses # isort:skip - - if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource @@ -104,13 +100,13 @@ "size", ] -# This regex is to get user defined filter column name, which is the first param in the filter_values function. -# see the definition of filter_values template: +# This regex is to get user defined filter column name, which is the first param in the +# filter_values function. See the definition of filter_values template: # https://github.com/apache/superset/blob/24ad6063d736c1f38ad6f962e586b9b1a21946af/superset/jinja_context.py#L63 FILTER_VALUES_REGEX = re.compile(r"filter_values\(['\"](\w+)['\"]\,") -class BaseViz: +class BaseViz: # pylint: disable=too-many-public-methods """All visualizations derive this base class""" @@ -165,9 +161,8 @@ def process_metrics(self) -> None: # metrics in Viz is order sensitive, so metric_dict should be # OrderedDict self.metric_dict = OrderedDict() - fd = self.form_data for mkey in METRIC_KEYS: - val = fd.get(mkey) + val = self.form_data.get(mkey) if val: if not isinstance(val, list): val = [val] @@ -183,13 +178,13 @@ def process_metrics(self) -> None: def handle_js_int_overflow( data: Dict[str, List[Dict[str, Any]]] ) -> Dict[str, List[Dict[str, Any]]]: - for d in data.get("records", {}): - for k, v in list(d.items()): + for record in data.get("records", {}): + for k, v in list(record.items()): if isinstance(v, int): # if an int is too big for Java Script to handle # convert it to a string if abs(v) > JS_MAX_INTEGER: - d[k] = str(v) + record[k] = str(v) return data def run_extra_queries(self) -> None: @@ -213,13 +208,11 @@ def run_extra_queries(self) -> None: when moving from caching the visualization's data itself, to caching the underlying query(ies). 
""" - pass def apply_rolling(self, df: pd.DataFrame) -> pd.DataFrame: - fd = self.form_data - rolling_type = fd.get("rolling_type") - rolling_periods = int(fd.get("rolling_periods") or 0) - min_periods = int(fd.get("min_periods") or 0) + rolling_type = self.form_data.get("rolling_type") + rolling_periods = int(self.form_data.get("rolling_periods") or 0) + min_periods = int(self.form_data.get("min_periods") or 0) if rolling_type in ("mean", "std", "sum") and rolling_periods: kwargs = dict(window=rolling_periods, min_periods=min_periods) @@ -311,15 +304,13 @@ def process_query_filters(self) -> None: merge_extra_filters(self.form_data) utils.split_adhoc_filters_into_base_filters(self.form_data) - def query_obj(self) -> QueryObjectDict: + def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals """Building a query object""" - form_data = self.form_data - self.process_query_filters() gb = self.groupby metrics = self.all_metrics or [] - columns = form_data.get("columns") or [] + columns = self.form_data.get("columns") or [] # merge list and dedup while preserving order groupby = list(OrderedDict.fromkeys(gb + columns)) @@ -328,26 +319,28 @@ def query_obj(self) -> QueryObjectDict: groupby.remove(DTTM_ALIAS) is_timeseries = True - granularity = form_data.get("granularity") or form_data.get("granularity_sqla") - limit = int(form_data.get("limit") or 0) - timeseries_limit_metric = form_data.get("timeseries_limit_metric") - row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"]) + granularity = self.form_data.get("granularity") or self.form_data.get( + "granularity_sqla" + ) + limit = int(self.form_data.get("limit") or 0) + timeseries_limit_metric = self.form_data.get("timeseries_limit_metric") + row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"]) # default order direction - order_desc = form_data.get("order_desc", True) + order_desc = self.form_data.get("order_desc", True) try: since, until = get_since_until( relative_start=relative_start, relative_end=relative_end, - time_range=form_data.get("time_range"), - since=form_data.get("since"), - until=form_data.get("until"), + time_range=self.form_data.get("time_range"), + since=self.form_data.get("since"), + until=self.form_data.get("until"), ) except ValueError as ex: - raise QueryObjectValidationError(str(ex)) + raise QueryObjectValidationError(str(ex)) from ex - time_shift = form_data.get("time_shift", "") + time_shift = self.form_data.get("time_shift", "") self.time_shift = parse_past_timedelta(time_shift) from_dttm = None if since is None else (since - self.time_shift) to_dttm = None if until is None else (until - self.time_shift) @@ -362,12 +355,12 @@ def query_obj(self) -> QueryObjectDict: # extras are used to query elements specific to a datasource type # for instance the extra where clause that applies only to Tables extras = { - "druid_time_origin": form_data.get("druid_time_origin", ""), - "having": form_data.get("having", ""), - "having_druid": form_data.get("having_filters", []), - "time_grain_sqla": form_data.get("time_grain_sqla"), - "time_range_endpoints": form_data.get("time_range_endpoints"), - "where": form_data.get("where", ""), + "druid_time_origin": self.form_data.get("druid_time_origin", ""), + "having": self.form_data.get("having", ""), + "having_druid": self.form_data.get("having_filters", []), + "time_grain_sqla": self.form_data.get("time_grain_sqla"), + "time_range_endpoints": self.form_data.get("time_range_endpoints"), + "where": self.form_data.get("where", ""), } return { @@ 
-495,7 +488,7 @@ def get_df_payload( query_obj = self.query_obj() cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None cache_value = None - logger.info("Cache key: {}".format(cache_key)) + logger.info("Cache key: %s", cache_key) is_loaded = False stacktrace = None df = None @@ -509,10 +502,11 @@ def get_df_payload( self.status = utils.QueryStatus.SUCCESS is_loaded = True stats_logger.incr("loaded_from_cache") - except Exception as ex: + except Exception as ex: # pylint: disable=broad-except logger.exception(ex) logger.error( - "Error reading cache: " + utils.error_msg_from_exception(ex), + "Error reading cache: %s", + utils.error_msg_from_exception(ex), exc_info=True, ) logger.info("Serving from cache") @@ -520,7 +514,8 @@ def get_df_payload( if query_obj and not is_loaded: if self.force_cached: logger.warning( - f"force_cached (viz.py): value not found for cache key {cache_key}" + "force_cached (viz.py): value not found for cache key %s", + cache_key, ) raise CacheLoadError(_("Cached value not found")) try: @@ -556,7 +551,7 @@ def get_df_payload( ) self.errors.append(error) self.status = utils.QueryStatus.FAILED - except Exception as ex: + except Exception as ex: # pylint: disable=broad-except logger.exception(ex) error = dataclasses.asdict( @@ -594,12 +589,17 @@ def get_df_payload( "rowcount": len(df.index) if df is not None else 0, } - def json_dumps(self, obj: Any, sort_keys: bool = False) -> str: + @staticmethod + def json_dumps(query_obj: Any, sort_keys: bool = False) -> str: return json.dumps( - obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys + query_obj, + default=utils.json_int_dttm_ser, + ignore_nan=True, + sort_keys=sort_keys, ) - def has_error(self, payload: VizPayload) -> bool: + @staticmethod + def has_error(payload: VizPayload) -> bool: return ( payload.get("status") == utils.QueryStatus.FAILED or payload.get("error") is not None @@ -625,7 +625,7 @@ def get_csv(self) -> Optional[str]: include_index = not isinstance(df.index, pd.RangeIndex) return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"]) - def get_data(self, df: pd.DataFrame) -> VizData: + def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=no-self-use return df.to_dict(orient="records") @property @@ -662,11 +662,14 @@ def process_metrics(self) -> None: """ # Verify form data first: if not specifying query mode, then cannot have both # GROUP BY and RAW COLUMNS. 
- fd = self.form_data if ( - not fd.get("query_mode") - and fd.get("all_columns") - and (fd.get("groupby") or fd.get("metrics") or fd.get("percent_metrics")) + not self.form_data.get("query_mode") + and self.form_data.get("all_columns") + and ( + self.form_data.get("groupby") + or self.form_data.get("metrics") + or self.form_data.get("percent_metrics") + ) ): raise QueryObjectValidationError( _( @@ -678,10 +681,12 @@ def process_metrics(self) -> None: super().process_metrics() - self.query_mode: QueryMode = QueryMode.get(fd.get("query_mode")) or ( + self.query_mode: QueryMode = QueryMode.get( + self.form_data.get("query_mode") + ) or ( # infer query mode from the presence of other fields QueryMode.RAW - if len(fd.get("all_columns") or []) > 0 + if len(self.form_data.get("all_columns") or []) > 0 else QueryMode.AGGREGATE ) @@ -689,53 +694,63 @@ def process_metrics(self) -> None: percent_columns: List[str] = [] # percent columns that needs extra computation if self.query_mode == QueryMode.RAW: - columns = utils.get_metric_names(fd.get("all_columns") or []) + columns = utils.get_metric_names(self.form_data.get("all_columns") or []) else: - columns = utils.get_metric_names(self.groupby + (fd.get("metrics") or [])) - percent_columns = utils.get_metric_names(fd.get("percent_metrics") or []) + columns = utils.get_metric_names( + self.groupby + (self.form_data.get("metrics") or []) + ) + percent_columns = utils.get_metric_names( + self.form_data.get("percent_metrics") or [] + ) self.columns = columns self.percent_columns = percent_columns self.is_timeseries = self.should_be_timeseries() def should_be_timeseries(self) -> bool: - fd = self.form_data # TODO handle datasource-type-specific code in datasource - conditions_met = (fd.get("granularity") and fd.get("granularity") != "all") or ( - fd.get("granularity_sqla") and fd.get("time_grain_sqla") + conditions_met = ( + self.form_data.get("granularity") + and self.form_data.get("granularity") != "all" + ) or ( + self.form_data.get("granularity_sqla") + and self.form_data.get("time_grain_sqla") ) - if fd.get("include_time") and not conditions_met: + if self.form_data.get("include_time") and not conditions_met: raise QueryObjectValidationError( _("Pick a granularity in the Time section or " "uncheck 'Include Time'") ) - return bool(fd.get("include_time")) + return bool(self.form_data.get("include_time")) def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - fd = self.form_data + query_obj = super().query_obj() if self.query_mode == QueryMode.RAW: - d["columns"] = fd.get("all_columns") - order_by_cols = fd.get("order_by_cols") or [] - d["orderby"] = [json.loads(t) for t in order_by_cols] + query_obj["columns"] = self.form_data.get("all_columns") + order_by_cols = self.form_data.get("order_by_cols") or [] + query_obj["orderby"] = [json.loads(t) for t in order_by_cols] # must disable groupby and metrics in raw mode - d["groupby"] = [] - d["metrics"] = [] + query_obj["groupby"] = [] + query_obj["metrics"] = [] # raw mode does not support timeseries queries - d["timeseries_limit_metric"] = None - d["timeseries_limit"] = None - d["is_timeseries"] = None + query_obj["timeseries_limit_metric"] = None + query_obj["timeseries_limit"] = None + query_obj["is_timeseries"] = None else: - sort_by = fd.get("timeseries_limit_metric") + sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): - d["metrics"].append(sort_by) - d["orderby"] 
= [(sort_by, not fd.get("order_desc", True))] - elif d["metrics"]: + if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): + query_obj["metrics"].append(sort_by) + query_obj["orderby"] = [ + (sort_by, not self.form_data.get("order_desc", True)) + ] + elif query_obj["metrics"]: # Legacy behavior of sorting by first metric by default - first_metric = d["metrics"][0] - d["orderby"] = [(first_metric, not fd.get("order_desc", True))] - return d + first_metric = query_obj["metrics"][0] + query_obj["orderby"] = [ + (first_metric, not self.form_data.get("order_desc", True)) + ] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: """ @@ -768,9 +783,13 @@ def get_data(self, df: pd.DataFrame) -> VizData: dict(records=df.to_dict(orient="records"), columns=list(df.columns)) ) - def json_dumps(self, obj: Any, sort_keys: bool = False) -> str: + @staticmethod + def json_dumps(query_obj: Any, sort_keys: bool = False) -> str: return json.dumps( - obj, default=utils.json_iso_dttm_ser, sort_keys=sort_keys, ignore_nan=True + query_obj, + default=utils.json_iso_dttm_ser, + sort_keys=sort_keys, + ignore_nan=True, ) @@ -784,35 +803,33 @@ class TimeTableViz(BaseViz): is_timeseries = True def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - fd = self.form_data + query_obj = super().query_obj() - if not fd.get("metrics"): + if not self.form_data.get("metrics"): raise QueryObjectValidationError(_("Pick at least one metric")) - if fd.get("groupby") and len(fd["metrics"]) > 1: + if self.form_data.get("groupby") and len(self.form_data["metrics"]) > 1: raise QueryObjectValidationError( _("When using 'Group By' you are limited to use a single metric") ) - return d + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data columns = None values: Union[List[str], str] = self.metric_labels - if fd.get("groupby"): + if self.form_data.get("groupby"): values = self.metric_labels[0] - columns = fd.get("groupby") + columns = self.form_data.get("groupby") pt = df.pivot_table(index=DTTM_ALIAS, columns=columns, values=values) pt.index = pt.index.map(str) pt = pt.sort_index() return dict( records=pt.to_dict(orient="index"), columns=list(pt.columns), - is_group_by=True if fd.get("groupby") else False, + is_group_by=bool(self.form_data.get("groupby")), ) @@ -827,7 +844,7 @@ class PivotTableViz(BaseViz): enforce_numerical_metrics = False def query_obj(self) -> QueryObjectDict: - d = super().query_obj() + query_obj = super().query_obj() groupby = self.form_data.get("groupby") columns = self.form_data.get("columns") metrics = self.form_data.get("metrics") @@ -856,11 +873,13 @@ def query_obj(self) -> QueryObjectDict: sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): - d["metrics"].append(sort_by) + if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): + query_obj["metrics"].append(sort_by) if self.form_data.get("order_desc"): - d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] - return d + query_obj["orderby"] = [ + (sort_by, not self.form_data.get("order_desc", True)) + ] + return query_obj @staticmethod def get_aggfunc( @@ -887,7 +906,7 @@ def _format_datetime(value: Union[pd.Timestamp, datetime, date, str]) -> str: tstamp: Optional[pd.Timestamp] = None if isinstance(value, pd.Timestamp): tstamp = value - if isinstance(value, datetime) or isinstance(value, date): + if 
isinstance(value, (date, datetime)): tstamp = pd.Timestamp(value) if isinstance(value, str): try: @@ -959,15 +978,17 @@ class TreemapViz(BaseViz): is_timeseries = False def query_obj(self) -> QueryObjectDict: - d = super().query_obj() + query_obj = super().query_obj() sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): - d["metrics"].append(sort_by) + if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): + query_obj["metrics"].append(sort_by) if self.form_data.get("order_desc"): - d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] - return d + query_obj["orderby"] = [ + (sort_by, not self.form_data.get("order_desc", True)) + ] + return query_obj def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]: nlevels = df.index.nlevels @@ -1001,7 +1022,7 @@ class CalHeatmapViz(BaseViz): credits = "cal-heatmap" is_timeseries = True - def get_data(self, df: pd.DataFrame) -> VizData: + def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=too-many-locals if df.empty: return None @@ -1010,11 +1031,11 @@ def get_data(self, df: pd.DataFrame) -> VizData: records = df.to_dict("records") for metric in self.metric_labels: values = {} - for obj in records: - v = obj[DTTM_ALIAS] + for query_obj in records: + v = query_obj[DTTM_ALIAS] if hasattr(v, "value"): v = v.value - values[str(v / 10 ** 9)] = obj.get(metric) + values[str(v / 10 ** 9)] = query_obj.get(metric) data[metric] = values try: @@ -1026,7 +1047,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: until=form_data.get("until"), ) except ValueError as ex: - raise QueryObjectValidationError(str(ex)) + raise QueryObjectValidationError(str(ex)) from ex if not start or not end: raise QueryObjectValidationError( "Please provide both time bounds (Since and Until)" @@ -1055,9 +1076,8 @@ def get_data(self, df: pd.DataFrame) -> VizData: } def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - fd = self.form_data - d["metrics"] = fd.get("metrics") + query_obj = super().query_obj() + query_obj["metrics"] = self.form_data.get("metrics") mapping = { "min": "PT1M", "hour": "PT1H", @@ -1066,12 +1086,12 @@ def query_obj(self) -> QueryObjectDict: "month": "P1M", "year": "P1Y", } - time_grain = mapping[fd.get("subdomain_granularity", "min")] + time_grain = mapping[self.form_data.get("subdomain_granularity", "min")] if self.datasource.type == "druid": - d["granularity"] = time_grain + query_obj["granularity"] = time_grain else: - d["extras"]["time_grain_sqla"] = time_grain - return d + query_obj["extras"]["time_grain_sqla"] = time_grain + return query_obj class NVD3Viz(BaseViz): @@ -1093,28 +1113,28 @@ class BubbleViz(NVD3Viz): is_timeseries = False def query_obj(self) -> QueryObjectDict: - form_data = self.form_data - d = super().query_obj() - d["groupby"] = [form_data.get("entity")] - if form_data.get("series"): - d["groupby"].append(form_data.get("series")) + query_obj = super().query_obj() + query_obj["groupby"] = [self.form_data.get("entity")] + if self.form_data.get("series"): + query_obj["groupby"].append(self.form_data.get("series")) # dedup groupby if it happens to be the same - d["groupby"] = list(dict.fromkeys(d["groupby"])) + query_obj["groupby"] = list(dict.fromkeys(query_obj["groupby"])) - self.x_metric = form_data["x"] - self.y_metric = form_data["y"] - self.z_metric = form_data["size"] - self.entity = form_data.get("entity") - self.series = 
form_data.get("series") or self.entity - d["row_limit"] = form_data.get("limit") + # pylint: disable=attribute-defined-outside-init + self.x_metric = self.form_data["x"] + self.y_metric = self.form_data["y"] + self.z_metric = self.form_data["size"] + self.entity = self.form_data.get("entity") + self.series = self.form_data.get("series") or self.entity + query_obj["row_limit"] = self.form_data.get("limit") - d["metrics"] = [self.z_metric, self.x_metric, self.y_metric] + query_obj["metrics"] = [self.z_metric, self.x_metric, self.y_metric] if len(set(self.metric_labels)) < 3: raise QueryObjectValidationError(_("Please use 3 different metric labels")) - if not all(d["metrics"] + [self.entity]): + if not all(query_obj["metrics"] + [self.entity]): raise QueryObjectValidationError(_("Pick a metric for x, y and size")) - return d + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: @@ -1145,13 +1165,15 @@ class BulletViz(NVD3Viz): def query_obj(self) -> QueryObjectDict: form_data = self.form_data - d = super().query_obj() - self.metric = form_data["metric"] + query_obj = super().query_obj() + self.metric = form_data[ # pylint: disable=attribute-defined-outside-init + "metric" + ] - d["metrics"] = [self.metric] + query_obj["metrics"] = [self.metric] if not self.metric: raise QueryObjectValidationError(_("Pick a metric to display")) - return d + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: @@ -1173,13 +1195,13 @@ class BigNumberViz(BaseViz): is_timeseries = True def query_obj(self) -> QueryObjectDict: - d = super().query_obj() + query_obj = super().query_obj() metric = self.form_data.get("metric") if not metric: raise QueryObjectValidationError(_("Pick a metric!")) - d["metrics"] = [self.form_data.get("metric")] + query_obj["metrics"] = [self.form_data.get("metric")] self.form_data["metric"] = metric - return d + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: @@ -1207,16 +1229,16 @@ class BigNumberTotalViz(BaseViz): is_timeseries = False def query_obj(self) -> QueryObjectDict: - d = super().query_obj() + query_obj = super().query_obj() metric = self.form_data.get("metric") if not metric: raise QueryObjectValidationError(_("Pick a metric!")) - d["metrics"] = [self.form_data.get("metric")] + query_obj["metrics"] = [self.form_data.get("metric")] self.form_data["metric"] = metric # Limiting rows is not required as only one cell is returned - d["row_limit"] = None - return d + query_obj["row_limit"] = None + return query_obj class NVD3TimeSeriesViz(NVD3Viz): @@ -1230,19 +1252,19 @@ class NVD3TimeSeriesViz(NVD3Viz): pivot_fill_value: Optional[int] = None def query_obj(self) -> QueryObjectDict: - d = super().query_obj() + query_obj = super().query_obj() sort_by = self.form_data.get( "timeseries_limit_metric" - ) or utils.get_first_metric_name(d.get("metrics") or []) + ) or utils.get_first_metric_name(query_obj.get("metrics") or []) is_asc = not self.form_data.get("order_desc") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): - d["metrics"].append(sort_by) - d["orderby"] = [(sort_by, is_asc)] - return d + if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): + query_obj["metrics"].append(sort_by) + query_obj["orderby"] = [(sort_by, is_asc)] + return query_obj - def to_series( + def to_series( # pylint: disable=too-many-branches self, df: pd.DataFrame, classed: str = "", title_suffix: str = "" ) -> List[Dict[str, Any]]: cols = 
[] @@ -1287,25 +1309,24 @@ def to_series( non_nan_cnt = 0 for ds in df.index: if ds in ys: - d = {"x": ds, "y": ys[ds]} + data = {"x": ds, "y": ys[ds]} if not np.isnan(ys[ds]): non_nan_cnt += 1 else: - d = {} - values.append(d) + data = {} + values.append(data) if non_nan_cnt == 0: continue - d = {"key": series_title, "values": values} + data = {"key": series_title, "values": values} if classed: - d["classed"] = classed - chart_data.append(d) + data["classed"] = classed + chart_data.append(data) return chart_data def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: - fd = self.form_data - if fd.get("granularity") == "all": + if self.form_data.get("granularity") == "all": raise QueryObjectValidationError( _("Pick a time granularity for your time series") ) @@ -1316,7 +1337,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: if aggregate: df = df.pivot_table( index=DTTM_ALIAS, - columns=fd.get("groupby"), + columns=self.form_data.get("groupby"), values=self.metric_labels, fill_value=0, aggfunc=sum, @@ -1324,13 +1345,13 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: else: df = df.pivot_table( index=DTTM_ALIAS, - columns=fd.get("groupby"), + columns=self.form_data.get("groupby"), values=self.metric_labels, fill_value=self.pivot_fill_value, ) - rule = fd.get("resample_rule") - method = fd.get("resample_method") + rule = self.form_data.get("resample_rule") + method = self.form_data.get("resample_method") if rule and method: df = getattr(df.resample(rule), method)() @@ -1341,16 +1362,14 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: df = df[dfs.index] df = self.apply_rolling(df) - if fd.get("contribution"): + if self.form_data.get("contribution"): dft = df.T df = (dft / dft.sum()).T return df def run_extra_queries(self) -> None: - fd = self.form_data - - time_compare = fd.get("time_compare") or [] + time_compare = self.form_data.get("time_compare") or [] # backwards compatibility if not isinstance(time_compare, list): time_compare = [time_compare] @@ -1360,7 +1379,7 @@ def run_extra_queries(self) -> None: try: delta = parse_past_timedelta(option) except ValueError as ex: - raise QueryObjectValidationError(str(ex)) + raise QueryObjectValidationError(str(ex)) from ex query_object["inner_from_dttm"] = query_object["from_dttm"] query_object["inner_to_dttm"] = query_object["to_dttm"] @@ -1384,8 +1403,7 @@ def run_extra_queries(self) -> None: self._extra_chart_data.append((label, df2)) def get_data(self, df: pd.DataFrame) -> VizData: - fd = self.form_data - comparison_type = fd.get("comparison_type") or "values" + comparison_type = self.form_data.get("comparison_type") or "values" df = self.process_data(df) if comparison_type == "values": # Filter out series with all NaN @@ -1446,6 +1464,7 @@ def query_obj(self) -> QueryObjectDict: return {} def get_data(self, df: pd.DataFrame) -> VizData: + # pylint: disable=import-outside-toplevel,too-many-locals multiline_fd = self.form_data # Late import to avoid circular import issues from superset.charts.dao import ChartDAO @@ -1520,10 +1539,9 @@ class NVD3DualLineViz(NVD3Viz): is_timeseries = True def query_obj(self) -> QueryObjectDict: - d = super().query_obj() + query_obj = super().query_obj() m1 = self.form_data.get("metric") m2 = self.form_data.get("metric_2") - d["metrics"] = [m1, m2] if not m1: raise QueryObjectValidationError(_("Pick a metric for left axis!")) if not m2: @@ -1532,7 +1550,8 @@ def query_obj(self) -> QueryObjectDict: 
raise QueryObjectValidationError( _("Please choose different metrics" " on left and right axis") ) - return d + query_obj["metrics"] = [m1, m2] + return query_obj def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]: cols = [] @@ -1547,37 +1566,36 @@ def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]] series = df.to_dict("series") chart_data = [] metrics = [self.form_data["metric"], self.form_data["metric_2"]] - for i, m in enumerate(metrics): - m = utils.get_metric_name(m) - ys = series[m] - if df[m].dtype.kind not in "biufc": + for i, metric in enumerate(metrics): + metric_name = utils.get_metric_name(metric) + ys = series[metric_name] + if df[metric_name].dtype.kind not in "biufc": continue - series_title = m - d = { - "key": series_title, - "classed": classed, - "values": [ - {"x": ds, "y": ys[ds] if ds in ys else None} for ds in df.index - ], - "yAxis": i + 1, - "type": "line", - } - chart_data.append(d) + series_title = metric_name + chart_data.append( + { + "key": series_title, + "classed": classed, + "values": [ + {"x": ds, "y": ys[ds] if ds in ys else None} for ds in df.index + ], + "yAxis": i + 1, + "type": "line", + } + ) return chart_data def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data - if self.form_data.get("granularity") == "all": raise QueryObjectValidationError( _("Pick a time granularity for your time series") ) - metric = utils.get_metric_name(fd["metric"]) - metric_2 = utils.get_metric_name(fd["metric_2"]) + metric = utils.get_metric_name(self.form_data["metric"]) + metric_2 = utils.get_metric_name(self.form_data["metric_2"]) df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2]) chart_data = self.to_series(df) @@ -1602,17 +1620,16 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz): verbose_name = _("Time Series - Period Pivot") def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - d["metrics"] = [self.form_data.get("metric")] - return d + query_obj = super().query_obj() + query_obj["metrics"] = [self.form_data.get("metric")] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data df = self.process_data(df) - freq = to_offset(fd.get("freq")) + freq = to_offset(self.form_data.get("freq")) try: freq = type(freq)(freq.n, normalize=True, **freq.kwds) except ValueError: @@ -1632,7 +1649,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: df = df.pivot_table( index=DTTM_ALIAS, columns="series", - values=utils.get_metric_name(fd["metric"]), + values=utils.get_metric_name(self.form_data["metric"]), ) chart_data = self.to_series(df) for serie in chart_data: @@ -1669,19 +1686,23 @@ class HistogramViz(BaseViz): def query_obj(self) -> QueryObjectDict: """Returns the query object for this visualization""" - d = super().query_obj() - d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"])) + query_obj = super().query_obj() + query_obj["row_limit"] = self.form_data.get( + "row_limit", int(config["VIZ_ROW_LIMIT"]) + ) numeric_columns = self.form_data.get("all_columns_x") if numeric_columns is None: raise QueryObjectValidationError( _("Must have at least one numeric column specified") ) - self.columns = numeric_columns - d["columns"] = numeric_columns + self.groupby + self.columns = ( # pylint: disable=attribute-defined-outside-init + numeric_columns + ) + query_obj["columns"] = numeric_columns + self.groupby # override groupby entry to avoid aggregation - d["groupby"] = None - 
d["metrics"] = None - return d + query_obj["groupby"] = None + query_obj["metrics"] = None + return query_obj def labelify(self, keys: Union[List[str], str], column: str) -> str: if isinstance(keys, str): @@ -1725,39 +1746,41 @@ class DistributionBarViz(BaseViz): is_timeseries = False def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - fd = self.form_data - if len(d["groupby"]) < len(fd.get("groupby") or []) + len( - fd.get("columns") or [] + query_obj = super().query_obj() + if len(query_obj["groupby"]) < len(self.form_data.get("groupby") or []) + len( + self.form_data.get("columns") or [] ): raise QueryObjectValidationError( _("Can't have overlap between Series and Breakdowns") ) - if not fd.get("metrics"): + if not self.form_data.get("metrics"): raise QueryObjectValidationError(_("Pick at least one metric")) - if not fd.get("groupby"): + if not self.form_data.get("groupby"): raise QueryObjectValidationError(_("Pick at least one field for [Series]")) - sort_by = fd.get("timeseries_limit_metric") + sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): - d["metrics"].append(sort_by) - d["orderby"] = [(sort_by, not fd.get("order_desc", True))] - elif d["metrics"]: + if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): + query_obj["metrics"].append(sort_by) + query_obj["orderby"] = [ + (sort_by, not self.form_data.get("order_desc", True)) + ] + elif query_obj["metrics"]: # Legacy behavior of sorting by first metric by default - first_metric = d["metrics"][0] - d["orderby"] = [(first_metric, not fd.get("order_desc", True))] + first_metric = query_obj["metrics"][0] + query_obj["orderby"] = [ + (first_metric, not self.form_data.get("order_desc", True)) + ] - return d + return query_obj - def get_data(self, df: pd.DataFrame) -> VizData: + def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=too-many-locals if df.empty: return None - fd = self.form_data metrics = self.metric_labels - columns = fd.get("columns") or [] + columns = self.form_data.get("columns") or [] # pandas will throw away nulls when grouping/pivoting, # so we substitute NULL_STRING for any nulls in the necessary columns @@ -1772,7 +1795,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: is_asc = not self.form_data.get("order_desc") row.sort_values(ascending=is_asc, inplace=True) pt = df.pivot_table(index=self.groupby, columns=columns, values=metrics) - if fd.get("contribution"): + if self.form_data.get("contribution"): pt = pt.T pt = (pt / pt.sum()).T pt = pt.reindex(row.index) @@ -1796,8 +1819,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: else: x = str(x) values.append({"x": x, "y": v}) - d = {"key": series_title, "values": values} - chart_data.append(d) + chart_data.append({"key": series_title, "values": values}) return chart_data @@ -1816,13 +1838,13 @@ class SunburstViz(BaseViz): def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = copy.deepcopy(self.form_data) - cols = fd.get("groupby") or [] + form_data = copy.deepcopy(self.form_data) + cols = form_data.get("groupby") or [] cols.extend(["m1", "m2"]) - metric = utils.get_metric_name(fd["metric"]) + metric = utils.get_metric_name(form_data["metric"]) secondary_metric = ( - utils.get_metric_name(fd["secondary_metric"]) - if "secondary_metric" in fd + utils.get_metric_name(form_data["secondary_metric"]) + if "secondary_metric" in form_data else None ) if metric == 
secondary_metric or secondary_metric is None: @@ -1838,15 +1860,14 @@ def get_data(self, df: pd.DataFrame) -> VizData: return df.to_numpy().tolist() def query_obj(self) -> QueryObjectDict: - qry = super().query_obj() - fd = self.form_data - qry["metrics"] = [fd["metric"]] - secondary_metric = fd.get("secondary_metric") - if secondary_metric and secondary_metric != fd["metric"]: - qry["metrics"].append(secondary_metric) + query_obj = super().query_obj() + query_obj["metrics"] = [self.form_data["metric"]] + secondary_metric = self.form_data.get("secondary_metric") + if secondary_metric and secondary_metric != self.form_data["metric"]: + query_obj["metrics"].append(secondary_metric) if self.form_data.get("sort_by_metric", False): - qry["orderby"] = [(qry["metrics"][0], False)] - return qry + query_obj["orderby"] = [(query_obj["metrics"][0], False)] + return query_obj class SankeyViz(BaseViz): @@ -1859,15 +1880,15 @@ class SankeyViz(BaseViz): credits = 'd3-sankey on npm' def query_obj(self) -> QueryObjectDict: - qry = super().query_obj() - if len(qry["groupby"]) != 2: + query_obj = super().query_obj() + if len(query_obj["groupby"]) != 2: raise QueryObjectValidationError( _("Pick exactly 2 columns as [Source / Target]") ) - qry["metrics"] = [self.form_data["metric"]] + query_obj["metrics"] = [self.form_data["metric"]] if self.form_data.get("sort_by_metric", False): - qry["orderby"] = [(qry["metrics"][0], False)] - return qry + query_obj["orderby"] = [(query_obj["metrics"][0], False)] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: @@ -1885,20 +1906,20 @@ def get_data(self, df: pd.DataFrame) -> VizData: for row in recs: hierarchy[row["source"]].add(row["target"]) - def find_cycle(g: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]: + def find_cycle(graph: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]: """Whether there's a cycle in a directed graph""" path = set() def visit(vertex: str) -> Optional[Tuple[str, str]]: path.add(vertex) - for neighbour in g.get(vertex, ()): + for neighbour in graph.get(vertex, ()): if neighbour in path or visit(neighbour): return (vertex, neighbour) path.remove(vertex) return None - for v in g: - cycle = visit(v) + for vertex in graph: + cycle = visit(vertex) if cycle: return cycle return None @@ -1924,13 +1945,15 @@ class ChordViz(BaseViz): is_timeseries = False def query_obj(self) -> QueryObjectDict: - qry = super().query_obj() - fd = self.form_data - qry["groupby"] = [fd.get("groupby"), fd.get("columns")] - qry["metrics"] = [fd.get("metric")] + query_obj = super().query_obj() + query_obj["groupby"] = [ + self.form_data.get("groupby"), + self.form_data.get("columns"), + ] + query_obj["metrics"] = [self.form_data.get("metric")] if self.form_data.get("sort_by_metric", False): - qry["orderby"] = [(qry["metrics"][0], False)] - return qry + query_obj["orderby"] = [(query_obj["metrics"][0], False)] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: @@ -1945,8 +1968,10 @@ def get_data(self, df: pd.DataFrame) -> VizData: matrix[(source, target)] = 0 for source, target, value in df.to_records(index=False): matrix[(source, target)] = value - m = [[matrix[(n1, n2)] for n1 in nodes] for n2 in nodes] - return {"nodes": list(nodes), "matrix": m} + return { + "nodes": list(nodes), + "matrix": [[matrix[(n1, n2)] for n1 in nodes] for n2 in nodes], + } class CountryMapViz(BaseViz): @@ -1959,7 +1984,7 @@ class CountryMapViz(BaseViz): credits = "From bl.ocks.org By john-guerra" def query_obj(self) -> QueryObjectDict: 
- qry = super().query_obj() + query_obj = super().query_obj() metric = self.form_data.get("metric") entity = self.form_data.get("entity") if not self.form_data.get("select_country"): @@ -1968,22 +1993,20 @@ def query_obj(self) -> QueryObjectDict: raise QueryObjectValidationError("Must specify a metric") if not entity: raise QueryObjectValidationError("Must provide ISO codes") - qry["metrics"] = [metric] - qry["groupby"] = [entity] - return qry + query_obj["metrics"] = [metric] + query_obj["groupby"] = [entity] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data - cols = [fd.get("entity")] + cols = [self.form_data.get("entity")] metric = self.metric_labels[0] cols += [metric] ndf = df[cols] df = ndf df.columns = ["country_id", "metric"] - d = df.to_dict(orient="records") - return d + return df.to_dict(orient="records") class WorldMapViz(BaseViz): @@ -1996,24 +2019,24 @@ class WorldMapViz(BaseViz): credits = 'datamaps on npm' def query_obj(self) -> QueryObjectDict: - qry = super().query_obj() - qry["groupby"] = [self.form_data["entity"]] + query_obj = super().query_obj() + query_obj["groupby"] = [self.form_data["entity"]] if self.form_data.get("sort_by_metric", False): - qry["orderby"] = [(qry["metrics"][0], False)] - return qry + query_obj["orderby"] = [(query_obj["metrics"][0], False)] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None + # pylint: disable=import-outside-toplevel from superset.examples import countries - fd = self.form_data - cols = [fd.get("entity")] - metric = utils.get_metric_name(fd["metric"]) + cols = [self.form_data.get("entity")] + metric = utils.get_metric_name(self.form_data["metric"]) secondary_metric = ( - utils.get_metric_name(fd["secondary_metric"]) - if "secondary_metric" in fd + utils.get_metric_name(self.form_data["secondary_metric"]) + if "secondary_metric" in self.form_data else None ) columns = ["country", "m1", "m2"] @@ -2030,12 +2053,14 @@ def get_data(self, df: pd.DataFrame) -> VizData: ndf = df[cols] df = ndf df.columns = columns - d = df.to_dict(orient="records") - for row in d: + data = df.to_dict(orient="records") + for row in data: country = None if isinstance(row["country"], str): - if "country_fieldtype" in fd: - country = countries.get(fd["country_fieldtype"], row["country"]) + if "country_fieldtype" in self.form_data: + country = countries.get( + self.form_data["country_fieldtype"], row["country"] + ) if country: row["country"] = country["cca3"] row["latitude"] = country["lat"] @@ -2043,7 +2068,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: row["name"] = country["name"] else: row["country"] = "XXX" - return d + return data class FilterBoxViz(BaseViz): @@ -2061,34 +2086,35 @@ def query_obj(self) -> QueryObjectDict: return {} def run_extra_queries(self) -> None: + # pylint: disable=import-outside-toplevel from superset.common.query_context import QueryContext - qry = super().query_obj() + query_obj = super().query_obj() filters = self.form_data.get("filter_configs") or [] - qry["row_limit"] = self.filter_row_limit - self.dataframes = {} + query_obj["row_limit"] = self.filter_row_limit + self.dataframes = {} # pylint: disable=attribute-defined-outside-init for flt in filters: col = flt.get("column") if not col: raise QueryObjectValidationError( _("Invalid filter configuration, please select a column") ) - qry["groupby"] = [col] + query_obj["groupby"] = [col] metric = flt.get("metric") - qry["metrics"] = [metric] if metric else [] + 
query_obj["metrics"] = [metric] if metric else [] asc = flt.get("asc") if metric and asc is not None: - qry["orderby"] = [(metric, asc)] + query_obj["orderby"] = [(metric, asc)] QueryContext( datasource={"id": self.datasource.id, "type": self.datasource.type}, - queries=[qry], + queries=[query_obj], ).raise_for_access() - df = self.get_df_payload(query_obj=qry).get("df") + df = self.get_df_payload(query_obj=query_obj).get("df") self.dataframes[col] = df def get_data(self, df: pd.DataFrame) -> VizData: filters = self.form_data.get("filter_configs") or [] - d = {} + data = {} for flt in filters: col = flt.get("column") metric = flt.get("metric") @@ -2098,19 +2124,19 @@ def get_data(self, df: pd.DataFrame) -> VizData: df = df.sort_values( utils.get_metric_name(metric), ascending=flt.get("asc") ) - d[col] = [ + data[col] = [ {"id": row[0], "text": row[0], "metric": row[1]} for row in df.itertuples(index=False) ] else: df = df.sort_values(col, ascending=flt.get("asc")) - d[col] = [ + data[col] = [ {"id": row[0], "text": row[0]} for row in df.itertuples(index=False) ] else: - d[col] = [] - return d + data[col] = [] + return data class ParallelCoordinatesViz(BaseViz): @@ -2130,17 +2156,18 @@ class ParallelCoordinatesViz(BaseViz): is_timeseries = False def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - fd = self.form_data - d["groupby"] = [fd.get("series")] + query_obj = super().query_obj() + query_obj["groupby"] = [self.form_data.get("series")] sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if sort_by_label not in utils.get_metric_names(d["metrics"]): - d["metrics"].append(sort_by) + if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): + query_obj["metrics"].append(sort_by) if self.form_data.get("order_desc"): - d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] - return d + query_obj["orderby"] = [ + (sort_by, not self.form_data.get("order_desc", True)) + ] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: return df.to_dict(orient="records") @@ -2159,30 +2186,31 @@ class HeatmapViz(BaseViz): ) def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - fd = self.form_data - d["metrics"] = [fd.get("metric")] - d["groupby"] = [fd.get("all_columns_x"), fd.get("all_columns_y")] + query_obj = super().query_obj() + query_obj["metrics"] = [self.form_data.get("metric")] + query_obj["groupby"] = [ + self.form_data.get("all_columns_x"), + self.form_data.get("all_columns_y"), + ] if self.form_data.get("sort_by_metric", False): - d["orderby"] = [(d["metrics"][0], False)] + query_obj["orderby"] = [(query_obj["metrics"][0], False)] - return d + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data - x = fd.get("all_columns_x") - y = fd.get("all_columns_y") + x = self.form_data.get("all_columns_x") + y = self.form_data.get("all_columns_y") v = self.metric_labels[0] if x == y: df.columns = ["x", "y", "v"] else: df = df[[x, y, v]] df.columns = ["x", "y", "v"] - norm = fd.get("normalize_across") + norm = self.form_data.get("normalize_across") overall = False max_ = df.v.max() min_ = df.v.min() @@ -2228,16 +2256,21 @@ class MapboxViz(BaseViz): credits = "Mapbox GL JS" def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - fd = self.form_data - label_col = fd.get("mapbox_label") + query_obj = super().query_obj() + label_col = self.form_data.get("mapbox_label") - if not fd.get("groupby"): - if 
fd.get("all_columns_x") is None or fd.get("all_columns_y") is None: + if not self.form_data.get("groupby"): + if ( + self.form_data.get("all_columns_x") is None + or self.form_data.get("all_columns_y") is None + ): raise QueryObjectValidationError( _("[Longitude] and [Latitude] must be set") ) - d["columns"] = [fd.get("all_columns_x"), fd.get("all_columns_y")] + query_obj["columns"] = [ + self.form_data.get("all_columns_x"), + self.form_data.get("all_columns_y"), + ] if label_col and len(label_col) >= 1: if label_col[0] == "count": @@ -2247,39 +2280,39 @@ def query_obj(self) -> QueryObjectDict: + "[Label]" ) ) - d["columns"].append(label_col[0]) + query_obj["columns"].append(label_col[0]) - if fd.get("point_radius") != "Auto": - d["columns"].append(fd.get("point_radius")) + if self.form_data.get("point_radius") != "Auto": + query_obj["columns"].append(self.form_data.get("point_radius")) # Ensure this value is sorted so that it does not # cause the cache key generation (which hashes the # query object) to generate different keys for values # that should be considered the same. - d["columns"] = sorted(set(d["columns"])) + query_obj["columns"] = sorted(set(query_obj["columns"])) else: # Ensuring columns chosen are all in group by if ( label_col and len(label_col) >= 1 and label_col[0] != "count" - and label_col[0] not in fd["groupby"] + and label_col[0] not in self.form_data["groupby"] ): raise QueryObjectValidationError( _("Choice of [Label] must be present in [Group By]") ) if ( - fd.get("point_radius") != "Auto" - and fd.get("point_radius") not in fd["groupby"] + self.form_data.get("point_radius") != "Auto" + and self.form_data.get("point_radius") not in self.form_data["groupby"] ): raise QueryObjectValidationError( _("Choice of [Point Radius] must be present in [Group By]") ) if ( - fd.get("all_columns_x") not in fd["groupby"] - or fd.get("all_columns_y") not in fd["groupby"] + self.form_data.get("all_columns_x") not in self.form_data["groupby"] + or self.form_data.get("all_columns_y") not in self.form_data["groupby"] ): raise QueryObjectValidationError( _( @@ -2287,32 +2320,31 @@ def query_obj(self) -> QueryObjectDict: + "[Group By]" ) ) - return d + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data - label_col = fd.get("mapbox_label") + label_col = self.form_data.get("mapbox_label") has_custom_metric = label_col is not None and len(label_col) > 0 metric_col = [None] * len(df.index) if has_custom_metric: - if label_col[0] == fd.get("all_columns_x"): # type: ignore - metric_col = df[fd.get("all_columns_x")] - elif label_col[0] == fd.get("all_columns_y"): # type: ignore - metric_col = df[fd.get("all_columns_y")] + if label_col[0] == self.form_data.get("all_columns_x"): # type: ignore + metric_col = df[self.form_data.get("all_columns_x")] + elif label_col[0] == self.form_data.get("all_columns_y"): # type: ignore + metric_col = df[self.form_data.get("all_columns_y")] else: metric_col = df[label_col[0]] # type: ignore point_radius_col = ( [None] * len(df.index) - if fd.get("point_radius") == "Auto" - else df[fd.get("point_radius")] + if self.form_data.get("point_radius") == "Auto" + else df[self.form_data.get("point_radius")] ) # limiting geo precision as long decimal values trigger issues # around json-bignumber in Mapbox - GEO_PRECISION = 10 + geo_precision = 10 # using geoJSON formatting geo_json = { "type": "FeatureCollection", @@ -2323,21 +2355,24 @@ def get_data(self, df: pd.DataFrame) -> VizData: "geometry": { "type": 
"Point", "coordinates": [ - round(lon, GEO_PRECISION), - round(lat, GEO_PRECISION), + round(lon, geo_precision), + round(lat, geo_precision), ], }, } for lon, lat, metric, point_radius in zip( - df[fd.get("all_columns_x")], - df[fd.get("all_columns_y")], + df[self.form_data.get("all_columns_x")], + df[self.form_data.get("all_columns_y")], metric_col, point_radius_col, ) ], } - x_series, y_series = df[fd.get("all_columns_x")], df[fd.get("all_columns_y")] + x_series, y_series = ( + df[self.form_data.get("all_columns_x")], + df[self.form_data.get("all_columns_y")], + ) south_west = [x_series.min(), y_series.min()] north_east = [x_series.max(), y_series.max()] @@ -2345,15 +2380,15 @@ def get_data(self, df: pd.DataFrame) -> VizData: "geoJSON": geo_json, "hasCustomMetric": has_custom_metric, "mapboxApiKey": config["MAPBOX_API_KEY"], - "mapStyle": fd.get("mapbox_style"), - "aggregatorName": fd.get("pandas_aggfunc"), - "clusteringRadius": fd.get("clustering_radius"), - "pointRadiusUnit": fd.get("point_radius_unit"), - "globalOpacity": fd.get("global_opacity"), + "mapStyle": self.form_data.get("mapbox_style"), + "aggregatorName": self.form_data.get("pandas_aggfunc"), + "clusteringRadius": self.form_data.get("clustering_radius"), + "pointRadiusUnit": self.form_data.get("point_radius_unit"), + "globalOpacity": self.form_data.get("global_opacity"), "bounds": [south_west, north_east], - "renderWhileDragging": fd.get("render_while_dragging"), - "tooltip": fd.get("rich_tooltip"), - "color": fd.get("mapbox_color"), + "renderWhileDragging": self.form_data.get("render_while_dragging"), + "tooltip": self.form_data.get("rich_tooltip"), + "color": self.form_data.get("mapbox_color"), } @@ -2371,12 +2406,12 @@ def query_obj(self) -> QueryObjectDict: return {} def get_data(self, df: pd.DataFrame) -> VizData: - fd = self.form_data # Late imports to avoid circular import issues + # pylint: disable=import-outside-toplevel from superset import db from superset.models.slice import Slice - slice_ids = fd.get("deck_slices") + slice_ids = self.form_data.get("deck_slices") slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all() return { "mapboxApiKey": config["MAPBOX_API_KEY"], @@ -2393,6 +2428,7 @@ class BaseDeckGLViz(BaseViz): spatial_control_keys: List[str] = [] def get_metrics(self) -> List[str]: + # pylint: disable=attribute-defined-outside-init self.metric = self.form_data.get("size") return [self.metric] if self.metric else [] @@ -2406,21 +2442,25 @@ def get_spatial_columns(self, key: str) -> List[str]: if spatial.get("type") == "latlong": return [spatial.get("lonCol"), spatial.get("latCol")] - elif spatial.get("type") == "delimited": + + if spatial.get("type") == "delimited": return [spatial.get("lonlatCol")] - elif spatial.get("type") == "geohash": + + if spatial.get("type") == "geohash": return [spatial.get("geohashCol")] return [] @staticmethod - def parse_coordinates(s: Any) -> Optional[Tuple[float, float]]: - if not s: + def parse_coordinates(latlog: Any) -> Optional[Tuple[float, float]]: + if not latlog: return None try: - p = Point(s) - return (p.latitude, p.longitude) - except Exception: - raise SpatialException(_("Invalid spatial point encountered: %s" % s)) + point = Point(latlog) + return (point.latitude, point.longitude) + except Exception as ex: + raise SpatialException( + _("Invalid spatial point encountered: %s" % latlog) + ) from ex @staticmethod def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]: @@ -2464,16 +2504,15 @@ def process_spatial_data_obj(self, key: str, df: 
pd.DataFrame) -> pd.DataFrame: return df def add_null_filters(self) -> None: - fd = self.form_data spatial_columns = set() for key in self.spatial_control_keys: for column in self.get_spatial_columns(key): spatial_columns.add(column) - if fd.get("adhoc_filters") is None: - fd["adhoc_filters"] = [] + if self.form_data.get("adhoc_filters") is None: + self.form_data["adhoc_filters"] = [] - line_column = fd.get("line_column") + line_column = self.form_data.get("line_column") if line_column: spatial_columns.add(line_column) @@ -2481,45 +2520,45 @@ def add_null_filters(self) -> None: filter_ = simple_filter_to_adhoc( {"col": column, "op": "IS NOT NULL", "val": ""} ) - fd["adhoc_filters"].append(filter_) + self.form_data["adhoc_filters"].append(filter_) def query_obj(self) -> QueryObjectDict: - fd = self.form_data - # add NULL filters - if fd.get("filter_nulls", True): + if self.form_data.get("filter_nulls", True): self.add_null_filters() - d = super().query_obj() - gb: List[str] = [] + query_obj = super().query_obj() + group_by: List[str] = [] for key in self.spatial_control_keys: - self.process_spatial_query_obj(key, gb) + self.process_spatial_query_obj(key, group_by) - if fd.get("dimension"): - gb += [fd["dimension"]] + if self.form_data.get("dimension"): + group_by += [self.form_data["dimension"]] - if fd.get("js_columns"): - gb += fd.get("js_columns") or [] + if self.form_data.get("js_columns"): + group_by += self.form_data.get("js_columns") or [] metrics = self.get_metrics() # Ensure this value is sorted so that it does not # cause the cache key generation (which hashes the # query object) to generate different keys for values # that should be considered the same. - gb = sorted(set(gb)) + group_by = sorted(set(group_by)) if metrics: - d["groupby"] = gb - d["metrics"] = metrics - d["columns"] = [] - first_metric = d["metrics"][0] - d["orderby"] = [(first_metric, not fd.get("order_desc", True))] + query_obj["groupby"] = group_by + query_obj["metrics"] = metrics + query_obj["columns"] = [] + first_metric = query_obj["metrics"][0] + query_obj["orderby"] = [ + (first_metric, not self.form_data.get("order_desc", True)) + ] else: - d["columns"] = gb - return d + query_obj["columns"] = group_by + return query_obj - def get_js_columns(self, d: Dict[str, Any]) -> Dict[str, Any]: + def get_js_columns(self, data: Dict[str, Any]) -> Dict[str, Any]: cols = self.form_data.get("js_columns") or [] - return {col: d.get(col) for col in cols} + return {col: data.get(col) for col in cols} def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: @@ -2530,9 +2569,9 @@ def get_data(self, df: pd.DataFrame) -> VizData: df = self.process_spatial_data_obj(key, df) features = [] - for d in df.to_dict(orient="records"): - feature = self.get_properties(d) - extra_props = self.get_js_columns(d) + for data in df.to_dict(orient="records"): + feature = self.get_properties(data) + extra_props = self.get_js_columns(data) if extra_props: feature["extraProps"] = extra_props features.append(feature) @@ -2543,7 +2582,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: "metricLabels": self.metric_labels, } - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: raise NotImplementedError() @@ -2557,38 +2596,41 @@ class DeckScatterViz(BaseDeckGLViz): is_timeseries = True def query_obj(self) -> QueryObjectDict: - fd = self.form_data - self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) - self.point_radius_fixed = 
fd.get("point_radius_fixed") or { + # pylint: disable=attribute-defined-outside-init + self.is_timeseries = bool( + self.form_data.get("time_grain_sqla") or self.form_data.get("granularity") + ) + self.point_radius_fixed = self.form_data.get("point_radius_fixed") or { "type": "fix", "value": 500, } return super().query_obj() def get_metrics(self) -> List[str]: + # pylint: disable=attribute-defined-outside-init self.metric = None if self.point_radius_fixed.get("type") == "metric": self.metric = self.point_radius_fixed["value"] return [self.metric] return [] - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: return { - "metric": d.get(self.metric_label) if self.metric_label else None, + "metric": data.get(self.metric_label) if self.metric_label else None, "radius": self.fixed_value if self.fixed_value - else d.get(self.metric_label) + else data.get(self.metric_label) if self.metric_label else None, - "cat_color": d.get(self.dim) if self.dim else None, - "position": d.get("spatial"), - DTTM_ALIAS: d.get(DTTM_ALIAS), + "cat_color": data.get(self.dim) if self.dim else None, + "position": data.get("spatial"), + DTTM_ALIAS: data.get(DTTM_ALIAS), } def get_data(self, df: pd.DataFrame) -> VizData: - fd = self.form_data + # pylint: disable=attribute-defined-outside-init self.metric_label = utils.get_metric_name(self.metric) if self.metric else None - self.point_radius_fixed = fd.get("point_radius_fixed") + self.point_radius_fixed = self.form_data.get("point_radius_fixed") self.fixed_value = None self.dim = self.form_data.get("dimension") if self.point_radius_fixed and self.point_radius_fixed.get("type") != "metric": @@ -2606,19 +2648,22 @@ class DeckScreengrid(BaseDeckGLViz): is_timeseries = True def query_obj(self) -> QueryObjectDict: - fd = self.form_data - self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) + self.is_timeseries = bool( + self.form_data.get("time_grain_sqla") or self.form_data.get("granularity") + ) return super().query_obj() - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: return { - "position": d.get("spatial"), - "weight": (d.get(self.metric_label) if self.metric_label else None) or 1, - "__timestamp": d.get(DTTM_ALIAS) or d.get("__time"), + "position": data.get("spatial"), + "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, + "__timestamp": data.get(DTTM_ALIAS) or data.get("__time"), } def get_data(self, df: pd.DataFrame) -> VizData: - self.metric_label = utils.get_metric_name(self.metric) if self.metric else None + self.metric_label = ( # pylint: disable=attribute-defined-outside-init + utils.get_metric_name(self.metric) if self.metric else None + ) return super().get_data(df) @@ -2630,25 +2675,27 @@ class DeckGrid(BaseDeckGLViz): verbose_name = _("Deck.gl - 3D Grid") spatial_control_keys = ["spatial"] - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: return { - "position": d.get("spatial"), - "weight": (d.get(self.metric_label) if self.metric_label else None) or 1, + "position": data.get("spatial"), + "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, } def get_data(self, df: pd.DataFrame) -> VizData: - self.metric_label = utils.get_metric_name(self.metric) if self.metric else None + self.metric_label = ( # pylint: disable=attribute-defined-outside-init 
+ utils.get_metric_name(self.metric) if self.metric else None + ) return super().get_data(df) def geohash_to_json(geohash_code: str) -> List[List[float]]: - p = geohash.bbox(geohash_code) + bbox = geohash.bbox(geohash_code) return [ - [p.get("w"), p.get("n")], - [p.get("e"), p.get("n")], - [p.get("e"), p.get("s")], - [p.get("w"), p.get("s")], - [p.get("w"), p.get("n")], + [bbox.get("w"), bbox.get("n")], + [bbox.get("e"), bbox.get("n")], + [bbox.get("e"), bbox.get("s")], + [bbox.get("w"), bbox.get("s")], + [bbox.get("w"), bbox.get("n")], ] @@ -2667,35 +2714,38 @@ class DeckPathViz(BaseDeckGLViz): } def query_obj(self) -> QueryObjectDict: - fd = self.form_data - self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) - d = super().query_obj() - self.metric = fd.get("metric") - line_col = fd.get("line_column") - if d["metrics"]: + # pylint: disable=attribute-defined-outside-init + self.is_timeseries = bool( + self.form_data.get("time_grain_sqla") or self.form_data.get("granularity") + ) + query_obj = super().query_obj() + self.metric = self.form_data.get("metric") + line_col = self.form_data.get("line_column") + if query_obj["metrics"]: self.has_metrics = True - d["groupby"].append(line_col) + query_obj["groupby"].append(line_col) else: self.has_metrics = False - d["columns"].append(line_col) - return d + query_obj["columns"].append(line_col) + return query_obj - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: - fd = self.form_data - line_type = fd["line_type"] + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + line_type = self.form_data["line_type"] deser = self.deser_map[line_type] - line_column = fd["line_column"] - path = deser(d[line_column]) - if fd.get("reverse_long_lat"): + line_column = self.form_data["line_column"] + path = deser(data[line_column]) + if self.form_data.get("reverse_long_lat"): path = [(o[1], o[0]) for o in path] - d[self.deck_viz_key] = path + data[self.deck_viz_key] = path if line_type != "geohash": - del d[line_column] - d["__timestamp"] = d.get(DTTM_ALIAS) or d.get("__time") - return d + del data[line_column] + data["__timestamp"] = data.get(DTTM_ALIAS) or data.get("__time") + return data def get_data(self, df: pd.DataFrame) -> VizData: - self.metric_label = utils.get_metric_name(self.metric) if self.metric else None + self.metric_label = ( # pylint: disable=attribute-defined-outside-init + utils.get_metric_name(self.metric) if self.metric else None + ) return super().get_data(df) @@ -2708,8 +2758,11 @@ class DeckPolygon(DeckPathViz): verbose_name = _("Deck.gl - Polygon") def query_obj(self) -> QueryObjectDict: - fd = self.form_data - self.elevation = fd.get("point_radius_fixed") or {"type": "fix", "value": 500} + # pylint: disable=attribute-defined-outside-init + self.elevation = self.form_data.get("point_radius_fixed") or { + "type": "fix", + "value": 500, + } return super().query_obj() def get_metrics(self) -> List[str]: @@ -2718,15 +2771,16 @@ def get_metrics(self) -> List[str]: metrics.append(self.elevation.get("value")) return [metric for metric in metrics if metric] - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: - super().get_properties(d) - fd = self.form_data - elevation = fd["point_radius_fixed"]["value"] - type_ = fd["point_radius_fixed"]["type"] - d["elevation"] = ( - d.get(utils.get_metric_name(elevation)) if type_ == "metric" else elevation + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + super().get_properties(data) + elevation = 
self.form_data["point_radius_fixed"]["value"] + type_ = self.form_data["point_radius_fixed"]["type"] + data["elevation"] = ( + data.get(utils.get_metric_name(elevation)) + if type_ == "metric" + else elevation ) - return d + return data class DeckHex(BaseDeckGLViz): @@ -2737,15 +2791,17 @@ class DeckHex(BaseDeckGLViz): verbose_name = _("Deck.gl - 3D HEX") spatial_control_keys = ["spatial"] - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: return { - "position": d.get("spatial"), - "weight": (d.get(self.metric_label) if self.metric_label else None) or 1, + "position": data.get("spatial"), + "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, } def get_data(self, df: pd.DataFrame) -> VizData: - self.metric_label = utils.get_metric_name(self.metric) if self.metric else None - return super(DeckHex, self).get_data(df) + self.metric_label = ( # pylint: disable=attribute-defined-outside-init + utils.get_metric_name(self.metric) if self.metric else None + ) + return super().get_data(df) class DeckGeoJson(BaseDeckGLViz): @@ -2756,14 +2812,14 @@ class DeckGeoJson(BaseDeckGLViz): verbose_name = _("Deck.gl - GeoJSON") def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - d["columns"] += [self.form_data.get("geojson")] - d["metrics"] = [] - d["groupby"] = [] - return d - - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: - geojson = d[self.form_data["geojson"]] + query_obj = super().query_obj() + query_obj["columns"] += [self.form_data.get("geojson")] + query_obj["metrics"] = [] + query_obj["groupby"] = [] + return query_obj + + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + geojson = data[self.form_data["geojson"]] return json.loads(geojson) @@ -2777,27 +2833,26 @@ class DeckArc(BaseDeckGLViz): is_timeseries = True def query_obj(self) -> QueryObjectDict: - fd = self.form_data - self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) + self.is_timeseries = bool( + self.form_data.get("time_grain_sqla") or self.form_data.get("granularity") + ) return super().query_obj() - def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: dim = self.form_data.get("dimension") return { - "sourcePosition": d.get("start_spatial"), - "targetPosition": d.get("end_spatial"), - "cat_color": d.get(dim) if dim else None, - DTTM_ALIAS: d.get(DTTM_ALIAS), + "sourcePosition": data.get("start_spatial"), + "targetPosition": data.get("end_spatial"), + "cat_color": data.get(dim) if dim else None, + DTTM_ALIAS: data.get(DTTM_ALIAS), } def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - d = super().get_data(df) - return { - "features": d["features"], # type: ignore + "features": super().get_data(df)["features"], # type: ignore "mapboxApiKey": config["MAPBOX_API_KEY"], } @@ -2820,7 +2875,7 @@ def query_obj(self) -> QueryObjectDict: meta_keys = [ col for col in form_data["all_columns"] or [] - if col != event_key and col != entity_key + if col not in (event_key, entity_key) ] query["columns"] = [event_key, entity_key] + meta_keys @@ -2844,16 +2899,17 @@ class PairedTTestViz(BaseViz): is_timeseries = True def query_obj(self) -> QueryObjectDict: - d = super().query_obj() - metrics = self.form_data.get("metrics") + query_obj = super().query_obj() sort_by = self.form_data.get("timeseries_limit_metric") if sort_by: sort_by_label = utils.get_metric_name(sort_by) - if 
sort_by_label not in utils.get_metric_names(d["metrics"]): - d["metrics"].append(sort_by) + if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): + query_obj["metrics"].append(sort_by) if self.form_data.get("order_desc"): - d["orderby"] = [(sort_by, not self.form_data.get("order_desc", True))] - return d + query_obj["orderby"] = [ + (sort_by, not self.form_data.get("order_desc", True)) + ] + return query_obj def get_data(self, df: pd.DataFrame) -> VizData: """ @@ -2871,8 +2927,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data - groups = fd.get("groupby") + groups = self.form_data.get("groupby") metrics = self.metric_labels df = df.pivot_table(index=DTTM_ALIAS, columns=groups, values=metrics) cols = [] @@ -2887,19 +2942,24 @@ def get_data(self, df: pd.DataFrame) -> VizData: df.columns = cols data: Dict[str, List[Dict[str, Any]]] = {} series = df.to_dict("series") - for nameSet in df.columns: + for name_set in df.columns: # If no groups are defined, nameSet will be the metric name - hasGroup = not isinstance(nameSet, str) - Y = series[nameSet] - d = { - "group": nameSet[1:] if hasGroup else "All", - "values": [{"x": t, "y": Y[t] if t in Y else None} for t in df.index], + has_group = not isinstance(name_set, str) + data_ = { + "group": name_set[1:] if has_group else "All", + "values": [ + { + "x": t, + "y": series[name_set][t] if t in series[name_set] else None, + } + for t in df.index + ], } - key = nameSet[0] if hasGroup else nameSet + key = name_set[0] if has_group else name_set if key in data: - data[key].append(d) + data[key].append(data_) else: - data[key] = [d] + data[key] = [data_] return data @@ -2950,8 +3010,9 @@ def query_obj(self) -> QueryObjectDict: query_obj["is_timeseries"] = time_op != "not_time" return query_obj + @staticmethod def levels_for( - self, time_op: str, groups: List[str], df: pd.DataFrame + time_op: str, groups: List[str], df: pd.DataFrame ) -> Dict[int, pd.Series]: """ Compute the partition at each `level` from the dataframe. @@ -2966,8 +3027,9 @@ def levels_for( ) return levels + @staticmethod def levels_for_diff( - self, time_op: str, groups: List[str], df: pd.DataFrame + time_op: str, groups: List[str], df: pd.DataFrame ) -> Dict[int, pd.DataFrame]: # Obtain a unique list of the time grains times = list(set(df[DTTM_ALIAS])) @@ -3094,10 +3156,9 @@ def nest_procs( def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data - groups = fd.get("groupby", []) - time_op = fd.get("time_series_option", "not_time") - if not len(groups): + groups = self.form_data.get("groupby", []) + time_op = self.form_data.get("time_series_option", "not_time") + if not groups: raise ValueError("Please choose at least one groupby") if time_op == "not_time": levels = self.levels_for("agg_sum", groups, df) diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py index 7e718f8ebd07e..a54e50ad81f0c 100644 --- a/tests/integration_tests/charts/commands_tests.py +++ b/tests/integration_tests/charts/commands_tests.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=no-self-use, invalid-name - import json from datetime import datetime from unittest.mock import patch @@ -313,7 +311,7 @@ class TestChartsUpdateCommand(SupersetTestCase): @patch("superset.security.manager.g") @pytest.mark.usefixtures("load_energy_table_with_slice") def test_update_v1_response(self, mock_sm_g, mock_g): - """"Test that a chart command updates properties""" + """Test that a chart command updates properties""" pk = db.session.query(Slice).all()[0].id actor = security_manager.find_user(username="admin") mock_g.user = mock_sm_g.user = actor @@ -334,7 +332,7 @@ def test_update_v1_response(self, mock_sm_g, mock_g): @patch("superset.security.manager.g") @pytest.mark.usefixtures("load_energy_table_with_slice") def test_query_context_update_command(self, mock_sm_g, mock_g): - """" + """ " Test that a user can generate the chart query context payload without affecting owners """ diff --git a/tests/integration_tests/commands_test.py b/tests/integration_tests/commands_test.py index bd4206d6550be..1adf5bd646288 100644 --- a/tests/integration_tests/commands_test.py +++ b/tests/integration_tests/commands_test.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use - from superset.commands.exceptions import CommandInvalidError from superset.commands.importers.v1.utils import is_valid_config from tests.integration_tests.base_tests import SupersetTestCase diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 039ee20dd74ae..743c4af65be02 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -# pylint: disable=too-many-public-methods, no-self-use, invalid-name, too-many-arguments """Unit tests for Superset""" import json from io import BytesIO diff --git a/tests/integration_tests/dashboards/commands_tests.py b/tests/integration_tests/dashboards/commands_tests.py index 7e466b8fcdfd4..596dc404035b3 100644 --- a/tests/integration_tests/dashboards/commands_tests.py +++ b/tests/integration_tests/dashboards/commands_tests.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use, invalid-name - import itertools import json from unittest.mock import MagicMock, patch diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 7c3406bf11ecf..bf87120c1b108 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -# pylint: disable=invalid-name, no-self-use, too-many-public-methods, too-many-arguments """Unit tests for Superset""" import dataclasses import json diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 7d762b10fb554..449dd4c59fafb 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -14,7 +14,6 @@ # KIND, either express or implied.
See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use, invalid-name from unittest import mock, skip from unittest.mock import patch diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 385025e3835cc..db6180b14e53c 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-public-methods, invalid-name """Unit tests for Superset""" import json import unittest diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index 62e8414f79f1d..834d91f516276 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use, invalid-name, line-too-long - from operator import itemgetter from typing import Any, List from unittest.mock import patch diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index 78f643c587af1..e4a2ec29d8611 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=line-too-long - from typing import Any, Dict, List # example V0 import/export format diff --git a/tests/integration_tests/fixtures/pyodbcRow.py b/tests/integration_tests/fixtures/pyodbcRow.py index e621286f1ba76..237b31524214b 100644 --- a/tests/integration_tests/fixtures/pyodbcRow.py +++ b/tests/integration_tests/fixtures/pyodbcRow.py @@ -14,10 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name - - -class Row(object): +class Row: def __init__(self, values): self.values = values diff --git a/tests/integration_tests/security/api_tests.py b/tests/integration_tests/security/api_tests.py index 8cae667cd0b56..6a81efaa677ca 100644 --- a/tests/integration_tests/security/api_tests.py +++ b/tests/integration_tests/security/api_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -# pylint: disable=too-many-public-methods, no-self-use, invalid-name, too-many-arguments """Unit tests for Superset""" import json diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index a569fff345cb5..e26b8416b7fa4 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
# isort:skip_file -# pylint: disable=invalid-name, no-self-use """Unit tests for Sql Lab""" import unittest from unittest.mock import MagicMock, patch diff --git a/tests/integration_tests/utils/core_tests.py b/tests/integration_tests/utils/core_tests.py index 3ba39032e4265..29b94d6d37eef 100644 --- a/tests/integration_tests/utils/core_tests.py +++ b/tests/integration_tests/utils/core_tests.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use import pytest from superset.utils.core import form_data_to_adhoc, simple_filter_to_adhoc diff --git a/tests/integration_tests/utils/csv_tests.py b/tests/integration_tests/utils/csv_tests.py index e992fcb5931ca..bf6110c639aa4 100644 --- a/tests/integration_tests/utils/csv_tests.py +++ b/tests/integration_tests/utils/csv_tests.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use import io import pandas as pd diff --git a/tests/integration_tests/utils/hashing_tests.py b/tests/integration_tests/utils/hashing_tests.py index 8931ff15c4ab6..406d383d7cfdd 100644 --- a/tests/integration_tests/utils/hashing_tests.py +++ b/tests/integration_tests/utils/hashing_tests.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use import datetime import math from typing import Any diff --git a/tests/integration_tests/utils/public_interfaces_test.py b/tests/integration_tests/utils/public_interfaces_test.py index 03bd0dee42031..7b5d6712464df 100644 --- a/tests/integration_tests/utils/public_interfaces_test.py +++ b/tests/integration_tests/utils/public_interfaces_test.py @@ -14,19 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-self-use +from typing import Any, Callable, Dict + import pytest -from superset.sql_lab import dummy_sql_query_mutator from superset.utils.public_interfaces import compute_hash, get_warning_message -from tests.integration_tests.base_tests import SupersetTestCase # These are public interfaces exposed by Superset. Make sure # to only change the interfaces and update the hashes in new # major versions of Superset. -hashes = { - dummy_sql_query_mutator: "Kv%NM3b;7BcpoD2wbPkW", -} +hashes: Dict[Callable[..., Any], str] = {} @pytest.mark.parametrize("interface,expected_hash", list(hashes.items())) diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 0caa7afc8e446..e0a7e2fc23cba 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-argument, invalid-name from flask.ctx import AppContext from pytest_mock import MockFixture diff --git a/tests/unit_tests/tasks/test_cron_util.py b/tests/unit_tests/tasks/test_cron_util.py index 8a4a511351c8c..c905df2797d35 100644 --- a/tests/unit_tests/tasks/test_cron_util.py +++ b/tests/unit_tests/tasks/test_cron_util.py @@ -14,8 +14,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-argument, invalid-name - from datetime import datetime from typing import List