Skip to content

Commit

Permalink
executor: support single precision value in json agg (#37389)
Browse files Browse the repository at this point in the history
close #37287
  • Loading branch information
YangKeao authored Aug 29, 2022
1 parent 622c6d6 commit a1af5af
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 111 deletions.
54 changes: 0 additions & 54 deletions executor/aggfuncs/aggfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,60 +218,6 @@ func rowMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (memDeltas

type multiArgsUpdateMemDeltaGens func(*chunk.Chunk, []*types.FieldType, []*util.ByItems) (memDeltas []int64, err error)

func defaultMultiArgsMemDeltaGens(srcChk *chunk.Chunk, dataTypes []*types.FieldType, byItems []*util.ByItems) (memDeltas []int64, err error) {
memDeltas = make([]int64, 0)
m := make(map[string]bool)
for i := 0; i < srcChk.NumRows(); i++ {
row := srcChk.GetRow(i)
if row.IsNull(0) {
memDeltas = append(memDeltas, int64(0))
continue
}
datum := row.GetDatum(0, dataTypes[0])
if datum.IsNull() {
memDeltas = append(memDeltas, int64(0))
continue
}

memDelta := int64(0)
key, err := datum.ToString()
if err != nil {
return memDeltas, errors.Errorf("fail to get key - %s", key)
}
if _, ok := m[key]; ok {
memDeltas = append(memDeltas, int64(0))
continue
}
m[key] = true
memDelta += int64(len(key))

memDelta += aggfuncs.DefInterfaceSize
switch dataTypes[1].GetType() {
case mysql.TypeLonglong:
memDelta += aggfuncs.DefUint64Size
case mysql.TypeDouble:
memDelta += aggfuncs.DefFloat64Size
case mysql.TypeString:
val := row.GetString(1)
memDelta += int64(len(val))
case mysql.TypeJSON:
val := row.GetJSON(1)
// +1 for the memory usage of the TypeCode of json
memDelta += int64(len(val.Value) + 1)
case mysql.TypeDuration:
memDelta += aggfuncs.DefDurationSize
case mysql.TypeDate:
memDelta += aggfuncs.DefTimeSize
case mysql.TypeNewDecimal:
memDelta += aggfuncs.DefMyDecimalSize
default:
return memDeltas, errors.Errorf("unsupported type - %v", dataTypes[1].GetType())
}
memDeltas = append(memDeltas, memDelta)
}
return memDeltas, nil
}

type aggMemTest struct {
aggTest aggTest
allocMemDelta int64
Expand Down
26 changes: 5 additions & 21 deletions executor/aggfuncs/func_json_arrayagg.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
)
Expand Down Expand Up @@ -55,24 +54,6 @@ func (e *jsonArrayagg) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Parti
return nil
}

// appendBinary does not support some type such as uint8、types.time,so convert is needed here
for idx, val := range p.entries {
switch x := val.(type) {
case *types.MyDecimal:
float64Val, err := x.ToFloat64()
if err != nil {
return errors.Trace(err)
}
p.entries[idx] = float64Val
case []uint8, types.Time, types.Duration:
strVal, err := types.ToString(x)
if err != nil {
return errors.Trace(err)
}
p.entries[idx] = strVal
}
}

chk.AppendJSON(e.ordinal, json.CreateBinary(p.entries))
return nil
}
Expand All @@ -85,10 +66,13 @@ func (e *jsonArrayagg) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup
return 0, errors.Trace(err)
}

realItem := getRealJSONValue(item, e.args[0].GetType())
realItem, err := getRealJSONValue(item, e.args[0].GetType())
if err != nil {
return 0, errors.Trace(err)
}

switch x := realItem.(type) {
case nil, bool, int64, uint64, float64, string, json.BinaryJSON, json.Opaque, *types.MyDecimal, []uint8, types.Time, types.Duration:
case nil, bool, int64, uint64, float64, string, json.BinaryJSON, json.Opaque:
p.entries = append(p.entries, realItem)
memDelta += getValMemDelta(realItem)
default:
Expand Down
28 changes: 17 additions & 11 deletions executor/aggfuncs/func_json_arrayagg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

func TestMergePartialResult4JsonArrayagg(t *testing.T) {
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeString, mysql.TypeJSON}
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeFloat, mysql.TypeString, mysql.TypeJSON}

tests := make([]aggTest, 0, len(typeList))
numRows := 5
Expand All @@ -36,18 +36,19 @@ func TestMergePartialResult4JsonArrayagg(t *testing.T) {
entries2 := make([]interface{}, 0)
entries3 := make([]interface{}, 0)

genFunc := getDataGenFunc(types.NewFieldType(argType))
argFieldType := types.NewFieldType(argType)
genFunc := getDataGenFunc(argFieldType)

for m := 0; m < numRows; m++ {
arg := genFunc(m)
entries1 = append(entries1, arg.GetValue())
entries1 = append(entries1, getJSONValue(arg, argFieldType))
}
// to adapt the `genSrcChk` Chunk format
entries1 = append(entries1, nil)

for m := 2; m < numRows; m++ {
arg := genFunc(m)
entries2 = append(entries2, arg.GetValue())
entries2 = append(entries2, getJSONValue(arg, argFieldType))
}
// to adapt the `genSrcChk` Chunk format
entries2 = append(entries2, nil)
Expand All @@ -64,19 +65,20 @@ func TestMergePartialResult4JsonArrayagg(t *testing.T) {
}

func TestJsonArrayagg(t *testing.T) {
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeString, mysql.TypeJSON}
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeFloat, mysql.TypeString, mysql.TypeJSON}

