Skip to content

Commit

Permalink
Merge #38967
Browse files Browse the repository at this point in the history
38967: exec: overflow handling for vectorized arithmetic r=rafiss a=rafiss

The overflow checks are done as part of the code generation in
overloads.go. The checks are done inline, rather than calling the
functions in the arith package for performance reasons.

The checks are only done for integer math. float math is already
well-defined since overflow will result in +Inf and -Inf as necessary.

The operations that these checks are relevant for are the SUM_INT
aggregator and projection. In the future, AVG will also benefit from
these overflow checks.

This changes the error message produced by overflows in the
non-vectorized SUM_INT aggregator so that the messages are consistent.
This should be fine in terms of postgres-compatibility since SUM_INT is
unique to CRDB and eventually we will get rid of it anyway.

resolves #38775

Release note: None

Co-authored-by: Rafi Shamim <[email protected]>
  • Loading branch information
craig[bot] and rafiss committed Jul 23, 2019
2 parents ec200a6 + 1dc97ee commit 4233a87
Show file tree
Hide file tree
Showing 11 changed files with 395 additions and 33 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,8 @@ EXECGEN_TARGETS = \
pkg/sql/exec/vec_comparators.eg.go \
pkg/sql/exec/vecbuiltins/rank.eg.go \
pkg/sql/exec/vecbuiltins/row_number.eg.go \
pkg/sql/exec/zerocolumns.eg.go
pkg/sql/exec/zerocolumns.eg.go \
pkg/sql/exec/overloads_test_utils.eg.go

execgen-exclusions = $(addprefix -not -path ,$(EXECGEN_TARGETS))

Expand Down
1 change: 1 addition & 0 deletions pkg/sql/exec/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ vec_comparators.eg.go
vecbuiltins/rank.eg.go
vecbuiltins/row_number.eg.go
zerocolumns.eg.go
overloads_test_utils.eg.go
118 changes: 117 additions & 1 deletion pkg/sql/exec/execgen/cmd/execgen/overloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"errors"
"fmt"
"regexp"
"strings"
"text/template"

