Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manually build ES pagination query for default sorter #4271

Merged
merged 1 commit into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 250 additions & 30 deletions common/persistence/visibility/store/elasticsearch/visibility_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
Expand All @@ -56,17 +57,6 @@ const (
delimiter = "~"
)

// Default sort by uses the sorting order defined in the index template, so no
// additional sorting is needed during query.
var defaultSorter = []elastic.Sorter{
elastic.NewFieldSort(searchattribute.CloseTime).Desc().Missing("_first"),
elastic.NewFieldSort(searchattribute.StartTime).Desc().Missing("_first"),
}

var docSorter = []elastic.Sorter{
elastic.SortByDoc{},
}

type (
visibilityStore struct {
esClient client.Client
Expand All @@ -82,12 +72,45 @@ type (
visibilityPageToken struct {
SearchAfter []interface{}
}

fieldSort struct {
name string
desc bool
missing_first bool
}
)

var _ store.VisibilityStore = (*visibilityStore)(nil)

var (
errUnexpectedJSONFieldType = errors.New("unexpected JSON field type")

// Default sorter uses the sorting order defined in the index template.
// It is indirectly built so buildPaginationQuery can have access to
// the fields names to build the page query from the token.
defaultSorterFields = []fieldSort{
{searchattribute.CloseTime, true, true},
{searchattribute.StartTime, true, true},
}

defaultSorter = func() []elastic.Sorter {
ret := make([]elastic.Sorter, 0, len(defaultSorterFields))
for _, item := range defaultSorterFields {
fs := elastic.NewFieldSort(item.name)
if item.desc {
fs.Desc()
}
if item.missing_first {
fs.Missing("_first")
}
ret = append(ret, fs)
}
return ret
}()

docSorter = []elastic.Sorter{
elastic.SortByDoc{},
}
)

// NewVisibilityStore create a visibility store connecting to ElasticSearch
Expand Down Expand Up @@ -439,15 +462,6 @@ func (s *visibilityStore) ListWorkflowExecutions(
return nil, err
}

token, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}

if token != nil && len(token.SearchAfter) > 0 {
p.SearchAfter = token.SearchAfter
}

searchResult, err := s.esClient.Search(ctx, p)
if err != nil {
return nil, convertElasticsearchClientError("ListWorkflowExecutions failed", err)
Expand All @@ -465,15 +479,6 @@ func (s *visibilityStore) ScanWorkflowExecutions(
return nil, err
}

token, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}

if token != nil && len(token.SearchAfter) > 0 {
p.SearchAfter = token.SearchAfter
}

searchResult, err := s.esClient.Search(ctx, p)
if err != nil {
return nil, convertElasticsearchClientError("ScanWorkflowExecutions failed", err)
Expand Down Expand Up @@ -588,7 +593,6 @@ func (s *visibilityStore) buildSearchParametersV2(
request *manager.ListWorkflowExecutionsRequestV2,
getFieldSorter func([]*elastic.FieldSort) ([]elastic.Sorter, error),
) (*client.SearchParameters, error) {

boolQuery, fieldSorts, err := s.convertQuery(
request.Namespace,
request.NamespaceID,
Expand Down Expand Up @@ -619,9 +623,63 @@ func (s *visibilityStore) buildSearchParametersV2(
Sorter: sorter,
}

pageToken, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}
err = s.processPageToken(params, pageToken)
if err != nil {
return nil, err
}

return params, nil
}

func (s *visibilityStore) processPageToken(
params *client.SearchParameters,
pageToken *visibilityPageToken,
) error {
if pageToken == nil || len(pageToken.SearchAfter) == 0 {
return nil
}
if len(pageToken.SearchAfter) != len(params.Sorter) {
return serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token for given sort fields: expected %d fields, got %d",
len(params.Sorter),
len(pageToken.SearchAfter),
))
}
if !isDefaultSorter(params.Sorter) {
params.SearchAfter = pageToken.SearchAfter
return nil
}

boolQuery, ok := params.Query.(*elastic.BoolQuery)
if !ok {
return serviceerror.NewInternal(fmt.Sprintf(
"Unexpected query type: expected *elastic.BoolQuery, got %T",
params.Query,
))
}

saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes(s.index, false)
if err != nil {
return serviceerror.NewUnavailable(
fmt.Sprintf("Unable to read search attribute types: %v", err),
)
}

// build pagination search query for default sorter
shouldQueries, err := buildPaginationQuery(defaultSorterFields, pageToken.SearchAfter, saTypeMap)
if err != nil {
return err
}

boolQuery.Should(shouldQueries...)
boolQuery.MinimumNumberShouldMatch(1)
return nil
}

func (s *visibilityStore) convertQuery(
namespace namespace.Name,
namespaceID namespace.ID,
Expand Down Expand Up @@ -998,3 +1056,165 @@ func detailedErrorMessage(err error) string {
}
return sb.String()
}

