feat: add global max row limit (apache#16683)
* feat: add global max limit

* fix lint and tests

* leave SAMPLES_ROW_LIMIT unchanged

* fix sample rowcount test

* replace max global limit with existing sql max row limit

* fix test

* make max_limit optional in util

* improve comments

(cherry picked from commit 4e3d4f6)
villebro authored and Steven Uray committed Sep 17, 2021
1 parent 3014da1 commit 1a58188
Showing 11 changed files with 145 additions and 43 deletions.
4 changes: 0 additions & 4 deletions superset/common/query_actions.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import copy
import math
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING

from flask_babel import _
@@ -131,15 +130,12 @@ def _get_samples(
query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
) -> Dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
row_limit = query_obj.row_limit or math.inf
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
query_obj.orderby = []
query_obj.groupby = []
query_obj.metrics = []
query_obj.post_processing = []
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
query_obj.row_offset = 0
query_obj.columns = [o.column_name for o in datasource.columns]
return _get_full(query_context, query_obj, force_cached)
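
Note: the per-call cap removed above (the min() of the requested limit and SAMPLES_ROW_LIMIT) is not lost; it moves into QueryObject.__init__, which now picks SAMPLES_ROW_LIMIT as the default for samples queries and passes every limit through apply_max_row_limit (see superset/common/query_object.py below).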

2 changes: 1 addition & 1 deletion superset/common/query_context.py
@@ -100,11 +100,11 @@ def __init__(
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.result_type = result_type or ChartDataResultType.FULL
self.result_format = result_format or ChartDataResultFormat.JSON
self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
self.cache_values = {
"datasource": datasource,
"queries": queries,
17 changes: 14 additions & 3 deletions superset/common/query_object.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING

from flask_babel import gettext as _
from pandas import DataFrame
@@ -28,6 +28,7 @@
from superset.typing import Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
apply_max_row_limit,
ChartDataResultType,
DatasourceDict,
DTTM_ALIAS,
@@ -41,6 +42,10 @@
from superset.utils.hashing import md5_sha_from_dict
from superset.views.utils import get_time_range_endpoints

if TYPE_CHECKING:
from superset.common.query_context import QueryContext # pragma: no cover


config = app.config
logger = logging.getLogger(__name__)

@@ -100,6 +105,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes

def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
query_context: "QueryContext",
datasource: Optional[DatasourceDict] = None,
result_type: Optional[ChartDataResultType] = None,
annotation_layers: Optional[List[Dict[str, Any]]] = None,
@@ -138,7 +144,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
self.result_type = result_type
self.result_type = result_type or query_context.result_type
self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
self.annotation_layers = [
layer
@@ -180,7 +186,12 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
for x in metrics
]

self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
default_row_limit = (
config["SAMPLES_ROW_LIMIT"]
if self.result_type == ChartDataResultType.SAMPLES
else config["ROW_LIMIT"]
)
self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
self.row_offset = row_offset or 0
self.filter = filters or []
self.timeseries_limit = timeseries_limit
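
To make the new resolution order concrete, here is a self-contained sketch that mirrors (rather than imports) the logic above, using the default values from superset/config.py in this diff. resolve_row_limit and cfg are illustrative names, not part of the codebase:

def resolve_row_limit(requested, result_type, cfg):
    # Sketch only: mirrors the row-limit resolution added in this hunk.
    default = (
        cfg["SAMPLES_ROW_LIMIT"] if result_type == "samples" else cfg["ROW_LIMIT"]
    )
    limit = requested or default
    # apply_max_row_limit semantics: a limit of 0 means "unlimited" and resolves to the cap
    return min(cfg["SQL_MAX_ROW"], limit) if limit != 0 else cfg["SQL_MAX_ROW"]

cfg = {"ROW_LIMIT": 50000, "SAMPLES_ROW_LIMIT": 1000, "SQL_MAX_ROW": 100000}
assert resolve_row_limit(None, "full", cfg) == 50000     # default stays under the cap
assert resolve_row_limit(None, "samples", cfg) == 1000   # samples use the smaller default
assert resolve_row_limit(250000, "full", cfg) == 100000  # explicit request gets capped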
8 changes: 3 additions & 5 deletions superset/config.py
@@ -115,9 +115,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
# default viz used in chart explorer
DEFAULT_VIZ_TYPE = "table"

# default row limit when requesting chart data
ROW_LIMIT = 50000
VIZ_ROW_LIMIT = 10000
# max rows retrieved when requesting samples from datasource in explore view
# default row limit when requesting samples from datasource in explore view
SAMPLES_ROW_LIMIT = 1000
# max rows retrieved by filter select auto complete
FILTER_SELECT_ROW_LIMIT = 10000
@@ -665,9 +665,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
# Set this API key to enable Mapbox visualizations
MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")

# Maximum number of rows returned from a database
# in async mode, no more than SQL_MAX_ROW will be returned and stored
# in the results backend. This also becomes the limit when exporting CSVs
# Maximum number of rows returned for any analytical database query
SQL_MAX_ROW = 100000

# Maximum number of rows displayed in SQL Lab UI
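
Since SQL_MAX_ROW is now the single ceiling for chart data, samples, filter autocomplete, and SQL Lab alike, a deployment can tighten everything in one place. A hypothetical superset_config.py override (values are illustrative only):

# superset_config.py -- hypothetical deployment override, not part of this diff
ROW_LIMIT = 20000         # default limit when requesting chart data
SAMPLES_ROW_LIMIT = 500   # default limit when sampling a datasource in Explore
SQL_MAX_ROW = 50000       # global cap enforced by apply_max_row_limit()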
22 changes: 22 additions & 0 deletions superset/utils/core.py
@@ -1761,3 +1761,25 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
return bool(strtobool(bool_str.lower()))
except ValueError:
return False


def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
"""
Override row limit if max global limit is defined
:param limit: requested row limit
:param max_limit: Maximum allowed row limit
:return: Capped row limit
>>> apply_max_row_limit(100000, 10)
10
>>> apply_max_row_limit(10, 100000)
10
>>> apply_max_row_limit(0, 10000)
10000
"""
if max_limit is None:
max_limit = current_app.config["SQL_MAX_ROW"]
if limit != 0:
return min(max_limit, limit)
return max_limit
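
As the doctests show, a limit of 0 means "no limit requested" and resolves to the cap itself; when max_limit is omitted, the cap is read from the running app's SQL_MAX_ROW. A hypothetical call site (requested_limit is an illustrative name) showing both forms:

# illustrative only, not part of this diff
row_limit = apply_max_row_limit(requested_limit or 0)             # cap from SQL_MAX_ROW
row_limit = apply_max_row_limit(requested_limit, max_limit=5000)  # explicit cap
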
5 changes: 3 additions & 2 deletions superset/utils/sqllab_execution_context.py
@@ -23,10 +23,11 @@

from flask import g

from superset import app, is_feature_enabled
from superset import is_feature_enabled
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
from superset.utils.core import apply_max_row_limit
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name

@@ -97,7 +98,7 @@ def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]:

@staticmethod
def _get_limit_param(query_params: Dict[str, Any]) -> int:
limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
if limit < 0:
logger.warning(
"Invalid limit of %i specified. Defaulting to max limit.", limit
5 changes: 3 additions & 2 deletions superset/views/core.py
@@ -107,7 +107,7 @@
from superset.utils import core as utils, csv
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.core import ReservedUrlParameters
from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
from superset.utils.dates import now_as_float
from superset.utils.decorators import check_dashboard_access
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
@@ -898,8 +898,9 @@ def filter( # pylint: disable=no-self-use
return json_error_response(DATASOURCE_MISSING_ERR)

datasource.raise_for_access()
row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
payload = json.dumps(
datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
datasource.values_for_column(column, row_limit),
default=utils.json_int_dttm_ser,
ignore_nan=True,
)
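
With the shipped defaults (FILTER_SELECT_ROW_LIMIT = 10000, SQL_MAX_ROW = 100000) this is a no-op, but a deployment that lowers SQL_MAX_ROW below 10000 now shrinks filter autocomplete results as well.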
45 changes: 26 additions & 19 deletions superset/viz.py
@@ -21,7 +21,7 @@
Superset can render.
"""
import copy
import inspect
import dataclasses
import logging
import math
import re
@@ -70,6 +70,7 @@
from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
apply_max_row_limit,
DTTM_ALIAS,
ExtraFiltersReasonType,
JS_MAX_INTEGER,
@@ -81,9 +82,6 @@
from superset.utils.dates import datetime_to_epoch
from superset.utils.hashing import md5_sha_from_str

import dataclasses # isort:skip


if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource

@@ -110,7 +108,7 @@
FILTER_VALUES_REGEX = re.compile(r"filter_values\(['\"](\w+)['\"]\,")


class BaseViz:
class BaseViz: # pylint: disable=too-many-public-methods

"""All visualizations derive this base class"""

@@ -332,6 +330,7 @@ def query_obj(self) -> QueryObjectDict:
limit = int(form_data.get("limit") or 0)
timeseries_limit_metric = form_data.get("timeseries_limit_metric")
row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"])
row_limit = apply_max_row_limit(row_limit)

# default order direction
order_desc = form_data.get("order_desc", True)
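
This one added line routes every legacy viz.py-based chart through the global cap: for example, a form_data row_limit of 1000000 now yields min(SQL_MAX_ROW, 1000000) = 100000 under the default config.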
@@ -556,7 +555,7 @@ def get_df_payload(
)
self.errors.append(error)
self.status = utils.QueryStatus.FAILED
except Exception as ex:
except Exception as ex: # pylint: disable=broad-except
logger.exception(ex)

error = dataclasses.asdict(
@@ -625,7 +624,7 @@ def get_csv(self) -> Optional[str]:
include_index = not isinstance(df.index, pd.RangeIndex)
return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"])

def get_data(self, df: pd.DataFrame) -> VizData:
def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=no-self-use
return df.to_dict(orient="records")

@property
@@ -1242,7 +1241,7 @@ def query_obj(self) -> QueryObjectDict:
d["orderby"] = [(sort_by, is_asc)]
return d

def to_series(
def to_series( # pylint: disable=too-many-branches
self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
) -> List[Dict[str, Any]]:
cols = []
@@ -1446,6 +1445,7 @@ def query_obj(self) -> QueryObjectDict:
return {}

def get_data(self, df: pd.DataFrame) -> VizData:
# pylint: disable=import-outside-toplevel,too-many-locals
multiline_fd = self.form_data
# Late import to avoid circular import issues
from superset.charts.dao import ChartDAO
@@ -1669,19 +1669,20 @@ class HistogramViz(BaseViz):

def query_obj(self) -> QueryObjectDict:
"""Returns the query object for this visualization"""
d = super().query_obj()
d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"]))
query_obj = super().query_obj()
numeric_columns = self.form_data.get("all_columns_x")
if numeric_columns is None:
raise QueryObjectValidationError(
_("Must have at least one numeric column specified")
)
self.columns = numeric_columns
d["columns"] = numeric_columns + self.groupby
self.columns = ( # pylint: disable=attribute-defined-outside-init
numeric_columns
)
query_obj["columns"] = numeric_columns + self.groupby
# override groupby entry to avoid aggregation
d["groupby"] = None
d["metrics"] = None
return d
query_obj["groupby"] = None
query_obj["metrics"] = None
return query_obj

def labelify(self, keys: Union[List[str], str], column: str) -> str:
if isinstance(keys, str):
@@ -1751,7 +1752,7 @@ def query_obj(self) -> QueryObjectDict:

return d

def get_data(self, df: pd.DataFrame) -> VizData:
def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=too-many-locals
if df.empty:
return None

@@ -2061,6 +2062,7 @@ def query_obj(self) -> QueryObjectDict:
return {}

def run_extra_queries(self) -> None:
# pylint: disable=import-outside-toplevel
from superset.common.query_context import QueryContext

qry = super().query_obj()
@@ -2373,6 +2375,7 @@ def query_obj(self) -> QueryObjectDict:
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
# Late imports to avoid circular import issues
# pylint: disable=import-outside-toplevel
from superset import db
from superset.models.slice import Slice

@@ -2393,6 +2396,7 @@ class BaseDeckGLViz(BaseViz):
spatial_control_keys: List[str] = []

def get_metrics(self) -> List[str]:
# pylint: disable=attribute-defined-outside-init
self.metric = self.form_data.get("size")
return [self.metric] if self.metric else []

@@ -2557,15 +2561,18 @@ class DeckScatterViz(BaseDeckGLViz):
is_timeseries = True

def query_obj(self) -> QueryObjectDict:
fd = self.form_data
self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
self.point_radius_fixed = fd.get("point_radius_fixed") or {
# pylint: disable=attribute-defined-outside-init
self.is_timeseries = bool(
self.form_data.get("time_grain_sqla") or self.form_data.get("granularity")
)
self.point_radius_fixed = self.form_data.get("point_radius_fixed") or {
"type": "fix",
"value": 500,
}
return super().query_obj()

def get_metrics(self) -> List[str]:
# pylint: disable=attribute-defined-outside-init
self.metric = None
if self.point_radius_fixed.get("type") == "metric":
self.metric = self.point_radius_fixed["value"]
(diff for the remaining 3 changed files did not load)