"github.com/cockroachdb/cockroach/pkg/sql/exec/types"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
Expand Down Expand Up @@ -103,7 +105,7 @@ var hashOverloads []*overload
// Assign produces a Go source string that assigns the "target" variable to the
// result of applying the overload to the two inputs, l and r.
//
// For example, an overload that implemented the int64 plus operation, when fed
// For example, an overload that implemented the float64 plus operation, when fed
// the inputs "x", "a", "b", would produce the string "x = a + b".
func (o overload) Assign(target, l, r string) string {
if o.AssignFunc != nil {
Expand Down Expand Up @@ -349,6 +351,120 @@ func (c intCustomizer) getHashAssignFunc() assignFunc {
}
}

func (c intCustomizer) getBinOpAssignFunc() assignFunc {
return func(op overload, target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
buf := strings.Builder{}
var t *template.Template

switch op.BinOp {

case tree.Plus:
t = template.Must(template.New("").Parse(`
{
result := {{.Left}} + {{.Right}}
if (result < {{.Left}}) != ({{.Right}} < 0) {
panic(tree.ErrIntOutOfRange)
}
{{.Target}} = result
}
`))

case tree.Minus:
t = template.Must(template.New("").Parse(`
{
result := {{.Left}} - {{.Right}}
if (result < {{.Left}}) != ({{.Right}} > 0) {
panic(tree.ErrIntOutOfRange)
}
{{.Target}} = result
}
`))

case tree.Mult:
// If the inputs are small enough, then we don't have to do any further
// checks. For the sake of legibility, upperBound and lowerBound are both
// not set to their maximal/minimal values. An even more advanced check
// (for positive values) might involve adding together the highest bit
// positions of the inputs, and checking if the sum is less than the
// integer width.
var upperBound, lowerBound string
switch c.width {
case 8:
upperBound = "10"
lowerBound = "-10"
case 16:
upperBound = "math.MaxInt8"
lowerBound = "math.MinInt8"
case 32:
upperBound = "math.MaxInt16"
lowerBound = "math.MinInt16"
case 64:
upperBound = "math.MaxInt32"
lowerBound = "math.MinInt32"
default:
panic(fmt.Sprintf("unhandled integer width %d", c.width))
}

args["UpperBound"] = upperBound
args["LowerBound"] = lowerBound
t = template.Must(template.New("").Parse(`
{
result := {{.Left}} * {{.Right}}
if {{.Left}} > {{.UpperBound}} || {{.Left}} < {{.LowerBound}} || {{.Right}} > {{.UpperBound}} || {{.Right}} < {{.LowerBound}} {
if {{.Left}} != 0 && {{.Right}} != 0 {
sameSign := ({{.Left}} < 0) == ({{.Right}} < 0)
if (result < 0) == sameSign {
panic(tree.ErrIntOutOfRange)
} else if result/{{.Right}} != {{.Left}} {
panic(tree.ErrIntOutOfRange)
}
}
}
{{.Target}} = result
}
`))

case tree.Div:
var minInt string
switch c.width {
case 8:
minInt = "math.MinInt8"
case 16:
minInt = "math.MinInt16"
case 32:
minInt = "math.MinInt32"
case 64:
minInt = "math.MinInt64"
default:
panic(fmt.Sprintf("unhandled integer width %d", c.width))
}

args["MinInt"] = minInt
t = template.Must(template.New("").Parse(`
{
if {{.Right}} == 0 {
panic(tree.ErrDivByZero)
}
result := {{.Left}} / {{.Right}}
if {{.Left}} == {{.MinInt}} && {{.Right}} == -1 {
panic(tree.ErrIntOutOfRange)
}
{{.Target}} = result
}
`))

default:
panic(fmt.Sprintf("unhandled binary operator %s", op.BinOp.String()))
}

if err := t.Execute(&buf, args); err != nil {
panic(err)
}
return buf.String()
}
}

func registerTypeCustomizers() {
typeCustomizers = make(map[types.T]typeCustomizer)
registerTypeCustomizer(types.Bool, boolCustomizer{})
Expand Down
65 changes: 65 additions & 0 deletions pkg/sql/exec/execgen/cmd/execgen/overloads_test_utils_gen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package main

import (
"io"
"text/template"

"github.com/cockroachdb/cockroach/pkg/sql/exec/types"
)

const overloadsTestUtilsTemplate = `
package exec
import (
"math"
"github.com/cockroachdb/apd"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
)
{{define "opName"}}perform{{.Name}}{{.LTyp}}{{end}}
{{/* The outer range is a types.T, and the inner is the overloads associated
with that type. */}}
{{range .}}
{{range .}}
func {{template "opName" .}}(a, b {{.LTyp.GoTypeName}}) {{.RetTyp.GoTypeName}} {
{{(.Assign "a" "a" "b")}}
return a
}
{{end}}
{{end}}
`

// genOverloadsTestUtils creates a file that has a function for each overload
// defined in overloads.go. This is so that we can more easily test each
// overload.
func genOverloadsTestUtils(wr io.Writer) error {
tmpl, err := template.New("overloads_test_utils").Parse(overloadsTestUtilsTemplate)
if err != nil {
return err
}

typToOverloads := make(map[types.T][]*overload)
for _, overload := range binaryOpOverloads {
typ := overload.LTyp
typToOverloads[typ] = append(typToOverloads[typ], overload)
}
return tmpl.Execute(wr, typToOverloads)
}

func init() {
registerGenerator(genOverloadsTestUtils, "overloads_test_utils.eg.go")
}
1 change: 1 addition & 0 deletions pkg/sql/exec/execgen/cmd/execgen/projection_ops_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package exec
import (
"bytes"
"context"
"math"
"github.com/cockroachdb/apd"
"github.com/cockroachdb/cockroach/pkg/sql/exec/coldata"
Expand Down
121 changes: 121 additions & 0 deletions pkg/sql/exec/overloads_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package exec

import (
"math"
"testing"

"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/stretchr/testify/assert"
)

func TestIntegerAddition(t *testing.T) {
// The addition overload is the same for all integer widths, so we only test
// one of them.
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(1, math.MaxInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(-1, math.MinInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(math.MaxInt16, 1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performPlusInt16(math.MinInt16, -1) })

assert.Equal(t, int16(math.MaxInt16), performPlusInt16(1, math.MaxInt16-1))
assert.Equal(t, int16(math.MinInt16), performPlusInt16(-1, math.MinInt16+1))
assert.Equal(t, int16(math.MaxInt16-1), performPlusInt16(-1, math.MaxInt16))
assert.Equal(t, int16(math.MinInt16+1), performPlusInt16(1, math.MinInt16))

assert.Equal(t, int16(22), performPlusInt16(10, 12))
assert.Equal(t, int16(-22), performPlusInt16(-10, -12))
assert.Equal(t, int16(2), performPlusInt16(-10, 12))
assert.Equal(t, int16(-2), performPlusInt16(10, -12))
}

func TestIntegerSubtraction(t *testing.T) {
// The subtraction overload is the same for all integer widths, so we only
// test one of them.
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(1, -math.MaxInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(-2, math.MaxInt16) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(math.MaxInt16, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMinusInt16(math.MinInt16, 1) })

assert.Equal(t, int16(math.MaxInt16), performMinusInt16(1, -math.MaxInt16+1))
assert.Equal(t, int16(math.MinInt16), performMinusInt16(-1, math.MaxInt16))
assert.Equal(t, int16(math.MaxInt16-1), performMinusInt16(-1, -math.MaxInt16))
assert.Equal(t, int16(math.MinInt16+1), performMinusInt16(0, math.MaxInt16))

assert.Equal(t, int16(-2), performMinusInt16(10, 12))
assert.Equal(t, int16(2), performMinusInt16(-10, -12))
assert.Equal(t, int16(-22), performMinusInt16(-10, 12))
assert.Equal(t, int16(22), performMinusInt16(10, -12))
}

func TestIntegerDivision(t *testing.T) {
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt8(math.MinInt8, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt16(math.MinInt16, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt32(math.MinInt32, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performDivInt64(math.MinInt64, -1) })

assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt8(10, 0) })
assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt16(10, 0) })
assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt32(10, 0) })
assert.PanicsWithValue(t, tree.ErrDivByZero, func() { performDivInt64(10, 0) })

