Skip to content

Commit

Permalink
colexec: add casts from datum-backed types to bools
Browse files Browse the repository at this point in the history
While investigating unrelated test failures, I added this cast, so we
might as well merge it.

Release note: None
  • Loading branch information
yuzefovich committed Jun 10, 2020
1 parent 3b116e9 commit d67888b
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 22 deletions.
6 changes: 6 additions & 0 deletions pkg/col/coldataext/datum_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ func (d *Datum) CompareDatum(dVec, other interface{}) int {
return d.Datum.Compare(dVec.(*datumVec).evalCtx, maybeUnwrapDatum(other))
}

// Cast returns the result of casting d to the type toType. dVec is the
// datumVec that stores d and is used to supply the eval context.
func (d *Datum) Cast(dVec interface{}, toType *types.T) (tree.Datum, error) {
return tree.PerformCast(dVec.(*datumVec).evalCtx, d.Datum, toType)
}

// Hash returns the hash of the datum as a byte slice.
func (d *Datum) Hash(da *sqlbase.DatumAlloc) []byte {
ed := sqlbase.EncDatum{Datum: maybeUnwrapDatum(d)}
Expand Down
10 changes: 5 additions & 5 deletions pkg/sql/colexec/cast_tmpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ const _RIGHT_CANONICAL_TYPE_FAMILY = types.UnknownFamily
// _RIGHT_TYPE_WIDTH is the template variable.
const _RIGHT_TYPE_WIDTH = 0

func _CAST(to, from interface{}) {
func _CAST(to, from, fromCol interface{}) {
colexecerror.InternalError("")
}

Expand Down Expand Up @@ -104,7 +104,7 @@ func cast(inputVec, outputVec coldata.Vec, n int, sel []int) {
} else {
v := _L_UNSAFEGET(inputCol, i)
var r _R_GO_TYPE
_CAST(r, v)
_CAST(r, v, inputCol)
_R_SET(outputCol, i, r)
}
}
Expand All @@ -116,7 +116,7 @@ func cast(inputVec, outputVec coldata.Vec, n int, sel []int) {
} else {
v := _L_UNSAFEGET(inputCol, i)
var r _R_GO_TYPE
_CAST(r, v)
_CAST(r, v, inputCol)
_R_SET(outputCol, i, r)
}
}
Expand All @@ -127,15 +127,15 @@ func cast(inputVec, outputVec coldata.Vec, n int, sel []int) {
for _, i := range sel {
v := _L_UNSAFEGET(inputCol, i)
var r _R_GO_TYPE
_CAST(r, v)
_CAST(r, v, inputCol)
_R_SET(outputCol, i, r)
}
} else {
inputCol = _L_SLICE(inputCol, 0, n)
for execgen.RANGE(i, inputCol, 0, n) {
v := _L_UNSAFEGET(inputCol, i)
var r _R_GO_TYPE
_CAST(r, v)
_CAST(r, v, inputCol)
_R_SET(outputCol, i, r)
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/colexec/execgen/cmd/execgen/cast_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ func genCastOperators(inputFileContents string, wr io.Writer) error {
)
s := r.Replace(inputFileContents)

castRe := makeFunctionRegex("_CAST", 2)
s = castRe.ReplaceAllString(s, makeTemplateFunctionCall("Right.Cast", 2))
castRe := makeFunctionRegex("_CAST", 3)
s = castRe.ReplaceAllString(s, makeTemplateFunctionCall("Right.Cast", 3))

s = strings.ReplaceAll(s, "_L_SLICE", "execgen.SLICE")
s = strings.ReplaceAll(s, "_L_UNSAFEGET", "execgen.UNSAFEGET")
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/colexec/execgen/cmd/execgen/overloads_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ type twoArgsResolvedOverloadRightWidthInfo struct {

type assignFunc func(op *lastArgWidthOverload, targetElem, leftElem, rightElem, targetCol, leftCol, rightCol string) string
type compareFunc func(targetElem, leftElem, rightElem, leftCol, rightCol string) string
type castFunc func(to, from string) string
type castFunc func(to, from, fromCol string) string

// Assign produces a Go source string that assigns the "targetElem" variable to
// the result of applying the overload to the two inputs, "leftElem" and
Expand Down Expand Up @@ -392,9 +392,9 @@ func (o *lastArgWidthOverload) Compare(
leftElem, rightElem, targetElem, leftElem, rightElem, targetElem, targetElem)
}

func (o *lastArgWidthOverload) Cast(to, from string) string {
func (o *lastArgWidthOverload) Cast(to, from, fromCol string) string {
if o.CastFunc != nil {
if ret := o.CastFunc(to, from); ret != "" {
if ret := o.CastFunc(to, from, fromCol); ret != "" {
return ret
}
}
Expand Down
57 changes: 45 additions & 12 deletions pkg/sql/colexec/execgen/cmd/execgen/overloads_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package main
import (
"fmt"

"github.com/cockroachdb/cockroach/pkg/col/typeconv"
"github.com/cockroachdb/cockroach/pkg/sql/colexecbase/colexecerror"
"github.com/cockroachdb/cockroach/pkg/sql/types"
)
Expand Down Expand Up @@ -56,45 +57,45 @@ func populateCastOverloads() {
}, castTypeCustomizers)
}

func intToDecimal(to, from string) string {
func intToDecimal(to, from, _ string) string {
convStr := `
%[1]s = *apd.New(int64(%[2]s), 0)
`
return fmt.Sprintf(convStr, to, from)
}

func intToFloat() func(string, string) string {
return func(to, from string) string {
func intToFloat() func(string, string, string) string {
return func(to, from, _ string) string {
convStr := `
%[1]s = float64(%[2]s)
`
return fmt.Sprintf(convStr, to, from)
}
}

func intToInt16(to, from string) string {
func intToInt16(to, from, _ string) string {
convStr := `
%[1]s = int16(%[2]s)
`
return fmt.Sprintf(convStr, to, from)
}

func intToInt32(to, from string) string {
func intToInt32(to, from, _ string) string {
convStr := `
%[1]s = int32(%[2]s)
`
return fmt.Sprintf(convStr, to, from)
}

func intToInt64(to, from string) string {
func intToInt64(to, from, _ string) string {
convStr := `
%[1]s = int64(%[2]s)
`
return fmt.Sprintf(convStr, to, from)
}

func floatToInt(intWidth, floatWidth int32) func(string, string) string {
return func(to, from string) string {
func floatToInt(intWidth, floatWidth int32) func(string, string, string) string {
return func(to, from, _ string) string {
convStr := `
if math.IsNaN(float64(%[2]s)) || %[2]s <= float%[4]d(math.MinInt%[3]d) || %[2]s >= float%[4]d(math.MaxInt%[3]d) {
colexecerror.ExpectedError(tree.ErrIntOutOfRange)
Expand All @@ -108,14 +109,14 @@ func floatToInt(intWidth, floatWidth int32) func(string, string) string {
}
}

func numToBool(to, from string) string {
func numToBool(to, from, _ string) string {
convStr := `
%[1]s = %[2]s != 0
`
return fmt.Sprintf(convStr, to, from)
}

func floatToDecimal(to, from string) string {
func floatToDecimal(to, from, _ string) string {
convStr := `
{
var tmpDec apd.Decimal
Expand All @@ -129,6 +130,19 @@ func floatToDecimal(to, from string) string {
return fmt.Sprintf(convStr, to, from)
}

func datumToBool(to, from, fromCol string) string {
convStr := `
{
_castedDatum, err := %[2]s.(*coldataext.Datum).Cast(%[3]s, types.Bool)
if err != nil {
colexecerror.ExpectedError(err)
}
%[1]s = _castedDatum == tree.DBoolTrue
}
`
return fmt.Sprintf(convStr, to, from, fromCol)
}

// castTypeCustomizer is a type customizer that changes how the templater
// produces cast operator output for a particular type.
type castTypeCustomizer interface {
Expand Down Expand Up @@ -200,6 +214,9 @@ func registerCastTypeCustomizers() {
registerCastTypeCustomizer(typePair{types.FloatFamily, anyWidth, toFamily, toWidth}, floatCastCustomizer{toFamily: toFamily, toWidth: toWidth})
}
}

// Casts from datum-backed types.
registerCastTypeCustomizer(typePair{typeconv.DatumVecCanonicalTypeFamily, anyWidth, types.BoolFamily, anyWidth}, datumCastCustomizer{toFamily: types.BoolFamily})
}

// boolCastCustomizer specifies casts from booleans.
Expand All @@ -220,8 +237,14 @@ type intCastCustomizer struct {
toWidth int32
}

// datumCastCustomizer specifies casts from types that are backed by tree.Datum
// to other types.
type datumCastCustomizer struct {
toFamily types.Family
}

func (boolCastCustomizer) getCastFunc() castFunc {
return func(to, from string) string {
return func(to, from, _ string) string {
convStr := `
%[1]s = 0
if %[2]s {
Expand All @@ -233,7 +256,7 @@ func (boolCastCustomizer) getCastFunc() castFunc {
}

func (decimalCastCustomizer) getCastFunc() castFunc {
return func(to, from string) string {
return func(to, from, _ string) string {
return fmt.Sprintf("%[1]s = %[2]s.Sign() != 0", to, from)
}
}
Expand Down Expand Up @@ -274,3 +297,13 @@ func (c intCastCustomizer) getCastFunc() castFunc {
// This code is unreachable, but the compiler cannot infer that.
return nil
}

func (c datumCastCustomizer) getCastFunc() castFunc {
switch c.toFamily {
case types.BoolFamily:
return datumToBool
}
colexecerror.InternalError(fmt.Sprintf("unexpectedly didn't find a cast from datum-backed type to %s", c.toFamily))
// This code is unreachable, but the compiler cannot infer that.
return nil
}

0 comments on commit d67888b

Please sign in to comment.