diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index 62be52745d..6adaa7d452 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -819,7 +819,10 @@ func widenJSONValues(val interface{}) sql.JSONWrapper { js = types.MustJSON(str) } - doc := js.ToInterface() + doc, err := js.ToInterface() + if err != nil { + panic(err) + } if _, ok := js.(sql.Statistic); ok { // avoid comparing time values in statistics diff --git a/sql/expression/function/aggregation/json_agg.go b/sql/expression/function/aggregation/json_agg.go index 7b35cf5fe6..952a5b3b33 100644 --- a/sql/expression/function/aggregation/json_agg.go +++ b/sql/expression/function/aggregation/json_agg.go @@ -159,7 +159,10 @@ func (j *jsonObjectBuffer) Update(ctx *sql.Context, row sql.Row) error { // unwrap JSON values if js, ok := val.(sql.JSONWrapper); ok { - val = js.ToInterface() + val, err = js.ToInterface() + if err != nil { + return err + } } // Update the map. diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index 07729f68b4..ec631f3026 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -647,7 +647,10 @@ func (j *jsonArrayBuffer) Update(ctx *sql.Context, row sql.Row) error { // unwrap JSON values if js, ok := v.(sql.JSONWrapper); ok { - v = js.ToInterface() + v, err = js.ToInterface() + if err != nil { + return err + } } j.vals = append(j.vals, v) diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index bd25373008..ad6783f0e6 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1050,7 +1050,10 @@ func (a *WindowedJSONArrayAgg) aggregateVals(ctx *sql.Context, interval sql.Wind // unwrap JSON values if js, ok := v.(sql.JSONWrapper); ok { - v = js.ToInterface() 
+ v, err = js.ToInterface() + if err != nil { + return nil, err + } } vals = append(vals, v) @@ -1134,7 +1137,10 @@ func (a *WindowedJSONObjectAgg) aggregateVals(ctx *sql.Context, interval sql.Win // unwrap JSON values if js, ok := val.(sql.JSONWrapper); ok { - val = js.ToInterface() + val, err = js.ToInterface() + if err != nil { + return nil, err + } } // Update the map. diff --git a/sql/expression/function/json/json_array.go b/sql/expression/function/json/json_array.go index da2742ca12..ab0fb49bb9 100644 --- a/sql/expression/function/json/json_array.go +++ b/sql/expression/function/json/json_array.go @@ -113,7 +113,10 @@ func (j *JSONArray) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { switch v := val.(type) { case sql.JSONWrapper: - val = v.ToInterface() + val, err = v.ToInterface() + if err != nil { + return nil, err + } case []byte: val = string(v) } diff --git a/sql/expression/function/json/json_common.go b/sql/expression/function/json/json_common.go index ea28801fd1..222416ab73 100644 --- a/sql/expression/function/json/json_common.go +++ b/sql/expression/function/json/json_common.go @@ -71,7 +71,11 @@ func getJSONDocumentFromRow(ctx *sql.Context, row sql.Row, json sql.Expression) doc, ok := converted.(types.JSONDocument) if !ok { // This should never happen, but just in case. 
- doc = types.JSONDocument{Val: js.(sql.JSONWrapper).ToInterface()} + val, err := js.(sql.JSONWrapper).ToInterface() + if err != nil { + return nil, err + } + doc = types.JSONDocument{Val: val} } return &doc, nil diff --git a/sql/expression/function/json/json_contains.go b/sql/expression/function/json/json_contains.go index 9b15dfe1d5..64852c7b80 100644 --- a/sql/expression/function/json/json_contains.go +++ b/sql/expression/function/json/json_contains.go @@ -162,7 +162,15 @@ func (j *JSONContains) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } // Now determine whether the candidate value exists in the target - return types.ContainsJSON(target.ToInterface(), candidate.ToInterface()) + targetVal, err := target.ToInterface() + if err != nil { + return nil, err + } + candidateVal, err := candidate.ToInterface() + if err != nil { + return nil, err + } + return types.ContainsJSON(targetVal, candidateVal) } func (j *JSONContains) Children() []sql.Expression { diff --git a/sql/expression/function/json/json_object.go b/sql/expression/function/json/json_object.go index 2b031db83c..73c93c6c40 100644 --- a/sql/expression/function/json/json_object.go +++ b/sql/expression/function/json/json_object.go @@ -109,7 +109,10 @@ func (j JSONObject) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { key = val.(string) } else { if json, ok := val.(sql.JSONWrapper); ok { - val = json.ToInterface() + val, err = json.ToInterface() + if err != nil { + return nil, err + } } obj[key] = val } diff --git a/sql/expression/function/json/json_value.go b/sql/expression/function/json/json_value.go index 89516bb3d7..ddb7b9e102 100644 --- a/sql/expression/function/json/json_value.go +++ b/sql/expression/function/json/json_value.go @@ -182,7 +182,7 @@ func GetJSONFromWrapperOrCoercibleString(js interface{}) (jsonData interface{}, } return jsonData, nil case sql.JSONWrapper: - return jsType.ToInterface(), nil + return jsType.ToInterface() default: return nil, InvalidJsonArgument.New() 
} diff --git a/sql/fulltext/fulltext.go b/sql/fulltext/fulltext.go index dd1031cd79..e67dd464e3 100644 --- a/sql/fulltext/fulltext.go +++ b/sql/fulltext/fulltext.go @@ -175,6 +175,14 @@ func writeHashedValue(h hash.Hash, val interface{}) (valIsNull bool, err error) if _, err := h.Write([]byte(str)); err != nil { return false, err } + case *types.LazyJSONDocument: + str, err := types.StringifyJSON(val) + if err != nil { + return false, err + } + if _, err := h.Write([]byte(str)); err != nil { + return false, err + } case nil: return true, nil default: diff --git a/sql/plan/histogram.go b/sql/plan/histogram.go index da8a6722d6..fa0c76177c 100644 --- a/sql/plan/histogram.go +++ b/sql/plan/histogram.go @@ -1,10 +1,11 @@ package plan import ( - "encoding/json" "fmt" "strings" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql" ) @@ -58,8 +59,7 @@ func (u *UpdateHistogram) Resolved() bool { } func (u *UpdateHistogram) String() string { - statMap := u.stats.ToInterface() - statBytes, _ := json.Marshal(statMap) + statBytes, _ := types.MarshallJson(u.stats) return fmt.Sprintf("update histogram %s.(%s) using %s", u.table, strings.Join(u.cols, ","), statBytes) } diff --git a/sql/statistics.go b/sql/statistics.go index fdd123ee94..826aa7ac9a 100644 --- a/sql/statistics.go +++ b/sql/statistics.go @@ -142,7 +142,7 @@ func (h Histogram) IsEmpty() bool { return len(h) == 0 } -func (h Histogram) ToInterface() interface{} { +func (h Histogram) ToInterface() (interface{}, error) { ret := make([]interface{}, len(h)) for i, b := range h { var upperBound Row @@ -167,7 +167,7 @@ func (h Histogram) ToInterface() interface{} { "upper_bound": upperBound, } } - return ret + return ret, nil } func (h Histogram) DebugString() string { @@ -216,5 +216,5 @@ type HistogramBucket interface { // by minimizing the need to unmarshall a JSONWrapper into a JSONDocument. 
type JSONWrapper interface { // ToInterface converts a JSONWrapper to an interface{} of simple types - ToInterface() interface{} + ToInterface() (interface{}, error) } diff --git a/sql/stats/statistic.go b/sql/stats/statistic.go index c077d86254..e9a1ab8adb 100644 --- a/sql/stats/statistic.go +++ b/sql/stats/statistic.go @@ -209,12 +209,17 @@ func (s *Statistic) IndexClass() sql.IndexClass { return sql.IndexClass(s.IdxClass) } -func (s *Statistic) ToInterface() interface{} { +func (s *Statistic) ToInterface() (interface{}, error) { typs := make([]string, len(s.Typs)) for i, t := range s.Typs { typs[i] = t.String() } + buckets, err := s.Histogram().ToInterface() + if err != nil { + return nil, err + } + return map[string]interface{}{ "statistic": map[string]interface{}{ "row_count": s.RowCount(), @@ -225,9 +230,9 @@ func (s *Statistic) ToInterface() interface{} { "qualifier": s.Qualifier().String(), "columns": s.Columns(), "types:": typs, - "buckets": s.Histogram().ToInterface(), + "buckets": buckets, }, - } + }, nil } func ParseTypeStrings(typs []string) ([]sql.Type, error) { diff --git a/sql/types/json.go b/sql/types/json.go index 4112f918af..215dc99b77 100644 --- a/sql/types/json.go +++ b/sql/types/json.go @@ -15,7 +15,6 @@ package types import ( - "bytes" "encoding/json" "reflect" @@ -131,30 +130,30 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.NULL, nil } - // Convert to jsonType - jsVal, _, err := t.Convert(v) - if err != nil { - return sqltypes.NULL, err - } - js := jsVal.(sql.JSONWrapper) - var val []byte - switch j := js.(type) { - case JSONStringer: - str, err := j.JSONString() + + // If we read the JSON from a table, pass through the bytes to avoid a deserialization and reserialization round-trip. + // This is kind of a hack, and it means that reading JSON from tables no longer matches MySQL byte-for-byte. + // But it's worth it to avoid the round-trip, which can be very slow. 
+ if j, ok := v.(*LazyJSONDocument); ok { + str, err := MarshallJson(j) if err != nil { return sqltypes.NULL, err } - val = AppendAndSliceString(dest, str) - default: - jsonBytes, err := json.Marshal(js.ToInterface()) + val = AppendAndSliceBytes(dest, str) + } else { + // Convert to jsonType + jsVal, _, err := t.Convert(v) if err != nil { return sqltypes.NULL, err } + js := jsVal.(sql.JSONWrapper) - jsonBytes = bytes.ReplaceAll(jsonBytes, []byte(",\""), []byte(", \"")) - jsonBytes = bytes.ReplaceAll(jsonBytes, []byte("\":"), []byte("\": ")) - val = AppendAndSliceBytes(dest, jsonBytes) + str, err := StringifyJSON(js) + if err != nil { + return sqltypes.NULL, err + } + val = AppendAndSliceString(dest, str) } return sqltypes.MakeTrusted(sqltypes.TypeJSON, val), nil diff --git a/sql/types/json_test.go b/sql/types/json_test.go index 9541eead0c..360292189c 100644 --- a/sql/types/json_test.go +++ b/sql/types/json_test.go @@ -212,6 +212,29 @@ func TestValuer(t *testing.T) { require.Equal(t, `{"a": "one"}`, res) } +func TestLazyJsonDocument(t *testing.T) { + testCases := []struct { + s string + json interface{} + }{ + {`"1"`, "1"}, + {`{"a": [1.0, null]}`, map[string]any{"a": []any{1.0, nil}}}, + } + for _, testCase := range testCases { + t.Run(testCase.s, func(t *testing.T) { + doc := NewLazyJSONDocument([]byte(testCase.s)) + val, err := doc.ToInterface() + require.NoError(t, err) + require.Equal(t, testCase.json, val) + }) + } + t.Run("lazy docs only error when deserialized", func(t *testing.T) { + doc := NewLazyJSONDocument([]byte("not valid json")) + _, err := doc.ToInterface() + require.Error(t, err) + }) +} + type JsonRoundtripTest struct { desc string input string diff --git a/sql/types/json_value.go b/sql/types/json_value.go index 8df3a6127f..14d4950cf6 100644 --- a/sql/types/json_value.go +++ b/sql/types/json_value.go @@ -16,12 +16,14 @@ package types import ( "database/sql/driver" + "encoding/json" "fmt" "io" "regexp" "sort" "strconv" "strings" + "sync" 
"github.com/dolthub/jsonpath" "github.com/shopspring/decimal" @@ -30,10 +32,40 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) +// JSONStringer can be converted to a string representation that is compatible with MySQL's JSON output, including spaces. type JSONStringer interface { JSONString() (string, error) } +// StringifyJSON generates a string representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces. +func StringifyJSON(jsonWrapper sql.JSONWrapper) (string, error) { + if stringer, ok := jsonWrapper.(JSONStringer); ok { + return stringer.JSONString() + } + val, err := jsonWrapper.ToInterface() + if err != nil { + return "", err + } + return marshalToMySqlString(val) +} + +// JSONBytes are values which can be represented as JSON. +type JSONBytes interface { + GetBytes() ([]byte, error) +} + +// JSONBytes returns or generates a byte array for the JSON representation of the underlying sql.JSONWrapper +func MarshallJson(jsonWrapper sql.JSONWrapper) ([]byte, error) { + if bytes, ok := jsonWrapper.(JSONBytes); ok { + return bytes.GetBytes() + } + val, err := jsonWrapper.ToInterface() + if err != nil { + return []byte{}, err + } + return json.Marshal(val) +} + type JsonObject = map[string]interface{} type JsonArray = []interface{} @@ -60,13 +92,18 @@ type JSONDocument struct { } var _ sql.JSONWrapper = JSONDocument{} +var _ MutableJSON = JSONDocument{} -func (doc JSONDocument) ToInterface() interface{} { - return doc.Val +func (doc JSONDocument) ToInterface() (interface{}, error) { + return doc.Val, nil } func (doc JSONDocument) Compare(other sql.JSONWrapper) (int, error) { - return CompareJSON(doc.Val, other.ToInterface()) + otherVal, err := other.ToInterface() + if err != nil { + return 0, err + } + return CompareJSON(doc.Val, otherVal) } func (doc JSONDocument) JSONString() (string, error) { @@ -82,19 +119,52 @@ func (doc JSONDocument) String() string { return result } -var _ sql.JSONWrapper = JSONDocument{} -var _ 
MutableJSON = JSONDocument{} - // Contains returns nil in case of a nil value for either the doc.Val or candidate. Otherwise // it returns a bool func (doc JSONDocument) Contains(candidate sql.JSONWrapper) (val interface{}, err error) { - return ContainsJSON(doc.Val, candidate.ToInterface()) + candidateVal, err := candidate.ToInterface() + if err != nil { + return nil, err + } + return ContainsJSON(doc.Val, candidateVal) } func (doc JSONDocument) Extract(path string) (sql.JSONWrapper, error) { return LookupJSONValue(doc, path) } +// LazyJSONDocument is an implementation of sql.JSONWrapper that wraps a JSON string and defers deserializing +// it unless needed. This is more efficient for queries that interact with JSON values but don't care about their structure. +type LazyJSONDocument struct { + Bytes []byte + interfaceFunc func() (interface{}, error) +} + +var _ sql.JSONWrapper = &LazyJSONDocument{} +var _ JSONBytes = &LazyJSONDocument{} + +func NewLazyJSONDocument(bytes []byte) sql.JSONWrapper { + return &LazyJSONDocument{ + Bytes: bytes, + interfaceFunc: sync.OnceValues(func() (interface{}, error) { + var val interface{} + err := json.Unmarshal(bytes, &val) + if err != nil { + return nil, err + } + return val, nil + }), + } +} + +func (j *LazyJSONDocument) ToInterface() (interface{}, error) { + return j.interfaceFunc() +} + +func (j *LazyJSONDocument) GetBytes() ([]byte, error) { + return j.Bytes, nil +} + func LookupJSONValue(j sql.JSONWrapper, path string) (sql.JSONWrapper, error) { if path == "$" { // Special case the identity operation to handle a nil value for doc.Val @@ -113,7 +183,10 @@ func LookupJSONValue(j sql.JSONWrapper, path string) (sql.JSONWrapper, error) { // Lookup(obj) throws an error if obj is nil. We want lookups on a json null // to always result in sql NULL, except in the case of the identity lookup // $. 
- r := j.ToInterface() + r, err := j.ToInterface() + if err != nil { + return nil, err + } if r == nil { return nil, nil } @@ -147,9 +220,13 @@ func (doc JSONDocument) Value() (driver.Value, error) { } func ConcatenateJSONValues(ctx *sql.Context, vals ...sql.JSONWrapper) (sql.JSONWrapper, error) { + var err error arr := make(JsonArray, len(vals)) for i, v := range vals { - arr[i] = v.ToInterface() + arr[i], err = v.ToInterface() + if err != nil { + return nil, err + } } return JSONDocument{Val: arr}, nil } @@ -349,6 +426,7 @@ func containsJSONNumber(a float64, b interface{}) (bool, error) { // // https://dev.mysql.com/doc/refman/8.0/en/json.html#json-comparison func CompareJSON(a, b interface{}) (int, error) { + var err error if hasNulls, res := CompareNulls(b, a); hasNulls { return res, nil } @@ -389,9 +467,16 @@ func CompareJSON(a, b interface{}) (int, error) { return compareJSONNumber(af, b) case sql.JSONWrapper: if jw, ok := b.(sql.JSONWrapper); ok { - b = jw.ToInterface() + b, err = jw.ToInterface() + if err != nil { + return 0, err + } + } + aVal, err := a.ToInterface() + if err != nil { + return 0, err } - return CompareJSON(a.ToInterface(), b) + return CompareJSON(aVal, b) default: return 0, sql.ErrInvalidType.New(a) } @@ -655,7 +740,10 @@ func (doc JSONDocument) unwrapAndExecute(path string, val sql.JSONWrapper, mode var err error var unmarshalled interface{} if val != nil { - unmarshalled = val.ToInterface() + unmarshalled, err = val.ToInterface() + if err != nil { + return nil, false, err + } } else if mode != REMOVE { return nil, false, fmt.Errorf("Invariant violation. value may not be nil") } diff --git a/sql/types/number.go b/sql/types/number.go index 73a5e51e67..f2cc14ef9a 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -233,6 +233,7 @@ func (t NumberTypeImpl_) Compare(a interface{}, b interface{}) (int, error) { // Convert implements Type interface. 
func (t NumberTypeImpl_) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { + var err error if v == nil { return nil, sql.InRange, nil } @@ -242,7 +243,10 @@ func (t NumberTypeImpl_) Convert(v interface{}) (interface{}, sql.ConvertInRange } if jv, ok := v.(sql.JSONWrapper); ok { - v = jv.ToInterface() + v, err = jv.ToInterface() + if err != nil { + return nil, sql.OutOfRange, err + } } switch t.baseType { diff --git a/sql/types/strings.go b/sql/types/strings.go index 0b5b2abd77..db1451c70a 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -15,7 +15,6 @@ package types import ( - "encoding/json" "fmt" "reflect" "strconv" @@ -372,20 +371,17 @@ func ConvertToString(v interface{}, t sql.StringType) (string, error) { return "", nil } val = s.Decimal.String() - - case JSONStringer: - var err error - val, err = s.JSONString() + case JSONDocument: + jsonString, err := StringifyJSON(s) if err != nil { return "", err } - val, err = strings.Unquote(val) + val, err = strings.Unquote(jsonString) if err != nil { return "", err } case sql.JSONWrapper: - jsonInterface := s.ToInterface() - jsonBytes, err := json.Marshal(jsonInterface) + jsonBytes, err := MarshallJson(s) if err != nil { return "", err }