Skip to content

Commit

Permalink
Audits eval package (#1709)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn authored Jan 13, 2025
1 parent 5ecd7bd commit a7b772e
Show file tree
Hide file tree
Showing 17 changed files with 91 additions and 58 deletions.
4 changes: 3 additions & 1 deletion partiql-eval/api/partiql-eval.api
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ public class org/partiql/eval/Mode {
}

public class org/partiql/eval/Row {
public final field values [Lorg/partiql/spi/value/Datum;
public fun <init> ()V
public fun <init> ([Lorg/partiql/spi/value/Datum;)V
public fun concat (Lorg/partiql/eval/Row;)Lorg/partiql/eval/Row;
public fun equals (Ljava/lang/Object;)Z
public fun get (I)Lorg/partiql/spi/value/Datum;
public fun getSize ()I
public fun hashCode ()I
public static fun of ([Lorg/partiql/spi/value/Datum;)Lorg/partiql/eval/Row;
public fun set (ILorg/partiql/spi/value/Datum;)V
public fun toString ()Ljava/lang/String;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public Environment push(Row row) {
*/
public Datum get(int depth, int offset) {
try {
return stack[depth].values[offset];
return stack[depth].get(offset);
} catch (IndexOutOfBoundsException ex) {
throw new RuntimeException("Invalid variable reference [$depth:$offset]\n$this");
}
Expand Down
26 changes: 22 additions & 4 deletions partiql-eval/src/main/java/org/partiql/eval/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,32 @@
*/
public class Row {

private final Datum[] values;

/**
* TODO internalize values.
* @param index the requested index
* @return the value at the given index
*/
public final Datum[] values;
public Datum get(int index) {
return values[index];
}

/**
* @param index the requested index
* @param value the value to insert
*/
public void set(int index, Datum value) {
values[index] = value;
}

/**
* @return the number of values in the record
*/
public int getSize() {
return values.length;
}

/**
* TODO keep ??
*
* @param values the values
* @return the record
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import org.partiql.plan.Operand;

/**
* Match represents a subtree match to be sent to the
* Match represents a subtree match to be sent to the {@link Strategy}.
*/
public class Match {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ default Statement prepare(@NotNull Plan plan, @NotNull Mode mode) {
return prepare(plan, mode, Context.standard());
}

/**
* Prepares the given plan into an executable PartiQL statement.
*
* @param plan The plan to compile.
* @param mode The mode to execute in.
* @param ctx The shared context object.
* @return The prepared statement.
*/
@NotNull
public Statement prepare(@NotNull Plan plan, @NotNull Mode mode, @NotNull Context ctx);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ protected Pattern(@NotNull Class<? extends Operator> clazz, @Nullable Predicate<
this.predicate = predicate;
}

/**
* @param operator the operator to match against.
* @return whether the operator matches the pattern.
*/
public boolean matches(Operator operator) {
if (!clazz.isInstance(operator)) {
return false;
}
return predicate == null || predicate.test(operator);
}

}
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
package org.partiql.eval.internal.helpers

import org.partiql.eval.Row
import org.partiql.spi.value.Datum

internal object DatumArrayComparator : Comparator<Array<Datum>> {
internal object DatumArrayComparator : Comparator<Row> {
private val delegate = Datum.comparator(false)
override fun compare(o1: Array<Datum>, o2: Array<Datum>): Int {
if (o1.size < o2.size) {
override fun compare(o1: Row, o2: Row): Int {
val o1Size = o1.size
val o2Size = o2.size
if (o1Size < o2Size) {
return -1
}
if (o1.size > o2.size) {
if (o1Size > o2Size) {
return 1
}
for (index in 0..o2.lastIndex) {
for (index in 0 until o2.size) {
val element1 = o1[index]
val element2 = o2[index]
val compared = delegate.compare(element1, element2)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.partiql.eval.internal.helpers

import org.partiql.eval.Row
import org.partiql.spi.value.Datum

internal object RecordUtility {
Expand All @@ -8,8 +9,8 @@ internal object RecordUtility {
* (treats null and missing as the same value) and we need to deterministically return a value. Here we use coerce
* to null to follow the PartiQL spec's grouping function.
*/
fun Array<Datum>.coerceMissing() {
for (i in indices) {
fun Row.coerceMissing() {
for (i in 0..getSize()) {
if (this[i].isMissing) {
this[i] = Datum.nullValue()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal class RelOpAggregate(

private lateinit var records: Iterator<Row>

private val aggregationMap = TreeMap<Array<Datum>, List<AccumulatorWrapper>>(DatumArrayComparator)
private val aggregationMap = TreeMap<Row, List<AccumulatorWrapper>>(DatumArrayComparator)

/**
* Wraps an [Aggregation.Accumulator] to help with filtering distinct values.
Expand All @@ -30,7 +30,7 @@ internal class RelOpAggregate(
class AccumulatorWrapper(
val delegate: Aggregation.Accumulator,
val args: List<ExprValue>,
val seen: TreeSet<Array<Datum>>?
val seen: TreeSet<Row>?
)

override fun open(env: Environment) {
Expand All @@ -50,7 +50,7 @@ internal class RelOpAggregate(
// TODO IT DOES NOT MATTER NOW, BUT SqlCompiler SHOULD HANDLE GET THE ARGUMENT TYPES FOR .getAccumulator
val args: Array<PType> = emptyArray()

val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) {
val accumulators = aggregationMap.getOrPut(Row(evaluatedGroupByKeys)) {
aggregates.map {
AccumulatorWrapper(
delegate = it.agg.getAccumulator(args),
Expand All @@ -71,7 +71,7 @@ internal class RelOpAggregate(
argument
}
// Skip over aggregation if DISTINCT and SEEN
if (function.seen != null && (function.seen.add(arguments).not())) {
if (function.seen != null && (function.seen.add(Row(arguments)).not())) {
return@forEachIndexed
}
accumulators[index].delegate.next(arguments)
Expand All @@ -93,8 +93,8 @@ internal class RelOpAggregate(

records = iterator {
aggregationMap.forEach { (keysEvaluated, accumulators) ->
val recordValues = accumulators.map { acc -> acc.delegate.value() } + keysEvaluated
yield(Row(recordValues.toTypedArray()))
val recordValues = Row(accumulators.map { acc -> acc.delegate.value() }.toTypedArray()).concat(keysEvaluated)
yield(recordValues)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ internal class RelOpDistinct(private val input: ExprRelation) : RelOpPeeking() {

override fun peek(): Row? {
for (next in input) {
val transformed = Array(next.values.size) { next.values[it] }
val transformed = next
if (seen.contains(transformed).not()) {
seen.add(transformed)
return next
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ import org.partiql.eval.ExprRelation
import org.partiql.eval.Row
import org.partiql.eval.internal.helpers.DatumArrayComparator
import org.partiql.eval.internal.helpers.RecordUtility.coerceMissing
import org.partiql.spi.value.Datum
import java.util.TreeMap

internal class RelOpExceptAll(
private val lhs: ExprRelation,
private val rhs: ExprRelation,
) : RelOpPeeking() {

private val seen = TreeMap<Array<Datum>, Int>(DatumArrayComparator)
private val seen = TreeMap<Row, Int>(DatumArrayComparator)
private var init: Boolean = false

override fun openPeeking(env: Environment) {
Expand All @@ -28,13 +27,13 @@ internal class RelOpExceptAll(
seed()
}
for (row in lhs) {
row.values.coerceMissing()
val remaining = seen[row.values] ?: 0
row.coerceMissing()
val remaining = seen[row] ?: 0
if (remaining > 0) {
seen[row.values] = remaining - 1
seen[row] = remaining - 1
continue
}
return Row(row.values)
return row
}
return null
}
Expand All @@ -51,9 +50,9 @@ internal class RelOpExceptAll(
private fun seed() {
init = true
for (row in rhs) {
row.values.coerceMissing()
val n = seen[row.values] ?: 0
seen[row.values] = n + 1
row.coerceMissing()
val n = seen[row] ?: 0
seen[row] = n + 1
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ internal class RelOpExceptDistinct(
seed()
}
for (row in lhs) {
row.values.coerceMissing()
if (!seen.contains(row.values)) {
return Row(row.values)
row.coerceMissing()
if (!seen.contains(row)) {
return row
}
}
return null
Expand All @@ -52,8 +52,8 @@ internal class RelOpExceptDistinct(
private fun seed() {
init = true
for (row in rhs) {
row.values.coerceMissing()
seen.add(row.values)
row.coerceMissing()
seen.add(row)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ internal class RelOpExclude(
exclusions.forEach { exclusion ->
// TODO memoize offsets and steps (i.e. don't call getVar(), getOffset(), and getItems() every time).
val o = exclusion.getVar().getOffset()
val value = record.values[o]
record.values[o] = value.exclude(exclusion.getItems())
val value = record[o]
record[o] = value.exclude(exclusion.getItems())
}
return record
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ import org.partiql.eval.ExprRelation
import org.partiql.eval.Row
import org.partiql.eval.internal.helpers.DatumArrayComparator
import org.partiql.eval.internal.helpers.RecordUtility.coerceMissing
import org.partiql.spi.value.Datum
import java.util.TreeMap

internal class RelOpIntersectAll(
private val lhs: ExprRelation,
private val rhs: ExprRelation,
) : RelOpPeeking() {

private val seen = TreeMap<Array<Datum>, Int>(DatumArrayComparator)
private val seen = TreeMap<Row, Int>(DatumArrayComparator)
private var init: Boolean = false

override fun openPeeking(env: Environment) {
Expand All @@ -28,11 +27,11 @@ internal class RelOpIntersectAll(
seed()
}
for (row in rhs) {
row.values.coerceMissing()
val remaining = seen[row.values] ?: 0
row.coerceMissing()
val remaining = seen[row] ?: 0
if (remaining > 0) {
seen[row.values] = remaining - 1
return Row(row.values)
seen[row] = remaining - 1
return row
}
}
return null
Expand All @@ -50,9 +49,9 @@ internal class RelOpIntersectAll(
private fun seed() {
init = true
for (row in lhs) {
row.values.coerceMissing()
val alreadySeen = seen[row.values] ?: 0
seen[row.values] = alreadySeen + 1
row.coerceMissing()
val alreadySeen = seen[row] ?: 0
seen[row] = alreadySeen + 1
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ internal class RelOpIntersectDistinct(
seed()
}
for (row in rhs) {
row.values.coerceMissing()
if (seen.remove(row.values)) {
return Row(row.values)
row.coerceMissing()
if (seen.remove(row)) {
return row
}
}
return null
Expand All @@ -47,8 +47,8 @@ internal class RelOpIntersectDistinct(
private fun seed() {
init = true
for (row in lhs) {
row.values.coerceMissing()
seen.add(row.values)
row.coerceMissing()
seen.add(row)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ internal class RelOpUnionAll(
return when (lhs.hasNext()) {
true -> {
val record = lhs.next()
record.values.coerceMissing()
record.coerceMissing()
record
}
false -> {
val record = rhs.next()
record.values.coerceMissing()
record.coerceMissing()
record
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ internal class RelOpUnionDistinct(

override fun peek(): Row? {
for (record in input) {
record.values.coerceMissing()
if (!seen.contains(record.values)) {
seen.add(record.values)
return Row(record.values)
record.coerceMissing()
if (!seen.contains(record)) {
seen.add(record)
return record
}
}
return null
Expand Down

0 comments on commit a7b772e

Please sign in to comment.