Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exec: overflow handling for vectorized arithmetic #38967

Merged
merged 2 commits into from
Jul 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,8 @@ EXECGEN_TARGETS = \
pkg/sql/exec/vec_comparators.eg.go \
pkg/sql/exec/vecbuiltins/rank.eg.go \
pkg/sql/exec/vecbuiltins/row_number.eg.go \
pkg/sql/exec/zerocolumns.eg.go
pkg/sql/exec/zerocolumns.eg.go \
pkg/sql/exec/overloads_test_utils.eg.go

execgen-exclusions = $(addprefix -not -path ,$(EXECGEN_TARGETS))

Expand Down
1 change: 1 addition & 0 deletions pkg/sql/exec/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ vec_comparators.eg.go
vecbuiltins/rank.eg.go
vecbuiltins/row_number.eg.go
zerocolumns.eg.go
overloads_test_utils.eg.go
118 changes: 117 additions & 1 deletion pkg/sql/exec/execgen/cmd/execgen/overloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"errors"
"fmt"
"regexp"
"strings"
"text/template"

"github.com/cockroachdb/cockroach/pkg/sql/exec/types"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
Expand Down Expand Up @@ -103,7 +105,7 @@ var hashOverloads []*overload
// Assign produces a Go source string that assigns the "target" variable to the
// result of applying the overload to the two inputs, l and r.
//
// For example, an overload that implemented the int64 plus operation, when fed
// For example, an overload that implemented the float64 plus operation, when fed
// the inputs "x", "a", "b", would produce the string "x = a + b".
func (o overload) Assign(target, l, r string) string {
if o.AssignFunc != nil {
Expand Down Expand Up @@ -349,6 +351,120 @@ func (c intCustomizer) getHashAssignFunc() assignFunc {
}
}

func (c intCustomizer) getBinOpAssignFunc() assignFunc {
return func(op overload, target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
buf := strings.Builder{}
var t *template.Template

switch op.BinOp {

case tree.Plus:
t = template.Must(template.New("").Parse(`
{
result := {{.Left}} + {{.Right}}
if (result < {{.Left}}) != ({{.Right}} < 0) {
panic(tree.ErrIntOutOfRange)
}
{{.Target}} = result
}
`))

case tree.Minus:
t = template.Must(template.New("").Parse(`
{
result := {{.Left}} - {{.Right}}
if (result < {{.Left}}) != ({{.Right}} > 0) {
panic(tree.ErrIntOutOfRange)
}
{{.Target}} = result
}
`))

case tree.Mult:
// If the inputs are small enough, then we don't have to do any further
// checks. For the sake of legibility, upperBound and lowerBound are both
// not set to their maximal/minimal values. An even more advanced check
// (for positive values) might involve adding together the highest bit
// positions of the inputs, and checking if the sum is less than the
// integer width.
var upperBound, lowerBound string
switch c.width {
case 8:
upperBound = "10"
lowerBound = "-10"
case 16:
upperBound = "math.MaxInt8"
lowerBound = "math.MinInt8"
case 32:
upperBound = "math.MaxInt16"
lowerBound = "math.MinInt16"
case 64:
upperBound = "math.MaxInt32"
lowerBound = "math.MinInt32"
default:
panic(fmt.Sprintf("unhandled integer width %d", c.width))
}

args["UpperBound"] = upperBound
args["LowerBound"] = lowerBound
t = template.Must(template.New("").Parse(`
{
result := {{.Left}} * {{.Right}}
if {{.Left}} > {{.UpperBound}} || {{.Left}} < {{.LowerBound}} || {{.Right}} > {{.UpperBound}} || {{.Right}} < {{.LowerBound}} {
if {{.Left}} != 0 && {{.Right}} != 0 {
sameSign := ({{.Left}} < 0) == ({{.Right}} < 0)
if (result < 0) == sameSign {
panic(tree.ErrIntOutOfRange)
} else if result/{{.Right}} != {{.Left}} {
panic(tree.ErrIntOutOfRange)
}
}
}
{{.Target}} = result
}
`))

case tree.Div:
var minInt string
switch c.width {
case 8:
minInt = "math.MinInt8"
case 16:
minInt = "math.MinInt16"
case 32:
minInt = "math.MinInt32"
case 64:
minInt = "math.MinInt64"
default:
panic(fmt.Sprintf("unhandled integer width %d", c.width))
}

args["MinInt"] = minInt
t = template.Must(template.New("").Parse(`
{
if {{.Right}} == 0 {
panic(tree.ErrDivByZero)
}
result := {{.Left}} / {{.Right}}
if {{.Left}} == {{.MinInt}} && {{.Right}} == -1 {
panic(tree.ErrIntOutOfRange)
}
{{.Target}} = result
}
`))

default:
panic(fmt.Sprintf("unhandled binary operator %s", op.BinOp.String()))
}

if err := t.Execute(&buf, args); err != nil {
panic(err)
}
return buf.String()
}
}

func registerTypeCustomizers() {
typeCustomizers = make(map[types.T]typeCustomizer)
registerTypeCustomizer(types.Bool, boolCustomizer{})
Expand Down
65 changes: 65 additions & 0 deletions pkg/sql/exec/execgen/cmd/execgen/overloads_test_utils_gen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package main

import (
"io"
"text/template"

"github.com/cockroachdb/cockroach/pkg/sql/exec/types"
)

const overloadsTestUtilsTemplate = `
package exec

import (
"math"

"github.com/cockroachdb/apd"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
)

{{define "opName"}}perform{{.Name}}{{.LTyp}}{{end}}

{{/* The outer range is a types.T, and the inner is the overloads associated
with that type. */}}
{{range .}}
{{range .}}

func {{template "opName" .}}(a, b {{.LTyp.GoTypeName}}) {{.RetTyp.GoTypeName}} {
{{(.Assign "a" "a" "b")}}
return a
}

{{end}}
{{end}}
`

// genOverloadsTestUtils creates a file that has a function for each overload
// defined in overloads.go. This is so that we can more easily test each
// overload.
func genOverloadsTestUtils(wr io.Writer) error {
tmpl, err := template.New("overloads_test_utils").Parse(overloadsTestUtilsTemplate)
if err != nil {
return err
}

typToOverloads := make(map[types.T][]*overload)
for _, overload := range binaryOpOverloads {
typ := overload.LTyp
typToOverloads[typ] = append(typToOverloads[typ], overload)
}
return tmpl.Execute(wr, typToOverloads)
}

func init() {
registerGenerator(genOverloadsTestUtils, "overloads_test_utils.eg.go")
}
1 change: 1 addition & 0 deletions pkg/sql/exec/execgen/cmd/execgen/projection_ops_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package exec
import (
"bytes"
"context"
"math"

"github.com/cockroachdb/apd"
"github.com/cockroachdb/cockroach/pkg/sql/exec/coldata"
Expand Down
121 changes: 121 additions & 0 deletions pkg/sql/exec/overloads_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package exec

import (
"math"
"testing"

"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/stretchr/testify/assert"
)

func TestIntegerAddition(t *testing.T) {
// The addition overload is the same for all integer widths, so we only test
// one of them.
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(1, math.MaxInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(-1, math.MinInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(math.MaxInt16, 1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(math.MinInt16, -1) })

assert.Equal(t, int16(math.MaxInt16), performPlusInt16(1, math.MaxInt16-1))
assert.Equal(t, int16(math.MinInt16), performPlusInt16(-1, math.MinInt16+1))
assert.Equal(t, int16(math.MaxInt16-1), performPlusInt16(-1, math.MaxInt16))
assert.Equal(t, int16(math.MinInt16+1), performPlusInt16(1, math.MinInt16))

assert.Equal(t, int16(22), performPlusInt16(10, 12))
assert.Equal(t, int16(-22), performPlusInt16(-10, -12))
assert.Equal(t, int16(2), performPlusInt16(-10, 12))
assert.Equal(t, int16(-2), performPlusInt16(10, -12))
}

func TestIntegerSubtraction(t *testing.T) {
// The subtraction overload is the same for all integer widths, so we only
// test one of them.
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(1, -math.MaxInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(-2, math.MaxInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(math.MaxInt16, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(math.MinInt16, 1) })

assert.Equal(t, int16(math.MaxInt16), performMinusInt16(1, -math.MaxInt16+1))
assert.Equal(t, int16(math.MinInt16), performMinusInt16(-1, math.MaxInt16))
assert.Equal(t, int16(math.MaxInt16-1), performMinusInt16(-1, -math.MaxInt16))
assert.Equal(t, int16(math.MinInt16+1), performMinusInt16(0, math.MaxInt16))

assert.Equal(t, int16(-2), performMinusInt16(10, 12))
assert.Equal(t, int16(2), performMinusInt16(-10, -12))
assert.Equal(t, int16(-22), performMinusInt16(-10, 12))
assert.Equal(t, int16(22), performMinusInt16(10, -12))
}

func TestIntegerDivision(t *testing.T) {
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt8(math.MinInt8, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt16(math.MinInt16, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt32(math.MinInt32, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt64(math.MinInt64, -1) })

assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt8(10, 0) })
assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt16(10, 0) })
assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt32(10, 0) })
assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt64(10, 0) })

