Skip to content

Commit

Permalink
feat: Sum and Average aggregations (#1387)
Browse files Browse the repository at this point in the history
* feat: SUM / AVG (#1218)

* feat: SUM / AVG (real files)

* Remove aggregate field duplicates (if any).

* Clean up and fixes.

* Clean up comments, and add Nonnull where possible.

* Add more public docs.

* More cleanup.

* Update hashCode and equals for AggregateQuery.

* Address code review comments. more to come.

* fix test name.

* Better comment.

* Fix alias encoding.

* Remove TODO.

* Revert the way alias is constructed.

* Backport test updates.

fix format.

fix import stmt.

* feat: Add long alias support for aggregations. (#1267)

* feat: Add long alias support for aggregations.

* address comments.

* Better method name and replace hardcoded "count" with "aggregate_0".

* Remove duplicate aggregations.

* add static import.

* Fix tests. All tests pass.

* Address comments.

* Improve the documentation to match strongly typed languages.

* Do not use wildcard import.

* Do not use wildcard import (2).

* Do not use wildcard import (3).

* Do not use wildcard import (4).

* Fix the javadoc.

* Add license header, and remove unused test code.

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* better regex.

* Add more tests for cursors.

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
ehsannas and gcf-owl-bot[bot] authored Oct 9, 2023
1 parent 5c6dcde commit afa5c01
Show file tree
Hide file tree
Showing 13 changed files with 1,833 additions and 153 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.firestore;

import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/** Represents an aggregation that can be performed by Firestore. */
public abstract class AggregateField {
/**
* Create a {@link CountAggregateField} object that can be used to compute the count of documents
* in the result set of a query.
*
* <p>The result of a count operation will always be a 64-bit integer value.
*
* @return The `CountAggregateField` object that can be used to compute the count of documents in
* the result set of a query.
*/
@Nonnull
public static CountAggregateField count() {
return new CountAggregateField();
}

/**
* Create a {@link SumAggregateField} object that can be used to compute the sum of a specified
* field over a range of documents in the result set of a query.
*
* <p>The result of a sum operation will always be a 64-bit integer value, a double, or NaN.
*
* <ul>
* <li>Summing over zero documents or fields will result in 0L.
* <li>Summing over NaN will result in a double value representing NaN.
* <li>A sum that overflows the maximum representable 64-bit integer value will result in a
* double return value. This may result in lost precision of the result.
* <li>A sum that overflows the maximum representable double value will result in a double
* return value representing infinity.
* </ul>
*
* @param field Specifies the field to sum across the result set.
* @return The `SumAggregateField` object that can be used to compute the sum of a specified field
* over a range of documents in the result set of a query.
*/
@Nonnull
public static SumAggregateField sum(@Nonnull String field) {
return new SumAggregateField(FieldPath.fromDotSeparatedString(field));
}

/**
* Create a {@link SumAggregateField} object that can be used to compute the sum of a specified
* field over a range of documents in the result set of a query.
*
* <p>The result of a sum operation will always be a 64-bit integer value, a double, or NaN.
*
* <ul>
* <li>Summing over zero documents or fields will result in 0L.
* <li>Summing over NaN will result in a double value representing NaN.
* <li>A sum that overflows the maximum representable 64-bit integer value will result in a
* double return value. This may result in lost precision of the result.
* <li>A sum that overflows the maximum representable double value will result in a double
* return value representing infinity.
* </ul>
*
* @param fieldPath Specifies the field to sum across the result set.
* @return The `SumAggregateField` object that can be used to compute the sum of a specified field
* over a range of documents in the result set of a query.
*/
@Nonnull
public static SumAggregateField sum(@Nonnull FieldPath fieldPath) {
return new SumAggregateField(fieldPath);
}

/**
* Create an {@link AverageAggregateField} object that can be used to compute the average of a
* specified field over a range of documents in the result set of a query.
*
* <p>The result of an average operation will always be a double or NaN.
*
* <ul>
* <li>Averaging over zero documents or fields will result in a double value representing NaN.
* <li>Averaging over NaN will result in a double value representing NaN.
* </ul>
*
* @param field Specifies the field to average across the result set.
* @return The `AverageAggregateField` object that can be used to compute the average of a
* specified field over a range of documents in the result set of a query.
*/
@Nonnull
public static AverageAggregateField average(@Nonnull String field) {
return new AverageAggregateField(FieldPath.fromDotSeparatedString(field));
}

/**
* Create an {@link AverageAggregateField} object that can be used to compute the average of a
* specified field over a range of documents in the result set of a query.
*
* <p>The result of an average operation will always be a double or NaN.
*
* <ul>
* <li>Averaging over zero documents or fields will result in a double value representing NaN.
* <li>Averaging over NaN will result in a double value representing NaN.
* </ul>
*
* @param fieldPath Specifies the field to average across the result set.
* @return The `AverageAggregateField` object that can be used to compute the average of a
* specified field over a range of documents in the result set of a query.
*/
@Nonnull
public static AverageAggregateField average(@Nonnull FieldPath fieldPath) {
return new AverageAggregateField(fieldPath);
}

/** The field over which the aggregation is performed. */
@Nullable FieldPath fieldPath;

/** Returns the alias used internally for this aggregate field. */
@Nonnull
String getAlias() {
// Use $operator_$field format if it's an aggregation of a specific field. For example: sum_foo.
// Use $operator format if there's no field. For example: count.
return getOperator() + (fieldPath == null ? "" : "_" + fieldPath.getEncodedPath());
}

/**
* Returns the field on which the aggregation takes place. Returns an empty string if there's no
* field (e.g. for count).
*/
@Nonnull
String getFieldPath() {
return fieldPath == null ? "" : fieldPath.getEncodedPath();
}

/** Returns a string representation of this aggregation's operator. For example: "sum" */
abstract @Nonnull String getOperator();

/**
* Returns true if the given object is equal to this object. Two `AggregateField` objects are
* considered equal if they have the same operator and operate on the same field.
*/
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof AggregateField)) {
return false;
}
AggregateField otherAggregateField = (AggregateField) other;
return getOperator().equals(otherAggregateField.getOperator())
&& getFieldPath().equals(otherAggregateField.getFieldPath());
}

/** Calculates and returns the hash code for this object. */
@Override
public int hashCode() {
return Objects.hash(getOperator(), getFieldPath());
}

/** Represents a "sum" aggregation that can be performed by Firestore. */
public static class SumAggregateField extends AggregateField {
private SumAggregateField(@Nonnull FieldPath field) {
fieldPath = field;
}

@Override
@Nonnull
public String getOperator() {
return "sum";
}
}

/** Represents an "average" aggregation that can be performed by Firestore. */
public static class AverageAggregateField extends AggregateField {
private AverageAggregateField(@Nonnull FieldPath field) {
fieldPath = field;
}

@Override
@Nonnull
public String getOperator() {
return "average";
}
}

/** Represents a "count" aggregation that can be performed by Firestore. */
public static class CountAggregateField extends AggregateField {
private CountAggregateField() {}

@Override
@Nonnull
public String getOperator() {
return "count";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,16 @@
import com.google.firestore.v1.RunAggregationQueryResponse;
import com.google.firestore.v1.RunQueryRequest;
import com.google.firestore.v1.StructuredAggregationQuery;
import com.google.firestore.v1.StructuredAggregationQuery.Aggregation;
import com.google.firestore.v1.StructuredQuery;
import com.google.firestore.v1.Value;
import com.google.protobuf.ByteString;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnull;
Expand All @@ -39,18 +47,16 @@
/** A query that calculates aggregations over an underlying query. */
@InternalExtensionOnly
public class AggregateQuery {
@Nonnull private final Query query;

/**
* The "alias" to specify in the {@link RunAggregationQueryRequest} proto when running a count
* query. The actual value is not meaningful, but will be used to get the count out of the {@link
* RunAggregationQueryResponse}.
*/
private static final String ALIAS_COUNT = "count";
@Nonnull private List<AggregateField> aggregateFieldList;

@Nonnull private final Query query;
@Nonnull private Map<String, String> aliasMap;

AggregateQuery(@Nonnull Query query) {
AggregateQuery(@Nonnull Query query, @Nonnull List<AggregateField> aggregateFields) {
this.query = query;
this.aggregateFieldList = aggregateFields;
this.aliasMap = new HashMap<>();
}

/** Returns the query whose aggregations will be calculated by this object. */
Expand Down Expand Up @@ -112,9 +118,11 @@ long getStartTimeNanos() {
return startTimeNanos;
}

void deliverResult(long count, Timestamp readTime) {
void deliverResult(@Nonnull Map<String, Value> data, Timestamp readTime) {
if (isFutureCompleted.compareAndSet(false, true)) {
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, count));
Map<String, Value> mappedData = new HashMap<>();
data.forEach((serverAlias, value) -> mappedData.put(aliasMap.get(serverAlias), value));
future.set(new AggregateQuerySnapshot(AggregateQuery.this, readTime, mappedData));
}
}

Expand Down Expand Up @@ -145,26 +153,13 @@ public void onResponse(RunAggregationQueryResponse response) {
// Close the stream to avoid it dangling, since we're not expecting any more responses.
streamController.cancel();

// Extract the count and read time from the RunAggregationQueryResponse.
// Extract the aggregations and read time from the RunAggregationQueryResponse.
Timestamp readTime = Timestamp.fromProto(response.getReadTime());
Value value = response.getResult().getAggregateFieldsMap().get(ALIAS_COUNT);
if (value == null) {
throw new IllegalArgumentException(
"RunAggregationQueryResponse is missing required alias: " + ALIAS_COUNT);
} else if (value.getValueTypeCase() != Value.ValueTypeCase.INTEGER_VALUE) {
throw new IllegalArgumentException(
"RunAggregationQueryResponse alias "
+ ALIAS_COUNT
+ " has incorrect type: "
+ value.getValueTypeCase());
}
long count = value.getIntegerValue();

// Deliver the result; even though the `RunAggregationQuery` RPC is a "streaming" RPC, meaning
// that `onResponse()` can be called multiple times, it _should_ only be called once for count
// queries. But even if it is called more than once, `responseDeliverer` will drop superfluous
// results.
responseDeliverer.deliverResult(count, readTime);
// that `onResponse()` can be called multiple times, it _should_ only be called once. But even
// if it is called more than once, `responseDeliverer` will drop superfluous results.
responseDeliverer.deliverResult(response.getResult().getAggregateFieldsMap(), readTime);
}

@Override
Expand Down Expand Up @@ -215,12 +210,45 @@ RunAggregationQueryRequest toProto(@Nullable final ByteString transactionId) {
request.getStructuredAggregationQueryBuilder();
structuredAggregationQuery.setStructuredQuery(runQueryRequest.getStructuredQuery());

StructuredAggregationQuery.Aggregation.Builder aggregation =
StructuredAggregationQuery.Aggregation.newBuilder();
aggregation.setCount(StructuredAggregationQuery.Aggregation.Count.getDefaultInstance());
aggregation.setAlias(ALIAS_COUNT);
structuredAggregationQuery.addAggregations(aggregation);
// We use this set to remove duplicate aggregates. e.g. `aggregate(sum("foo"), sum("foo"))`
HashSet<String> uniqueAggregates = new HashSet<>();
List<StructuredAggregationQuery.Aggregation> aggregations = new ArrayList<>();
int aggregationNum = 0;
for (AggregateField aggregateField : aggregateFieldList) {
// `getAlias()` provides a unique representation of an AggregateField.
boolean isNewAggregateField = uniqueAggregates.add(aggregateField.getAlias());
if (!isNewAggregateField) {
// This is a duplicate AggregateField. We don't need to include it in the request.
continue;
}

// If there's a field for this aggregation, build its proto.
StructuredQuery.FieldReference field = null;
if (!aggregateField.getFieldPath().isEmpty()) {
field =
StructuredQuery.FieldReference.newBuilder()
.setFieldPath(aggregateField.getFieldPath())
.build();
}
// Build the aggregation proto.
Aggregation.Builder aggregation = Aggregation.newBuilder();
if (aggregateField instanceof AggregateField.CountAggregateField) {
aggregation.setCount(Aggregation.Count.getDefaultInstance());
} else if (aggregateField instanceof AggregateField.SumAggregateField) {
aggregation.setSum(Aggregation.Sum.newBuilder().setField(field).build());
} else if (aggregateField instanceof AggregateField.AverageAggregateField) {
aggregation.setAvg(Aggregation.Avg.newBuilder().setField(field).build());
} else {
throw new RuntimeException("Unsupported aggregation");
}
// Map all client-side aliases to a unique short-form alias.
// This avoids issues with client-side aliases that exceed the 1500-byte string size limit.
String serverAlias = "aggregate_" + aggregationNum++;
aliasMap.put(serverAlias, aggregateField.getAlias());
aggregation.setAlias(serverAlias);
aggregations.add(aggregation.build());
}
structuredAggregationQuery.addAllAggregations(aggregations);
return request.build();
}

Expand All @@ -243,7 +271,23 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
.setStructuredQuery(proto.getStructuredAggregationQuery().getStructuredQuery())
.build();
Query query = Query.fromProto(firestore, runQueryRequest);
return new AggregateQuery(query);

List<AggregateField> aggregateFields = new ArrayList<>();
List<Aggregation> aggregations = proto.getStructuredAggregationQuery().getAggregationsList();
aggregations.forEach(
aggregation -> {
if (aggregation.hasCount()) {
aggregateFields.add(AggregateField.count());
} else if (aggregation.hasAvg()) {
aggregateFields.add(
AggregateField.average(aggregation.getAvg().getField().getFieldPath()));
} else if (aggregation.hasSum()) {
aggregateFields.add(AggregateField.sum(aggregation.getSum().getField().getFieldPath()));
} else {
throw new RuntimeException("Unsupported aggregation.");
}
});
return new AggregateQuery(query, aggregateFields);
}

/**
Expand All @@ -253,7 +297,7 @@ public static AggregateQuery fromProto(Firestore firestore, RunAggregationQueryR
*/
@Override
public int hashCode() {
return query.hashCode();
return Objects.hash(query, aggregateFieldList);
}

/**
Expand All @@ -280,6 +324,6 @@ public boolean equals(Object object) {
return false;
}
AggregateQuery other = (AggregateQuery) object;
return query.equals(other.query);
return query.equals(other.query) && aggregateFieldList.equals(other.aggregateFieldList);
}
}
Loading

0 comments on commit afa5c01

Please sign in to comment.