Skip to content

Commit

Permalink
[expression] coalesce shoudl cast return values (#2853)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-hoffman authored Feb 17, 2025
1 parent debff7a commit 710ac92
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 16 deletions.
20 changes: 20 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,26 @@ var QueryTests = []QueryTest{
Query: "SELECT 1 WHERE ((1 IN (NULL * 1)) IS NULL);",
Expected: []sql.Row{{1}},
},
{
Query: "select coalesce(1, 0.0);",
Expected: []sql.Row{{"1"}},
},
{
Query: "select coalesce(1, '0');",
Expected: []sql.Row{{"1"}},
},
{
Query: "select coalesce(1, 'x');",
Expected: []sql.Row{{"1"}},
},
{
Query: "select coalesce(1, 1);",
Expected: []sql.Row{{1}},
},
{
Query: "select coalesce(1, CAST('2017-08-29' AS DATE))",
Expected: []sql.Row{{"1"}},
},
{
Query: "SELECT count(*) from mytable WHERE ((i IN (NULL >= 1)) IS NULL);",
Expected: []sql.Row{{3}},
Expand Down
13 changes: 12 additions & 1 deletion sql/expression/function/coalesce.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
// Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values.
type Coalesce struct {
args []sql.Expression
typ sql.Type
}

var _ sql.FunctionExpression = (*Coalesce)(nil)
Expand All @@ -38,7 +39,7 @@ func NewCoalesce(args ...sql.Expression) (sql.Expression, error) {
return nil, sql.ErrInvalidArgumentNumber.New("COALESCE", "1 or more", 0)
}

return &Coalesce{args}, nil
return &Coalesce{args: args}, nil
}

// FunctionName implements sql.FunctionExpression
Expand All @@ -54,6 +55,9 @@ func (c *Coalesce) Description() string {
// Type implements the sql.Expression interface.
// The return type of Type() is the aggregated type of the argument types.
func (c *Coalesce) Type() sql.Type {
if c.typ != nil {
return c.typ
}
retType := types.Null
for i, arg := range c.args {
if arg == nil {
Expand Down Expand Up @@ -120,6 +124,7 @@ func (c *Coalesce) Type() sql.Type {
}
}

c.typ = retType
return retType
}

Expand Down Expand Up @@ -201,6 +206,12 @@ func (c *Coalesce) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
continue
}

if !types.IsEnum(c.Type()) && !types.IsSet(c.Type()) {
val, _, err = c.Type().Convert(val)
if err != nil {
return nil, err
}
}
return val, nil
}

Expand Down
30 changes: 15 additions & 15 deletions sql/expression/function/coalesce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(2, types.Int32),
expression.NewLiteral(3, types.Int32),
},
expected: 1,
expected: int32(1),
typ: types.Int32,
nullable: false,
},
Expand All @@ -56,7 +56,7 @@ func TestCoalesce(t *testing.T) {
nil,
expression.NewLiteral(3, types.Int32),
},
expected: 3,
expected: int32(3),
typ: types.Int32,
nullable: false,
},
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(decimal.NewFromFloat(2.0), types.MustCreateDecimalType(10, 0)),
expression.NewLiteral("3", types.LongText),
},
expected: 1,
expected: "1",
typ: types.LongText,
nullable: false,
},
Expand All @@ -110,7 +110,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(1, types.Int32),
expression.NewLiteral(2, types.Uint32),
},
expected: 1,
expected: decimal.New(1, 0),
typ: types.MustCreateDecimalType(20, 0),
nullable: false,
},
Expand All @@ -120,7 +120,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(1, types.Int32),
expression.NewLiteral(2, types.Uint32),
},
expected: 1,
expected: decimal.New(1, 0),
typ: types.MustCreateDecimalType(20, 0),
nullable: false,
},
Expand All @@ -130,7 +130,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(1, types.MustCreateDecimalType(10, 0)),
expression.NewLiteral(2, types.Float64),
},
expected: 1,
expected: float64(1),
typ: types.Float64,
nullable: false,
},
Expand All @@ -139,7 +139,7 @@ func TestCoalesce(t *testing.T) {
input: []sql.Expression{
expression.NewLiteral(2, types.Float64),
},
expected: 2,
expected: float64(2),
typ: types.Float64,
nullable: false,
},
Expand All @@ -148,7 +148,7 @@ func TestCoalesce(t *testing.T) {
input: []sql.Expression{
expression.NewLiteral(1, types.Float64),
},
expected: 1,
expected: float64(1),
typ: types.Float64,
nullable: false,
},
Expand All @@ -158,7 +158,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
expression.NewLiteral(2, types.NewSystemIntType("int2", 0, 10, false)),
},
expected: 1,
expected: int64(1),
typ: types.Int64,
nullable: false,
},
Expand All @@ -168,7 +168,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
},
expected: 1,
expected: decimal.New(1, 0),
typ: types.MustCreateDecimalType(20, 0),
nullable: false,
},
Expand All @@ -178,7 +178,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(1, types.NewSystemUintType("int1", 0, 10)),
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
},
expected: 1,
expected: uint64(1),
typ: types.Uint64,
nullable: false,
},
Expand All @@ -188,7 +188,7 @@ func TestCoalesce(t *testing.T) {
expression.NewLiteral(1.0, types.NewSystemDoubleType("dbl1", 0.0, 10.0)),
expression.NewLiteral(2.0, types.NewSystemDoubleType("dbl2", 0.0, 10.0)),
},
expected: 1.0,
expected: float64(1),
typ: types.Float64,
nullable: false,
},
Expand Down Expand Up @@ -249,19 +249,19 @@ func TestComposeCoalasce(t *testing.T) {
require.Equal(t, types.Int32, c2.Type())
v, err = c2.Eval(ctx, nil)
require.NoError(t, err)
require.Equal(t, 1, v)
require.Equal(t, int32(1), v)

c3, err := NewCoalesce(nil, c1, c2)
require.NoError(t, err)
require.Equal(t, types.Int32, c3.Type())
v, err = c3.Eval(ctx, nil)
require.NoError(t, err)
require.Equal(t, 1, v)
require.Equal(t, int32(1), v)

c4, err := NewCoalesce(expression.NewLiteral(nil, types.Null), c1, c2)
require.NoError(t, err)
require.Equal(t, types.Int32, c4.Type())
v, err = c4.Eval(ctx, nil)
require.NoError(t, err)
require.Equal(t, 1, v)
require.Equal(t, int32(1), v)
}

0 comments on commit 710ac92

Please sign in to comment.