From bee95497d165b74bc17322b6c33c9fef1f474247 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Mon, 17 Feb 2025 10:29:18 -0800 Subject: [PATCH] [expression] coalesce shoudl cast return values --- enginetest/queries/queries.go | 20 ++++++++++++++++ sql/expression/function/coalesce.go | 13 +++++++++- sql/expression/function/coalesce_test.go | 30 ++++++++++++------------ 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index eab345375e..f080edd44a 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -842,6 +842,26 @@ var QueryTests = []QueryTest{ Query: "SELECT 1 WHERE ((1 IN (NULL * 1)) IS NULL);", Expected: []sql.Row{{1}}, }, + { + Query: "select coalesce(1, 0.0);", + Expected: []sql.Row{{"1"}}, + }, + { + Query: "select coalesce(1, '0');", + Expected: []sql.Row{{"1"}}, + }, + { + Query: "select coalesce(1, 'x');", + Expected: []sql.Row{{"1"}}, + }, + { + Query: "select coalesce(1, 1);", + Expected: []sql.Row{{1}}, + }, + { + Query: "select coalesce(1, CAST('2017-08-29' AS DATE))", + Expected: []sql.Row{{"1"}}, + }, { Query: "SELECT count(*) from mytable WHERE ((i IN (NULL >= 1)) IS NULL);", Expected: []sql.Row{{3}}, diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go index 11fc9e04ec..5e1eb703cf 100644 --- a/sql/expression/function/coalesce.go +++ b/sql/expression/function/coalesce.go @@ -27,6 +27,7 @@ import ( // Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values. type Coalesce struct { args []sql.Expression + typ sql.Type } var _ sql.FunctionExpression = (*Coalesce)(nil) @@ -38,7 +39,7 @@ func NewCoalesce(args ...sql.Expression) (sql.Expression, error) { return nil, sql.ErrInvalidArgumentNumber.New("COALESCE", "1 or more", 0) } - return &Coalesce{args}, nil + return &Coalesce{args: args}, nil } // FunctionName implements sql.FunctionExpression @@ -54,6 +55,9 @@ func (c *Coalesce) Description() string { // Type implements the sql.Expression interface. // The return type of Type() is the aggregated type of the argument types. func (c *Coalesce) Type() sql.Type { + if c.typ != nil { + return c.typ + } retType := types.Null for i, arg := range c.args { if arg == nil { @@ -120,6 +124,7 @@ func (c *Coalesce) Type() sql.Type { } } + c.typ = retType return retType } @@ -201,6 +206,12 @@ func (c *Coalesce) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { continue } + if !types.IsEnum(c.Type()) && !types.IsSet(c.Type()) { + val, _, err = c.Type().Convert(val) + if err != nil { + return nil, err + } + } return val, nil } diff --git a/sql/expression/function/coalesce_test.go b/sql/expression/function/coalesce_test.go index 887c975a16..df48bfded4 100644 --- a/sql/expression/function/coalesce_test.go +++ b/sql/expression/function/coalesce_test.go @@ -45,7 +45,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(2, types.Int32), expression.NewLiteral(3, types.Int32), }, - expected: 1, + expected: int32(1), typ: types.Int32, nullable: false, }, @@ -56,7 +56,7 @@ func TestCoalesce(t *testing.T) { nil, expression.NewLiteral(3, types.Int32), }, - expected: 3, + expected: int32(3), typ: types.Int32, nullable: false, }, @@ -100,7 +100,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(decimal.NewFromFloat(2.0), types.MustCreateDecimalType(10, 0)), expression.NewLiteral("3", types.LongText), }, - expected: 1, + expected: "1", typ: types.LongText, nullable: false, }, @@ -110,7 +110,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(1, types.Int32), expression.NewLiteral(2, types.Uint32), }, - expected: 1, + expected: decimal.New(1, 0), typ: types.MustCreateDecimalType(20, 0), nullable: false, }, @@ -120,7 +120,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(1, types.Int32), expression.NewLiteral(2, types.Uint32), }, - expected: 1, + expected: decimal.New(1, 0), typ: types.MustCreateDecimalType(20, 0), nullable: false, }, @@ -130,7 +130,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(1, types.MustCreateDecimalType(10, 0)), expression.NewLiteral(2, types.Float64), }, - expected: 1, + expected: float64(1), typ: types.Float64, nullable: false, }, @@ -139,7 +139,7 @@ func TestCoalesce(t *testing.T) { input: []sql.Expression{ expression.NewLiteral(2, types.Float64), }, - expected: 2, + expected: float64(2), typ: types.Float64, nullable: false, }, @@ -148,7 +148,7 @@ func TestCoalesce(t *testing.T) { input: []sql.Expression{ expression.NewLiteral(1, types.Float64), }, - expected: 1, + expected: float64(1), typ: types.Float64, nullable: false, }, @@ -158,7 +158,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)), expression.NewLiteral(2, types.NewSystemIntType("int2", 0, 10, false)), }, - expected: 1, + expected: int64(1), typ: types.Int64, nullable: false, }, @@ -168,7 +168,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)), expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)), }, - expected: 1, + expected: decimal.New(1, 0), typ: types.MustCreateDecimalType(20, 0), nullable: false, }, @@ -178,7 +178,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(1, types.NewSystemUintType("int1", 0, 10)), expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)), }, - expected: 1, + expected: uint64(1), typ: types.Uint64, nullable: false, }, @@ -188,7 +188,7 @@ func TestCoalesce(t *testing.T) { expression.NewLiteral(1.0, types.NewSystemDoubleType("dbl1", 0.0, 10.0)), expression.NewLiteral(2.0, types.NewSystemDoubleType("dbl2", 0.0, 10.0)), }, - expected: 1.0, + expected: float64(1), typ: types.Float64, nullable: false, }, @@ -249,19 +249,19 @@ func TestComposeCoalasce(t *testing.T) { require.Equal(t, types.Int32, c2.Type()) v, err = c2.Eval(ctx, nil) require.NoError(t, err) - require.Equal(t, 1, v) + require.Equal(t, int32(1), v) c3, err := NewCoalesce(nil, c1, c2) require.NoError(t, err) require.Equal(t, types.Int32, c3.Type()) v, err = c3.Eval(ctx, nil) require.NoError(t, err) - require.Equal(t, 1, v) + require.Equal(t, int32(1), v) c4, err := NewCoalesce(expression.NewLiteral(nil, types.Null), c1, c2) require.NoError(t, err) require.Equal(t, types.Int32, c4.Type()) v, err = c4.Eval(ctx, nil) require.NoError(t, err) - require.Equal(t, 1, v) + require.Equal(t, int32(1), v) }