diff --git a/pkg/sql/distsqlrun/aggregator.go b/pkg/sql/distsqlrun/aggregator.go index 49ef3c0bfedb..8b5d9d0ed1f1 100644 --- a/pkg/sql/distsqlrun/aggregator.go +++ b/pkg/sql/distsqlrun/aggregator.go @@ -418,10 +418,10 @@ func (ag *orderedAggregator) close() { // columns, and false otherwise. func (ag *aggregatorBase) matchLastOrdGroupCols(row sqlbase.EncDatumRow) (bool, error) { for _, colIdx := range ag.orderedGroupCols { - res, err := ag.lastOrdGroupCols[colIdx].Compare( + cmp, err := ag.lastOrdGroupCols[colIdx].Distinct( &ag.inputTypes[colIdx], &ag.datumAlloc, &ag.flowCtx.EvalCtx, &row[colIdx], ) - if res != 0 || err != nil { + if cmp || err != nil { return false, err } } diff --git a/pkg/sql/distsqlrun/disk_row_container_test.go b/pkg/sql/distsqlrun/disk_row_container_test.go index 3969c2b1af7b..6c6c12fc8d5f 100644 --- a/pkg/sql/distsqlrun/disk_row_container_test.go +++ b/pkg/sql/distsqlrun/disk_row_container_test.go @@ -38,27 +38,24 @@ import ( // l < r, 0 if l == r, and 1 if l > r. If an error is returned the int returned // is invalid. Note that the comparison is only performed on the ordering // columns. -func compareRows( +func distinctRows( lTypes []sqlbase.ColumnType, l, r sqlbase.EncDatumRow, e *tree.EvalContext, d *sqlbase.DatumAlloc, ordering sqlbase.ColumnOrdering, -) (int, error) { +) (bool, error) { for _, orderInfo := range ordering { col := orderInfo.ColIdx - cmp, err := l[col].Compare(&lTypes[col], d, e, &r[orderInfo.ColIdx]) + cmp, err := l[col].Distinct(&lTypes[col], d, e, &r[orderInfo.ColIdx]) if err != nil { - return 0, err + return false, err } - if cmp != 0 { - if orderInfo.Direction == encoding.Descending { - cmp = -cmp - } - return cmp, nil + if cmp { + return true, nil } } - return 0, nil + return false, nil } func TestDiskRowContainer(t *testing.T) { @@ -160,9 +157,9 @@ func TestDiskRowContainer(t *testing.T) { // Check equality of the row we wrote and the row we read. for i := range row { - if cmp, err := readRow[i].Compare(&types[i], &d.datumAlloc, &evalCtx, &row[i]); err != nil { + if cmp, err := readRow[i].Distinct(&types[i], &d.datumAlloc, &evalCtx, &row[i]); err != nil { t.Fatal(err) - } else if cmp != 0 { + } else if cmp { t.Fatalf("encoded %s but decoded %s", row.String(types), readRow.String(types)) } } @@ -222,13 +219,13 @@ func TestDiskRowContainer(t *testing.T) { } // Check sorted order. - if cmp, err := compareRows( + if cmp, err := distinctRows( types, sortedRows.EncRow(numKeysRead), row, &evalCtx, &d.datumAlloc, ordering, ); err != nil { t.Fatal(err) - } else if cmp != 0 { + } else if cmp { t.Fatalf( - "expected %s to be equal to %s", + "expected %s to be not distinct from %s", row.String(types), sortedRows.EncRow(numKeysRead).String(types), ) @@ -322,9 +319,9 @@ func TestDiskRowContainerFinalIterator(t *testing.T) { // checkEqual checks that the given row is equal to otherRow. 
checkEqual := func(row sqlbase.EncDatumRow, otherRow sqlbase.EncDatumRow) error { for j, c := range row { - if cmp, err := c.Compare(&intType, alloc, &evalCtx, &otherRow[j]); err != nil { + if cmp, err := c.Distinct(&intType, alloc, &evalCtx, &otherRow[j]); err != nil { return err - } else if cmp != 0 { + } else if cmp { return fmt.Errorf( "unexpected row %v, expected %v", row.String(oneIntCol), diff --git a/pkg/sql/distsqlrun/distinct.go b/pkg/sql/distsqlrun/distinct.go index 82751de2bc75..94032ece2f26 100644 --- a/pkg/sql/distsqlrun/distinct.go +++ b/pkg/sql/distsqlrun/distinct.go @@ -26,7 +26,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/mon" "github.com/cockroachdb/cockroach/pkg/util/stringarena" "github.com/cockroachdb/cockroach/pkg/util/tracing" - "github.com/opentracing/opentracing-go" + opentracing "github.com/opentracing/opentracing-go" "github.com/pkg/errors" ) @@ -163,10 +163,10 @@ func (d *Distinct) matchLastGroupKey(row sqlbase.EncDatumRow) (bool, error) { return false, nil } for _, colIdx := range d.orderedCols { - res, err := d.lastGroupKey[colIdx].Compare( + cmp, err := d.lastGroupKey[colIdx].Distinct( &d.types[colIdx], &d.datumAlloc, d.evalCtx, &row[colIdx], ) - if res != 0 || err != nil { + if cmp || err != nil { return false, err } } diff --git a/pkg/sql/distsqlrun/input_sync.go b/pkg/sql/distsqlrun/input_sync.go index cb7bf2f96e3e..c134ccc75e52 100644 --- a/pkg/sql/distsqlrun/input_sync.go +++ b/pkg/sql/distsqlrun/input_sync.go @@ -108,7 +108,7 @@ func (s *orderedSynchronizer) Len() int { func (s *orderedSynchronizer) Less(i, j int) bool { si := &s.sources[s.heap[i]] sj := &s.sources[s.heap[j]] - cmp, err := si.row.Compare(s.types, &s.alloc, s.ordering, s.evalCtx, sj.row) + cmp, err := si.row.TotalOrderCompare(s.types, &s.alloc, s.ordering, s.evalCtx, sj.row) if err != nil { s.err = err return false @@ -241,7 +241,7 @@ func (s *orderedSynchronizer) advanceRoot() error { } else { heap.Fix(s, 0) // TODO(radu): this check may be costly, we could disable it in production - if cmp, err := oldRow.Compare(s.types, &s.alloc, s.ordering, s.evalCtx, src.row); err != nil { + if cmp, err := oldRow.TotalOrderCompare(s.types, &s.alloc, s.ordering, s.evalCtx, src.row); err != nil { return err } else if cmp > 0 { return errors.Errorf( diff --git a/pkg/sql/distsqlrun/routers_test.go b/pkg/sql/distsqlrun/routers_test.go index 7d09253eddd6..296424ad1acd 100644 --- a/pkg/sql/distsqlrun/routers_test.go +++ b/pkg/sql/distsqlrun/routers_test.go @@ -39,7 +39,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/mon" "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" - "github.com/opentracing/opentracing-go" + opentracing "github.com/opentracing/opentracing-go" ) // setupRouter creates and starts a router. 
Returns the router and a WaitGroup @@ -172,11 +172,11 @@ func TestRouters(t *testing.T) { for _, row2 := range r2 { equal := true for _, c := range tc.spec.HashColumns { - cmp, err := row[c].Compare(&types[c], alloc, evalCtx, &row2[c]) + cmp, err := row[c].Distinct(&types[c], alloc, evalCtx, &row2[c]) if err != nil { t.Fatal(err) } - if cmp != 0 { + if cmp { equal = false break } @@ -207,11 +207,11 @@ func TestRouters(t *testing.T) { equal := true for j, c := range row { - cmp, err := c.Compare(&types[j], alloc, evalCtx, &row2[j]) + cmp, err := c.Distinct(&types[j], alloc, evalCtx, &row2[j]) if err != nil { t.Fatal(err) } - if cmp != 0 { + if cmp { equal = false break } @@ -801,9 +801,9 @@ func TestRouterDiskSpill(t *testing.T) { } // Verify correct order (should be the order in which we added rows). for j, c := range row { - if cmp, err := c.Compare(&intType, alloc, &flowCtx.EvalCtx, &rows[i][j]); err != nil { + if cmp, err := c.Distinct(&intType, alloc, &flowCtx.EvalCtx, &rows[i][j]); err != nil { t.Fatal(err) - } else if cmp != 0 { + } else if cmp { t.Fatalf( "order violated on row %d, expected %v got %v", i, diff --git a/pkg/sql/distsqlrun/row_container.go b/pkg/sql/distsqlrun/row_container.go index b6f0e7dfd073..8d079702af25 100644 --- a/pkg/sql/distsqlrun/row_container.go +++ b/pkg/sql/distsqlrun/row_container.go @@ -140,11 +140,8 @@ func (mc *memRowContainer) initWithMon( // Less is part of heap.Interface and is only meant to be used internally. func (mc *memRowContainer) Less(i, j int) bool { - cmp := sqlbase.CompareDatums(mc.ordering, mc.evalCtx, mc.At(i), mc.At(j)) - if mc.invertSorting { - cmp = -cmp - } - return cmp < 0 + ra, rb := mc.At(i), mc.At(j) + return sqlbase.LessDatums(mc.ordering, mc.invertSorting, mc.evalCtx, ra, rb) } // EncRow returns the idx-th row as an EncDatumRow. The slice itself is reused @@ -189,11 +186,11 @@ func (mc *memRowContainer) Pop() interface{} { panic("unimplemented") } // smaller. Assumes InitTopK was called. func (mc *memRowContainer) MaybeReplaceMax(ctx context.Context, row sqlbase.EncDatumRow) error { max := mc.At(0) - cmp, err := row.CompareToDatums(mc.types, &mc.datumAlloc, mc.ordering, mc.evalCtx, max) + cmp, err := row.LessThanDatums(mc.types, &mc.datumAlloc, mc.ordering, mc.evalCtx, max) if err != nil { return err } - if cmp < 0 { + if cmp { // row is smaller than the max; replace. 
for i := range row { if err := row[i].EnsureDecoded(&mc.types[i], &mc.datumAlloc); err != nil { diff --git a/pkg/sql/distsqlrun/row_container_test.go b/pkg/sql/distsqlrun/row_container_test.go index e8998da51f18..c4e3ed18639b 100644 --- a/pkg/sql/distsqlrun/row_container_test.go +++ b/pkg/sql/distsqlrun/row_container_test.go @@ -52,11 +52,11 @@ func verifyRows( if err != nil { return err } - if cmp, err := compareRows( + if cmp, err := distinctRows( oneIntCol, row, expectedRows[0], evalCtx, &sqlbase.DatumAlloc{}, ordering, ); err != nil { return err - } else if cmp != 0 { + } else if cmp { return fmt.Errorf("unexpected row %v, expected %v", row, expectedRows[0]) } expectedRows = expectedRows[1:] diff --git a/pkg/sql/distsqlrun/sorter.go b/pkg/sql/distsqlrun/sorter.go index e03917c0ec19..af143e59cfb7 100644 --- a/pkg/sql/distsqlrun/sorter.go +++ b/pkg/sql/distsqlrun/sorter.go @@ -23,7 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/humanizeutil" "github.com/cockroachdb/cockroach/pkg/util/mon" "github.com/cockroachdb/cockroach/pkg/util/tracing" - "github.com/opentracing/opentracing-go" + opentracing "github.com/opentracing/opentracing-go" ) // sorter sorts the input rows according to the specified ordering. @@ -501,8 +501,8 @@ func (s *sortChunksProcessor) chunkCompleted( types := s.input.OutputTypes() for _, ord := range s.ordering[:s.matchLen] { col := ord.ColIdx - cmp, err := nextChunkRow[col].Compare(&types[col], &s.alloc, s.evalCtx, &prefix[col]) - if cmp != 0 || err != nil { + cmp, err := nextChunkRow[col].Distinct(&types[col], &s.alloc, s.evalCtx, &prefix[col]) + if cmp || err != nil { return true, err } } diff --git a/pkg/sql/distsqlrun/stream_group_accumulator.go b/pkg/sql/distsqlrun/stream_group_accumulator.go index d344a8dc4217..2dd32b664f1f 100644 --- a/pkg/sql/distsqlrun/stream_group_accumulator.go +++ b/pkg/sql/distsqlrun/stream_group_accumulator.go @@ -101,7 +101,7 @@ func (s *streamGroupAccumulator) nextGroup( continue } - cmp, err := s.curGroup[0].Compare(s.types, &s.datumAlloc, s.ordering, evalCtx, row) + cmp, err := s.curGroup[0].TotalOrderCompare(s.types, &s.datumAlloc, s.ordering, evalCtx, row) if err != nil { return nil, &ProducerMetadata{Err: err} } diff --git a/pkg/sql/distsqlrun/stream_merger.go b/pkg/sql/distsqlrun/stream_merger.go index ee4306224aeb..00b105240b54 100644 --- a/pkg/sql/distsqlrun/stream_merger.go +++ b/pkg/sql/distsqlrun/stream_merger.go @@ -143,7 +143,7 @@ func CompareEncDatumRowForMerge( } continue } - cmp, err := lhs[lIdx].Compare(&lhsTypes[lIdx], da, evalCtx, &rhs[rIdx]) + cmp, err := lhs[lIdx].TotalOrderCompare(&lhsTypes[lIdx], da, evalCtx, &rhs[rIdx]) if err != nil { return 0, err } diff --git a/pkg/sql/distsqlrun/values_test.go b/pkg/sql/distsqlrun/values_test.go index 95ef6263b214..062c6643aae7 100644 --- a/pkg/sql/distsqlrun/values_test.go +++ b/pkg/sql/distsqlrun/values_test.go @@ -105,11 +105,11 @@ func TestValuesProcessor(t *testing.T) { t.Fatalf("row %d incorrect length %d, expected %d", i, len(res[i]), numCols) } for j, val := range res[i] { - cmp, err := val.Compare(&colTypes[j], &a, evalCtx, &inRows[i][j]) + cmp, err := val.Distinct(&colTypes[j], &a, evalCtx, &inRows[i][j]) if err != nil { t.Fatal(err) } - if cmp != 0 { + if cmp { t.Errorf( "row %d, column %d: received %s, expected %s", i, j, val.String(&colTypes[j]), inRows[i][j].String(&colTypes[j]), diff --git a/pkg/sql/distsqlrun/zigzagjoiner.go b/pkg/sql/distsqlrun/zigzagjoiner.go index 2af6b778ee3d..d274def9effa 100644 --- a/pkg/sql/distsqlrun/zigzagjoiner.go +++ 
b/pkg/sql/distsqlrun/zigzagjoiner.go @@ -245,6 +245,8 @@ type zigzagJoiner struct { // TODO(andrei): get rid of this field and move the actions it gates into the // Start() method. started bool + + datumAlloc sqlbase.DatumAlloc } // Batch size is a parameter which determines how many rows should be fetched @@ -604,12 +606,9 @@ func (z *zigzagJoiner) matchBase(curRow sqlbase.EncDatumRow, side int) (bool, er } // Compare the equality columns of the baseRow to that of the curRow. - da := &sqlbase.DatumAlloc{} - cmp, err := prevEqDatums.Compare(eqColTypes, da, ordering, &z.flowCtx.EvalCtx, curEqDatums) - if err != nil { - return false, err - } - return cmp == 0, nil + cmp, err := prevEqDatums.Distinct( + eqColTypes, &z.datumAlloc, ordering, &z.flowCtx.EvalCtx, curEqDatums) + return !cmp, err } // emitFromContainers returns the next row that is to be emitted from those @@ -753,8 +752,8 @@ func (z *zigzagJoiner) nextRow( if err != nil { return nil, z.producerMeta(err) } - da := &sqlbase.DatumAlloc{} - cmp, err := prevEqCols.Compare(eqColTypes, da, ordering, &z.flowCtx.EvalCtx, currentEqCols) + cmp, err := prevEqCols.TotalOrderCompare( + eqColTypes, &z.datumAlloc, ordering, &z.flowCtx.EvalCtx, currentEqCols) if err != nil { return nil, z.producerMeta(err) } diff --git a/pkg/sql/group.go b/pkg/sql/group.go index a4f6bc46044d..dc920d78a648 100644 --- a/pkg/sql/group.go +++ b/pkg/sql/group.go @@ -331,7 +331,7 @@ type groupRun struct { // grouping columns, and false otherwise. func (n *groupNode) matchLastGroupKey(ctx *tree.EvalContext, row tree.Datums) bool { for _, i := range n.orderedGroupCols { - if n.run.lastOrderedGroupKey[i].Compare(ctx, row[i]) != 0 { + if tree.Distinct(ctx, n.run.lastOrderedGroupKey[i], row[i]) { return false } } diff --git a/pkg/sql/opt/constraint/constraint.go b/pkg/sql/opt/constraint/constraint.go index 9bb19e172fb5..39e093f16ca7 100644 --- a/pkg/sql/opt/constraint/constraint.go +++ b/pkg/sql/opt/constraint/constraint.go @@ -436,12 +436,12 @@ func (c *Constraint) ExactPrefix(evalCtx *tree.EvalContext) int { return col } startVal := sp.start.Value(col) - if startVal.Compare(evalCtx, sp.end.Value(col)) != 0 { + if tree.Distinct(evalCtx, startVal, sp.end.Value(col)) { return col } if i == 0 { val = startVal - } else if startVal.Compare(evalCtx, val) != 0 { + } else if tree.Distinct(evalCtx, startVal, val) { return col } } @@ -465,7 +465,7 @@ func (c *Constraint) Prefix(evalCtx *tree.EvalContext) int { start := sp.StartKey() end := sp.EndKey() if start.Length() <= prefix || end.Length() <= prefix || - start.Value(prefix).Compare(evalCtx, end.Value(prefix)) != 0 { + tree.Distinct(evalCtx, start.Value(prefix), end.Value(prefix)) { return prefix } } diff --git a/pkg/sql/opt/constraint/key.go b/pkg/sql/opt/constraint/key.go index 00ec3fd1a00a..e0d9ceda4cc8 100644 --- a/pkg/sql/opt/constraint/key.go +++ b/pkg/sql/opt/constraint/key.go @@ -291,7 +291,7 @@ func (c *KeyContext) Compare(colIdx int, a, b tree.Datum) int { if a == b { return 0 } - cmp := a.Compare(c.EvalCtx, b) + cmp := tree.TotalOrderCompare(c.EvalCtx, a, b) if c.Columns.Get(colIdx).Descending() { cmp = -cmp } diff --git a/pkg/sql/opt/idxconstraint/index_constraints.go b/pkg/sql/opt/idxconstraint/index_constraints.go index b37a074f0214..52a3ebfd05c1 100644 --- a/pkg/sql/opt/idxconstraint/index_constraints.go +++ b/pkg/sql/opt/idxconstraint/index_constraints.go @@ -883,9 +883,9 @@ func (c *indexConstraintCtx) getMaxSimplifyPrefix(idxConstraint *constraint.Cons for i := 0; i < idxConstraint.Spans.Count(); i++ { sp := 
idxConstraint.Spans.Get(i) j := 0 - // Find the longest prefix of equal values. + // Find the longest prefix of non-distinct values. for ; j < sp.StartKey().Length() && j < sp.EndKey().Length(); j++ { - if sp.StartKey().Value(j).Compare(c.evalCtx, sp.EndKey().Value(j)) != 0 { + if tree.Distinct(c.evalCtx, sp.StartKey().Value(j), sp.EndKey().Value(j)) { break } } diff --git a/pkg/sql/opt/memo/statistics_builder.go b/pkg/sql/opt/memo/statistics_builder.go index ef85c661725c..0209d3d7c638 100644 --- a/pkg/sql/opt/memo/statistics_builder.go +++ b/pkg/sql/opt/memo/statistics_builder.go @@ -1189,7 +1189,7 @@ func (sb *statisticsBuilder) updateDistinctCountsFromConstraint( } startVal := sp.StartKey().Value(col) endVal := sp.EndKey().Value(col) - if startVal.Compare(sb.evalCtx, endVal) != 0 { + if tree.Distinct(sb.evalCtx, startVal, endVal) { // TODO(rytaft): are there other types we should handle here // besides int? if startVal.ResolvedType() == types.Int && endVal.ResolvedType() == types.Int { @@ -1211,7 +1211,7 @@ func (sb *statisticsBuilder) updateDistinctCountsFromConstraint( } } if i != 0 { - compare := startVal.Compare(sb.evalCtx, val) + compare := tree.TotalOrderCompare(sb.evalCtx, startVal, val) ascending := c.Columns.Get(col).Ascending() if (compare > 0 && ascending) || (compare < 0 && !ascending) { // This check is needed to ensure that we calculate the correct distinct diff --git a/pkg/sql/opt/xfunc/custom_funcs.go b/pkg/sql/opt/xfunc/custom_funcs.go index 96803e612072..e2fb87d19dcc 100644 --- a/pkg/sql/opt/xfunc/custom_funcs.go +++ b/pkg/sql/opt/xfunc/custom_funcs.go @@ -81,7 +81,7 @@ func (c *CustomFuncs) ConstructSortedUniqueList(list memo.ListID) (memo.ListID, // Remove duplicates from the list. n := 0 for i := 0; i < int(list.Length); i++ { - if i == 0 || ls.compare(i-1, i) < 0 { + if i == 0 || ls.less(i-1, i) { lb.items[n] = lb.items[i] n++ } diff --git a/pkg/sql/opt/xfunc/list_sorter.go b/pkg/sql/opt/xfunc/list_sorter.go index 3454dbca7dfa..17d867178b6c 100644 --- a/pkg/sql/opt/xfunc/list_sorter.go +++ b/pkg/sql/opt/xfunc/list_sorter.go @@ -16,6 +16,7 @@ package xfunc import ( "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" ) // listSorter is a helper struct that implements the sort.Slice "less" @@ -27,37 +28,30 @@ type listSorter struct { // less returns true if item i in the list compares less than item j. // sort.Slice uses this method to sort the list. -func (s listSorter) less(i, j int) bool { - return s.compare(i, j) < 0 -} - -// compare returns -1 if item i compares less than item j, 0 if they are equal, -// and 1 if item i compares greater. Constants sort before non-constants, and +// Constants sort before non-constants, and // are sorted and uniquified according to Datum comparison rules. Non-constants // are sorted and uniquified by GroupID (arbitrary, but stable). -func (s listSorter) compare(i, j int) int { +func (s listSorter) less(i, j int) bool { // If both are constant values, then use datum comparison. 
isLeftConst := s.cf.mem.NormExpr(s.list[i]).IsConstValue() isRightConst := s.cf.mem.NormExpr(s.list[j]).IsConstValue() if isLeftConst { if !isRightConst { // Constant always sorts before non-constant - return -1 + return true } leftD := memo.ExtractConstDatum(memo.MakeNormExprView(s.cf.mem, s.list[i])) rightD := memo.ExtractConstDatum(memo.MakeNormExprView(s.cf.mem, s.list[j])) - return leftD.Compare(s.cf.evalCtx, rightD) + return tree.TotalOrderLess(s.cf.evalCtx, leftD, rightD) } else if isRightConst { // Non-constant always sorts after constant. - return 1 + return false } // Arbitrarily order by GroupID. if s.list[i] < s.list[j] { - return -1 - } else if s.list[i] > s.list[j] { - return 1 + return true } - return 0 + return false } diff --git a/pkg/sql/pgwire/binary_test.go b/pkg/sql/pgwire/binary_test.go index 2b3af7a3b472..347ae4b6641e 100644 --- a/pkg/sql/pgwire/binary_test.go +++ b/pkg/sql/pgwire/binary_test.go @@ -90,7 +90,7 @@ func testBinaryDatumType(t *testing.T, typ string, datumConstructor func(val str t.Fatalf("unable to decode %v: %s", got[4:], err) } - if d.Compare(evalCtx, datum) != 0 { + if tree.Distinct(evalCtx, d, datum) { t.Errorf("expected %s, got %s", d, datum) } }() @@ -249,7 +249,7 @@ func TestBinaryIntArray(t *testing.T) { } evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) defer evalCtx.Stop(context.Background()) - if got.Compare(evalCtx, d) != 0 { + if tree.Distinct(evalCtx, got, d) { t.Fatalf("expected %s, got %s", d, got) } } @@ -318,7 +318,7 @@ func TestRandomBinaryDecimal(t *testing.T) { oid.T_numeric, pgwirebase.FormatBinary, got[4:], ); err != nil { t.Errorf("%q: unable to decode %v: %s", test.In, got[4:], err) - } else if dec.Compare(evalCtx, datum) != 0 { + } else if tree.Distinct(evalCtx, dec, datum) { t.Errorf("%q: expected %s, got %s", test.In, dec, datum) } evalCtx.Stop(context.Background()) diff --git a/pkg/sql/pgwire/types_test.go b/pkg/sql/pgwire/types_test.go index 205f5e52478c..aad4e04133fc 100644 --- a/pkg/sql/pgwire/types_test.go +++ b/pkg/sql/pgwire/types_test.go @@ -139,7 +139,7 @@ func TestIntArrayRoundTrip(t *testing.T) { } evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) defer evalCtx.Stop(context.Background()) - if got.Compare(evalCtx, d) != 0 { + if tree.Distinct(evalCtx, got, d) { t.Fatalf("expected %s, got %s", d, got) } } @@ -177,7 +177,7 @@ func TestByteArrayRoundTrip(t *testing.T) { } evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) defer evalCtx.Stop(context.Background()) - if got.Compare(evalCtx, d) != 0 { + if tree.Distinct(evalCtx, got, d) { t.Fatalf("expected %s, got %s", d, got) } }) @@ -427,7 +427,7 @@ func BenchmarkDecodeBinaryDecimal(b *testing.B) { defer evalCtx.Stop(context.Background()) if err != nil { b.Fatal(err) - } else if got.Compare(evalCtx, expected) != 0 { + } else if tree.Distinct(evalCtx, got, expected) { b.Fatalf("expected %s, got %s", expected, got) } } diff --git a/pkg/sql/sem/builtins/aggregate_builtins.go b/pkg/sql/sem/builtins/aggregate_builtins.go index 2be923089e78..2215d7174a1f 100644 --- a/pkg/sql/sem/builtins/aggregate_builtins.go +++ b/pkg/sql/sem/builtins/aggregate_builtins.go @@ -325,13 +325,13 @@ func makeAggOverloadWithReturnType( case *MinAggregate: min := &slidingWindowFunc{} min.sw = makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { - return -a.Compare(evalCtx, b) + return -tree.TotalOrderCompare(evalCtx, a, b) }) return min case *MaxAggregate: max := &slidingWindowFunc{} max.sw = 
makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { - return a.Compare(evalCtx, b) + return tree.TotalOrderCompare(evalCtx, a, b) }) return max case *intSumAggregate: @@ -681,8 +681,7 @@ func (a *MaxAggregate) Add(_ context.Context, datum tree.Datum, _ ...tree.Datum) a.max = datum return nil } - c := a.max.Compare(a.evalCtx, datum) - if c < 0 { + if tree.TotalOrderLess(a.evalCtx, a.max, datum) { a.max = datum } return nil @@ -718,8 +717,7 @@ func (a *MinAggregate) Add(_ context.Context, datum tree.Datum, _ ...tree.Datum) a.min = datum return nil } - c := a.min.Compare(a.evalCtx, datum) - if c > 0 { + if tree.TotalOrderLess(a.evalCtx, datum, a.min) { a.min = datum } return nil diff --git a/pkg/sql/sem/builtins/builtins.go b/pkg/sql/sem/builtins/builtins.go index d73fe20864f3..5c235a066061 100644 --- a/pkg/sql/sem/builtins/builtins.go +++ b/pkg/sql/sem/builtins/builtins.go @@ -2464,7 +2464,7 @@ may increase either contention or retry errors, or both.`, } result := tree.NewDArray(typ) for _, e := range tree.MustBeDArray(args[0]).Array { - if e.Compare(ctx, args[1]) != 0 { + if tree.Distinct(ctx, e, args[1]) { if err := result.Append(e); err != nil { return nil, err } @@ -2486,7 +2486,7 @@ may increase either contention or retry errors, or both.`, } result := tree.NewDArray(typ) for _, e := range tree.MustBeDArray(args[0]).Array { - if e.Compare(ctx, args[1]) == 0 { + if !tree.Distinct(ctx, e, args[1]) { if err := result.Append(args[2]); err != nil { return nil, err } @@ -2511,7 +2511,7 @@ may increase either contention or retry errors, or both.`, return tree.DNull, nil } for i, e := range tree.MustBeDArray(args[0]).Array { - if e.Compare(ctx, args[1]) == 0 { + if !tree.Distinct(ctx, e, args[1]) { return tree.NewDInt(tree.DInt(i + 1)), nil } } @@ -2531,7 +2531,7 @@ may increase either contention or retry errors, or both.`, } result := tree.NewDArray(types.Int) for i, e := range tree.MustBeDArray(args[0]).Array { - if e.Compare(ctx, args[1]) == 0 { + if !tree.Distinct(ctx, e, args[1]) { if err := result.Append(tree.NewDInt(tree.DInt(i + 1))); err != nil { return nil, err } diff --git a/pkg/sql/sem/builtins/builtins_test.go b/pkg/sql/sem/builtins/builtins_test.go index bf2eb54fb6b5..1e740e128426 100644 --- a/pkg/sql/sem/builtins/builtins_test.go +++ b/pkg/sql/sem/builtins/builtins_test.go @@ -107,7 +107,7 @@ func TestStringToArrayAndBack(t *testing.T) { } evalContext := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) - if result.Compare(evalContext, expectedArray) != 0 { + if tree.Distinct(evalContext, result, expectedArray) { t.Errorf("expected %v, got %v", tc.expected, result) } diff --git a/pkg/sql/sem/builtins/pg_builtins.go b/pkg/sql/sem/builtins/pg_builtins.go index 39c05b7009cb..be16ed7dd4c3 100644 --- a/pkg/sql/sem/builtins/pg_builtins.go +++ b/pkg/sql/sem/builtins/pg_builtins.go @@ -533,7 +533,7 @@ var pgBuiltins = map[string]builtinDefinition{ }, ReturnType: tree.FixedReturnType(types.String), Fn: func(ctx *tree.EvalContext, args tree.Datums) (tree.Datum, error) { - if args[0].Compare(ctx, DatEncodingUTFId) == 0 { + if !tree.Distinct(ctx, args[0], DatEncodingUTFId) { return datEncodingUTF8ShortName, nil } return tree.DNull, nil diff --git a/pkg/sql/sem/builtins/window_frame_builtins_test.go b/pkg/sql/sem/builtins/window_frame_builtins_test.go index 94c5872067a8..9e04e5e02611 100644 --- a/pkg/sql/sem/builtins/window_frame_builtins_test.go +++ b/pkg/sql/sem/builtins/window_frame_builtins_test.go @@ -51,7 +51,7 @@ func testMin(t *testing.T, 
evalCtx *tree.EvalContext, wfr *tree.WindowFrameRun) wfr.EndBoundOffset = offset min := &slidingWindowFunc{} min.sw = makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { - return -a.Compare(evalCtx, b) + return -tree.TotalOrderCompare(evalCtx, a, b) }) for wfr.RowIdx = 0; wfr.RowIdx < wfr.PartitionSize(); wfr.RowIdx++ { res, err := min.Compute(evalCtx.Ctx(), evalCtx, wfr) @@ -86,7 +86,7 @@ func testMax(t *testing.T, evalCtx *tree.EvalContext, wfr *tree.WindowFrameRun) wfr.EndBoundOffset = offset max := &slidingWindowFunc{} max.sw = makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { - return a.Compare(evalCtx, b) + return tree.TotalOrderCompare(evalCtx, a, b) }) for wfr.RowIdx = 0; wfr.RowIdx < wfr.PartitionSize(); wfr.RowIdx++ { res, err := max.Compute(evalCtx.Ctx(), evalCtx, wfr) @@ -221,7 +221,7 @@ func testRingBuffer(t *testing.T, count int) { for pos, iv := range naiveBuffer { res := ring.get(pos) - if res.idx != iv.idx || res.value.Compare(evalCtx, iv.value) != 0 { + if res.idx != iv.idx || tree.Distinct(evalCtx, res.value, iv.value) { t.Errorf("Ring buffer returned incorrect value: expected %+v, found %+v", iv, res) panic("") } diff --git a/pkg/sql/sem/tree/compare.go b/pkg/sql/sem/tree/compare.go new file mode 100644 index 000000000000..61b98b7657dc --- /dev/null +++ b/pkg/sql/sem/tree/compare.go @@ -0,0 +1,141 @@ +// Copyright 2018 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package tree + +import "fmt" + +// This file implements comparisons between datums. +// +// This implements the comparators for three separate relations. +// +// - a total ordering for the purpose of sorting and searching +// values in indexes. In that relation, NULL sorts equal to +// itself and before all other values. +// +// Functions: TotalOrderLess(), TotalOrderCompare(). +// +// - the logical SQL scalar partial ordering, where non-NULL +// values can be compared to each other but NULL comparisons +// themselves produce a NULL result. +// +// Function: ScalarCompare(). +// +// - the IS [NOT] DISTINCT relation, in which every value can +// be compared to every other, NULLs are distinct from every +// non-NULL value but not distinct from each other. +// +// Function: Distinct(). +// +// Due to the way the SQL language semantics are constructed, it is +// the case that Distinct() returns true if and only if +// TotalOrderCompare() returns nonzero. However, one should be +// careful when using these functions to convey *intent* properly to the +// reader of the code: +// +// - the functions related to the total order for sorting should only +// be used in contexts that are about sorting values. +// +// - Distinct() and ScalarCompare() should be used everywhere else. +// +// In addition, separating Distinct() from TotalOrderCompare() enables +// later performance optimizations of the former by specializing the +// code. This is currently done for e.g. EncDatums.
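+//
+// To make the distinction concrete, here is an illustrative sketch
+// (ctx stands for any *EvalContext; NewDInt(1) is just an arbitrary
+// non-NULL datum):
+//
+//   TotalOrderLess(ctx, DNull, NewDInt(1))              // true: NULL sorts first
+//   Distinct(ctx, DNull, DNull)                         // false: NULL is not distinct from NULL
+//   Distinct(ctx, DNull, NewDInt(1))                    // true
+//   ScalarCompare(ctx, DNull, NewDInt(1), EQ)           // DNull (SQL: NULL = 1 is NULL)
+//   ScalarCompare(ctx, DNull, DNull, IsNotDistinctFrom) // DBoolTrue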
+ +// TotalOrderLess returns true if and only if a sorts before b. NULLs +// are considered to sort first. +func TotalOrderLess(ctx *EvalContext, a, b Datum) bool { + return doCompare(ctx, true /* orderedNULLs */, a, b) < 0 +} + +// TotalOrderCompare returns -1 if a sorts before b, +1 if a sorts +// after b, and 0 if a and b are considered equal for the purpose of +// sorting. +// This function is only suitable for sorting and index span computations; +// it should not be used to test equality. Consider Distinct() +// and ScalarCompare() instead. +func TotalOrderCompare(ctx *EvalContext, a, b Datum) int { + return doCompare(ctx, true /* orderedNULLs */, a, b) +} + +// Distinct returns true if and only if a and b are distinct +// from each other. NULLs are considered to not be distinct from each +// other but are distinct from every other value. +func Distinct(ctx *EvalContext, a, b Datum) bool { + return doCompare(ctx, true /* orderedNULLs */, a, b) != 0 +} + +// Distinct checks whether two slices of datums are distinct +// from each other. Any difference in value is considered distinct; +// however, a NULL value is NOT considered distinct from another NULL +// value. +func (d Datums) Distinct(evalCtx *EvalContext, other Datums) bool { + if len(d) != len(other) { + return true + } + for i, val := range d { + if Distinct(evalCtx, val, other[i]) { + return true + } + } + return false +} + +// ScalarCompare returns a SQL value for the given comparison. +// +// It returns SQL NULL if the operator is not IsNotDistinctFrom and +// either operand is NULL. +// +// The operator op and its two operands must have already undergone +// normalization via foldComparisonExpr: the only supported operators +// here are EQ, LT, LE and IsNotDistinctFrom. +func ScalarCompare(ctx *EvalContext, a, b Datum, op ComparisonOperator) Datum { + if op == IsNotDistinctFrom { + return MakeDBool(DBool(!Distinct(ctx, a, b))) + } + + cmp := doCompare(ctx, false /* orderedNULLs */, a, b) + if cmp == -2 { + return DNull + } + + switch op { + case EQ: + return MakeDBool(cmp == 0) + case LT: + return MakeDBool(cmp < 0) + case LE: + return MakeDBool(cmp <= 0) + default: + panic(fmt.Sprintf("unexpected ComparisonOperator in ScalarCompare: %v", op)) + } +} + +// doCompare implements the comparison logic shared by the functions above. +// When orderedNULLs is true, NULLs sort before all other values and compare +// equal to each other; otherwise a comparison involving NULL returns the +// sentinel value -2. +func doCompare(ctx *EvalContext, orderedNULLs bool, a, b Datum) int { + if a == DNull || b == DNull { + if orderedNULLs { + if b != DNull { + return -1 + } + if a != DNull { + return 1 + } + return 0 + } + return -2 + } + // TODO(knz): this should not use internalCompare any more.
+ return a.internalCompare(ctx, b) +} diff --git a/pkg/sql/sem/tree/constant_test.go b/pkg/sql/sem/tree/constant_test.go index 0d666796f432..f978c0e2ca14 100644 --- a/pkg/sql/sem/tree/constant_test.go +++ b/pkg/sql/sem/tree/constant_test.go @@ -334,7 +334,7 @@ func TestStringConstantResolveAvailableTypes(t *testing.T) { expectedDatum := parseFuncs[availType](t, test.c.RawString()) evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) defer evalCtx.Stop(context.Background()) - if res.Compare(evalCtx, expectedDatum) != 0 { + if tree.Distinct(evalCtx, res, expectedDatum) { t.Errorf("%d: type %s expected to be resolved from the tree.StrVal %v to tree.Datum %v"+ ", found %v", i, availType, test.c, expectedDatum, res) diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go index f54cbd959dc7..82034a9dd3fa 100644 --- a/pkg/sql/sem/tree/datum.go +++ b/pkg/sql/sem/tree/datum.go @@ -76,9 +76,13 @@ type Datum interface { // fmtFlags.disambiguateDatumTypes. AmbiguousFormat() bool - // Compare returns -1 if the receiver is less than other, 0 if receiver is - // equal to other and +1 if receiver is greater than other. - Compare(ctx *EvalContext, other Datum) int + // internalCompare returns -1 if the receiver sorts before other, 0 + // if receiver is not distinct to other, +1 if receiver sorts after + // other. + // This interface is broken! Do not use. Use instead the functions + // in compare.go. + // TODO(knz): Fix this. + internalCompare(ctx *EvalContext, other Datum) int // Prev returns the previous datum and true, if one exists, or nil and false. // The previous datum satisfies the following definition: if the receiver is @@ -89,9 +93,9 @@ type Datum interface { // // TODO(#12022): for DTuple, the contract is actually that "x < b" (SQL order, // where NULL < x is unknown for all x) is true only if "x <= a" - // (.Compare/encoding order, where NULL <= x is true for all x) is true. This + // (TotalOrderCompare/encoding order, where NULL <= x is true for all x) is true. This // is okay for now: the returned datum is used only to construct a span, which - // uses .Compare/encoding order and is guaranteed to be large enough by this + // uses TotalOrderCompare/encoding order and is guaranteed to be large enough by this // weaker contract. The original filter expression is left in place to catch // false positives. Prev(ctx *EvalContext) (Datum, bool) @@ -109,9 +113,9 @@ type Datum interface { // // TODO(#12022): for DTuple, the contract is actually that "x > a" (SQL order, // where x > NULL is unknown for all x) is true only if "x >= b" - // (.Compare/encoding order, where x >= NULL is true for all x) is true. This + // (TotalOrderCompare/encoding order, where x >= NULL is true for all x) is true. This // is okay for now: the returned datum is used only to construct a span, which - // uses .Compare/encoding order and is guaranteed to be large enough by this + // uses TotalOrderCompare/encoding order and is guaranteed to be large enough by this // weaker contract. The original filter expression is left in place to catch // false positives. Next(ctx *EvalContext) (Datum, bool) @@ -161,27 +165,6 @@ func (d *Datums) Format(ctx *FmtCtx) { ctx.WriteByte(')') } -// IsDistinctFrom checks to see if two datums are distinct from each other. Any -// change in value is considered distinct, however, a NULL value is NOT -// considered disctinct from another NULL value. 
-func (d Datums) IsDistinctFrom(evalCtx *EvalContext, other Datums) bool { - if len(d) != len(other) { - return true - } - for i, val := range d { - if val == DNull { - if other[i] != DNull { - return true - } - } else { - if val.Compare(evalCtx, other[i]) != 0 { - return true - } - } - } - return false -} - // CompositeDatum is a Datum that may require composite encoding in // indexes. Any Datum implementing this interface must also add itself to // sqlbase/HasCompositeKeyEncoding. @@ -355,8 +338,8 @@ func (*DBool) ResolvedType() types.T { return types.Bool } -// Compare implements the Datum interface. -func (d *DBool) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DBool) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -464,8 +447,8 @@ func (*DInt) ResolvedType() types.T { return types.Int } -// Compare implements the Datum interface. -func (d *DInt) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DInt) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -475,7 +458,7 @@ func (d *DInt) Compare(ctx *EvalContext, other Datum) int { case *DInt: v = *t case *DFloat, *DDecimal: - return -t.Compare(ctx, d) + return -t.internalCompare(ctx, d) default: panic(makeUnsupportedComparisonMessage(d, other)) } @@ -567,8 +550,8 @@ func (*DFloat) ResolvedType() types.T { return types.Float } -// Compare implements the Datum interface. -func (d *DFloat) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DFloat) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -580,7 +563,7 @@ func (d *DFloat) Compare(ctx *EvalContext, other Datum) int { case *DInt: v = DFloat(MustBeDInt(t)) case *DDecimal: - return -t.Compare(ctx, d) + return -t.internalCompare(ctx, d) default: panic(makeUnsupportedComparisonMessage(d, other)) } @@ -720,8 +703,8 @@ func (*DDecimal) ResolvedType() types.T { return types.Decimal } -// Compare implements the Datum interface. -func (d *DDecimal) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DDecimal) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -861,8 +844,8 @@ func (*DString) ResolvedType() types.T { return types.String } -// Compare implements the Datum interface. -func (d *DString) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DString) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1003,8 +986,8 @@ func (d *DCollatedString) ResolvedType() types.T { return types.TCollatedString{Locale: d.Locale} } -// Compare implements the Datum interface. -func (d *DCollatedString) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DCollatedString) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1090,8 +1073,8 @@ func (*DBytes) ResolvedType() types.T { return types.Bytes } -// Compare implements the Datum interface. 
-func (d *DBytes) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DBytes) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1192,8 +1175,8 @@ func (*DUuid) ResolvedType() types.T { return types.UUID } -// Compare implements the Datum interface. -func (d *DUuid) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DUuid) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1308,8 +1291,8 @@ func (*DIPAddr) ResolvedType() types.T { return types.INet } -// Compare implements the Datum interface. -func (d *DIPAddr) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DIPAddr) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1462,8 +1445,8 @@ func (*DDate) ResolvedType() types.T { return types.Date } -// Compare implements the Datum interface. -func (d *DDate) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DDate) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1564,8 +1547,8 @@ func (*DTime) ResolvedType() types.T { return types.Time } -// Compare implements the Datum interface. -func (d *DTime) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DTime) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1667,8 +1650,8 @@ func (*DTimeTZ) ResolvedType() types.T { return types.TimeTZ } -// Compare implements the Datum interface. -func (d *DTimeTZ) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DTimeTZ) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -1914,8 +1897,8 @@ func compareTimestamps(ctx *EvalContext, l Datum, r Datum) int { return 0 } -// Compare implements the Datum interface. -func (d *DTimestamp) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DTimestamp) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -2013,8 +1996,8 @@ func (*DTimestampTZ) ResolvedType() types.T { return types.TimestampTZ } -// Compare implements the Datum interface. -func (d *DTimestampTZ) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DTimestampTZ) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -2206,8 +2189,8 @@ func (*DInterval) ResolvedType() types.T { return types.Interval } -// Compare implements the Datum interface. -func (d *DInterval) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DInterval) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -2398,8 +2381,8 @@ func (*DJSON) ResolvedType() types.T { return types.JSON } -// Compare implements the Datum interface. 
-func (d *DJSON) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DJSON) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -2410,7 +2393,7 @@ func (d *DJSON) Compare(ctx *EvalContext, other Datum) int { } // No avenue for us to pass up this error here at the moment, but Compare // only errors for invalid encoded data. - // TODO(justin): modify Compare to allow passing up errors. + // TODO(justin): modify internalCompare to allow passing up errors. c, err := d.JSON.Compare(v.JSON) if err != nil { panic(err) @@ -2530,8 +2513,8 @@ func (d *DTuple) ResolvedType() types.T { return d.typ } -// Compare implements the Datum interface. -func (d *DTuple) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DTuple) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -2545,7 +2528,7 @@ func (d *DTuple) Compare(ctx *EvalContext, other Datum) int { n = len(v.D) } for i := 0; i < n; i++ { - c := d.D[i].Compare(ctx, v.D[i]) + c := d.D[i].internalCompare(ctx, v.D[i]) if c != 0 { return c } @@ -2711,16 +2694,10 @@ func (d *DTuple) AssertSorted() { // be func (d *DTuple) SearchSorted(ctx *EvalContext, target Datum) (int, bool) { d.AssertSorted() - if target == DNull { - panic(fmt.Sprintf("NULL target (d: %s)", d)) - } - if t, ok := target.(*DTuple); ok && t.ContainsNull() { - panic(fmt.Sprintf("target containing NULLs: %#v (d: %s)", target, d)) - } i := sort.Search(len(d.D), func(i int) bool { - return d.D[i].Compare(ctx, target) >= 0 + return !TotalOrderLess(ctx, d.D[i], target) }) - found := i < len(d.D) && d.D[i].Compare(ctx, target) == 0 + found := i < len(d.D) && !Distinct(ctx, d.D[i], target) return i, found } @@ -2733,7 +2710,7 @@ func (d *DTuple) Normalize(ctx *EvalContext) { func (d *DTuple) sort(ctx *EvalContext) { if !d.sorted { sort.Slice(d.D, func(i, j int) bool { - return d.D[i].Compare(ctx, d.D[j]) < 0 + return TotalOrderLess(ctx, d.D[i], d.D[j]) }) d.SetSorted() } @@ -2742,7 +2719,7 @@ func (d *DTuple) sort(ctx *EvalContext) { func (d *DTuple) makeUnique(ctx *EvalContext) { n := 0 for i := 0; i < len(d.D); i++ { - if n == 0 || d.D[n-1].Compare(ctx, d.D[i]) < 0 { + if n == 0 || TotalOrderLess(ctx, d.D[n-1], d.D[i]) { d.D[n] = d.D[i] n++ } @@ -2786,8 +2763,8 @@ func (dNull) ResolvedType() types.T { return types.Unknown } -// Compare implements the Datum interface. -func (d dNull) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d dNull) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { return 0 } @@ -2881,8 +2858,8 @@ func (d *DArray) ResolvedType() types.T { return types.TArray{Typ: d.ParamTyp} } -// Compare implements the Datum interface. -func (d *DArray) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DArray) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -2896,7 +2873,7 @@ func (d *DArray) Compare(ctx *EvalContext, other Datum) int { n = v.Len() } for i := 0; i < n; i++ { - c := d.Array[i].Compare(ctx, v.Array[i]) + c := d.Array[i].internalCompare(ctx, v.Array[i]) if c != 0 { return c } @@ -3060,8 +3037,8 @@ func (d *DOid) AsRegProc(name string) *DOid { // AmbiguousFormat implements the Datum interface. 
func (*DOid) AmbiguousFormat() bool { return true } -// Compare implements the Datum interface. -func (d *DOid) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DOid) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 @@ -3205,16 +3182,16 @@ func (d *DOidWrapper) ResolvedType() types.T { return types.WrapTypeWithOid(d.Wrapped.ResolvedType(), d.Oid) } -// Compare implements the Datum interface. -func (d *DOidWrapper) Compare(ctx *EvalContext, other Datum) int { +// internalCompare implements the Datum interface. +func (d *DOidWrapper) internalCompare(ctx *EvalContext, other Datum) int { if other == DNull { // NULL is less than any non-NULL value. return 1 } if v, ok := other.(*DOidWrapper); ok { - return d.Wrapped.Compare(ctx, v.Wrapped) + return d.Wrapped.internalCompare(ctx, v.Wrapped) } - return d.Wrapped.Compare(ctx, other) + return d.Wrapped.internalCompare(ctx, other) } // Prev implements the Datum interface. @@ -3284,9 +3261,9 @@ func (d *Placeholder) mustGetValue(ctx *EvalContext) Datum { return out } -// Compare implements the Datum interface. -func (d *Placeholder) Compare(ctx *EvalContext, other Datum) int { - return d.mustGetValue(ctx).Compare(ctx, other) +// internalCompare implements the Datum interface. +func (d *Placeholder) internalCompare(ctx *EvalContext, other Datum) int { + return d.mustGetValue(ctx).internalCompare(ctx, other) } // Prev implements the Datum interface. diff --git a/pkg/sql/sem/tree/datum_test.go b/pkg/sql/sem/tree/datum_test.go index 3dcda192f7ff..7fce5b7f3e6a 100644 --- a/pkg/sql/sem/tree/datum_test.go +++ b/pkg/sql/sem/tree/datum_test.go @@ -317,7 +317,7 @@ func TestDFloatCompare(t *testing.T) { } evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) defer evalCtx.Stop(context.Background()) - got := x.Compare(evalCtx, y) + got := tree.TotalOrderCompare(evalCtx, x, y) if got != expected { t.Errorf("comparing DFloats %s and %s: expected %d, got %d", x, y, expected, got) } @@ -367,7 +367,7 @@ func TestParseDIntervalWithField(t *testing.T) { } evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) defer evalCtx.Stop(context.Background()) - if expected.Compare(evalCtx, actual) != 0 { + if tree.Distinct(evalCtx, expected, actual) { t.Errorf("INTERVAL %s %v: got %s, expected %s", td.str, td.field, actual, expected) } } @@ -408,7 +408,7 @@ func TestParseDDate(t *testing.T) { } evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) defer evalCtx.Stop(context.Background()) - if expected.Compare(evalCtx, actual) != 0 { + if tree.Distinct(evalCtx, expected, actual) { t.Errorf("DATE %s: got %s, expected %s", td.str, actual, expected) } } @@ -619,12 +619,13 @@ func TestMakeDJSON(t *testing.T) { if err != nil { t.Fatal(err) } - if j1.Compare(tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()), j2) != -1 { + evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) + if !tree.TotalOrderLess(evalCtx, j1, j2) { t.Fatal("expected JSON 1 < 2") } } -func TestIsDistinctFrom(t *testing.T) { +func TestDistinct(t *testing.T) { testData := []struct { a string // comma separated list of strings, `NULL` is converted to a NULL b string // same as a @@ -709,7 +710,7 @@ func TestIsDistinctFrom(t *testing.T) { t.Run(fmt.Sprintf("%s to %s", td.a, td.b), func(t *testing.T) { datumsA := convert(td.a) datumsB := convert(td.b) - if e, a := td.expected, 
datumsA.IsDistinctFrom(&tree.EvalContext{}, datumsB); e != a { + if e, a := td.expected, datumsA.Distinct(&tree.EvalContext{}, datumsB); e != a { if e { t.Errorf("expected %s to be distinct from %s, but got %t", datumsA, datumsB, e) } else { diff --git a/pkg/sql/sem/tree/eval.go b/pkg/sql/sem/tree/eval.go index e85c8d55fc7a..04f874630567 100644 --- a/pkg/sql/sem/tree/eval.go +++ b/pkg/sql/sem/tree/eval.go @@ -2039,6 +2039,19 @@ func init() { } } +func cmpOpScalarEQFn(ctx *EvalContext, left, right Datum) (Datum, error) { + return ScalarCompare(ctx, left, right, EQ), nil +} +func cmpOpScalarLTFn(ctx *EvalContext, left, right Datum) (Datum, error) { + return ScalarCompare(ctx, left, right, LT), nil +} +func cmpOpScalarLEFn(ctx *EvalContext, left, right Datum) (Datum, error) { + return ScalarCompare(ctx, left, right, LE), nil +} +func cmpOpScalarIsFn(ctx *EvalContext, left, right Datum) (Datum, error) { + return ScalarCompare(ctx, left, right, IsNotDistinctFrom), nil +} + func boolFromCmp(cmp int, op ComparisonOperator) *DBool { switch op { case EQ, IsNotDistinctFrom: @@ -2052,37 +2065,6 @@ func boolFromCmp(cmp int, op ComparisonOperator) *DBool { } } -func cmpOpScalarFn(ctx *EvalContext, left, right Datum, op ComparisonOperator) Datum { - // Before deferring to the Datum.Compare method, check for values that should - // be handled differently during SQL comparison evaluation than they should when - // ordering Datum values. - if left == DNull || right == DNull { - switch op { - case IsNotDistinctFrom: - return MakeDBool((left == DNull) == (right == DNull)) - - default: - // If either Datum is NULL, the result of the comparison is NULL. - return DNull - } - } - cmp := left.Compare(ctx, right) - return boolFromCmp(cmp, op) -} - -func cmpOpScalarEQFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, EQ), nil -} -func cmpOpScalarLTFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, LT), nil -} -func cmpOpScalarLEFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, LE), nil -} -func cmpOpScalarIsFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, IsNotDistinctFrom), nil -} - func cmpOpTupleFn(ctx *EvalContext, left, right DTuple, op ComparisonOperator) Datum { cmp := 0 sawNull := false @@ -2115,7 +2097,7 @@ func cmpOpTupleFn(ctx *EvalContext, left, right DTuple, op ComparisonOperator) D return DNull } } else { - cmp = leftElem.Compare(ctx, rightElem) + cmp = leftElem.internalCompare(ctx, rightElem) if cmp != 0 { break } @@ -2166,7 +2148,7 @@ func makeEvalTupleIn(typ types.T) CmpOp { for _, val := range vtuple.D { if val == DNull { sawNull = true - } else if val.Compare(ctx, arg) == 0 { + } else if val.internalCompare(ctx, arg) == 0 { return DBoolTrue, nil } } @@ -4030,7 +4012,7 @@ func foldComparisonExpr( // NotRegIMatch(left, right) is implemented as !RegIMatch(left, right) return RegIMatch, left, right, false, true case IsDistinctFrom: - // IsDistinctFrom(left, right) is implemented as !IsNotDistinctFrom(left, right) + // Distinct(left, right) is implemented as !IsNotDistinctFrom(left, right) // Note: this seems backwards, but IS NOT DISTINCT FROM is an extended // version of IS and IS DISTINCT FROM is an extended version of IS NOT. 
return IsNotDistinctFrom, left, right, false, true diff --git a/pkg/sql/sem/tree/indexed_vars_test.go b/pkg/sql/sem/tree/indexed_vars_test.go index 4b58293e2277..778182b710af 100644 --- a/pkg/sql/sem/tree/indexed_vars_test.go +++ b/pkg/sql/sem/tree/indexed_vars_test.go @@ -101,7 +101,7 @@ func TestIndexedVars(t *testing.T) { if err != nil { t.Fatal(err) } - if d.Compare(evalCtx, NewDInt(3+5*6)) != 0 { + if Distinct(evalCtx, d, NewDInt(3+5*6)) { t.Errorf("invalid result %s (expected %d)", d, 3+5*6) } } diff --git a/pkg/sql/sem/tree/normalize.go b/pkg/sql/sem/tree/normalize.go index c9307726ae4d..6aa8d232dd70 100644 --- a/pkg/sql/sem/tree/normalize.go +++ b/pkg/sql/sem/tree/normalize.go @@ -315,7 +315,7 @@ func (expr *ComparisonExpr) normalize(v *NormalizeVisitor) TypedExpr { v.err = err return expr } - if divisor.Compare(v.ctx, DZero) < 0 { + if TotalOrderLess(v.ctx, divisor, DZero) { if !exprCopied { exprCopy := *expr expr = &exprCopy diff --git a/pkg/sql/sem/tree/parse_array_test.go b/pkg/sql/sem/tree/parse_array_test.go index 5006a8697244..7fbf5a894034 100644 --- a/pkg/sql/sem/tree/parse_array_test.go +++ b/pkg/sql/sem/tree/parse_array_test.go @@ -94,7 +94,7 @@ lo}`, coltypes.String, Datums{NewDString(`hel`), NewDString(`lo`)}}, if err != nil { t.Fatalf("ARRAY %s: got error %s, expected %s", td.str, err.Error(), expected) } - if actual.Compare(evalContext, expected) != 0 { + if Distinct(evalContext, actual, expected) { t.Fatalf("ARRAY %s: got %s, expected %s", td.str, actual, expected) } }) diff --git a/pkg/sql/sort.go b/pkg/sql/sort.go index ef3eec5c7543..518e85b9483a 100644 --- a/pkg/sql/sort.go +++ b/pkg/sql/sort.go @@ -646,7 +646,7 @@ func (ss *sortTopKStrategy) Add(ctx context.Context, values tree.Datums) error { if err := ss.vNode.PushValues(ctx, values); err != nil { return err } - case ss.vNode.ValuesLess(values, ss.vNode.rows.At(0)): + case sqlbase.LessDatums(ss.vNode.ordering, false, ss.vNode.evalCtx, values, ss.vNode.rows.At(0)): // Once the heap is full, only replace the top // value if a new value is less than it. If so // replace and fix the heap. @@ -756,19 +756,13 @@ func (n *sortValues) Len() int { return n.rows.Len() - n.rowsPopped } -// ValuesLess returns the comparison result between the two provided -// Datums slices in the context of the sortValues ordering. -func (n *sortValues) ValuesLess(ra, rb tree.Datums) bool { - return sqlbase.CompareDatums(n.ordering, n.evalCtx, ra, rb) < 0 -} - // Less implements the sort.Interface interface. func (n *sortValues) Less(i, j int) bool { // TODO(pmattis): An alternative to this type of field-based comparison would // be to construct a sort-key per row using encodeTableKey(). Using a // sort-key approach would likely fit better with a disk-based sort. ra, rb := n.rows.At(i), n.rows.At(j) - return n.invertSorting != n.ValuesLess(ra, rb) + return sqlbase.LessDatums(n.ordering, n.invertSorting, n.evalCtx, ra, rb) } // Swap implements the sort.Interface interface. diff --git a/pkg/sql/sqlbase/cascader.go b/pkg/sql/sqlbase/cascader.go index 1f237621c189..60bc0a6f2017 100644 --- a/pkg/sql/sqlbase/cascader.go +++ b/pkg/sql/sqlbase/cascader.go @@ -833,7 +833,7 @@ func (c *cascader) updateRows( } // Is there something to update? If not, skip it. 
- if !rowToUpdate.IsDistinctFrom(c.evalCtx, updateRow) { + if !rowToUpdate.Distinct(c.evalCtx, updateRow) { continue } @@ -1166,7 +1166,7 @@ func (c *cascader) cascadeAll( if _, exists := skipList[j]; exists { continue } - if !originalRows.At(j).IsDistinctFrom(c.evalCtx, finalRow) { + if !originalRows.At(j).Distinct(c.evalCtx, finalRow) { // The row has been updated again. finalRow = updatedRows.At(j) skipList[j] = struct{}{} diff --git a/pkg/sql/sqlbase/encoded_datum.go b/pkg/sql/sqlbase/encoded_datum.go index 24e05b7ecc41..dc1a6913a197 100644 --- a/pkg/sql/sqlbase/encoded_datum.go +++ b/pkg/sql/sqlbase/encoded_datum.go @@ -267,11 +267,11 @@ func (ed *EncDatum) Encode( } } -// Compare returns: -// -1 if the receiver is less than rhs, -// 0 if the receiver is equal to rhs, -// +1 if the receiver is greater than rhs. -func (ed *EncDatum) Compare( +// TotalOrderCompare returns: +// -1 if the receiver sorts before rhs, +// 0 if the receiver sorts at the same place as rhs, +// +1 if the receiver sorts after rhs. +func (ed *EncDatum) TotalOrderCompare( typ *ColumnType, a *DatumAlloc, evalCtx *tree.EvalContext, rhs *EncDatum, ) (int, error) { // TODO(radu): if we have both the Datum and a key encoding available, which @@ -290,7 +290,26 @@ func (ed *EncDatum) Compare( if err := rhs.EnsureDecoded(typ, a); err != nil { return 0, err } - return ed.Datum.Compare(evalCtx, rhs.Datum), nil + return tree.TotalOrderCompare(evalCtx, ed.Datum, rhs.Datum), nil +} + +// Distinct returns true if and only if the receiver is distinct +// from rhs. +func (ed *EncDatum) Distinct( + typ *ColumnType, a *DatumAlloc, evalCtx *tree.EvalContext, rhs *EncDatum, +) (bool, error) { + // TODO(radu): if we have both the Datum and a key encoding available, which + // one would be faster to use? + if ed.encoding == rhs.encoding && ed.encoded != nil && rhs.encoded != nil { + return !bytes.Equal(ed.encoded, rhs.encoded), nil + } + if err := ed.EnsureDecoded(typ, a); err != nil { + return false, err + } + if err := rhs.EnsureDecoded(typ, a); err != nil { + return false, err + } + return tree.Distinct(evalCtx, ed.Datum, rhs.Datum), nil } // GetInt decodes an EncDatum that is known to be of integer type and returns @@ -409,7 +428,34 @@ func EncDatumRowToDatums( return nil } -// Compare returns the relative ordering of two EncDatumRows according to a +// Distinct determines whether the two EncDatumRows are distinct +// according to a ColumnOrdering. +// +// Note that a return value of false does not (in general) imply that +// the rows are equal; for example, rows +// [1 1 5] and [1 1 6] are non-distinct when compared against ordering +// {{0, asc}, {1, asc}} (i.e. ordered by first column and then by +// second column, but the 3rd column doesn't matter).
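+//
+// A typical use is detecting group boundaries on the ordered columns, the
+// same check the aggregator and distinct processors perform column by
+// column. A rough sketch (names are illustrative only):
+//
+//   distinct, err := lastGroupRow.Distinct(types, &alloc, ordering, evalCtx, row)
+//   if err != nil {
+//       return err
+//   }
+//   if distinct {
+//       // row starts a new group.
+//   }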
+func (r EncDatumRow) Distinct(
+    types []ColumnType,
+    a *DatumAlloc,
+    ordering ColumnOrdering,
+    evalCtx *tree.EvalContext,
+    rhs EncDatumRow,
+) (bool, error) {
+    if len(r) != len(types) || len(rhs) != len(types) {
+        panic(fmt.Sprintf("length mismatch: %d types, %d lhs, %d rhs\n%+v\n%+v\n%+v", len(types), len(r), len(rhs), types, r, rhs))
+    }
+    for _, c := range ordering {
+        cmp, err := r[c.ColIdx].Distinct(&types[c.ColIdx], a, evalCtx, &rhs[c.ColIdx])
+        if cmp || err != nil {
+            return cmp, err
+        }
+    }
+    return false, nil
+}
+
+// TotalOrderCompare returns the relative ordering of two EncDatumRows according to a
 // ColumnOrdering:
 // -1 if the receiver comes before the rhs in the ordering,
 // +1 if the receiver comes after the rhs in the ordering,
@@ -420,7 +466,7 @@ func EncDatumRowToDatums(
 // equal; for example, rows [1 1 5] and [1 1 6] when compared against ordering
 // {{0, asc}, {1, asc}} (i.e. ordered by first column and then by second
 // column).
-func (r EncDatumRow) Compare(
+func (r EncDatumRow) TotalOrderCompare(
     types []ColumnType,
     a *DatumAlloc,
     ordering ColumnOrdering,
@@ -431,7 +477,7 @@ func EncDatumRowToDatums(
         panic(fmt.Sprintf("length mismatch: %d types, %d lhs, %d rhs\n%+v\n%+v\n%+v", len(types), len(r), len(rhs), types, r, rhs))
     }
     for _, c := range ordering {
-        cmp, err := r[c.ColIdx].Compare(&types[c.ColIdx], a, evalCtx, &rhs[c.ColIdx])
+        cmp, err := r[c.ColIdx].TotalOrderCompare(&types[c.ColIdx], a, evalCtx, &rhs[c.ColIdx])
         if err != nil {
             return 0, err
         }
@@ -445,27 +491,27 @@ func EncDatumRowToDatums(
     return 0, nil
 }
 
-// CompareToDatums is a version of Compare which compares against decoded Datums.
-func (r EncDatumRow) CompareToDatums(
+// LessThanDatums is a specialization of TotalOrderCompare that checks for
+// strict order inequality against decoded Datums.
+func (r EncDatumRow) LessThanDatums(
     types []ColumnType,
     a *DatumAlloc,
     ordering ColumnOrdering,
     evalCtx *tree.EvalContext,
     rhs tree.Datums,
-) (int, error) {
+) (bool, error) {
     for _, c := range ordering {
         if err := r[c.ColIdx].EnsureDecoded(&types[c.ColIdx], a); err != nil {
-            return 0, err
+            return false, err
         }
-        cmp := r[c.ColIdx].Datum.Compare(evalCtx, rhs[c.ColIdx])
-        if cmp != 0 {
+        if cmp := tree.TotalOrderCompare(evalCtx, r[c.ColIdx].Datum, rhs[c.ColIdx]); cmp != 0 {
             if c.Direction == encoding.Descending {
                 cmp = -cmp
             }
-            return cmp, nil
+            return cmp < 0, nil
         }
     }
-    return 0, nil
+    return false, nil
 }
 
 // EncDatumRows is a slice of EncDatumRows having the same schema.
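
The new EncDatumRow methods split the old three-way Compare into an ordering query (TotalOrderCompare) and a distinctness query (Distinct) over the ordering columns. The sketch below is not part of the change; it shows how a consumer of sorted rows might combine the two, using only the signatures introduced above and written as if it lived in the sqlbase package. The helper name newGroupStarts is hypothetical.

package sqlbase

import (
    "fmt"

    "github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
)

// newGroupStarts reports whether cur begins a new group relative to prev
// under the given ordering: the rows must not be out of order, and a new
// group starts exactly when the ordering columns are distinct.
func newGroupStarts(
    types []ColumnType,
    a *DatumAlloc,
    ordering ColumnOrdering,
    evalCtx *tree.EvalContext,
    prev, cur EncDatumRow,
) (bool, error) {
    cmp, err := prev.TotalOrderCompare(types, a, ordering, evalCtx, cur)
    if err != nil {
        return false, err
    }
    if cmp > 0 {
        return false, fmt.Errorf("rows %v and %v are not in %v order", prev, cur, ordering)
    }
    return prev.Distinct(types, a, ordering, evalCtx, cur)
}
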
diff --git a/pkg/sql/sqlbase/encoded_datum_test.go b/pkg/sql/sqlbase/encoded_datum_test.go
index 89fde9c719e7..7edd8b034118 100644
--- a/pkg/sql/sqlbase/encoded_datum_test.go
+++ b/pkg/sql/sqlbase/encoded_datum_test.go
@@ -81,8 +81,8 @@ func TestEncDatum(t *testing.T) {
     if err != nil {
         t.Fatal(err)
     }
-    if cmp := y.Datum.Compare(evalCtx, x.Datum); cmp != 0 {
-        t.Errorf("Datums should be equal, cmp = %d", cmp)
+    if tree.Distinct(evalCtx, x.Datum, y.Datum) {
+        t.Errorf("Datums should not be distinct: %v vs %v", x.Datum, y.Datum)
     }
 
     enc2, err := y.Encode(&typeInt, a, DatumEncoding_DESCENDING_KEY, nil)
@@ -107,8 +107,8 @@ func TestEncDatum(t *testing.T) {
     if err != nil {
         t.Fatal(err)
     }
-    if cmp := y.Datum.Compare(evalCtx, z.Datum); cmp != 0 {
-        t.Errorf("Datums should be equal, cmp = %d", cmp)
+    if tree.Distinct(evalCtx, y.Datum, z.Datum) {
+        t.Errorf("Datums should not be distinct: %v vs %v", y.Datum, z.Datum)
     }
     y.UnsetDatum()
     if !y.IsUnset() {
@@ -182,7 +182,7 @@ func checkEncDatumCmp(
     evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
     defer evalCtx.Stop(context.Background())
 
-    if val, err := dec1.Compare(&typ, a, evalCtx, &dec2); err != nil {
+    if val, err := dec1.TotalOrderCompare(&typ, a, evalCtx, &dec2); err != nil {
         t.Fatal(err)
     } else if val != expectedCmp {
         t.Errorf("comparing %s (%s), %s (%s) resulted in %d, expected %d",
@@ -231,14 +231,14 @@ func TestEncDatumCompare(t *testing.T) {
         for {
             d1 = RandDatum(rng, typ, false)
             d2 = RandDatum(rng, typ, false)
-            if cmp := d1.Compare(evalCtx, d2); cmp < 0 {
+            if tree.TotalOrderLess(evalCtx, d1, d2) {
                 break
             }
         }
         v1 := DatumToEncDatum(typ, d1)
         v2 := DatumToEncDatum(typ, d2)
 
-        if val, err := v1.Compare(&typ, a, evalCtx, &v2); err != nil {
+        if val, err := v1.TotalOrderCompare(&typ, a, evalCtx, &v2); err != nil {
             t.Fatal(err)
         } else if val != -1 {
             t.Errorf("compare(1, 2) = %d", val)
@@ -318,8 +318,8 @@ func TestEncDatumFromBuffer(t *testing.T) {
         if err != nil {
             t.Fatal(err)
         }
-        if decoded.Datum.Compare(evalCtx, ed[i].Datum) != 0 {
-            t.Errorf("decoded datum %s doesn't equal original %s", decoded.Datum, ed[i].Datum)
+        if tree.Distinct(evalCtx, decoded.Datum, ed[i].Datum) {
+            t.Errorf("decoded datum %s is distinct from original %s", decoded.Datum, ed[i].Datum)
         }
     }
     if len(b) != 0 {
@@ -433,7 +433,7 @@ func TestEncDatumRowCompare(t *testing.T) {
         for i := range types {
             types[i] = typeInt
         }
-        cmp, err := c.row1.Compare(types, a, c.ord, evalCtx, c.row2)
+        cmp, err := c.row1.TotalOrderCompare(types, a, c.ord, evalCtx, c.row2)
         if err != nil {
             t.Error(err)
         } else if cmp != c.cmp {
@@ -479,8 +479,8 @@ func TestEncDatumRowAlloc(t *testing.T) {
     }
     for i := 0; i < rows; i++ {
         for j := 0; j < cols; j++ {
-            if a, b := in[i][j].Datum, out[i][j].Datum; a.Compare(evalCtx, b) != 0 {
-                t.Errorf("copied datum %s doesn't equal original %s", b, a)
+            if a, b := in[i][j].Datum, out[i][j].Datum; tree.Distinct(evalCtx, a, b) {
+                t.Errorf("copied datum %s is distinct from original %s", b, a)
             }
         }
     }
@@ -504,8 +504,9 @@ func TestValueEncodeDecodeTuple(t *testing.T) {
         tests[i] = RandDatum(rng, colTypes[i], true)
     }
 
-    for i, test := range tests {
+    evalCtx := &tree.EvalContext{}
+    for i, test := range tests {
         switch typedTest := test.(type) {
         case *tree.DTuple:
@@ -527,8 +528,8 @@ func TestValueEncodeDecodeTuple(t *testing.T) {
                     seed, test, colTypes[i], testTyp, len(buf))
             }
-            if cmp := decodedTuple.Compare(&tree.EvalContext{}, test); cmp != 0 {
-                t.Fatalf("seed %d: encoded %+v, decoded %+v, expected equal, received comparison: %d", seed, test, decodedTuple, cmp)
+            if tree.Distinct(evalCtx, decodedTuple, test) {
+                t.Fatalf("seed %d: encoded %+v, decoded %+v are distinct", seed, test, decodedTuple)
             }
         default:
             if test == tree.DNull {
diff --git a/pkg/sql/sqlbase/ordering.go b/pkg/sql/sqlbase/ordering.go
index ceaded26dd4b..cfc667e34f70 100644
--- a/pkg/sql/sqlbase/ordering.go
+++ b/pkg/sql/sqlbase/ordering.go
@@ -46,22 +46,19 @@ func (a ColumnOrdering) IsPrefixOf(b ColumnOrdering) bool {
     return true
 }
 
-// CompareDatums compares two datum rows according to a column ordering. Returns:
-// - 0 if lhs and rhs are equal on the ordering columns;
-// - less than 0 if lhs comes first;
-// - greater than 0 if rhs comes first.
-func CompareDatums(ordering ColumnOrdering, evalCtx *tree.EvalContext, lhs, rhs tree.Datums) int {
+// LessDatums compares two datum rows according to a column
+// ordering. Returns true if lhs sorts before rhs, false
+// otherwise. The ordering is inverted if invertSort is true.
+func LessDatums(
+    ordering ColumnOrdering, invertSort bool, evalCtx *tree.EvalContext, lhs, rhs tree.Datums,
+) bool {
     for _, c := range ordering {
-        // TODO(pmattis): This is assuming that the datum types are compatible. I'm
-        // not sure this always holds as `CASE` expressions can return different
-        // types for a column for different rows. Investigate how other RDBMs
-        // handle this.
-        if cmp := lhs[c.ColIdx].Compare(evalCtx, rhs[c.ColIdx]); cmp != 0 {
-            if c.Direction == encoding.Descending {
+        if cmp := tree.TotalOrderCompare(evalCtx, lhs[c.ColIdx], rhs[c.ColIdx]); cmp != 0 {
+            if !invertSort == (c.Direction == encoding.Descending) {
                 cmp = -cmp
             }
-            return cmp
+            return cmp < 0
         }
     }
-    return 0
+    return false
 }
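
LessDatums now folds the per-column direction and the caller's invertSort flag into a single negation, which is easy to misread. The table-driven sketch below spells out the intended results for a one-column ordering; the test name and the cases are hypothetical, the expected values follow from the negation logic in the hunk above, and import paths not spelled out in the patch (settings/cluster, util/encoding) are assumptions.

package sqlbase

import (
    "context"
    "testing"

    "github.com/cockroachdb/cockroach/pkg/settings/cluster"
    "github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    "github.com/cockroachdb/cockroach/pkg/util/encoding"
)

func TestLessDatumsSketch(t *testing.T) {
    evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
    defer evalCtx.Stop(context.Background())

    asc := ColumnOrdering{{ColIdx: 0, Direction: encoding.Ascending}}
    desc := ColumnOrdering{{ColIdx: 0, Direction: encoding.Descending}}
    one := tree.Datums{tree.NewDInt(1)}
    two := tree.Datums{tree.NewDInt(2)}

    for _, tc := range []struct {
        ordering   ColumnOrdering
        invertSort bool
        lhs, rhs   tree.Datums
        expected   bool
    }{
        {asc, false, one, two, true},   // 1 sorts before 2 ascending
        {asc, true, one, two, false},   // inverted ascending: 2 comes first
        {desc, false, one, two, false}, // 2 sorts before 1 descending
        {desc, true, one, two, true},   // inverted descending: 1 comes first
        {asc, false, one, one, false},  // rows that are not distinct are never "less"
    } {
        if got := LessDatums(tc.ordering, tc.invertSort, evalCtx, tc.lhs, tc.rhs); got != tc.expected {
            t.Errorf("LessDatums(%v, %t, %v, %v) = %t, expected %t",
                tc.ordering, tc.invertSort, tc.lhs, tc.rhs, got, tc.expected)
        }
    }
}
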
diff --git a/pkg/sql/sqlbase/rowfetcher.go b/pkg/sql/sqlbase/rowfetcher.go
index 5ec5e3784a9e..3c38a1f843a9 100644
--- a/pkg/sql/sqlbase/rowfetcher.go
+++ b/pkg/sql/sqlbase/rowfetcher.go
@@ -1196,7 +1196,9 @@ func (rf *RowFetcher) checkKeyOrdering(ctx context.Context) error {
     evalCtx := tree.EvalContext{}
     for i, id := range rf.rowReadyTable.index.ColumnIDs {
         idx := rf.rowReadyTable.colIdxMap[id]
-        result := rf.rowReadyTable.decodedRow[idx].Compare(&evalCtx, rf.rowReadyTable.lastDatums[idx])
+        result := tree.TotalOrderCompare(&evalCtx,
+            rf.rowReadyTable.decodedRow[idx],
+            rf.rowReadyTable.lastDatums[idx])
         expectedDirection := rf.rowReadyTable.index.ColumnDirections[i]
         if rf.reverse && expectedDirection == IndexDescriptor_ASC {
             expectedDirection = IndexDescriptor_DESC
diff --git a/pkg/sql/sqlbase/table_test.go b/pkg/sql/sqlbase/table_test.go
index 5986988a459b..770af0ded746 100644
--- a/pkg/sql/sqlbase/table_test.go
+++ b/pkg/sql/sqlbase/table_test.go
@@ -234,7 +234,7 @@ func TestIndexKey(t *testing.T) {
         for j, value := range values {
             testValue := testValues[colMap[index.ColumnIDs[j]]]
-            if value.Compare(evalCtx, testValue) != 0 {
+            if tree.Distinct(evalCtx, value, testValue) {
                 t.Fatalf("%d: value %d got %q but expected %q", i, j, value, testValue)
             }
         }
@@ -365,7 +365,7 @@ func TestArrayEncoding(t *testing.T) {
                 t.Fatal(err)
             }
             evalContext := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
-            if d.Compare(evalContext, &test.datum) != 0 {
+            if tree.Distinct(evalContext, d, &test.datum) {
                 t.Fatalf("expected %v to decode to %s, got %s", enc, test.datum.String(), d.String())
             }
         })
@@ -1463,6 +1463,7 @@ func TestAdjustEndKeyForInterleave(t *testing.T) {
 
 func TestDecodeTableValue(t *testing.T) {
     a := &DatumAlloc{}
+    evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
     for _, tc := range []struct {
         in  tree.Datum
         typ types.T
@@ -1489,7 +1490,7 @@ func TestDecodeTableValue(t *testing.T) {
             } else if err != nil {
                 return
             }
-            if tc.in.Compare(tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()), d) != 0 {
+            if tree.Distinct(evalCtx, tc.in, d) {
                 t.Fatalf("decoded datum %[1]v (%[1]T) does not match encoded datum %[2]v (%[2]T)", d, tc.in)
             }
         })
diff --git a/pkg/sql/stats/histogram.go b/pkg/sql/stats/histogram.go
index ffabd1c25c33..779386a6140b 100644
--- a/pkg/sql/stats/histogram.go
+++ b/pkg/sql/stats/histogram.go
@@ -44,7 +44,7 @@ func EquiDepthHistogram(
         }
     }
     sort.Slice(samples, func(i, j int) bool {
-        return samples[i].Compare(evalCtx, samples[j]) < 0
+        return tree.TotalOrderLess(evalCtx, samples[i], samples[j])
     })
     numBuckets := maxBuckets
     if maxBuckets > numSamples {
@@ -69,15 +69,15 @@ func EquiDepthHistogram(
         // numLess is the number of samples less than upper (in this bucket).
         numLess := 0
         for ; numLess < num-1; numLess++ {
-            if c := samples[i+numLess].Compare(evalCtx, upper); c == 0 {
+            if c := tree.TotalOrderCompare(evalCtx, samples[i+numLess], upper); c == 0 {
                 break
             } else if c > 0 {
                 panic("samples not sorted")
             }
         }
-        // Advance the boundary of the bucket to cover all samples equal to upper.
+        // Advance the boundary of the bucket to cover all samples that are not distinct from upper.
         for ; i+num < numSamples; num++ {
-            if samples[i+num].Compare(evalCtx, upper) != 0 {
+            if tree.Distinct(evalCtx, samples[i+num], upper) {
                 break
             }
         }
diff --git a/pkg/sql/window.go b/pkg/sql/window.go
index 1abfcb6e9636..7ed619fbff2d 100644
--- a/pkg/sql/window.go
+++ b/pkg/sql/window.go
@@ -559,27 +559,36 @@ type partitionSorter struct {
 }
 
 // partitionSorter implements the sort.Interface interface.
-func (n *partitionSorter) Len() int           { return len(n.rows) }
-func (n *partitionSorter) Swap(i, j int)      { n.rows[i], n.rows[j] = n.rows[j], n.rows[i] }
-func (n *partitionSorter) Less(i, j int) bool { return n.Compare(i, j) < 0 }
+func (n *partitionSorter) Len() int      { return len(n.rows) }
+func (n *partitionSorter) Swap(i, j int) { n.rows[i], n.rows[j] = n.rows[j], n.rows[i] }
+func (n *partitionSorter) Less(i, j int) bool {
+    ra, rb := n.rows[i], n.rows[j]
+    defa, defb := n.windowDefVals.At(ra.Idx), n.windowDefVals.At(rb.Idx)
+    for _, o := range n.ordering {
+        da := defa[o.ColIdx]
+        db := defb[o.ColIdx]
+        if cmp := tree.TotalOrderCompare(n.evalCtx, da, db); cmp != 0 {
+            if o.Direction == encoding.Descending {
+                cmp = -cmp
+            }
+            return cmp < 0
+        }
+    }
+    return false
+}
 
 // partitionSorter implements the peerGroupChecker interface.
-func (n *partitionSorter) InSameGroup(i, j int) bool { return n.Compare(i, j) == 0 }
-
-func (n *partitionSorter) Compare(i, j int) int {
+func (n *partitionSorter) InSameGroup(i, j int) bool {
     ra, rb := n.rows[i], n.rows[j]
     defa, defb := n.windowDefVals.At(ra.Idx), n.windowDefVals.At(rb.Idx)
     for _, o := range n.ordering {
         da := defa[o.ColIdx]
         db := defb[o.ColIdx]
-        if c := da.Compare(n.evalCtx, db); c != 0 {
-            if o.Direction != encoding.Ascending {
-                return -c
-            }
-            return c
+        if tree.Distinct(n.evalCtx, da, db) {
+            return false
         }
     }
-    return 0
+    return true
 }
 
 type allPeers struct{}
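
The partitionSorter rewrite is the clearest instance of the pattern behind this change: the shared three-way Compare is split into a direction-aware Less used for sorting and a direction-agnostic InSameGroup used for peer grouping, which only needs to know whether the ordering columns are distinct and can therefore short-circuit without looking at directions at all. The same split, restated over plain tree.Datums rows as a sketch only (the function names and the freestanding form are illustrative, not part of the change):

package sql

import (
    "github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    "github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    "github.com/cockroachdb/cockroach/pkg/util/encoding"
)

// lessDatumRows is the sorting predicate: it honors each column's direction.
func lessDatumRows(
    evalCtx *tree.EvalContext, ordering sqlbase.ColumnOrdering, ra, rb tree.Datums,
) bool {
    for _, o := range ordering {
        if cmp := tree.TotalOrderCompare(evalCtx, ra[o.ColIdx], rb[o.ColIdx]); cmp != 0 {
            if o.Direction == encoding.Descending {
                cmp = -cmp
            }
            return cmp < 0
        }
    }
    return false
}

// inSamePeerGroup is the grouping predicate: direction is irrelevant, so it
// stops at the first ordering column on which the rows are distinct.
func inSamePeerGroup(
    evalCtx *tree.EvalContext, ordering sqlbase.ColumnOrdering, ra, rb tree.Datums,
) bool {
    for _, o := range ordering {
        if tree.Distinct(evalCtx, ra[o.ColIdx], rb[o.ColIdx]) {
            return false
        }
    }
    return true
}

Keeping the two predicates separate also avoids negating a comparison result only to test it against zero, which is what the old InSameGroup did through the shared Compare.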