func isDefaultSorter(sorter []elastic.Sorter) bool {
if len(sorter) != len(defaultSorter) {
return false
}
for i := 0; i < len(defaultSorter); i++ {
if &sorter[i] != &defaultSorter[i] {
return false
}
}
return true
}

// buildPaginationQuery builds the Elasticsearch conditions for the next page based on searchAfter.
//
// For example, if sorterFields = [A, B, C] and searchAfter = [lastA, lastB, lastC],
// it will build the following conditions (assuming all values are non-null and orders are desc):
// - k = 0: A < lastA
// - k = 1: A = lastA AND B < lastB
// - k = 2: A = lastA AND B = lastB AND C < lastC
//
//nolint:revive // cyclomatic complexity
func buildPaginationQuery(
sorterFields []fieldSort,
searchAfter []any,
saTypeMap searchattribute.NameTypeMap,
) ([]elastic.Query, error) {
n := len(sorterFields)
if len(sorterFields) != len(searchAfter) {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token for given sort fields: expected %d fields, got %d",
len(sorterFields),
len(searchAfter),
))
}

parsedSearchAfter := make([]any, n)
for i := 0; i < n; i++ {
tp, err := saTypeMap.GetType(sorterFields[i].name)
if err != nil {
return nil, err
}
parsedSearchAfter[i], err = parsePageTokenValue(sorterFields[i].name, searchAfter[i], tp)
if err != nil {
return nil, err
}
}

// Last field of sorter must be a tie breaker, and thus cannot contain null value.
if parsedSearchAfter[len(parsedSearchAfter)-1] == nil {
return nil, serviceerror.NewInternal(fmt.Sprintf(
"Last field of sorter cannot be a nullable field: %q has null values",
sorterFields[len(sorterFields)-1].name,
))
}

shouldQueries := make([]elastic.Query, 0, len(sorterFields))
for k := 0; k < len(sorterFields); k++ {
bq := elastic.NewBoolQuery()
for i := 0; i <= k; i++ {
field := sorterFields[i]
value := parsedSearchAfter[i]
if i == k {
if value == nil {
bq.Filter(elastic.NewExistsQuery(field.name))
} else if field.desc {
bq.Filter(elastic.NewRangeQuery(field.name).Lt(value))
} else {
bq.Filter(elastic.NewRangeQuery(field.name).Gt(value))
}
} else {
if value == nil {
bq.MustNot(elastic.NewExistsQuery(field.name))
} else {
bq.Filter(elastic.NewTermQuery(field.name, value))
}
}
}
shouldQueries = append(shouldQueries, bq)
}
return shouldQueries, nil
}

// parsePageTokenValue parses the page token values to be used in the search query.
// The page token comes from the `sort` field from the previous response from Elasticsearch.
// Depending on the type of the field, the null value is represented differently:
// - integer, bool, and datetime: MaxInt64 (desc) or MinInt64 (asc)
// - double: "Infinity" (desc) or "-Infinity" (asc)
// - keyword: nil
//
// Furthermore, for bool and datetime, they need to be converted to boolean or the RFC3339Nano
// formats respectively.
//
//nolint:revive // cyclomatic complexity
func parsePageTokenValue(
fieldName string, jsonValue any,
tp enumspb.IndexedValueType,
) (any, error) {
switch tp {
case enumspb.INDEXED_VALUE_TYPE_INT,
enumspb.INDEXED_VALUE_TYPE_BOOL,
enumspb.INDEXED_VALUE_TYPE_DATETIME:
jsonNumber, ok := jsonValue.(json.Number)
if !ok {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected interger type, got %q", jsonValue))
}
num, err := jsonNumber.Int64()
if err != nil {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected interger type, got %v", jsonValue))
}
if num == math.MaxInt64 || num == math.MinInt64 {
return nil, nil
}
if tp == enumspb.INDEXED_VALUE_TYPE_BOOL {
return num != 0, nil
}
if tp == enumspb.INDEXED_VALUE_TYPE_DATETIME {
return time.Unix(0, num).UTC().Format(time.RFC3339Nano), nil
}
return num, nil

case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
switch v := jsonValue.(type) {
case json.Number:
num, err := v.Float64()
if err != nil {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %v", jsonValue))
}
return num, nil
case string:
// it can be the string representation of infinity
if _, err := strconv.ParseFloat(v, 64); err != nil {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %q", jsonValue))
}
return nil, nil
default:
// it should never reach here
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %#v", jsonValue))
}

case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
if jsonValue == nil {
return nil, nil
}
if _, ok := jsonValue.(string); !ok {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected string type, got %v", jsonValue))
}
return jsonValue, nil

default:
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid field type in sorter: cannot order by %q",
fieldName,
))
}
}
Loading