diff --git a/pkg/col/coldataext/datum_vec.go b/pkg/col/coldataext/datum_vec.go index a46646a5e5a7..c6d5eac49265 100644 --- a/pkg/col/coldataext/datum_vec.go +++ b/pkg/col/coldataext/datum_vec.go @@ -70,6 +70,12 @@ func (d *Datum) CompareDatum(dVec, other interface{}) int { return d.Datum.Compare(dVec.(*datumVec).evalCtx, maybeUnwrapDatum(other)) } +// Cast returns the result of casting d to the type toType. dVec is the +// datumVec that stores d and is used to supply the eval context. +func (d *Datum) Cast(dVec interface{}, toType *types.T) (tree.Datum, error) { + return tree.PerformCast(dVec.(*datumVec).evalCtx, d.Datum, toType) +} + // Hash returns the hash of the datum as a byte slice. func (d *Datum) Hash(da *sqlbase.DatumAlloc) []byte { ed := sqlbase.EncDatum{Datum: maybeUnwrapDatum(d)} diff --git a/pkg/sql/colexec/cast_tmpl.go b/pkg/sql/colexec/cast_tmpl.go index ff08c5a8ae50..37f97533bf7c 100644 --- a/pkg/sql/colexec/cast_tmpl.go +++ b/pkg/sql/colexec/cast_tmpl.go @@ -56,7 +56,7 @@ const _RIGHT_CANONICAL_TYPE_FAMILY = types.UnknownFamily // _RIGHT_TYPE_WIDTH is the template variable. const _RIGHT_TYPE_WIDTH = 0 -func _CAST(to, from interface{}) { +func _CAST(to, from, fromCol interface{}) { colexecerror.InternalError("") } @@ -104,7 +104,7 @@ func cast(inputVec, outputVec coldata.Vec, n int, sel []int) { } else { v := _L_UNSAFEGET(inputCol, i) var r _R_GO_TYPE - _CAST(r, v) + _CAST(r, v, inputCol) _R_SET(outputCol, i, r) } } @@ -116,7 +116,7 @@ func cast(inputVec, outputVec coldata.Vec, n int, sel []int) { } else { v := _L_UNSAFEGET(inputCol, i) var r _R_GO_TYPE - _CAST(r, v) + _CAST(r, v, inputCol) _R_SET(outputCol, i, r) } } @@ -127,7 +127,7 @@ func cast(inputVec, outputVec coldata.Vec, n int, sel []int) { for _, i := range sel { v := _L_UNSAFEGET(inputCol, i) var r _R_GO_TYPE - _CAST(r, v) + _CAST(r, v, inputCol) _R_SET(outputCol, i, r) } } else { @@ -135,7 +135,7 @@ func cast(inputVec, outputVec coldata.Vec, n int, sel []int) { for execgen.RANGE(i, inputCol, 0, n) { v := _L_UNSAFEGET(inputCol, i) var r _R_GO_TYPE - _CAST(r, v) + _CAST(r, v, inputCol) _R_SET(outputCol, i, r) } } diff --git a/pkg/sql/colexec/execgen/cmd/execgen/cast_gen.go b/pkg/sql/colexec/execgen/cmd/execgen/cast_gen.go index 4dc444480ed2..fa6a80ca83cc 100644 --- a/pkg/sql/colexec/execgen/cmd/execgen/cast_gen.go +++ b/pkg/sql/colexec/execgen/cmd/execgen/cast_gen.go @@ -30,8 +30,8 @@ func genCastOperators(inputFileContents string, wr io.Writer) error { ) s := r.Replace(inputFileContents) - castRe := makeFunctionRegex("_CAST", 2) - s = castRe.ReplaceAllString(s, makeTemplateFunctionCall("Right.Cast", 2)) + castRe := makeFunctionRegex("_CAST", 3) + s = castRe.ReplaceAllString(s, makeTemplateFunctionCall("Right.Cast", 3)) s = strings.ReplaceAll(s, "_L_SLICE", "execgen.SLICE") s = strings.ReplaceAll(s, "_L_UNSAFEGET", "execgen.UNSAFEGET") diff --git a/pkg/sql/colexec/execgen/cmd/execgen/overloads_base.go b/pkg/sql/colexec/execgen/cmd/execgen/overloads_base.go index 5abbbaa0154c..f7d8cb4cfa4e 100644 --- a/pkg/sql/colexec/execgen/cmd/execgen/overloads_base.go +++ b/pkg/sql/colexec/execgen/cmd/execgen/overloads_base.go @@ -344,7 +344,7 @@ type twoArgsResolvedOverloadRightWidthInfo struct { type assignFunc func(op *lastArgWidthOverload, targetElem, leftElem, rightElem, targetCol, leftCol, rightCol string) string type compareFunc func(targetElem, leftElem, rightElem, leftCol, rightCol string) string -type castFunc func(to, from string) string +type castFunc func(to, from, fromCol string) string // Assign produces a Go source string that assigns the "targetElem" variable to // the result of applying the overload to the two inputs, "leftElem" and @@ -392,9 +392,9 @@ func (o *lastArgWidthOverload) Compare( leftElem, rightElem, targetElem, leftElem, rightElem, targetElem, targetElem) } -func (o *lastArgWidthOverload) Cast(to, from string) string { +func (o *lastArgWidthOverload) Cast(to, from, fromCol string) string { if o.CastFunc != nil { - if ret := o.CastFunc(to, from); ret != "" { + if ret := o.CastFunc(to, from, fromCol); ret != "" { return ret } } diff --git a/pkg/sql/colexec/execgen/cmd/execgen/overloads_cast.go b/pkg/sql/colexec/execgen/cmd/execgen/overloads_cast.go index f0cf1ce3b85e..6c720361d2e8 100644 --- a/pkg/sql/colexec/execgen/cmd/execgen/overloads_cast.go +++ b/pkg/sql/colexec/execgen/cmd/execgen/overloads_cast.go @@ -13,6 +13,7 @@ package main import ( "fmt" + "github.com/cockroachdb/cockroach/pkg/col/typeconv" "github.com/cockroachdb/cockroach/pkg/sql/colexecbase/colexecerror" "github.com/cockroachdb/cockroach/pkg/sql/types" ) @@ -56,15 +57,15 @@ func populateCastOverloads() { }, castTypeCustomizers) } -func intToDecimal(to, from string) string { +func intToDecimal(to, from, _ string) string { convStr := ` %[1]s = *apd.New(int64(%[2]s), 0) ` return fmt.Sprintf(convStr, to, from) } -func intToFloat() func(string, string) string { - return func(to, from string) string { +func intToFloat() func(string, string, string) string { + return func(to, from, _ string) string { convStr := ` %[1]s = float64(%[2]s) ` @@ -72,29 +73,29 @@ func intToFloat() func(string, string) string { } } -func intToInt16(to, from string) string { +func intToInt16(to, from, _ string) string { convStr := ` %[1]s = int16(%[2]s) ` return fmt.Sprintf(convStr, to, from) } -func intToInt32(to, from string) string { +func intToInt32(to, from, _ string) string { convStr := ` %[1]s = int32(%[2]s) ` return fmt.Sprintf(convStr, to, from) } -func intToInt64(to, from string) string { +func intToInt64(to, from, _ string) string { convStr := ` %[1]s = int64(%[2]s) ` return fmt.Sprintf(convStr, to, from) } -func floatToInt(intWidth, floatWidth int32) func(string, string) string { - return func(to, from string) string { +func floatToInt(intWidth, floatWidth int32) func(string, string, string) string { + return func(to, from, _ string) string { convStr := ` if math.IsNaN(float64(%[2]s)) || %[2]s <= float%[4]d(math.MinInt%[3]d) || %[2]s >= float%[4]d(math.MaxInt%[3]d) { colexecerror.ExpectedError(tree.ErrIntOutOfRange) @@ -108,14 +109,14 @@ func floatToInt(intWidth, floatWidth int32) func(string, string) string { } } -func numToBool(to, from string) string { +func numToBool(to, from, _ string) string { convStr := ` %[1]s = %[2]s != 0 ` return fmt.Sprintf(convStr, to, from) } -func floatToDecimal(to, from string) string { +func floatToDecimal(to, from, _ string) string { convStr := ` { var tmpDec apd.Decimal @@ -129,6 +130,19 @@ func floatToDecimal(to, from string) string { return fmt.Sprintf(convStr, to, from) } +func datumToBool(to, from, fromCol string) string { + convStr := ` + { + _castedDatum, err := %[2]s.(*coldataext.Datum).Cast(%[3]s, types.Bool) + if err != nil { + colexecerror.ExpectedError(err) + } + %[1]s = _castedDatum == tree.DBoolTrue + } + ` + return fmt.Sprintf(convStr, to, from, fromCol) +} + // castTypeCustomizer is a type customizer that changes how the templater // produces cast operator output for a particular type. type castTypeCustomizer interface { @@ -200,6 +214,9 @@ func registerCastTypeCustomizers() { registerCastTypeCustomizer(typePair{types.FloatFamily, anyWidth, toFamily, toWidth}, floatCastCustomizer{toFamily: toFamily, toWidth: toWidth}) } } + + // Casts from datum-backed types. + registerCastTypeCustomizer(typePair{typeconv.DatumVecCanonicalTypeFamily, anyWidth, types.BoolFamily, anyWidth}, datumCastCustomizer{toFamily: types.BoolFamily}) } // boolCastCustomizer specifies casts from booleans. @@ -220,8 +237,14 @@ type intCastCustomizer struct { toWidth int32 } +// datumCastCustomizer specifies casts from types that are backed by tree.Datum +// to other types. +type datumCastCustomizer struct { + toFamily types.Family +} + func (boolCastCustomizer) getCastFunc() castFunc { - return func(to, from string) string { + return func(to, from, _ string) string { convStr := ` %[1]s = 0 if %[2]s { @@ -233,7 +256,7 @@ func (boolCastCustomizer) getCastFunc() castFunc { } func (decimalCastCustomizer) getCastFunc() castFunc { - return func(to, from string) string { + return func(to, from, _ string) string { return fmt.Sprintf("%[1]s = %[2]s.Sign() != 0", to, from) } } @@ -274,3 +297,13 @@ func (c intCastCustomizer) getCastFunc() castFunc { // This code is unreachable, but the compiler cannot infer that. return nil } + +func (c datumCastCustomizer) getCastFunc() castFunc { + switch c.toFamily { + case types.BoolFamily: + return datumToBool + } + colexecerror.InternalError(fmt.Sprintf("unexpectedly didn't find a cast from datum-backed type to %s", c.toFamily)) + // This code is unreachable, but the compiler cannot infer that. + return nil +}