From bb50e33509b3cb6ccf6256feecd609904d4ae01e Mon Sep 17 00:00:00 2001 From: xiongjiwei Date: Wed, 14 Dec 2022 17:16:52 +0800 Subject: [PATCH] expression: support `member of` function (#39880) ref pingcap/tidb#39866 --- executor/showtest/show_test.go | 2 +- expression/builtin.go | 1 + expression/builtin_json.go | 64 +++++++++++++++++++++++++++++++++ expression/builtin_json_test.go | 41 +++++++++++++++++++++ expression/builtin_json_vec.go | 53 +++++++++++++++++++++++++++ expression/integration_test.go | 17 +++++++++ types/json_binary.go | 7 ++-- types/json_binary_functions.go | 46 ++++++++++++------------ 8 files changed, 204 insertions(+), 27 deletions(-) diff --git a/executor/showtest/show_test.go b/executor/showtest/show_test.go index bcbd7c0016f85..0573de30137f6 100644 --- a/executor/showtest/show_test.go +++ b/executor/showtest/show_test.go @@ -1515,7 +1515,7 @@ func TestShowBuiltin(t *testing.T) { res := tk.MustQuery("show builtins;") require.NotNil(t, res) rows := res.Rows() - const builtinFuncNum = 284 + const builtinFuncNum = 285 require.Equal(t, builtinFuncNum, len(rows)) require.Equal(t, rows[0][0].(string), "abs") require.Equal(t, rows[builtinFuncNum-1][0].(string), "yearweek") diff --git a/expression/builtin.go b/expression/builtin.go index 18e78ba17bdb8..66abac551a3e6 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -873,6 +873,7 @@ var funcs = map[string]functionClass{ ast.JSONMerge: &jsonMergeFunctionClass{baseFunctionClass{ast.JSONMerge, 2, -1}}, ast.JSONObject: &jsonObjectFunctionClass{baseFunctionClass{ast.JSONObject, 0, -1}}, ast.JSONArray: &jsonArrayFunctionClass{baseFunctionClass{ast.JSONArray, 0, -1}}, + ast.JSONMemberOf: &jsonMemberOfFunctionClass{baseFunctionClass{ast.JSONMemberOf, 2, 2}}, ast.JSONContains: &jsonContainsFunctionClass{baseFunctionClass{ast.JSONContains, 2, 3}}, ast.JSONOverlaps: &jsonOverlapsFunctionClass{baseFunctionClass{ast.JSONOverlaps, 2, 2}}, ast.JSONContainsPath: &jsonContainsPathFunctionClass{baseFunctionClass{ast.JSONContainsPath, 3, -1}}, diff --git a/expression/builtin_json.go b/expression/builtin_json.go index eeabef6fe2880..e9f803dcc86df 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -43,6 +43,7 @@ var ( _ functionClass = &jsonMergeFunctionClass{} _ functionClass = &jsonObjectFunctionClass{} _ functionClass = &jsonArrayFunctionClass{} + _ functionClass = &jsonMemberOfFunctionClass{} _ functionClass = &jsonContainsFunctionClass{} _ functionClass = &jsonOverlapsFunctionClass{} _ functionClass = &jsonContainsPathFunctionClass{} @@ -72,6 +73,7 @@ var ( _ builtinFunc = &builtinJSONReplaceSig{} _ builtinFunc = &builtinJSONRemoveSig{} _ builtinFunc = &builtinJSONMergeSig{} + _ builtinFunc = &builtinJSONMemberOfSig{} _ builtinFunc = &builtinJSONContainsSig{} _ builtinFunc = &builtinJSONOverlapsSig{} _ builtinFunc = &builtinJSONStorageSizeSig{} @@ -742,6 +744,68 @@ func jsonModify(ctx sessionctx.Context, args []Expression, row chunk.Row, mt typ return res, false, nil } +type jsonMemberOfFunctionClass struct { + baseFunctionClass +} + +type builtinJSONMemberOfSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONMemberOfSig) Clone() builtinFunc { + newSig := &builtinJSONMemberOfSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonMemberOfFunctionClass) verifyArgs(args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[1].GetType().EvalType(); evalType != types.ETJson && evalType != types.ETString { + return types.ErrInvalidJSONData.GenWithStackByArgs(2, "member of") + } + return nil +} + +func (c *jsonMemberOfFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := []types.EvalType{types.ETJson, types.ETJson} + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) + if err != nil { + return nil, err + } + DisableParseJSONFlag4Expr(args[0]) + sig := &builtinJSONMemberOfSig{bf} + return sig, nil +} + +func (b *builtinJSONMemberOfSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { + target, isNull, err := b.args[0].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, err + } + obj, isNull, err := b.args[1].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + if obj.TypeCode != types.JSONTypeCodeArray { + return boolToInt64(types.CompareBinaryJSON(obj, target) == 0), false, nil + } + + elemCount := obj.GetElemCount() + for i := 0; i < elemCount; i++ { + if types.CompareBinaryJSON(obj.ArrayGetElem(i), target) == 0 { + return 1, false, nil + } + } + + return 0, false, nil +} + type jsonContainsFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_json_test.go b/expression/builtin_json_test.go index 72e3c725594c4..23da60f380652 100644 --- a/expression/builtin_json_test.go +++ b/expression/builtin_json_test.go @@ -386,6 +386,47 @@ func TestJSONRemove(t *testing.T) { } } +func TestJSONMemberOf(t *testing.T) { + ctx := createContext(t) + fc := funcs[ast.JSONMemberOf] + tbl := []struct { + input []interface{} + expected interface{} + err error + }{ + {[]interface{}{`1`, `a:1`}, 1, types.ErrInvalidJSONText}, + + {[]interface{}{1, `[1, 2]`}, 1, nil}, + {[]interface{}{1, `[1]`}, 1, nil}, + {[]interface{}{1, `[0]`}, 0, nil}, + {[]interface{}{1, `[1]`}, 1, nil}, + {[]interface{}{1, `[[1]]`}, 0, nil}, + {[]interface{}{"1", `[1]`}, 0, nil}, + {[]interface{}{"1", `["1"]`}, 1, nil}, + {[]interface{}{`{"a":1}`, `{"a":1}`}, 0, nil}, + {[]interface{}{`{"a":1}`, `[{"a":1}]`}, 0, nil}, + {[]interface{}{`{"a":1}`, `[{"a":1}, 1]`}, 0, nil}, + {[]interface{}{`{"a":1}`, `["{\"a\":1}"]`}, 1, nil}, + {[]interface{}{`{"a":1}`, `["{\"a\":1}", 1]`}, 1, nil}, + } + for _, tt := range tbl { + args := types.MakeDatums(tt.input...) + f, err := fc.getFunction(ctx, datumsToConstants(args)) + require.NoError(t, err, tt.input) + d, err := evalBuiltinFunc(f, chunk.Row{}) + if tt.err == nil { + require.NoError(t, err, tt.input) + if tt.expected == nil { + require.True(t, d.IsNull(), tt.input) + } else { + require.Equal(t, int64(tt.expected.(int)), d.GetInt64(), tt.input) + } + } else { + require.True(t, tt.err.(*terror.Error).Equal(err), tt.input) + } + } +} + func TestJSONContains(t *testing.T) { ctx := createContext(t) fc := funcs[ast.JSONContains] diff --git a/expression/builtin_json_vec.go b/expression/builtin_json_vec.go index 45cca97232d2c..0610a1f6ea3ca 100644 --- a/expression/builtin_json_vec.go +++ b/expression/builtin_json_vec.go @@ -274,6 +274,59 @@ func (b *builtinJSONArraySig) vecEvalJSON(input *chunk.Chunk, result *chunk.Colu return nil } +func (b *builtinJSONMemberOfSig) vectorized() bool { + return true +} + +func (b *builtinJSONMemberOfSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { + nr := input.NumRows() + + targetCol, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(targetCol) + + if err := b.args[0].VecEvalJSON(b.ctx, input, targetCol); err != nil { + return err + } + + objCol, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(objCol) + + if err := b.args[1].VecEvalJSON(b.ctx, input, objCol); err != nil { + return err + } + + result.ResizeInt64(nr, false) + resI64s := result.Int64s() + + result.MergeNulls(targetCol, objCol) + for i := 0; i < nr; i++ { + if result.IsNull(i) { + continue + } + obj := objCol.GetJSON(i) + target := targetCol.GetJSON(i) + if obj.TypeCode != types.JSONTypeCodeArray { + resI64s[i] = boolToInt64(types.CompareBinaryJSON(obj, target) == 0) + } else { + elemCount := obj.GetElemCount() + for j := 0; j < elemCount; j++ { + if types.CompareBinaryJSON(obj.ArrayGetElem(j), target) == 0 { + resI64s[i] = 1 + break + } + } + } + } + + return nil +} + func (b *builtinJSONContainsSig) vectorized() bool { return true } diff --git a/expression/integration_test.go b/expression/integration_test.go index 797742d528ca6..3efa93eff5d13 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2722,6 +2722,23 @@ func TestFuncJSON(t *testing.T) { tk.MustExec("insert into tx1 values (1, 0.1, 0.2, 0.3, 0.0)") tk.MustQuery("select a+b, c from tx1").Check(testkit.Rows("0.30000000000000004 0.3")) tk.MustQuery("select json_array(a+b) = json_array(c) from tx1").Check(testkit.Rows("0")) + + tk.MustQuery("SELECT '{\"a\":1}' MEMBER OF('{\"a\":1}');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT '{\"a\":1}' MEMBER OF('[{\"a\":1}]');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT 1 MEMBER OF('1');").Check(testkit.Rows("1")) + tk.MustQuery("SELECT '{\"a\":1}' MEMBER OF('{\"a\":1}');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT '[4,5]' MEMBER OF('[[3,4],[4,5]]');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT '[4,5]' MEMBER OF('[[3,4],\"[4,5]\"]');").Check(testkit.Rows("1")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a enum('a', 'b'), b time, c binary(10))") + tk.MustExec("insert into t values ('a', '11:00:00', 'a')") + tk.MustQuery("select a member of ('\"a\"') from t").Check(testkit.Rows(`1`)) + tk.MustQuery("select b member of (json_array(cast('11:00:00' as time))) from t;").Check(testkit.Rows(`1`)) + tk.MustQuery("select b member of ('\"11:00:00\"') from t").Check(testkit.Rows(`0`)) + tk.MustQuery("select c member of ('\"a\"') from t").Check(testkit.Rows(`0`)) + err = tk.QueryToErr("select 'a' member of ('a')") + require.Error(t, err, "ERROR 3140 (22032): Invalid JSON text: The document root must not be followed by other values.") } func TestColumnInfoModified(t *testing.T) { diff --git a/types/json_binary.go b/types/json_binary.go index eb9a818b6a7f5..6fe01d2b4f28e 100644 --- a/types/json_binary.go +++ b/types/json_binary.go @@ -275,7 +275,8 @@ func (bj BinaryJSON) GetElemCount() int { return int(jsonEndian.Uint32(bj.Value)) } -func (bj BinaryJSON) arrayGetElem(idx int) BinaryJSON { +// ArrayGetElem gets the element of the index `idx`. +func (bj BinaryJSON) ArrayGetElem(idx int) BinaryJSON { return bj.valEntryGet(headerSize + idx*valEntrySize) } @@ -355,7 +356,7 @@ func (bj BinaryJSON) marshalArrayTo(buf []byte) ([]byte, error) { buf = append(buf, ", "...) } var err error - buf, err = bj.arrayGetElem(i).marshalTo(buf) + buf, err = bj.ArrayGetElem(i).marshalTo(buf) if err != nil { return nil, errors.Trace(err) } @@ -557,7 +558,7 @@ func (bj BinaryJSON) HashValue(buf []byte) []byte { elemCount := int(jsonEndian.Uint32(bj.Value)) buf = append(buf, bj.Value[0:dataSizeOff]...) for i := 0; i < elemCount; i++ { - buf = bj.arrayGetElem(i).HashValue(buf) + buf = bj.ArrayGetElem(i).HashValue(buf) } case JSONTypeCodeObject: // this hash value is bidirectional, because you can get the key using the json diff --git a/types/json_binary_functions.go b/types/json_binary_functions.go index 84bd86445aa94..7c4d7bf7f97bc 100644 --- a/types/json_binary_functions.go +++ b/types/json_binary_functions.go @@ -294,7 +294,7 @@ func (bj BinaryJSON) extractTo(buf []BinaryJSON, pathExpr JSONPathExpression, du start, end := currentLeg.arraySelection.getIndexRange(bj) if start >= 0 && start <= end { for i := start; i <= end; i++ { - buf = bj.arrayGetElem(i).extractTo(buf, subPathExpr, dup, one) + buf = bj.ArrayGetElem(i).extractTo(buf, subPathExpr, dup, one) } } } else if currentLeg.typ == jsonPathLegKey && bj.TypeCode == JSONTypeCodeObject { @@ -314,7 +314,7 @@ func (bj BinaryJSON) extractTo(buf []BinaryJSON, pathExpr JSONPathExpression, du if bj.TypeCode == JSONTypeCodeArray { elemCount := bj.GetElemCount() for i := 0; i < elemCount && !jsonFinished(buf, one); i++ { - buf = bj.arrayGetElem(i).extractTo(buf, pathExpr, dup, one) + buf = bj.ArrayGetElem(i).extractTo(buf, pathExpr, dup, one) } } else if bj.TypeCode == JSONTypeCodeObject { elemCount := bj.GetElemCount() @@ -459,12 +459,12 @@ func (bj BinaryJSON) ArrayInsert(pathExpr JSONPathExpression, value BinaryJSON) // Insert into the array newArray := make([]BinaryJSON, 0, count+1) for i := 0; i < idx; i++ { - elem := obj.arrayGetElem(i) + elem := obj.ArrayGetElem(i) newArray = append(newArray, elem) } newArray = append(newArray, value) for i := idx; i < count; i++ { - elem := obj.arrayGetElem(i) + elem := obj.ArrayGetElem(i) newArray = append(newArray, elem) } obj = buildBinaryJSONArray(newArray) @@ -556,7 +556,7 @@ func (bm *binaryModifier) doInsert(path JSONPathExpression, newBj BinaryJSON) { elemCount := parentBj.GetElemCount() elems := make([]BinaryJSON, 0, elemCount+1) for i := 0; i < elemCount; i++ { - elems = append(elems, parentBj.arrayGetElem(i)) + elems = append(elems, parentBj.ArrayGetElem(i)) } elems = append(elems, newBj) bm.modifyValue = buildBinaryJSONArray(elems) @@ -622,7 +622,7 @@ func (bm *binaryModifier) doRemove(path JSONPathExpression) { elems := make([]BinaryJSON, 0, elemCount-1) for i := 0; i < elemCount; i++ { if i != idx { - elems = append(elems, parentBj.arrayGetElem(i)) + elems = append(elems, parentBj.ArrayGetElem(i)) } } bm.modifyValue = buildBinaryJSONArray(elems) @@ -809,8 +809,8 @@ func CompareBinaryJSON(left, right BinaryJSON) int { leftCount := left.GetElemCount() rightCount := right.GetElemCount() for i := 0; i < leftCount && i < rightCount; i++ { - elem1 := left.arrayGetElem(i) - elem2 := right.arrayGetElem(i) + elem1 := left.ArrayGetElem(i) + elem2 := right.ArrayGetElem(i) cmp = CompareBinaryJSON(elem1, elem2) if cmp != 0 { return cmp @@ -993,7 +993,7 @@ func mergeBinaryArray(elems []BinaryJSON) BinaryJSON { } else { childCount := elem.GetElemCount() for j := 0; j < childCount; j++ { - buf = append(buf, elem.arrayGetElem(j)) + buf = append(buf, elem.ArrayGetElem(j)) } } } @@ -1088,7 +1088,7 @@ func ContainsBinaryJSON(obj, target BinaryJSON) bool { if target.TypeCode == JSONTypeCodeArray { elemCount := target.GetElemCount() for i := 0; i < elemCount; i++ { - if !ContainsBinaryJSON(obj, target.arrayGetElem(i)) { + if !ContainsBinaryJSON(obj, target.ArrayGetElem(i)) { return false } } @@ -1096,7 +1096,7 @@ func ContainsBinaryJSON(obj, target BinaryJSON) bool { } elemCount := obj.GetElemCount() for i := 0; i < elemCount; i++ { - if ContainsBinaryJSON(obj.arrayGetElem(i), target) { + if ContainsBinaryJSON(obj.ArrayGetElem(i), target) { return true } } @@ -1127,9 +1127,9 @@ func OverlapsBinaryJSON(obj, target BinaryJSON) bool { case JSONTypeCodeArray: if target.TypeCode == JSONTypeCodeArray { for i := 0; i < obj.GetElemCount(); i++ { - o := obj.arrayGetElem(i) + o := obj.ArrayGetElem(i) for j := 0; j < target.GetElemCount(); j++ { - if CompareBinaryJSON(o, target.arrayGetElem(j)) == 0 { + if CompareBinaryJSON(o, target.ArrayGetElem(j)) == 0 { return true } } @@ -1138,7 +1138,7 @@ func OverlapsBinaryJSON(obj, target BinaryJSON) bool { } elemCount := obj.GetElemCount() for i := 0; i < elemCount; i++ { - if CompareBinaryJSON(obj.arrayGetElem(i), target) == 0 { + if CompareBinaryJSON(obj.ArrayGetElem(i), target) == 0 { return true } } @@ -1175,7 +1175,7 @@ func (bj BinaryJSON) GetElemDepth() int { elemCount := bj.GetElemCount() maxDepth := 0 for i := 0; i < elemCount; i++ { - obj := bj.arrayGetElem(i) + obj := bj.ArrayGetElem(i) depth := obj.GetElemDepth() if depth > maxDepth { maxDepth = depth @@ -1246,9 +1246,9 @@ func (bj BinaryJSON) extractToCallback(pathExpr JSONPathExpression, callbackFn e switch selection := currentLeg.arraySelection.(type) { case jsonPathArraySelectionAsterisk: for i := 0; i < elemCount; i++ { - // buf = bj.arrayGetElem(i).extractTo(buf, subPathExpr) + // buf = bj.ArrayGetElem(i).extractTo(buf, subPathExpr) path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)}) - stop, err = bj.arrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path) + stop, err = bj.ArrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path) if stop || err != nil { return } @@ -1256,9 +1256,9 @@ func (bj BinaryJSON) extractToCallback(pathExpr JSONPathExpression, callbackFn e case jsonPathArraySelectionIndex: idx := selection.index.getIndexFromStart(bj) if idx < elemCount && idx >= 0 { - // buf = bj.arrayGetElem(currentLeg.arraySelection).extractTo(buf, subPathExpr) + // buf = bj.ArrayGetElem(currentLeg.arraySelection).extractTo(buf, subPathExpr) path := fullpath.pushBackOneArraySelectionLeg(currentLeg.arraySelection) - stop, err = bj.arrayGetElem(idx).extractToCallback(subPathExpr, callbackFn, path) + stop, err = bj.ArrayGetElem(idx).extractToCallback(subPathExpr, callbackFn, path) if stop || err != nil { return } @@ -1272,7 +1272,7 @@ func (bj BinaryJSON) extractToCallback(pathExpr JSONPathExpression, callbackFn e if start <= end && start >= 0 { for i := start; i <= end; i++ { path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)}) - stop, err = bj.arrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path) + stop, err = bj.ArrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path) if stop || err != nil { return } @@ -1311,9 +1311,9 @@ func (bj BinaryJSON) extractToCallback(pathExpr JSONPathExpression, callbackFn e if bj.TypeCode == JSONTypeCodeArray { elemCount := bj.GetElemCount() for i := 0; i < elemCount; i++ { - // buf = bj.arrayGetElem(i).extractTo(buf, pathExpr) + // buf = bj.ArrayGetElem(i).extractTo(buf, pathExpr) path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)}) - stop, err = bj.arrayGetElem(i).extractToCallback(pathExpr, callbackFn, path) + stop, err = bj.ArrayGetElem(i).extractToCallback(pathExpr, callbackFn, path) if stop || err != nil { return } @@ -1357,7 +1357,7 @@ func (bj BinaryJSON) Walk(walkFn BinaryJSONWalkFunc, pathExprList ...JSONPathExp elemCount := bj.GetElemCount() for i := 0; i < elemCount; i++ { path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)}) - stop, err = doWalk(path, bj.arrayGetElem(i)) + stop, err = doWalk(path, bj.ArrayGetElem(i)) if stop || err != nil { return }