diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 049503bbbf505..a0801792aa484 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -110,7 +110,7 @@ For large features or major changes to codebase, please create **Superset Improv
 ### Fix Bugs
 
 Look through the GitHub issues. Issues tagged with `#bug` are
-open to whoever wants to implement it.
+open to whoever wants to implement them.
 
 ### Implement Features
 
diff --git a/superset/__init__.py b/superset/__init__.py
index 0feaca91fd987..01d98f55fd621 100644
--- a/superset/__init__.py
+++ b/superset/__init__.py
@@ -35,7 +35,7 @@
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.security import SupersetSecurityManager
 from superset.utils.core import pessimistic_connection_handling, setup_cache
-from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
+from superset.utils.log import get_event_logger_from_cfg_value
 
 wtforms_json.init()
 
@@ -132,19 +132,19 @@ def get_manifest():
 
 app.config["LOGGING_CONFIGURATOR"].configure_logging(app.config, app.debug)
 
-if app.config.get("ENABLE_CORS"):
+if app.config["ENABLE_CORS"]:
     from flask_cors import CORS
 
-    CORS(app, **app.config.get("CORS_OPTIONS"))
+    CORS(app, **app.config["CORS_OPTIONS"])
 
-if app.config.get("ENABLE_PROXY_FIX"):
+if app.config["ENABLE_PROXY_FIX"]:
     from werkzeug.middleware.proxy_fix import ProxyFix
 
     app.wsgi_app = ProxyFix(  # type: ignore
-        app.wsgi_app, **app.config.get("PROXY_FIX_CONFIG")
+        app.wsgi_app, **app.config["PROXY_FIX_CONFIG"]
     )
 
-if app.config.get("ENABLE_CHUNK_ENCODING"):
+if app.config["ENABLE_CHUNK_ENCODING"]:
 
     class ChunkedEncodingFix(object):
         def __init__(self, app):
@@ -175,7 +175,7 @@ def index(self):
         return redirect("/superset/welcome")
 
 
-custom_sm = app.config.get("CUSTOM_SECURITY_MANAGER") or SupersetSecurityManager
+custom_sm = app.config["CUSTOM_SECURITY_MANAGER"] or SupersetSecurityManager
 if not issubclass(custom_sm, SupersetSecurityManager):
     raise Exception(
         """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager,
@@ -195,21 +195,19 @@ def index(self):
 
 security_manager = appbuilder.sm
 
-results_backend = app.config.get("RESULTS_BACKEND")
-results_backend_use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")
+results_backend = app.config["RESULTS_BACKEND"]
+results_backend_use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]
 
 # Merge user defined feature flags with default feature flags
-_feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
-_feature_flags.update(app.config.get("FEATURE_FLAGS") or {})
+_feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
+_feature_flags.update(app.config["FEATURE_FLAGS"])
 
 # Event Logger
-event_logger = get_event_logger_from_cfg_value(
-    app.config.get("EVENT_LOGGER", DBEventLogger())
-)
+event_logger = get_event_logger_from_cfg_value(app.config["EVENT_LOGGER"])
 
 
 def get_feature_flags():
-    GET_FEATURE_FLAGS_FUNC = app.config.get("GET_FEATURE_FLAGS_FUNC")
+    GET_FEATURE_FLAGS_FUNC = app.config["GET_FEATURE_FLAGS_FUNC"]
     if GET_FEATURE_FLAGS_FUNC:
         return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags))
     return _feature_flags
@@ -232,7 +230,7 @@ def is_feature_enabled(feature):
 
 # Hook that provides administrators a handle on the Flask APP
 # after initialization
-flask_app_mutator = app.config.get("FLASK_APP_MUTATOR")
+flask_app_mutator = app.config["FLASK_APP_MUTATOR"]
 if flask_app_mutator:
     flask_app_mutator(app)
 
diff --git a/superset/cli.py b/superset/cli.py
index a32449eeb8061..8e695bbf292c6 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -62,7 +62,7 @@ def version(verbose):
         Fore.YELLOW
         + "Superset "
         + Fore.CYAN
-        + "{version}".format(version=config.get("VERSION_STRING"))
+        + "{version}".format(version=config["VERSION_STRING"])
     )
     print(Fore.BLUE + "-=" * 15)
     if verbose:
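The change repeated throughout this patch swaps `config.get("KEY")` for `config["KEY"]`, which only works because `config.py` now declares a default for every key the codebase reads. A minimal sketch of the difference, using a deliberately typo'd key (`ROW_LIMTI`, hypothetical) to show the failure mode that bracket access surfaces:

```python
from flask import Flask

app = Flask(__name__)
app.config["ROW_LIMIT"] = 50000  # default declared up front, as config.py now does

print(app.config["ROW_LIMIT"])      # 50000
print(app.config.get("ROW_LIMTI"))  # typo'd key: silently returns None
try:
    app.config["ROW_LIMTI"]         # typo'd key: fails loudly instead
except KeyError as exc:
    print(f"misconfiguration caught early: {exc}")
```

With `.get()`, the typo propagates as `None` and surfaces far from its cause; with bracket access it raises at the point of the mistake.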
@@ -372,10 +372,8 @@ def worker(workers):
     )
     if workers:
         celery_app.conf.update(CELERYD_CONCURRENCY=workers)
-    elif config.get("SUPERSET_CELERY_WORKERS"):
-        celery_app.conf.update(
-            CELERYD_CONCURRENCY=config.get("SUPERSET_CELERY_WORKERS")
-        )
+    elif config["SUPERSET_CELERY_WORKERS"]:
+        celery_app.conf.update(CELERYD_CONCURRENCY=config["SUPERSET_CELERY_WORKERS"])
 
     worker = celery_app.Worker(optimization="fair")
     worker.start()
@@ -428,7 +426,7 @@ def load_test_users_run():
 
     Syncs permissions for those users/roles
     """
-    if config.get("TESTING"):
+    if config["TESTING"]:
 
         sm = security_manager
 
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index b2822e37daca1..21649d1d0a32c 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -66,8 +66,8 @@ def __init__(
         extras: Optional[Dict] = None,
         columns: Optional[List[str]] = None,
         orderby: Optional[List[List]] = None,
-        relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"),
-        relative_end: str = app.config.get("DEFAULT_RELATIVE_END_TIME", "today"),
+        relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"],
+        relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"],
     ):
         self.granularity = granularity
         self.from_dttm, self.to_dttm = utils.get_since_until(
diff --git a/superset/config.py b/superset/config.py
index b2632855106e9..838c3c6e9600c 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -35,10 +35,14 @@
 from flask_appbuilder.security.manager import AUTH_DB
 
 from superset.stats_logger import DummyStatsLogger
+from superset.utils.log import DBEventLogger
 from superset.utils.logging_configurator import DefaultLoggingConfigurator
 
 # Realtime stats logger, a StatsD implementation exists
 STATS_LOGGER = DummyStatsLogger()
+EVENT_LOGGER = DBEventLogger()
+
+SUPERSET_LOG_VIEW = True
 
 BASE_DIR = os.path.abspath(os.path.dirname(__file__))
 if "SUPERSET_HOME" in os.environ:
@@ -109,6 +113,7 @@ def _try_json_readfile(filepath):
 # def lookup_password(url):
 #     return 'secret'
 # SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password
+SQLALCHEMY_CUSTOM_PASSWORD_STORE = None
 
 # The limit of queries fetched for query search
 QUERY_SEARCH_LIMIT = 1000
@@ -232,6 +237,9 @@ def _try_json_readfile(filepath):
     "PRESTO_EXPAND_DATA": False,
 }
 
+# This is merely a default.
+FEATURE_FLAGS: Dict[str, bool] = {}
+
 # A function that receives a dict of all feature flags
 # (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS)
 # can alter it, and returns a similar dict. Note the dict of feature
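With `FEATURE_FLAGS` defaulting to `{}` here, the merge in `superset/__init__.py` no longer needs its `or {}` fallbacks. A standalone sketch of that merge and of the `GET_FEATURE_FLAGS_FUNC` hook; the `dark_mode` flag is made up for illustration:

```python
from copy import deepcopy
from typing import Callable, Dict, Optional

DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {"PRESTO_EXPAND_DATA": False}
FEATURE_FLAGS: Dict[str, bool] = {"PRESTO_EXPAND_DATA": True}  # user override

# User-defined flags win over the defaults, as in superset/__init__.py.
_feature_flags = dict(DEFAULT_FEATURE_FLAGS)
_feature_flags.update(FEATURE_FLAGS)


def get_feature_flags(
    func: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
) -> Dict[str, bool]:
    # The hook receives a deepcopy, so it may mutate freely without side effects.
    if func:
        return func(deepcopy(_feature_flags))
    return _feature_flags


print(get_feature_flags(lambda flags: {**flags, "dark_mode": True}))
```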
@@ -371,6 +379,7 @@ def _try_json_readfile(filepath):
 #        security_manager=None,
 #    ):
 #        pass
+QUERY_LOGGER = None
 
 # Set this API key to enable Mapbox visualizations
 MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")
@@ -444,6 +453,7 @@ class CeleryConfig(object):
 # override anything set within the app
 DEFAULT_HTTP_HEADERS: Dict[str, Any] = {}
 OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {}
+HTTP_HEADERS: Dict[str, Any] = {}
 
 # The db id here results in selecting this one as a default in SQL Lab
 DEFAULT_DB_ID = None
@@ -522,13 +532,18 @@ class CeleryConfig(object):
 SMTP_MAIL_FROM = "superset@superset.com"
 
 if not CACHE_DEFAULT_TIMEOUT:
-    CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get("CACHE_DEFAULT_TIMEOUT")  # type: ignore
+    CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"]
+
+
+ENABLE_CHUNK_ENCODING = False
 
 # Whether to bump the logging level to ERROR on the flask_appbuilder package
 # Set to False if/when debugging FAB related issues like
 # permission management
 SILENCE_FAB = True
 
+FAB_ADD_SECURITY_VIEWS = True
+
 # The link to a page containing common errors and their resolutions
 # It will be appended at the bottom of sql_lab errors.
 TROUBLESHOOTING_LINK = ""
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 5504eb370c748..0a40bcb17c968 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -544,7 +544,7 @@ def mutate_query_from_config(self, sql: str) -> str:
         """Apply config's SQL_QUERY_MUTATOR
 
         Typically adds comments to the query with context"""
-        SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR")
+        SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
         if SQL_QUERY_MUTATOR:
             username = utils.get_username()
             sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 97e214f42f820..4071538229eaa 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -183,7 +183,7 @@ def get_time_grains(cls) -> Tuple[TimeGrain, ...]:
         ret_list = []
         time_grain_functions = cls.get_time_grain_functions()
         time_grains = builtin_time_grains.copy()
-        time_grains.update(config.get("TIME_GRAIN_ADDONS", {}))
+        time_grains.update(config["TIME_GRAIN_ADDONS"])
         for duration, func in time_grain_functions.items():
             if duration in time_grains:
                 name = time_grains[duration]
@@ -200,9 +200,9 @@ def get_time_grain_functions(cls) -> Dict[Optional[str], str]:
         """
         # TODO: use @memoize decorator or similar to avoid recomputation on every call
         time_grain_functions = cls._time_grain_functions.copy()
-        grain_addon_functions = config.get("TIME_GRAIN_ADDON_FUNCTIONS", {})
+        grain_addon_functions = config["TIME_GRAIN_ADDON_FUNCTIONS"]
         time_grain_functions.update(grain_addon_functions.get(cls.engine, {}))
-        blacklist: List[str] = config.get("TIME_GRAIN_BLACKLIST", [])
+        blacklist: List[str] = config["TIME_GRAIN_BLACKLIST"]
         for key in blacklist:
             time_grain_functions.pop(key)
         return time_grain_functions
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index dbcb337fe9bd6..cf1b0d0ab770c 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -122,16 +122,16 @@ def convert_to_hive_type(col_type):
         table_name = form.name.data
         schema_name = form.schema.data
 
-        if config.get("UPLOADED_CSV_HIVE_NAMESPACE"):
+        if config["UPLOADED_CSV_HIVE_NAMESPACE"]:
             if "." in table_name or schema_name:
                 raise Exception(
                     "You can't specify a namespace. "
                     "All tables will be uploaded to the `{}` namespace".format(
-                        config.get("HIVE_NAMESPACE")
+                        config["HIVE_NAMESPACE"]
                     )
                 )
             full_table_name = "{}.{}".format(
-                config.get("UPLOADED_CSV_HIVE_NAMESPACE"), table_name
+                config["UPLOADED_CSV_HIVE_NAMESPACE"], table_name
            )
         else:
             if "." in table_name and schema_name:
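A condensed, standalone sketch of the namespace logic the hive.py hunk touches, with an illustrative `UPLOADED_CSV_HIVE_NAMESPACE` value (the shipped default leaves it unset):

```python
UPLOADED_CSV_HIVE_NAMESPACE = "uploads"  # illustrative; the default is None

table_name, schema_name = "trips", None

if UPLOADED_CSV_HIVE_NAMESPACE:
    # A user-supplied namespace would conflict with the forced one.
    if "." in table_name or schema_name:
        raise Exception("You can't specify a namespace.")
    full_table_name = "{}.{}".format(UPLOADED_CSV_HIVE_NAMESPACE, table_name)
else:
    # Without a forced namespace, the name and schema must not both carry one.
    if "." in table_name and schema_name:
        raise Exception("You can't specify a namespace in both fields.")
    full_table_name = (
        "{}.{}".format(schema_name, table_name) if schema_name else table_name
    )

print(full_table_name)  # uploads.trips
```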
" "All tables will be uploaded to the `{}` namespace".format( - config.get("HIVE_NAMESPACE") + config["HIVE_NAMESPACE"] ) ) full_table_name = "{}.{}".format( - config.get("UPLOADED_CSV_HIVE_NAMESPACE"), table_name + config["UPLOADED_CSV_HIVE_NAMESPACE"], table_name ) else: if "." in table_name and schema_name: diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 66e798234d674..44215440e993d 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -455,7 +455,7 @@ def estimate_statement_cost( # pylint: disable=too-many-locals parsed_query = ParsedQuery(statement) sql = parsed_query.stripped() - sql_query_mutator = config.get("SQL_QUERY_MUTATOR") + sql_query_mutator = config["SQL_QUERY_MUTATOR"] if sql_query_mutator: sql = sql_query_mutator(sql, user_name, security_manager, database) diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index bad084d8a5e50..ec79dd70d2041 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -122,7 +122,7 @@ def load_birth_names(only_metadata=False, force=False): "optionName": "metric_11", } ], - "row_limit": config.get("ROW_LIMIT"), + "row_limit": config["ROW_LIMIT"], "since": "100 years ago", "until": "now", "viz_type": "table", diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index 4a03f967d3e4c..84ac93e26b880 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -95,7 +95,7 @@ def load_multiformat_time_series(only_metadata=False, force=False): slice_data = { "metrics": ["count"], "granularity_sqla": col.column_name, - "row_limit": config.get("ROW_LIMIT"), + "row_limit": config["ROW_LIMIT"], "since": "2015", "until": "2016", "where": "", diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index 2e8f51f7bf6a5..eb1e82721ea16 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -58,7 +58,7 @@ def load_random_time_series_data(only_metadata=False, force=False): slice_data = { "granularity_sqla": "day", - "row_limit": config.get("ROW_LIMIT"), + "row_limit": config["ROW_LIMIT"], "since": "1 year ago", "until": "now", "metric": "count", diff --git a/superset/examples/unicode_test_data.py b/superset/examples/unicode_test_data.py index 1c88456c80d16..14e691852a5b0 100644 --- a/superset/examples/unicode_test_data.py +++ b/superset/examples/unicode_test_data.py @@ -87,7 +87,7 @@ def load_unicode_test_data(only_metadata=False, force=False): "expressionType": "SIMPLE", "label": "Value", }, - "row_limit": config.get("ROW_LIMIT"), + "row_limit": config["ROW_LIMIT"], "since": "100 years ago", "until": "now", "where": "", diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 699eb70450b8c..34cb394b11729 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -104,7 +104,7 @@ def load_world_bank_health_n_pop(only_metadata=False, force=False): "groupby": [], "metric": "sum__SP_POP_TOTL", "metrics": ["sum__SP_POP_TOTL"], - "row_limit": config.get("ROW_LIMIT"), + "row_limit": config["ROW_LIMIT"], "since": "2014-01-01", "until": "2014-01-02", "time_range": "2014-01-01 : 2014-01-02", diff --git a/superset/jinja_context.py b/superset/jinja_context.py index c48218a54de5a..f522de3440c88 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -39,7 +39,7 @@ "timedelta": 
diff --git a/superset/migrations/env.py b/superset/migrations/env.py
index 7e647d9063e57..2e81bc06beb3c 100755
--- a/superset/migrations/env.py
+++ b/superset/migrations/env.py
@@ -33,9 +33,7 @@
 logger = logging.getLogger("alembic.env")
 
-config.set_main_option(
-    "sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI")
-)
+config.set_main_option("sqlalchemy.url", current_app.config["SQLALCHEMY_DATABASE_URI"])
 target_metadata = Base.metadata  # pylint: disable=no-member
 
 # other values from the config, defined by the needs of env.py,
diff --git a/superset/models/core.py b/superset/models/core.py
index c9b63da449199..a177d26c077a2 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -63,9 +63,9 @@
 from superset.viz import viz_types
 
 config = app.config
-custom_password_store = config.get("SQLALCHEMY_CUSTOM_PASSWORD_STORE")
-stats_logger = config.get("STATS_LOGGER")
-log_query = config.get("QUERY_LOGGER")
+custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
+stats_logger = config["STATS_LOGGER"]
+log_query = config["QUERY_LOGGER"]
 metadata = Model.metadata  # pylint: disable=no-member
 
 PASSWORD_MASK = "X" * 10
@@ -81,7 +81,7 @@ def set_related_perm(mapper, connection, target):
 
 
 def copy_dashboard(mapper, connection, target):
-    dashboard_id = config.get("DASHBOARD_TEMPLATE_ID")
+    dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
     if dashboard_id is None:
         return
 
@@ -725,7 +725,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
     # short unique name, used in permissions
     database_name = Column(String(250), unique=True, nullable=False)
     sqlalchemy_uri = Column(String(1024))
-    password = Column(EncryptedType(String(1024), config.get("SECRET_KEY")))
+    password = Column(EncryptedType(String(1024), config["SECRET_KEY"]))
     cache_timeout = Column(Integer)
     select_as_create_table_as = Column(Boolean, default=False)
     expose_in_sqllab = Column(Boolean, default=True)
@@ -906,7 +906,7 @@ def get_sqla_engine(self, schema=None, nullpool=True, user_name=None, source=Non
 
         params.update(self.get_encrypted_extra())
 
-        DB_CONNECTION_MUTATOR = config.get("DB_CONNECTION_MUTATOR")
+        DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
         if DB_CONNECTION_MUTATOR:
             url, params = DB_CONNECTION_MUTATOR(
                 url, params, effective_username, security_manager, source
@@ -1262,7 +1262,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
     datasource_id = Column(Integer)
     datasource_type = Column(String(200))
 
-    ROLES_BLACKLIST = set(config.get("ROBOT_PERMISSION_ROLES", []))
+    ROLES_BLACKLIST = set(config["ROBOT_PERMISSION_ROLES"])
 
     @property
     def cls_model(self):
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 9050f06d95bbd..8cbc9346b2830 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -50,10 +50,10 @@
 from superset.utils.decorators import stats_timing
 
 config = app.config
-stats_logger = config.get("STATS_LOGGER")
-SQLLAB_TIMEOUT = config.get("SQLLAB_ASYNC_TIME_LIMIT_SEC", 600)
+stats_logger = config["STATS_LOGGER"]
+SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
-log_query = config.get("QUERY_LOGGER")
+log_query = config["QUERY_LOGGER"]
 
 
 class SqlLabException(Exception):
@@ -114,7 +114,7 @@ def session_scope(nullpool):
     """Provide a transactional scope around a series of operations."""
     if nullpool:
         engine = sqlalchemy.create_engine(
-            app.config.get("SQLALCHEMY_DATABASE_URI"), poolclass=NullPool
+            app.config["SQLALCHEMY_DATABASE_URI"], poolclass=NullPool
         )
         session_class = sessionmaker()
         session_class.configure(bind=engine)
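`SQL_QUERY_MUTATOR`, read with bracket access in several files in this patch, defaults to `None` in `config.py`, so the `if SQL_QUERY_MUTATOR:` guard still short-circuits cleanly. A sketch of a typical mutator; the comment format is illustrative:

```python
from datetime import datetime


def sql_query_mutator(sql, username, security_manager, database):
    """Prepend a traceable comment, the usual use for this hook."""
    return f"-- run by {username} at {datetime.utcnow().isoformat()}\n{sql}"


SQL_QUERY_MUTATOR = sql_query_mutator  # would be set in superset_config.py

sql = "SELECT 1"
if SQL_QUERY_MUTATOR:
    sql = SQL_QUERY_MUTATOR(sql, "alice", None, None)
print(sql)
```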
@@ -177,7 +177,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
     db_engine_spec = database.db_engine_spec
     parsed_query = ParsedQuery(sql_statement)
     sql = parsed_query.stripped()
-    SQL_MAX_ROWS = app.config.get("SQL_MAX_ROW")
+    SQL_MAX_ROWS = app.config["SQL_MAX_ROW"]
 
     if not parsed_query.is_readonly() and not database.allow_dml:
         raise SqlLabSecurityException(
@@ -205,7 +205,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
         sql = database.apply_limit_to_sql(sql, query.limit)
 
     # Hook to allow environment-specific mutation (usually comments) to the SQL
-    SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR")
+    SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
     if SQL_QUERY_MUTATOR:
         sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
 
@@ -400,7 +400,7 @@ def execute_sql_statements(
         )
         cache_timeout = database.cache_timeout
         if cache_timeout is None:
-            cache_timeout = config.get("CACHE_DEFAULT_TIMEOUT", 0)
+            cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]
 
         compressed = zlib_compress(serialized_payload)
         logging.debug(
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index f684db652cc81..90d22f0e6b3bf 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -52,7 +52,7 @@ def validate_statement(
 
         # Hook to allow environment-specific mutation (usually comments) to the SQL
         # pylint: disable=invalid-name
-        SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR")
+        SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
         if SQL_QUERY_MUTATOR:
             sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
 
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 6bda5b7ba2b6c..64d9b4dc57aae 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -60,7 +60,7 @@
 
 
 def _get_recipients(schedule):
-    bcc = config.get("EMAIL_REPORT_BCC_ADDRESS", None)
+    bcc = config["EMAIL_REPORT_BCC_ADDRESS"]
 
     if schedule.deliver_as_group:
         to = schedule.recipients
@@ -81,7 +81,7 @@ def _deliver_email(schedule, subject, email):
         images=email.images,
         bcc=bcc,
         mime_subtype="related",
-        dryrun=config.get("SCHEDULED_EMAIL_DEBUG_MODE"),
+        dryrun=config["SCHEDULED_EMAIL_DEBUG_MODE"],
     )
 
 
@@ -97,7 +97,7 @@ def _generate_mail_content(schedule, screenshot, name, url):
     elif schedule.delivery_type == EmailDeliveryType.inline:
         # Get the domain from the 'From' address ..
         # and make a message id without the < > in the ends
-        domain = parseaddr(config.get("SMTP_MAIL_FROM"))[1].split("@")[1]
+        domain = parseaddr(config["SMTP_MAIL_FROM"])[1].split("@")[1]
         msgid = make_msgid(domain)[1:-1]
         images = {msgid: screenshot}
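The `_generate_mail_content` hunk ends on a compact stdlib idiom; a standalone version, using the `SMTP_MAIL_FROM` default from `config.py`:

```python
from email.utils import make_msgid, parseaddr

SMTP_MAIL_FROM = "superset@superset.com"  # the config.py default

# Take the domain from the From address, then build a Message-ID and trim
# the surrounding angle brackets, exactly as the hunk above does.
domain = parseaddr(SMTP_MAIL_FROM)[1].split("@")[1]
msgid = make_msgid(domain)[1:-1]

print(domain)  # superset.com
print(msgid)   # a unique id string, without the enclosing < >
```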
@@ -118,7 +118,7 @@ def _generate_mail_content(schedule, screenshot, name, url):
 def _get_auth_cookies():
     # Login with the user specified to get the reports
     with app.test_request_context():
-        user = security_manager.find_user(config.get("EMAIL_REPORTS_USER"))
+        user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
         login_user(user)
 
         # A mock response object to get the cookie information from
@@ -139,16 +139,16 @@ def _get_auth_cookies():
 def _get_url_path(view, **kwargs):
     with app.test_request_context():
         return urllib.parse.urljoin(
-            str(config.get("WEBDRIVER_BASEURL")), url_for(view, **kwargs)
+            str(config["WEBDRIVER_BASEURL"]), url_for(view, **kwargs)
         )
 
 
 def create_webdriver():
     # Create a webdriver for use in fetching reports
-    if config.get("EMAIL_REPORTS_WEBDRIVER") == "firefox":
+    if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox":
         driver_class = firefox.webdriver.WebDriver
         options = firefox.options.Options()
-    elif config.get("EMAIL_REPORTS_WEBDRIVER") == "chrome":
+    elif config["EMAIL_REPORTS_WEBDRIVER"] == "chrome":
         driver_class = chrome.webdriver.WebDriver
         options = chrome.options.Options()
 
@@ -156,7 +156,7 @@ def create_webdriver():
 
     # Prepare args for the webdriver init
     kwargs = dict(options=options)
-    kwargs.update(config.get("WEBDRIVER_CONFIGURATION"))
+    kwargs.update(config["WEBDRIVER_CONFIGURATION"])
 
     # Initialize the driver
     driver = driver_class(**kwargs)
@@ -208,7 +208,7 @@ def deliver_dashboard(schedule):
 
     # Create a driver, fetch the page, wait for the page to render
     driver = create_webdriver()
-    window = config.get("WEBDRIVER_WINDOW")["dashboard"]
+    window = config["WEBDRIVER_WINDOW"]["dashboard"]
    driver.set_window_size(*window)
     driver.get(dashboard_url)
     time.sleep(PAGE_RENDER_WAIT)
@@ -236,7 +236,7 @@
 
     subject = __(
         "%(prefix)s %(title)s",
-        prefix=config.get("EMAIL_REPORTS_SUBJECT_PREFIX"),
+        prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
         title=dashboard.dashboard_title,
     )
 
@@ -296,7 +296,7 @@ def _get_slice_visualization(schedule):
 
     # Create a driver, fetch the page, wait for the page to render
     driver = create_webdriver()
-    window = config.get("WEBDRIVER_WINDOW")["slice"]
+    window = config["WEBDRIVER_WINDOW"]["slice"]
     driver.set_window_size(*window)
 
     slice_url = _get_url_path("Superset.slice", slice_id=slc.id)
@@ -339,7 +339,7 @@ def deliver_slice(schedule):
 
     subject = __(
         "%(prefix)s %(title)s",
-        prefix=config.get("EMAIL_REPORTS_SUBJECT_PREFIX"),
+        prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
         title=schedule.slice.slice_name,
     )
 
@@ -413,11 +413,11 @@ def schedule_window(report_type, start_at, stop_at, resolution):
 def schedule_hourly():
     """ Celery beat job meant to be invoked hourly """
 
-    if not config.get("ENABLE_SCHEDULED_EMAIL_REPORTS"):
+    if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
         logging.info("Scheduled email reports not enabled in config")
         return
 
-    resolution = config.get("EMAIL_REPORTS_CRON_RESOLUTION", 0) * 60
+    resolution = config["EMAIL_REPORTS_CRON_RESOLUTION"] * 60
 
     # Get the top of the hour
     start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0)
diff --git a/superset/templates/appbuilder/navbar.html b/superset/templates/appbuilder/navbar.html
index 72eefdfc5e13f..312596657a731 100644
--- a/superset/templates/appbuilder/navbar.html
+++ b/superset/templates/appbuilder/navbar.html
@@ -18,9 +18,9 @@
 #}
 {% set menu = appbuilder.menu %}
 {% set languages = appbuilder.languages %}
-{% set WARNING_MSG = appbuilder.app.config.get('WARNING_MSG') %}
-{% set app_icon_width = appbuilder.app.config.get('APP_ICON_WIDTH', 126) %}
-{% set logo_target_path = appbuilder.app.config.get('LOGO_TARGET_PATH') or '/profile/{}/'.format(current_user.username) %}
+{% set WARNING_MSG = appbuilder.app.config['WARNING_MSG'] %}
+{% set app_icon_width = appbuilder.app.config['APP_ICON_WIDTH'] %}
+{% set logo_target_path = appbuilder.app.config['LOGO_TARGET_PATH'] or '/profile/{}/'.format(current_user.username) %}
 {% set root_path = logo_target_path if not logo_target_path.startswith('/') else '/superset' + logo_target_path if current_user.username is defined else '#' %}
 
 <div class="navbar navbar-static-top {{menu.extra_classes}}" role="navigation">
diff --git a/superset/templates/appbuilder/navbar_right.html b/superset/templates/appbuilder/navbar_right.html
index 73c545e3974ad..41e526dad9f99 100644
--- a/superset/templates/appbuilder/navbar_right.html
+++ b/superset/templates/appbuilder/navbar_right.html
@@ -17,8 +17,8 @@
   under the License.
 #}
 
-{% set bug_report_url = appbuilder.app.config.get('BUG_REPORT_URL') %}
-{% set documentation_url = appbuilder.app.config.get('DOCUMENTATION_URL') %}
+{% set bug_report_url = appbuilder.app.config['BUG_REPORT_URL'] %}
+{% set documentation_url = appbuilder.app.config['DOCUMENTATION_URL'] %}
 {% set locale = session['locale'] %}
 {% if not locale %}
   {% set locale = 'en' %}
diff --git a/superset/utils/core.py b/superset/utils/core.py
index a27ff71a0fc4a..2e79c893a7005 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -668,7 +668,7 @@ def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, conf
         msg,
         config,
         bcc=granter.email,
-        dryrun=not config.get("EMAIL_NOTIFICATIONS"),
+        dryrun=not config["EMAIL_NOTIFICATIONS"],
     )
 
 
@@ -690,7 +690,7 @@ def send_email_smtp(
         send_email_smtp(
             'test@example.com', 'foo', '<b>Foo</b> bar',['/dev/null'], dryrun=True)
     """
-    smtp_mail_from = config.get("SMTP_MAIL_FROM")
+    smtp_mail_from = config["SMTP_MAIL_FROM"]
     to = get_email_address_list(to)
 
     msg = MIMEMultipart(mime_subtype)
@@ -746,12 +746,12 @@ def send_email_smtp(
 
 
 def send_MIME_email(e_from, e_to, mime_msg, config, dryrun=False):
-    SMTP_HOST = config.get("SMTP_HOST")
-    SMTP_PORT = config.get("SMTP_PORT")
-    SMTP_USER = config.get("SMTP_USER")
-    SMTP_PASSWORD = config.get("SMTP_PASSWORD")
-    SMTP_STARTTLS = config.get("SMTP_STARTTLS")
-    SMTP_SSL = config.get("SMTP_SSL")
+    SMTP_HOST = config["SMTP_HOST"]
+    SMTP_PORT = config["SMTP_PORT"]
+    SMTP_USER = config["SMTP_USER"]
+    SMTP_PASSWORD = config["SMTP_PASSWORD"]
+    SMTP_STARTTLS = config["SMTP_STARTTLS"]
+    SMTP_SSL = config["SMTP_SSL"]
 
     if not dryrun:
         s = (
@@ -794,7 +794,7 @@ def setup_cache(app: Flask, cache_config) -> Optional[Cache]:
     """Setup the flask-cache on a flask app"""
     if cache_config:
         if isinstance(cache_config, dict):
-            if cache_config.get("CACHE_TYPE") != "null":
+            if cache_config["CACHE_TYPE"] != "null":
                 return Cache(app, config=cache_config)
         else:
             # Accepts a custom cache initialization function,
@@ -839,7 +839,7 @@ def get_celery_app(config):
     if _celery_app:
         return _celery_app
     _celery_app = celery.Celery()
-    _celery_app.config_from_object(config.get("CELERY_CONFIG"))
+    _celery_app.config_from_object(config["CELERY_CONFIG"])
     _celery_app.set_default()
     return _celery_app
 
@@ -1210,7 +1210,7 @@ class DatasourceName(NamedTuple):
 
 
 def get_stacktrace():
-    if current_app.config.get("SHOW_STACKTRACE"):
+    if current_app.config["SHOW_STACKTRACE"]:
         return traceback.format_exc()
 
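The utils/log.py hunk below completes the `EVENT_LOGGER` move: the `DBEventLogger()` default now lives in `config.py` (see the config.py hunk earlier), so `superset/__init__.py` can read it with plain bracket access. A sketch of overriding it from `superset_config.py`; the stand-in class below only illustrates the shape of a logger and is not Superset's actual `AbstractEventLogger` API:

```python
import logging


class StdoutEventLogger:
    """Hypothetical event logger; a real one subclasses AbstractEventLogger."""

    def log(self, user_id, action, *args, **kwargs):
        logging.info("event: user=%s action=%s", user_id, action)


# In superset_config.py, replacing the DBEventLogger() default:
EVENT_LOGGER = StdoutEventLogger()
```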
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 85d41398c9287..ca8bd98bc089e 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -84,7 +84,7 @@ def wrapper(*args, **kwargs):
 
     @property
     def stats_logger(self):
-        return current_app.config.get("STATS_LOGGER")
+        return current_app.config["STATS_LOGGER"]
 
 
 def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
diff --git a/superset/utils/logging_configurator.py b/superset/utils/logging_configurator.py
index 37d6717fdd973..a145bf9107a21 100644
--- a/superset/utils/logging_configurator.py
+++ b/superset/utils/logging_configurator.py
@@ -35,7 +35,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator):
     def configure_logging(
         self, app_config: flask.config.Config, debug_mode: bool
     ) -> None:
-        if app_config.get("SILENCE_FAB"):
+        if app_config["SILENCE_FAB"]:
             logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
 
         # configure superset app logger
@@ -54,13 +54,13 @@ def configure_logging(
         logging.basicConfig(format=app_config["LOG_FORMAT"])
         logging.getLogger().setLevel(app_config["LOG_LEVEL"])
 
-        if app_config.get("ENABLE_TIME_ROTATE"):
+        if app_config["ENABLE_TIME_ROTATE"]:
             logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
             handler = TimedRotatingFileHandler(  # type: ignore
-                app_config.get("FILENAME"),
-                when=app_config.get("ROLLOVER"),
-                interval=app_config.get("INTERVAL"),
-                backupCount=app_config.get("BACKUP_COUNT"),
+                app_config["FILENAME"],
+                when=app_config["ROLLOVER"],
+                interval=app_config["INTERVAL"],
+                backupCount=app_config["BACKUP_COUNT"],
             )
             logging.getLogger().addHandler(handler)
diff --git a/superset/views/core.py b/superset/views/core.py
index 69a4ff1e58a2e..8fe91e17607fe 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -110,11 +110,9 @@
 )
 
 config = app.config
-CACHE_DEFAULT_TIMEOUT = config.get("CACHE_DEFAULT_TIMEOUT", 0)
-SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config.get(
-    "SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT", 10
-)
-stats_logger = config.get("STATS_LOGGER")
+CACHE_DEFAULT_TIMEOUT = config["CACHE_DEFAULT_TIMEOUT"]
+SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"]
+stats_logger = config["STATS_LOGGER"]
 DAR = models.DatasourceAccessRequest
 QueryStatus = utils.QueryStatus
 
@@ -127,7 +125,7 @@
 USER_MISSING_ERR = __("The user seems to have been deleted")
 
 FORM_DATA_KEY_BLACKLIST: List[str] = []
-if not config.get("ENABLE_JAVASCRIPT_CONTROLS"):
+if not config["ENABLE_JAVASCRIPT_CONTROLS"]:
     FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"]
 
 
@@ -296,7 +294,7 @@ def apply(self, query, func):
         return query
 
 
-if config.get("ENABLE_ACCESS_REQUEST"):
+if config["ENABLE_ACCESS_REQUEST"]:
 
     class AccessRequestsModelView(SupersetModelView, DeleteMixin):
         datamodel = SQLAInterface(DAR)
@@ -1159,7 +1157,7 @@ def explore(self, datasource_type=None, datasource_id=None):
             flash(DATASOURCE_MISSING_ERR, "danger")
             return redirect(error_redirect)
 
-        if config.get("ENABLE_ACCESS_REQUEST") and (
+        if config["ENABLE_ACCESS_REQUEST"] and (
             not security_manager.datasource_access(datasource)
         ):
             flash(
@@ -1273,9 +1271,7 @@ def filter(self, datasource_type, datasource_id, column):
             return json_error_response(DATASOURCE_MISSING_ERR)
         security_manager.assert_datasource_permission(datasource)
         payload = json.dumps(
-            datasource.values_for_column(
-                column, config.get("FILTER_SELECT_ROW_LIMIT", 10000)
-            ),
+            datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
             default=utils.json_int_dttm_ser,
         )
         return json_success(payload)
config.get("FILTER_SELECT_ROW_LIMIT", 10000) - ), + datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]), default=utils.json_int_dttm_ser, ) return json_success(payload) @@ -1491,7 +1487,7 @@ def get_datasource_label(ds_name: utils.DatasourceName) -> str: tables = [tn for tn in tables if tn.schema in valid_schemas] views = [vn for vn in views if vn.schema in valid_schemas] - max_items = config.get("MAX_TABLE_NAMES") or len(tables) + max_items = config["MAX_TABLE_NAMES"] or len(tables) total_items = len(tables) + len(views) max_tables = len(tables) max_views = len(views) @@ -2101,7 +2097,7 @@ def dashboard(self, dashboard_id): if datasource: datasources.add(datasource) - if config.get("ENABLE_ACCESS_REQUEST"): + if config["ENABLE_ACCESS_REQUEST"]: for datasource in datasources: if datasource and not security_manager.datasource_access(datasource): flash( @@ -2529,7 +2525,7 @@ def validate_sql_json(self): ) try: - timeout = config.get("SQLLAB_VALIDATION_TIMEOUT") + timeout = config["SQLLAB_VALIDATION_TIMEOUT"] timeout_msg = f"The query exceeded the {timeout} seconds timeout." with utils.timeout(seconds=timeout, error_message=timeout_msg): errors = validator.validate(sql, schema, mydb) @@ -2604,7 +2600,7 @@ def _sql_json_sync( :return: String JSON response """ try: - timeout = config.get("SQLLAB_TIMEOUT") + timeout = config["SQLLAB_TIMEOUT"] timeout_msg = f"The query exceeded the {timeout} seconds timeout." with utils.timeout(seconds=timeout, error_message=timeout_msg): # pylint: disable=no-value-for-parameter @@ -2647,7 +2643,7 @@ def sql_json(self): " specified. Defaulting to empty dict" ) template_params = {} - limit = request.json.get("queryLimit") or app.config.get("SQL_MAX_ROW") + limit = request.json.get("queryLimit") or app.config["SQL_MAX_ROW"] async_flag: bool = request.json.get("runAsync") if limit < 0: logging.warning( @@ -2762,13 +2758,13 @@ def csv(self, client_id): columns = [c["name"] for c in obj["columns"]] df = pd.DataFrame.from_records(obj["data"], columns=columns) logging.info("Using pandas to convert to CSV") - csv = df.to_csv(index=False, **config.get("CSV_EXPORT")) + csv = df.to_csv(index=False, **config["CSV_EXPORT"]) else: logging.info("Running a query to turn into CSV") sql = query.select_sql or query.executed_sql df = query.database.get_df(sql, query.schema) # TODO(bkyryliuk): add compression=gzip for big files. 
@@ -2878,7 +2874,7 @@ def search_queries(self) -> Response:
         if to_time:
             query = query.filter(Query.start_time < int(to_time))
 
-        query_limit = config.get("QUERY_SEARCH_LIMIT", 1000)
+        query_limit = config["QUERY_SEARCH_LIMIT"]
         sql_queries = query.order_by(Query.start_time.asc()).limit(query_limit).all()
 
         dict_queries = [q.to_dict() for q in sql_queries]
@@ -2952,7 +2948,7 @@ def profile(self, username):
     def sqllab(self):
         """SQL Editor"""
         d = {
-            "defaultDbId": config.get("SQLLAB_DEFAULT_DBID"),
+            "defaultDbId": config["SQLLAB_DEFAULT_DBID"],
             "common": self.common_bootstrap_payload(),
         }
         return self.render_template(
@@ -3084,7 +3080,7 @@ def apply_http_headers(response: Response):
 
     # HTTP_HEADERS is deprecated, this provides backwards compatibility
     response.headers.extend(
-        {**config["OVERRIDE_HTTP_HEADERS"], **config.get("HTTP_HEADERS", {})}
+        {**config["OVERRIDE_HTTP_HEADERS"], **config["HTTP_HEADERS"]}
     )
 
     for k, v in config["DEFAULT_HTTP_HEADERS"].items():
diff --git a/superset/views/database/views.py b/superset/views/database/views.py
index bb8700c26ecbc..c7a97964bd616 100644
--- a/superset/views/database/views.py
+++ b/superset/views/database/views.py
@@ -37,7 +37,7 @@
 from .forms import CsvToDatabaseForm
 
 config = app.config
-stats_logger = config.get("STATS_LOGGER")
+stats_logger = config["STATS_LOGGER"]
 
 
 def sqlalchemy_uri_form_validator(form: DynamicForm, field: StringField) -> None:
diff --git a/superset/views/log/api.py b/superset/views/log/api.py
index 0ebbd5d660fb8..b9b1ae43a7855 100644
--- a/superset/views/log/api.py
+++ b/superset/views/log/api.py
@@ -41,7 +41,7 @@ class LogRestApi(LogMixin, ModelRestApi):
 
 
 if (
-    not app.config.get("FAB_ADD_SECURITY_VIEWS") is False
-    or app.config.get("SUPERSET_LOG_VIEW") is False
+    not app.config["FAB_ADD_SECURITY_VIEWS"] is False
+    or app.config["SUPERSET_LOG_VIEW"] is False
 ):
     appbuilder.add_api(LogRestApi)
diff --git a/superset/views/log/views.py b/superset/views/log/views.py
index 7cb8709928527..b84360b7dc08f 100644
--- a/superset/views/log/views.py
+++ b/superset/views/log/views.py
@@ -30,8 +30,8 @@ class LogModelView(LogMixin, SupersetModelView):
 
 
 if (
-    not app.config.get("FAB_ADD_SECURITY_VIEWS") is False
-    or app.config.get("SUPERSET_LOG_VIEW") is False
+    not app.config["FAB_ADD_SECURITY_VIEWS"] is False
+    or app.config["SUPERSET_LOG_VIEW"] is False
 ):
     appbuilder.add_view(
         LogModelView,
diff --git a/superset/views/schedules.py b/superset/views/schedules.py
index 7e602d949775d..2a7e0e16d8590 100644
--- a/superset/views/schedules.py
+++ b/superset/views/schedules.py
@@ -291,5 +291,5 @@ def _register_schedule_menus():
     )
 
 
-if app.config.get("ENABLE_SCHEDULED_EMAIL_REPORTS"):
+if app.config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
     _register_schedule_menus()
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 231373537a15f..db3bc42dc55f4 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -30,7 +30,7 @@
 from superset.utils.core import QueryStatus
 
 FORM_DATA_KEY_BLACKLIST: List[str] = []
-if not app.config.get("ENABLE_JAVASCRIPT_CONTROLS"):
+if not app.config["ENABLE_JAVASCRIPT_CONTROLS"]:
     FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"]
 
 
@@ -189,7 +189,9 @@ def apply_display_max_row_limit(
     :param sql_results: The results of a sql query from sql_lab.get_sql_results
     :returns: The mutated sql_results structure
     """
-    display_limit = rows or app.config.get("DISPLAY_MAX_ROW")
+
+    display_limit = rows or app.config["DISPLAY_MAX_ROW"]
+
     if (
         display_limit
         and sql_results["status"] == QueryStatus.SUCCESS
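A condensed sketch of the truncation `apply_display_max_row_limit` performs around `DISPLAY_MAX_ROW`; the limit value here is illustrative, and the shipped default is far larger:

```python
DISPLAY_MAX_ROW = 3  # illustrative; config.py ships a much larger default

sql_results = {"status": "success", "data": [1, 2, 3, 4, 5]}

rows = None  # an explicit per-request override takes precedence when set
display_limit = rows or DISPLAY_MAX_ROW
if display_limit and sql_results["status"] == "success":
    if len(sql_results["data"]) > display_limit:
        sql_results["data"] = sql_results["data"][:display_limit]
        sql_results["displayLimitReached"] = True

print(sql_results)
```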
diff --git a/superset/viz.py b/superset/viz.py
index 0d11a9d07d90b..85fbf32452d90 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -57,9 +57,9 @@
 )
 
 config = app.config
-stats_logger = config.get("STATS_LOGGER")
-relative_start = config.get("DEFAULT_RELATIVE_START_TIME", "today")
-relative_end = config.get("DEFAULT_RELATIVE_END_TIME", "today")
+stats_logger = config["STATS_LOGGER"]
+relative_start = config["DEFAULT_RELATIVE_START_TIME"]
+relative_end = config["DEFAULT_RELATIVE_END_TIME"]
 
 METRIC_KEYS = [
     "metric",
@@ -277,7 +277,7 @@ def query_obj(self):
         granularity = form_data.get("granularity") or form_data.get("granularity_sqla")
         limit = int(form_data.get("limit") or 0)
         timeseries_limit_metric = form_data.get("timeseries_limit_metric")
-        row_limit = int(form_data.get("row_limit") or config.get("ROW_LIMIT"))
+        row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"])
 
         # default order direction
         order_desc = form_data.get("order_desc", True)
@@ -336,7 +336,7 @@ def cache_timeout(self):
             and self.datasource.database.cache_timeout
         ) is not None:
             return self.datasource.database.cache_timeout
-        return config.get("CACHE_DEFAULT_TIMEOUT")
+        return config["CACHE_DEFAULT_TIMEOUT"]
 
     def get_json(self):
         return json.dumps(
@@ -491,7 +491,7 @@ def data(self):
     def get_csv(self):
         df = self.get_df()
         include_index = not isinstance(df.index, pd.RangeIndex)
-        return df.to_csv(index=include_index, **config.get("CSV_EXPORT"))
+        return df.to_csv(index=include_index, **config["CSV_EXPORT"])
 
     def get_data(self, df):
         return df.to_dict(orient="records")
@@ -1473,9 +1473,7 @@ class HistogramViz(BaseViz):
     def query_obj(self):
         """Returns the query object for this visualization"""
         d = super().query_obj()
-        d["row_limit"] = self.form_data.get(
-            "row_limit", int(config.get("VIZ_ROW_LIMIT"))
-        )
+        d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"]))
         numeric_columns = self.form_data.get("all_columns_x")
         if numeric_columns is None:
             raise Exception(_("Must have at least one numeric column specified"))
@@ -2063,7 +2061,7 @@ def get_data(self, df):
         return {
             "geoJSON": geo_json,
             "hasCustomMetric": has_custom_metric,
-            "mapboxApiKey": config.get("MAPBOX_API_KEY"),
+            "mapboxApiKey": config["MAPBOX_API_KEY"],
             "mapStyle": fd.get("mapbox_style"),
             "aggregatorName": fd.get("pandas_aggfunc"),
             "clusteringRadius": fd.get("clustering_radius"),
@@ -2098,7 +2096,7 @@ def get_data(self, df):
         slice_ids = fd.get("deck_slices")
         slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
         return {
-            "mapboxApiKey": config.get("MAPBOX_API_KEY"),
+            "mapboxApiKey": config["MAPBOX_API_KEY"],
             "slices": [slc.data for slc in slices],
         }
 
@@ -2249,7 +2247,7 @@ def get_data(self, df):
 
         return {
             "features": features,
-            "mapboxApiKey": config.get("MAPBOX_API_KEY"),
+            "mapboxApiKey": config["MAPBOX_API_KEY"],
             "metricLabels": self.metric_labels,
         }
 
@@ -2495,7 +2493,7 @@ def get_properties(self, d):
 
     def get_data(self, df):
         d = super().get_data(df)
-        return {"features": d["features"], "mapboxApiKey": config.get("MAPBOX_API_KEY")}
+        return {"features": d["features"], "mapboxApiKey": config["MAPBOX_API_KEY"]}
 
 
 class EventFlowViz(BaseViz):
diff --git a/tests/access_tests.py b/tests/access_tests.py
index 221934cdee196..a27000ac29dd9 100644
--- a/tests/access_tests.py
+++ b/tests/access_tests.py
@@ -212,7 +212,7 @@ def test_clean_requests_after_role_extend(self):
         # Check if access request for gamma at energy_usage was deleted
 
         # gamma2 and gamma request table_role on energy usage
-        if app.config.get("ENABLE_ACCESS_REQUEST"):
+        if app.config["ENABLE_ACCESS_REQUEST"]:
             access_request1 = create_access_request(
                 session, "table", "random_time_series", TEST_ROLE_1, "gamma2"
             )
@@ -354,7 +354,7 @@ def test_clean_requests_after_schema_grant(self):
 
     @mock.patch("superset.utils.core.send_MIME_email")
     def test_approve(self, mock_send_mime):
-        if app.config.get("ENABLE_ACCESS_REQUEST"):
+        if app.config["ENABLE_ACCESS_REQUEST"]:
             session = db.session
             TEST_ROLE_NAME = "table_role"
             security_manager.add_role(TEST_ROLE_NAME)
@@ -481,7 +481,7 @@ def test_approve(self, mock_send_mime):
             session.commit()
 
     def test_request_access(self):
-        if app.config.get("ENABLE_ACCESS_REQUEST"):
+        if app.config["ENABLE_ACCESS_REQUEST"]:
             session = db.session
             self.logout()
             self.login(username="gamma")
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 7fb1c99c26f38..9e342d818395b 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -30,7 +30,7 @@
 from superset.models.core import Database
 from superset.utils.core import get_example_database
 
-BASE_DIR = app.config.get("BASE_DIR")
+BASE_DIR = app.config["BASE_DIR"]
 
 
 class SupersetTestCase(unittest.TestCase):
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 742e61904bf03..c16087bbf9829 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -32,12 +32,12 @@
 from .base_tests import SupersetTestCase
 
-BASE_DIR = app.config.get("BASE_DIR")
+BASE_DIR = app.config["BASE_DIR"]
 CELERY_SLEEP_TIME = 5
 
 
 class CeleryConfig(object):
-    BROKER_URL = app.config.get("CELERY_RESULT_BACKEND")
+    BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL
     CELERY_IMPORTS = ("superset.sql_lab",)
     CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
     CONCURRENCY = 1
diff --git a/tests/email_tests.py b/tests/email_tests.py
index cc3388d5691c2..ba4c2638e84f6 100644
--- a/tests/email_tests.py
+++ b/tests/email_tests.py
@@ -47,11 +47,11 @@ def test_send_smtp(self, mock_send_mime):
         assert mock_send_mime.called
         call_args = mock_send_mime.call_args[0]
         logging.debug(call_args)
-        assert call_args[0] == app.config.get("SMTP_MAIL_FROM")
+        assert call_args[0] == app.config["SMTP_MAIL_FROM"]
         assert call_args[1] == ["to"]
         msg = call_args[2]
         assert msg["Subject"] == "subject"
-        assert msg["From"] == app.config.get("SMTP_MAIL_FROM")
+        assert msg["From"] == app.config["SMTP_MAIL_FROM"]
         assert len(msg.get_payload()) == 2
         mimeapp = MIMEApplication("attachment")
         assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@@ -64,11 +64,11 @@ def test_send_smtp_data(self, mock_send_mime):
         assert mock_send_mime.called
         call_args = mock_send_mime.call_args[0]
         logging.debug(call_args)
-        assert call_args[0] == app.config.get("SMTP_MAIL_FROM")
+        assert call_args[0] == app.config["SMTP_MAIL_FROM"]
         assert call_args[1] == ["to"]
         msg = call_args[2]
         assert msg["Subject"] == "subject"
-        assert msg["From"] == app.config.get("SMTP_MAIL_FROM")
+        assert msg["From"] == app.config["SMTP_MAIL_FROM"]
         assert len(msg.get_payload()) == 2
         mimeapp = MIMEApplication("data")
         assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@@ -82,11 +82,11 @@ def test_send_smtp_inline_images(self, mock_send_mime):
         assert mock_send_mime.called
         call_args = mock_send_mime.call_args[0]
         logging.debug(call_args)
-        assert call_args[0] == app.config.get("SMTP_MAIL_FROM")
+        assert call_args[0] == app.config["SMTP_MAIL_FROM"]
         assert call_args[1] == ["to"]
         msg = call_args[2]
         assert msg["Subject"] == "subject"
-        assert msg["From"] == app.config.get("SMTP_MAIL_FROM")
+        assert msg["From"] == app.config["SMTP_MAIL_FROM"]
         assert len(msg.get_payload()) == 2
         mimeapp = MIMEImage(image)
         assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
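The test edits above are mechanical, but the pattern they exercise is worth a standalone sketch: patch `smtplib.SMTP`, run the code under test, then assert against the same config values that code read:

```python
import smtplib
from unittest import mock

config = {"SMTP_HOST": "localhost", "SMTP_PORT": 25}  # stand-in for app.config

with mock.patch("smtplib.SMTP") as mock_smtp:
    # The code under test would do this internally:
    smtplib.SMTP(config["SMTP_HOST"], config["SMTP_PORT"])
    # The assertion mirrors tests/email_tests.py:
    mock_smtp.assert_called_with(config["SMTP_HOST"], config["SMTP_PORT"])
print("assertion passed")
```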
@@ -107,11 +107,11 @@ def test_send_bcc_smtp(self, mock_send_mime):
         )
         assert mock_send_mime.called
         call_args = mock_send_mime.call_args[0]
-        assert call_args[0] == app.config.get("SMTP_MAIL_FROM")
+        assert call_args[0] == app.config["SMTP_MAIL_FROM"]
         assert call_args[1] == ["to", "cc", "bcc"]
         msg = call_args[2]
         assert msg["Subject"] == "subject"
-        assert msg["From"] == app.config.get("SMTP_MAIL_FROM")
+        assert msg["From"] == app.config["SMTP_MAIL_FROM"]
         assert len(msg.get_payload()) == 2
         mimeapp = MIMEApplication("attachment")
         assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@@ -123,12 +123,10 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl):
         mock_smtp_ssl.return_value = mock.Mock()
         msg = MIMEMultipart()
         utils.send_MIME_email("from", "to", msg, app.config, dryrun=False)
-        mock_smtp.assert_called_with(
-            app.config.get("SMTP_HOST"), app.config.get("SMTP_PORT")
-        )
+        mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"])
         assert mock_smtp.return_value.starttls.called
         mock_smtp.return_value.login.assert_called_with(
-            app.config.get("SMTP_USER"), app.config.get("SMTP_PASSWORD")
+            app.config["SMTP_USER"], app.config["SMTP_PASSWORD"]
        )
         mock_smtp.return_value.sendmail.assert_called_with(
             "from", "to", msg.as_string()
@@ -144,7 +142,7 @@ def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
         utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
         assert not mock_smtp.called
         mock_smtp_ssl.assert_called_with(
-            app.config.get("SMTP_HOST"), app.config.get("SMTP_PORT")
+            app.config["SMTP_HOST"], app.config["SMTP_PORT"]
         )
 
     @mock.patch("smtplib.SMTP_SSL")
@@ -156,9 +154,7 @@ def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
         mock_smtp_ssl.return_value = mock.Mock()
         utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
         assert not mock_smtp_ssl.called
-        mock_smtp.assert_called_with(
-            app.config.get("SMTP_HOST"), app.config.get("SMTP_PORT")
-        )
+        mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"])
         assert not mock_smtp.login.called
 
     @mock.patch("smtplib.SMTP_SSL")
diff --git a/tests/security_tests.py b/tests/security_tests.py
index 4192843ae5cb7..4478389902fc9 100644
--- a/tests/security_tests.py
+++ b/tests/security_tests.py
@@ -94,7 +94,7 @@ def assert_can_alpha(self, perm_set):
         self.assertIn(("muldelete", "DruidDatasourceModelView"), perm_set)
 
     def assert_cannot_alpha(self, perm_set):
-        if app.config.get("ENABLE_ACCESS_REQUEST"):
+        if app.config["ENABLE_ACCESS_REQUEST"]:
             self.assert_cannot_write("AccessRequestsModelView", perm_set)
             self.assert_can_all("AccessRequestsModelView", perm_set)
         self.assert_cannot_write("Queries", perm_set)
@@ -133,7 +133,7 @@ def test_is_admin_only(self):
                 security_manager.find_permission_view_menu("can_delete", "DatabaseView")
             )
         )
-        if app.config.get("ENABLE_ACCESS_REQUEST"):
+        if app.config["ENABLE_ACCESS_REQUEST"]:
             self.assertTrue(
                 security_manager._is_admin_only(
                     security_manager.find_permission_view_menu(