assert.Equal(t, int8(-math.MaxInt8), performDivInt8(math.MaxInt8, -1))
assert.Equal(t, int16(-math.MaxInt16), performDivInt16(math.MaxInt16, -1))
assert.Equal(t, int32(-math.MaxInt32), performDivInt32(math.MaxInt32, -1))
assert.Equal(t, int64(-math.MaxInt64), performDivInt64(math.MaxInt64, -1))

assert.Equal(t, int16(0), performDivInt16(10, 12))
assert.Equal(t, int16(0), performDivInt16(-10, -12))
assert.Equal(t, int16(-1), performDivInt16(-12, 10))
assert.Equal(t, int16(-1), performDivInt16(12, -10))
}

func TestIntegerMultiplication(t *testing.T) {
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MaxInt8-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MaxInt8-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MinInt8+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MinInt8+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MaxInt16-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MaxInt16-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MinInt16+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MinInt16+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MaxInt32-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MaxInt32-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MinInt32+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MinInt32+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MaxInt64-1, 100) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MaxInt64-1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MinInt64+1, 3) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MinInt64+1, 100) })

assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt8(math.MinInt8, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt16(math.MinInt16, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt32(math.MinInt32, -1) })
assert.PanicsWithValue(t, tree.ErrIntOutOfRange, func() { performMultInt64(math.MinInt64, -1) })

assert.Equal(t, int8(-math.MaxInt8), performMultInt8(math.MaxInt8, -1))
assert.Equal(t, int16(-math.MaxInt16), performMultInt16(math.MaxInt16, -1))
assert.Equal(t, int32(-math.MaxInt32), performMultInt32(math.MaxInt32, -1))
assert.Equal(t, int64(-math.MaxInt64), performMultInt64(math.MaxInt64, -1))

assert.Equal(t, int8(0), performMultInt8(math.MinInt8, 0))
assert.Equal(t, int16(0), performMultInt16(math.MinInt16, 0))
assert.Equal(t, int32(0), performMultInt32(math.MinInt32, 0))
assert.Equal(t, int64(0), performMultInt64(math.MinInt64, 0))

assert.Equal(t, int8(120), performMultInt8(10, 12))
assert.Equal(t, int16(120), performMultInt16(-10, -12))
assert.Equal(t, int32(-120), performMultInt32(-12, 10))
assert.Equal(t, int64(-120), performMultInt64(12, -10))
}
Loading

0 comments on commit 4233a87

Please sign in to comment.