diff --git a/docs/generated/sql/operators.md b/docs/generated/sql/operators.md index da1b097298d0..370135313419 100644 --- a/docs/generated/sql/operators.md +++ b/docs/generated/sql/operators.md @@ -414,11 +414,13 @@ + + @@ -441,6 +443,8 @@ + + diff --git a/pkg/sql/sem/builtins/pg_builtins.go b/pkg/sql/sem/builtins/pg_builtins.go index c6a9f18ef3d4..44747b1d5814 100644 --- a/pkg/sql/sem/builtins/pg_builtins.go +++ b/pkg/sql/sem/builtins/pg_builtins.go @@ -59,6 +59,7 @@ func makeNotUsableFalseBuiltin() builtinDefinition { // the existence of this map. var typeBuiltinsHaveUnderscore = map[oid.Oid]struct{}{ types.Any.Oid(): {}, + types.AnyNonArray.Oid(): {}, types.AnyArray.Oid(): {}, types.Date.Oid(): {}, types.Time.Oid(): {}, diff --git a/pkg/sql/sem/tree/eval.go b/pkg/sql/sem/tree/eval.go index 6c1ab4bc622e..3945b6027978 100644 --- a/pkg/sql/sem/tree/eval.go +++ b/pkg/sql/sem/tree/eval.go @@ -202,7 +202,7 @@ func (op *BinOp) returnType() ReturnTyper { return op.retType } -func (*BinOp) preferred() bool { +func (op *BinOp) preferred() bool { return false } @@ -1411,6 +1411,22 @@ var BinOps = map[BinaryOperator]binOpOverload{ return NewDString(string(MustBeDString(left) + MustBeDString(right))), nil }, }, + &BinOp{ + LeftType: types.String, + RightType: types.AnyNonArray, + ReturnType: types.String, + Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { + return NewDString(string(MustBeDString(left)) + right.String()), nil + }, + }, + &BinOp{ + LeftType: types.AnyNonArray, + RightType: types.String, + ReturnType: types.String, + Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { + return NewDString(left.String() + string(MustBeDString(right))), nil + }, + }, &BinOp{ LeftType: types.Bytes, RightType: types.Bytes, diff --git a/pkg/sql/sem/tree/overload.go b/pkg/sql/sem/tree/overload.go index 68af575e73ac..136d96bfa608 100644 --- a/pkg/sql/sem/tree/overload.go +++ b/pkg/sql/sem/tree/overload.go @@ -415,7 +415,6 @@ func typeCheckOverloadedExprs( if len(overloads) > math.MaxUint8 { return nil, nil, errors.AssertionFailedf("too many overloads (%d > 255)", len(overloads)) } - var s typeCheckOverloadState s.exprs = exprs s.overloads = overloads @@ -649,6 +648,21 @@ func typeCheckOverloadedExprs( } } + // We apply another heuristic to try remove any attempts where the overload has Any. + // We prefer more strictly typed options if available. + if ok, typedExprs, fns, err := filterAttempt(ctx, &s, func() { + s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, func(o overloadImpl) bool { + for _, typ := range o.params().Types() { + if typ.Family() == types.AnyFamily { + return false + } + } + return true + }) + }); ok { + return typedExprs, fns, err + } + // In a binary expression, in the case of one of the arguments being untyped NULL, // we prefer overloads where we infer the type of the NULL to be the same as the // other argument. This is used to differentiate the behavior of @@ -684,6 +698,13 @@ func typeCheckOverloadedExprs( } s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, func(o overloadImpl) bool { + // If any side is Any, this filtering will not be useful + // for whittling down types. Ignore them. + // This is equivalent to adding the AnyFamily check above. + if o.params().GetAt(0).Family() == types.AnyFamily || + o.params().GetAt(1).Family() == types.AnyFamily { + return false + } return o.params().GetAt(0).Equivalent(leftType) && o.params().GetAt(1).Equivalent(rightType) }) diff --git a/pkg/sql/sem/tree/testdata/eval/concat b/pkg/sql/sem/tree/testdata/eval/concat index 78f1693faebd..051621be3db4 100644 --- a/pkg/sql/sem/tree/testdata/eval/concat +++ b/pkg/sql/sem/tree/testdata/eval/concat @@ -10,6 +10,11 @@ eval ---- 'a3' +eval +b'hello' || b'world' +---- +'\x68656c6c6f776f726c64' + eval b'hello' || 'world' ---- @@ -19,3 +24,25 @@ eval array['foo'] || '{a,b}' ---- ARRAY['foo','a','b'] + +# String || any; any || String + +eval +'b' || (5)::char || (8)::char || 'c' +---- +'b58c' + +eval +3 || 'a' || 3 +---- +3a3 + +eval +3::oid || 'a' || 3::oid +---- +3a3 + +eval +3.33 || 'a' || 3.33 +---- +3.33a3.33 diff --git a/pkg/sql/sqlbase/testutils.go b/pkg/sql/sqlbase/testutils.go index 4ed6dbaed567..c3e6c5d5b625 100644 --- a/pkg/sql/sqlbase/testutils.go +++ b/pkg/sql/sqlbase/testutils.go @@ -480,7 +480,7 @@ var ( func init() { for _, typ := range types.OidToType { switch typ.Oid() { - case oid.T_unknown, oid.T_anyelement: + case oid.T_unknown, oid.T_anyelement, oid.T_anynonarray: // Don't include these. case oid.T_anyarray, oid.T_oidvector, oid.T_int2vector: // Include these. diff --git a/pkg/sql/types/types.go b/pkg/sql/types/types.go index 10acaf7d9fed..b1086c704dd8 100644 --- a/pkg/sql/types/types.go +++ b/pkg/sql/types/types.go @@ -352,6 +352,12 @@ var ( Any = &T{InternalType: InternalType{ Family: AnyFamily, Oid: oid.T_anyelement, Locale: &emptyLocale}} + // AnyNonArray is a special type used only during static analysis as a wildcard + // type that matches non-array types. + // Execution-time values should never have this type. + AnyNonArray = &T{InternalType: InternalType{ + Family: AnyFamily, Oid: oid.T_anynonarray, Locale: &emptyLocale}} + // AnyArray is a special type used only during static analysis as a wildcard // type that matches an array having elements of any (uniform) type (including // nested array types). Execution-time values should never have this type. @@ -884,6 +890,10 @@ func (t *T) TupleLabels() []string { func (t *T) Name() string { switch t.Family() { case AnyFamily: + switch t.Oid() { + case oid.T_anynonarray: + return "anynonarray" + } return "anyelement" case ArrayFamily: switch t.Oid() { @@ -1021,6 +1031,10 @@ func (t *T) SQLStandardNameWithTypmod(haveTypmod bool, typmod int) string { var buf strings.Builder switch t.Family() { case AnyFamily: + switch t.Oid() { + case oid.T_anynonarray: + return "anynonarray" + } return "anyelement" case ArrayFamily: switch t.Oid() { @@ -1273,6 +1287,13 @@ func (t *T) SQLString() string { // types. And a wildcard collation (empty string) matches any other collation. func (t *T) Equivalent(other *T) bool { if t.Family() == AnyFamily || other.Family() == AnyFamily { + // Non array families must not have an array family on the other side. + if t == AnyNonArray && other.Family() == ArrayFamily { + return false + } + if other == AnyNonArray && t.Family() == ArrayFamily { + return false + } return true } if t.Family() != other.Family() { diff --git a/pkg/sql/types/types_test.go b/pkg/sql/types/types_test.go index 8a7e27af4247..3e93c5f51750 100644 --- a/pkg/sql/types/types_test.go +++ b/pkg/sql/types/types_test.go @@ -383,6 +383,7 @@ func TestEquivalent(t *testing.T) { }{ // ARRAY {Int2Vector, IntArray, true}, + {Int2Vector, AnyNonArray, false}, {OidVector, MakeArray(Oid), true}, {MakeArray(Int), MakeArray(Int4), true}, {MakeArray(String), MakeArray(MakeChar(10)), true}, @@ -394,6 +395,7 @@ func TestEquivalent(t *testing.T) { {MakeBit(1), MakeBit(2), true}, {MakeBit(1), MakeVarBit(2), true}, {MakeVarBit(10), Any, true}, + {MakeVarBit(10), AnyNonArray, true}, {VarBit, Bytes, false}, // COLLATEDSTRING @@ -413,6 +415,7 @@ func TestEquivalent(t *testing.T) { {Int2, Int4, true}, {Int4, Int, true}, {Int, Any, true}, + {Int, AnyNonArray, true}, {Int, IntArray, false}, // TUPLE @@ -420,6 +423,7 @@ func TestEquivalent(t *testing.T) { {MakeTuple([]T{*Int, *String}), MakeTuple([]T{*Int4, *VarChar}), true}, {MakeTuple([]T{*Int, *String}), AnyTuple, true}, {AnyTuple, MakeTuple([]T{*Int, *String}), true}, + {AnyNonArray, MakeTuple([]T{*Int, *String}), true}, {MakeTuple([]T{*Int, *String}), MakeLabeledTuple([]T{*Int4, *VarChar}, []string{"label2", "label1"}), true}, {MakeLabeledTuple([]T{*Int, *String}, []string{"label1", "label2"}), @@ -430,6 +434,7 @@ func TestEquivalent(t *testing.T) { {Unknown, &T{InternalType: InternalType{ Family: UnknownFamily, Oid: oid.T_unknown, Locale: &emptyLocale}}, true}, {Any, Unknown, true}, + {AnyNonArray, Unknown, true}, {Unknown, Int, false}, }
||Return
anynonarray || stringstring
bool || bool[]bool[]
bool[] || boolbool[]
bool[] || bool[]bool[]
bytes || bytesbytes
bytes || bytes[]bytes[]
bytes || stringbytes
bytes[] || bytesbytes[]
bytes[] || bytes[]bytes[]
date || date[]date[]
interval[] || interval[]interval[]
jsonb || jsonbjsonb
oid || oidoid
string || anynonarraystring
string || bytesbytes
string || stringstring
string || string[]string[]
string[] || stringstring[]