assert.Equal(t, int8(-math.MaxInt8), performDivInt8(math.MaxInt8, -1))
assert.Equal(t, int16(-math.MaxInt16), performDivInt16(math.MaxInt16, -1))
assert.Equal(t, int32(-math.MaxInt32), performDivInt32(math.MaxInt32, -1))
assert.Equal(t, int64(-math.MaxInt64), performDivInt64(math.MaxInt64, -1))

assert.Equal(t, int16(0), performDivInt16(10, 12))
assert.Equal(t, int16(0), performDivInt16(-10, -12))
assert.Equal(t, int16(-1), performDivInt16(-12, 10))
assert.Equal(t, int16(-1), performDivInt16(12, -10))
}

func TestIntegerMultiplication(t *testing.T) {
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MaxInt8-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MaxInt8-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MinInt8+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MinInt8+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MaxInt16-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MaxInt16-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MinInt16+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MinInt16+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MaxInt32-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MaxInt32-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MinInt32+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MinInt32+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MaxInt64-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MaxInt64-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MinInt64+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MinInt64+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MinInt8, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MinInt16, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MinInt32, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MinInt64, -1) })

assert.Equal(t, int8(-math.MaxInt8), performMultInt8(math.MaxInt8, -1))
assert.Equal(t, int16(-math.MaxInt16), performMultInt16(math.MaxInt16, -1))
assert.Equal(t, int32(-math.MaxInt32), performMultInt32(math.MaxInt32, -1))
assert.Equal(t, int64(-math.MaxInt64), performMultInt64(math.MaxInt64, -1))

assert.Equal(t, int8(0), performMultInt8(math.MinInt8, 0))
assert.Equal(t, int16(0), performMultInt16(math.MinInt16, 0))
assert.Equal(t, int32(0), performMultInt32(math.MinInt32, 0))
assert.Equal(t, int64(0), performMultInt64(math.MinInt64, 0))

assert.Equal(t, int8(120), performMultInt8(10, 12))
assert.Equal(t, int16(120), performMultInt16(-10, -12))
assert.Equal(t, int32(-120), performMultInt32(-12, 10))
assert.Equal(t, int64(-120), performMultInt64(12, -10))
}
Loading