Skip to content

Commit

Permalink
Manually build ES pagination query for default sorter
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigozhou committed May 4, 2023
1 parent d4cd05e commit ebb5dbf
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 36 deletions.
233 changes: 203 additions & 30 deletions common/persistence/visibility/store/elasticsearch/visibility_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
Expand All @@ -56,17 +57,6 @@ const (
delimiter = "~"
)

// Default sort by uses the sorting order defined in the index template, so no
// additional sorting is needed during query.
var defaultSorter = []elastic.Sorter{
elastic.NewFieldSort(searchattribute.CloseTime).Desc().Missing("_first"),
elastic.NewFieldSort(searchattribute.StartTime).Desc().Missing("_first"),
}

var docSorter = []elastic.Sorter{
elastic.SortByDoc{},
}

type (
visibilityStore struct {
esClient client.Client
Expand All @@ -82,12 +72,45 @@ type (
visibilityPageToken struct {
SearchAfter []interface{}
}

fieldSort struct {
name string
desc bool
missing_first bool
}
)

var _ store.VisibilityStore = (*visibilityStore)(nil)

var (
errUnexpectedJSONFieldType = errors.New("unexpected JSON field type")

// Default sorter uses the sorting order defined in the index template.
// It is indirectly built so buildPageQuery can have access to the fields names
// to build the page query from the token.
defaultSorterFields = []fieldSort{
{searchattribute.CloseTime, true, true},
{searchattribute.StartTime, true, true},
}

defaultSorter = func() []elastic.Sorter {
ret := make([]elastic.Sorter, 0, len(defaultSorterFields))
for _, item := range defaultSorterFields {
fs := elastic.NewFieldSort(item.name)
if item.desc {
fs.Desc()
}
if item.missing_first {
fs.Missing("_first")
}
ret = append(ret, fs)
}
return ret
}()

docSorter = []elastic.Sorter{
elastic.SortByDoc{},
}
)

// NewVisibilityStore create a visibility store connecting to ElasticSearch
Expand Down Expand Up @@ -439,15 +462,6 @@ func (s *visibilityStore) ListWorkflowExecutions(
return nil, err
}

token, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}

if token != nil && len(token.SearchAfter) > 0 {
p.SearchAfter = token.SearchAfter
}

searchResult, err := s.esClient.Search(ctx, p)
if err != nil {
return nil, convertElasticsearchClientError("ListWorkflowExecutions failed", err)
Expand All @@ -465,15 +479,6 @@ func (s *visibilityStore) ScanWorkflowExecutions(
return nil, err
}

token, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}

if token != nil && len(token.SearchAfter) > 0 {
p.SearchAfter = token.SearchAfter
}

searchResult, err := s.esClient.Search(ctx, p)
if err != nil {
return nil, convertElasticsearchClientError("ScanWorkflowExecutions failed", err)
Expand Down Expand Up @@ -588,7 +593,6 @@ func (s *visibilityStore) buildSearchParametersV2(
request *manager.ListWorkflowExecutionsRequestV2,
getFieldSorter func([]*elastic.FieldSort) ([]elastic.Sorter, error),
) (*client.SearchParameters, error) {

boolQuery, fieldSorts, err := s.convertQuery(
request.Namespace,
request.NamespaceID,
Expand Down Expand Up @@ -619,9 +623,99 @@ func (s *visibilityStore) buildSearchParametersV2(
Sorter: sorter,
}

pageToken, err := s.deserializePageToken(request.NextPageToken)
if err != nil {
return nil, err
}
err = s.buildPageQuery(params, pageToken)
if err != nil {
return nil, err
}

return params, nil
}

//nolint:revive // cyclomatic complexity
func (s *visibilityStore) buildPageQuery(
params *client.SearchParameters,
pageToken *visibilityPageToken,
) error {
if pageToken == nil || len(pageToken.SearchAfter) == 0 {
return nil
}
if len(pageToken.SearchAfter) != len(params.Sorter) {
return serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token for given sort fields: expected %d fields, got %d",
len(params.Sorter),
len(pageToken.SearchAfter),
))
}
if !isDefaultSorter(params.Sorter) {
params.SearchAfter = pageToken.SearchAfter
return nil
}

boolQuery, ok := params.Query.(*elastic.BoolQuery)
if !ok {
return serviceerror.NewInternal(fmt.Sprintf(
"Unexpected query type: expected *elastic.BoolQuery, got %T",
params.Query,
))
}

saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes(s.index, false)
if err != nil {
return serviceerror.NewUnavailable(
fmt.Sprintf("Unable to read search attribute types: %v", err),
)
}

// build pagination search query for default sorter
shouldQueries := make([]elastic.Query, 0, len(defaultSorter))
searchAfter := pageToken.SearchAfter
for k := 0; k < len(defaultSorterFields); k++ {
bq := elastic.NewBoolQuery()
for i := 0; i <= k; i++ {
tp, err := saTypeMap.GetType(defaultSorterFields[i].name)
if err != nil {
return err
}

lastValue, err := parsePageTokenValue(defaultSorterFields[i].name, searchAfter[i], tp)
if err != nil {
return err
}
if i == len(defaultSorterFields)-1 && lastValue == nil {
return serviceerror.NewInternal(fmt.Sprintf(
"Last field of sorter cannot be a nullable field: %q has null values",
defaultSorterFields[i].name,
))
}

if i == k {
if lastValue == nil {
bq.Filter(elastic.NewExistsQuery(defaultSorterFields[i].name))
} else if defaultSorterFields[i].desc {
bq.Filter(elastic.NewRangeQuery(defaultSorterFields[i].name).Lt(lastValue))
} else {
bq.Filter(elastic.NewRangeQuery(defaultSorterFields[i].name).Gt(lastValue))
}
} else {
if lastValue == nil {
bq.MustNot(elastic.NewExistsQuery(defaultSorterFields[i].name))
} else {
bq.Filter(elastic.NewTermQuery(defaultSorterFields[i].name, lastValue))
}
}
}
shouldQueries = append(shouldQueries, bq)
}

boolQuery.Should(shouldQueries...)
boolQuery.MinimumShouldMatch("1")
return nil
}

func (s *visibilityStore) convertQuery(
namespace namespace.Name,
namespaceID namespace.ID,
Expand Down Expand Up @@ -998,3 +1092,82 @@ func detailedErrorMessage(err error) string {
}
return sb.String()
}

func isDefaultSorter(sorter []elastic.Sorter) bool {
if len(sorter) != len(defaultSorter) {
return false
}
for i := 0; i < len(defaultSorter); i++ {
if &sorter[i] != &defaultSorter[i] {
return false
}
}
return true
}

//nolint:revive // cyclomatic complexity
func parsePageTokenValue(fieldName string, jsonValue any, tp enumspb.IndexedValueType) (any, error) {
switch tp {
case enumspb.INDEXED_VALUE_TYPE_INT,
enumspb.INDEXED_VALUE_TYPE_BOOL,
enumspb.INDEXED_VALUE_TYPE_DATETIME:
jsonNumber, ok := jsonValue.(json.Number)
if !ok {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected interger type, got %v", jsonNumber))
}
num, err := jsonNumber.Int64()
if err != nil {
return false, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected interger type, got %v", jsonValue))
}
if num == math.MaxInt64 || num == math.MinInt64 {
return nil, nil
}
if tp == enumspb.INDEXED_VALUE_TYPE_BOOL {
return num != 0, nil
}
if tp == enumspb.INDEXED_VALUE_TYPE_DATETIME {
return time.Unix(0, num).UTC().Format(time.RFC3339Nano), nil
}
return num, nil

case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
switch v := jsonValue.(type) {
case json.Number:
num, err := v.Float64()
if err != nil {
return false, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %v", jsonValue))
}
return num, nil
case string:
// it can be the string representation of infinity
if _, err := strconv.ParseFloat(v, 64); err != nil {
return false, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %v", jsonValue))
}
return nil, nil
default:
// it should never reach here
return false, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected float type, got %v", jsonValue))
}

case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
if jsonValue == nil {
return nil, nil
}
if _, ok := jsonValue.(string); !ok {
return false, serviceerror.NewInvalidArgument(fmt.Sprintf(
"Invalid page token: expected string type, got %v", jsonValue))
}
return jsonValue, nil

default:
return nil, serviceerror.NewInternal(fmt.Sprintf(
"Invalid field type in sorter: cannot order by %q",
fieldName,
))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions() {
Hits: []*elastic.SearchHit{
{
Source: source,
Sort: []interface{}{json.Number("123"), "runId"},
Sort: []interface{}{json.Number("123")},
},
},
},
Expand All @@ -1156,15 +1156,15 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions() {
request.Query = `ExecutionStatus = "Terminated"`
s.mockESClient.EXPECT().Search(gomock.Any(), gomock.Any()).Return(searchResult, nil)

token := &visibilityPageToken{SearchAfter: []interface{}{json.Number("1528358645123456789"), "qwe"}}
token := &visibilityPageToken{SearchAfter: []interface{}{json.Number("1528358645123456789")}}
tokenBytes, err := s.visibilityStore.serializePageToken(token)
s.NoError(err)
request.NextPageToken = tokenBytes
result, err := s.visibilityStore.ScanWorkflowExecutions(context.Background(), request)
s.NoError(err)
responseToken, err := s.visibilityStore.deserializePageToken(result.NextPageToken)
s.NoError(err)
s.Equal([]interface{}{json.Number("123"), "runId"}, responseToken.SearchAfter)
s.Equal([]interface{}{json.Number("123")}, responseToken.SearchAfter)

// test last page
searchResult = &elastic.SearchResult{
Expand Down Expand Up @@ -1209,7 +1209,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions_OldPageToken() {
Hits: []*elastic.SearchHit{
{
Source: source,
Sort: []interface{}{json.Number("123"), "runId"},
Sort: []interface{}{json.Number("123")},
},
},
},
Expand All @@ -1224,7 +1224,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions_OldPageToken() {
ScrollID string
PointInTimeID string
}{
SearchAfter: []interface{}{json.Number("1528358645123456789"), "qwe"},
SearchAfter: []interface{}{json.Number("1528358645123456789")},
ScrollID: "random-scroll",
PointInTimeID: "random-pit",
}
Expand All @@ -1236,7 +1236,7 @@ func (s *ESVisibilitySuite) TestScanWorkflowExecutions_OldPageToken() {
s.NoError(err)
responseToken, err := s.visibilityStore.deserializePageToken(result.NextPageToken)
s.NoError(err)
s.Equal([]interface{}{json.Number("123"), "runId"}, responseToken.SearchAfter)
s.Equal([]interface{}{json.Number("123")}, responseToken.SearchAfter)
}

func (s *ESVisibilitySuite) TestCountWorkflowExecutions() {
Expand Down

0 comments on commit ebb5dbf

Please sign in to comment.