tests := make([]aggTest, 0, len(typeList))
numRows := 5

for _, argType := range typeList {
entries := make([]interface{}, 0)

genFunc := getDataGenFunc(types.NewFieldType(argType))
argFieldType := types.NewFieldType(argType)
genFunc := getDataGenFunc(argFieldType)

for m := 0; m < numRows; m++ {
arg := genFunc(m)
entries = append(entries, arg.GetValue())
entries = append(entries, getJSONValue(arg, argFieldType))
}
// to adapt the `genSrcChk` Chunk format
entries = append(entries, nil)
Expand All @@ -103,6 +105,8 @@ func jsonArrayaggMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (m
switch dataType.GetType() {
case mysql.TypeLonglong:
memDelta += aggfuncs.DefUint64Size
case mysql.TypeFloat:
memDelta += aggfuncs.DefFloat64Size
case mysql.TypeDouble:
memDelta += aggfuncs.DefFloat64Size
case mysql.TypeString:
Expand All @@ -113,11 +117,13 @@ func jsonArrayaggMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (m
// +1 for the memory usage of the TypeCode of json
memDelta += int64(len(val.Value) + 1)
case mysql.TypeDuration:
memDelta += aggfuncs.DefDurationSize
val := row.GetDuration(0, dataType.GetDecimal())
memDelta += int64(len(val.String()))
case mysql.TypeDate:
memDelta += aggfuncs.DefTimeSize
val := row.GetTime(0)
memDelta += int64(len(val.String()))
case mysql.TypeNewDecimal:
memDelta += aggfuncs.DefMyDecimalSize
memDelta += aggfuncs.DefFloat64Size
default:
return memDeltas, errors.Errorf("unsupported type - %v", dataType.GetType())
}
Expand All @@ -127,7 +133,7 @@ func jsonArrayaggMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (m
}

func TestMemJsonArrayagg(t *testing.T) {
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeString, mysql.TypeJSON}
typeList := []byte{mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeString, mysql.TypeJSON, mysql.TypeDuration, mysql.TypeNewDecimal, mysql.TypeDate}

tests := make([]aggMemTest, 0, len(typeList))
numRows := 5
Expand Down
48 changes: 26 additions & 22 deletions executor/aggfuncs/func_json_objectagg.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,6 @@ func (e *jsonObjectAgg) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Part
return nil
}

// appendBinary does not support some type such as uint8、types.time,so convert is needed here
for key, val := range p.entries {
switch x := val.(type) {
case *types.MyDecimal:
float64Val, err := x.ToFloat64()
if err != nil {
return errors.Trace(err)
}
p.entries[key] = float64Val
case []uint8, types.Time, types.Duration:
strVal, err := types.ToString(x)
if err != nil {
return errors.Trace(err)
}
p.entries[key] = strVal
}
}

chk.AppendJSON(e.ordinal, json.CreateBinary(p.entries))
return nil
}
Expand All @@ -102,9 +84,13 @@ func (e *jsonObjectAgg) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup
return 0, errors.Trace(err)
}

realVal := getRealJSONValue(value, e.args[1].GetType())
realVal, err := getRealJSONValue(value, e.args[1].GetType())
if err != nil {
return 0, errors.Trace(err)
}

switch x := realVal.(type) {
case nil, bool, int64, uint64, float64, string, json.BinaryJSON, json.Opaque, *types.MyDecimal, types.Time, types.Duration:
case nil, bool, int64, uint64, float64, string, json.BinaryJSON, json.Opaque:
if _, ok := p.entries[key]; !ok {
memDelta += int64(len(key)) + getValMemDelta(realVal)
if len(p.entries)+1 > (1<<p.bInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
Expand All @@ -121,7 +107,7 @@ func (e *jsonObjectAgg) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup
return memDelta, nil
}

func getRealJSONValue(value types.Datum, ft *types.FieldType) interface{} {
func getRealJSONValue(value types.Datum, ft *types.FieldType) (interface{}, error) {
realVal := value.Clone().GetValue()
switch value.Kind() {
case types.KindBinaryLiteral, types.KindMysqlBit, types.KindBytes:
Expand All @@ -146,7 +132,25 @@ func getRealJSONValue(value types.Datum, ft *types.FieldType) interface{} {
}
}

return realVal
// appendBinary does not support some type such as uint8、types.time,so convert is needed here
switch x := realVal.(type) {
case float32:
realVal = float64(x)
case *types.MyDecimal:
float64Val, err := x.ToFloat64()
if err != nil {
return nil, errors.Trace(err)
}
realVal = float64Val
case []uint8, types.Time, types.Duration:
strVal, err := types.ToString(x)
if err != nil {
return nil, errors.Trace(err)
}
realVal = strVal
}

return realVal, nil
}

func getValMemDelta(val interface{}) (memDelta int64) {
Expand Down
Loading

0 comments on commit a1af5af

Please sign in to comment.