diff --git a/pinot-connectors/prestodb-pinot-dependencies/pinot-spi-jdk8/pom.xml b/pinot-connectors/prestodb-pinot-dependencies/pinot-spi-jdk8/pom.xml index 6ac4e5514abb..5d29caec14b5 100644 --- a/pinot-connectors/prestodb-pinot-dependencies/pinot-spi-jdk8/pom.xml +++ b/pinot-connectors/prestodb-pinot-dependencies/pinot-spi-jdk8/pom.xml @@ -129,5 +129,10 @@ org.reflections reflections + + com.clearspring.analytics + stream + 2.7.0 + diff --git a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/DistinctCountHLLStarTreeV2Test.java b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/DistinctCountHLLStarTreeV2Test.java index d903f505b413..9bc103ffd247 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/DistinctCountHLLStarTreeV2Test.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/DistinctCountHLLStarTreeV2Test.java @@ -19,6 +19,7 @@ package org.apache.pinot.core.startree.v2; import com.clearspring.analytics.stream.cardinality.HyperLogLog; +import java.util.Collections; import java.util.Random; import org.apache.pinot.segment.local.aggregator.DistinctCountHLLValueAggregator; import org.apache.pinot.segment.local.aggregator.ValueAggregator; @@ -31,7 +32,7 @@ public class DistinctCountHLLStarTreeV2Test extends BaseStarTreeV2Test getValueAggregator() { - return new DistinctCountHLLValueAggregator(); + return new DistinctCountHLLValueAggregator(Collections.emptyList()); } @Override diff --git a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/PreAggregatedDistinctCountHLLStarTreeV2Test.java b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/PreAggregatedDistinctCountHLLStarTreeV2Test.java index 354497b9cb75..531348c77157 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/PreAggregatedDistinctCountHLLStarTreeV2Test.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/PreAggregatedDistinctCountHLLStarTreeV2Test.java @@ -19,6 +19,7 @@ package org.apache.pinot.core.startree.v2; import com.clearspring.analytics.stream.cardinality.HyperLogLog; +import java.util.Collections; import java.util.Random; import org.apache.pinot.core.common.ObjectSerDeUtils; import org.apache.pinot.segment.local.aggregator.DistinctCountHLLValueAggregator; @@ -34,7 +35,7 @@ public class PreAggregatedDistinctCountHLLStarTreeV2Test extends BaseStarTreeV2T @Override ValueAggregator getValueAggregator() { - return new DistinctCountHLLValueAggregator(); + return new DistinctCountHLLValueAggregator(Collections.emptyList()); } @Override diff --git a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/SumPrecisionStarTreeV2Test.java b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/SumPrecisionStarTreeV2Test.java index f99aa3df1f0f..87c74aa59152 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/SumPrecisionStarTreeV2Test.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/SumPrecisionStarTreeV2Test.java @@ -19,6 +19,7 @@ package org.apache.pinot.core.startree.v2; import java.math.BigDecimal; +import java.util.Collections; import java.util.Random; import org.apache.pinot.segment.local.aggregator.SumPrecisionValueAggregator; import org.apache.pinot.segment.local.aggregator.ValueAggregator; @@ -31,7 +32,7 @@ public class SumPrecisionStarTreeV2Test extends BaseStarTreeV2Test getValueAggregator() { - return new SumPrecisionValueAggregator(); + return new SumPrecisionValueAggregator(Collections.emptyList()); } @Override diff --git 
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/DistinctCountHLLValueAggregator.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/DistinctCountHLLValueAggregator.java index 2cf26a3ac14f..ada9d5a170cc 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/DistinctCountHLLValueAggregator.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/DistinctCountHLLValueAggregator.java @@ -20,7 +20,10 @@ import com.clearspring.analytics.stream.cardinality.CardinalityMergeException; import com.clearspring.analytics.stream.cardinality.HyperLogLog; +import java.util.List; +import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.segment.local.utils.CustomSerDeUtils; +import org.apache.pinot.segment.local.utils.HyperLogLogUtils; import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.utils.CommonConstants; @@ -28,11 +31,21 @@ public class DistinctCountHLLValueAggregator implements ValueAggregator { public static final DataType AGGREGATED_VALUE_TYPE = DataType.BYTES; - private static final int DEFAULT_LOG2M_BYTE_SIZE = 180; + + private final int _log2m; // Byte size won't change once we get the initial aggregated value private int _maxByteSize; + public DistinctCountHLLValueAggregator(List<ExpressionContext> arguments) { + // If log2m is not specified (only the source column argument is given), use the default log2m of 8 + if (arguments.size() <= 1) { + _log2m = CommonConstants.Helix.DEFAULT_HYPERLOGLOG_LOG2M; + } else { + _log2m = arguments.get(1).getLiteral().getIntValue(); + } + } + @Override public AggregationFunctionType getAggregationType() { return AggregationFunctionType.DISTINCTCOUNTHLL; @@ -49,12 +62,11 @@ public HyperLogLog getInitialAggregatedValue(Object rawValue) { if (rawValue instanceof byte[]) { byte[] bytes = (byte[]) rawValue; initialValue = deserializeAggregatedValue(bytes); - _maxByteSize = Math.max(_maxByteSize, bytes.length); + _maxByteSize = bytes.length; } else { - // TODO: Handle configurable log2m for StarTreeBuilder - initialValue = new HyperLogLog(CommonConstants.Helix.DEFAULT_HYPERLOGLOG_LOG2M); + initialValue = new HyperLogLog(_log2m); initialValue.offer(rawValue); - _maxByteSize = Math.max(_maxByteSize, DEFAULT_LOG2M_BYTE_SIZE); + _maxByteSize = HyperLogLogUtils.byteSize(initialValue); } return initialValue; } @@ -90,7 +102,9 @@ public HyperLogLog cloneAggregatedValue(HyperLogLog value) { @Override public int getMaxAggregatedValueByteSize() { - return _maxByteSize; + // NOTE: For aggregated metrics, the initial aggregated value might not have been generated. Return the byte size + // based on log2m. + return _maxByteSize > 0 ?
_maxByteSize : HyperLogLogUtils.byteSize(_log2m); } @Override diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/SumPrecisionValueAggregator.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/SumPrecisionValueAggregator.java index 2aab0deaad93..0257057e24f2 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/SumPrecisionValueAggregator.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/SumPrecisionValueAggregator.java @@ -18,7 +18,10 @@ */ package org.apache.pinot.segment.local.aggregator; +import com.google.common.base.Preconditions; import java.math.BigDecimal; +import java.util.List; +import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.utils.BigDecimalUtils; @@ -27,8 +30,23 @@ public class SumPrecisionValueAggregator implements ValueAggregator { public static final DataType AGGREGATED_VALUE_TYPE = DataType.BYTES; + private final int _fixedSize; + private int _maxByteSize; + /** + * Optional second argument is the maximum precision. Scale is always stored as 2 bytes. During query time, the + * optional scale parameter can be provided, but during ingestion, we don't limit it. + */ + public SumPrecisionValueAggregator(List arguments) { + // length 1 means we don't have any caps on maximum precision nor do we have a fixed size then + if (arguments.size() <= 1) { + _fixedSize = -1; + } else { + _fixedSize = BigDecimalUtils.byteSizeForFixedPrecision(arguments.get(1).getLiteral().getIntValue()); + } + } + @Override public AggregationFunctionType getAggregationType() { return AggregationFunctionType.SUMPRECISION; @@ -42,14 +60,18 @@ public DataType getAggregatedValueType() { @Override public BigDecimal getInitialAggregatedValue(Object rawValue) { BigDecimal initialValue = toBigDecimal(rawValue); - _maxByteSize = Math.max(_maxByteSize, BigDecimalUtils.byteSize(initialValue)); + if (_fixedSize < 0) { + _maxByteSize = Math.max(_maxByteSize, BigDecimalUtils.byteSize(initialValue)); + } return initialValue; } @Override public BigDecimal applyRawValue(BigDecimal value, Object rawValue) { value = value.add(toBigDecimal(rawValue)); - _maxByteSize = Math.max(_maxByteSize, BigDecimalUtils.byteSize(value)); + if (_fixedSize < 0) { + _maxByteSize = Math.max(_maxByteSize, BigDecimalUtils.byteSize(value)); + } return value; } @@ -66,7 +88,9 @@ private static BigDecimal toBigDecimal(Object rawValue) { @Override public BigDecimal applyAggregatedValue(BigDecimal value, BigDecimal aggregatedValue) { value = value.add(aggregatedValue); - _maxByteSize = Math.max(_maxByteSize, BigDecimalUtils.byteSize(value)); + if (_fixedSize < 0) { + _maxByteSize = Math.max(_maxByteSize, BigDecimalUtils.byteSize(value)); + } return value; } @@ -78,12 +102,14 @@ public BigDecimal cloneAggregatedValue(BigDecimal value) { @Override public int getMaxAggregatedValueByteSize() { - return _maxByteSize; + Preconditions.checkState(_fixedSize > 0 || _maxByteSize > 0, + "Unknown max aggregated value byte size, please provide maximum precision as the second argument"); + return _fixedSize > 0 ? _fixedSize : _maxByteSize; } @Override public byte[] serializeAggregatedValue(BigDecimal value) { - return BigDecimalUtils.serialize(value); + return _fixedSize > 0 ? 
BigDecimalUtils.serializeWithSize(value, _fixedSize) : BigDecimalUtils.serialize(value); } @Override diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/ValueAggregatorFactory.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/ValueAggregatorFactory.java index b4f90c4952de..b348b1ff4c87 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/ValueAggregatorFactory.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/aggregator/ValueAggregatorFactory.java @@ -18,7 +18,9 @@ */ package org.apache.pinot.segment.local.aggregator; +import java.util.List; import org.apache.datasketches.tuple.aninteger.IntegerSummary; +import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.data.FieldSpec.DataType; @@ -37,7 +39,8 @@ private ValueAggregatorFactory() { * @param aggregationType Aggregation type * @return Value aggregator */ - public static ValueAggregator getValueAggregator(AggregationFunctionType aggregationType) { + public static ValueAggregator getValueAggregator(AggregationFunctionType aggregationType, + List arguments) { switch (aggregationType) { case COUNT: return new CountValueAggregator(); @@ -48,7 +51,7 @@ public static ValueAggregator getValueAggregator(AggregationFunctionType aggrega case SUM: return new SumValueAggregator(); case SUMPRECISION: - return new SumPrecisionValueAggregator(); + return new SumPrecisionValueAggregator(arguments); case AVG: return new AvgValueAggregator(); case MINMAXRANGE: @@ -57,7 +60,7 @@ public static ValueAggregator getValueAggregator(AggregationFunctionType aggrega return new DistinctCountBitmapValueAggregator(); case DISTINCTCOUNTHLL: case DISTINCTCOUNTRAWHLL: - return new DistinctCountHLLValueAggregator(); + return new DistinctCountHLLValueAggregator(arguments); case PERCENTILEEST: case PERCENTILERAWEST: return new PercentileEstValueAggregator(); diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java index 2c6885ace5b8..f798b62783f9 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java @@ -251,15 +251,27 @@ public boolean isMutableSegment() { metricsAggregators = getMetricsAggregators(config); } - Set specialIndexes = Sets.newHashSet( - StandardIndexes.dictionary(), // dictionaries implement other contract - StandardIndexes.nullValueVector()); // null value vector implement other contract + Set specialIndexes = + Sets.newHashSet(StandardIndexes.dictionary(), // dictionaries implement other contract + StandardIndexes.nullValueVector()); // null value vector implement other contract // Initialize for each column for (FieldSpec fieldSpec : _physicalFieldSpecs) { String column = fieldSpec.getName(); - FieldIndexConfigs indexConfigs = Optional.ofNullable(config.getIndexConfigByCol().get(column)) - .orElse(FieldIndexConfigs.EMPTY); + + int fixedByteSize = -1; + DataType dataType = fieldSpec.getDataType(); + DataType storedType = dataType.getStoredType(); + if (!storedType.isFixedWidth()) { + // For aggregated metrics, we need to store values with fixed byte size so that in-place 
replacement is possible + Pair aggregatorPair = metricsAggregators.get(column); + if (aggregatorPair != null) { + fixedByteSize = aggregatorPair.getRight().getMaxAggregatedValueByteSize(); + } + } + + FieldIndexConfigs indexConfigs = + Optional.ofNullable(config.getIndexConfigByCol().get(column)).orElse(FieldIndexConfigs.EMPTY); boolean isDictionary = !isNoDictionaryColumn(indexConfigs, fieldSpec, column); MutableIndexContext context = MutableIndexContext.builder().withFieldSpec(fieldSpec).withMemoryManager(_memoryManager) @@ -268,7 +280,7 @@ public boolean isMutableSegment() { .withEstimatedColSize(_statsHistory.getEstimatedAvgColSize(column)) .withAvgNumMultiValues(_statsHistory.getEstimatedAvgColSize(column)) .withConsumerDir(config.getConsumerDir() != null ? new File(config.getConsumerDir()) : null) - .build(); + .withFixedLengthBytes(fixedByteSize).build(); // Partition info PartitionFunction partitionFunction = null; @@ -306,8 +318,7 @@ public boolean isMutableSegment() { dictionary = null; if (!fieldSpec.isSingleValueField()) { // Raw MV columns - DataType dataType = fieldSpec.getDataType().getStoredType(); - switch (dataType) { + switch (storedType) { case INT: case LONG: case FLOAT: @@ -416,15 +427,14 @@ private boolean isNoDictionaryColumn(FieldIndexConfigs indexConfigs, FieldSpec f // if the column is part of noDictionary set from table config if (fieldSpec instanceof DimensionFieldSpec && isAggregateMetricsEnabled() && (dataType == STRING || dataType == BYTES)) { - _logger.info( - "Aggregate metrics is enabled. Will create dictionary in consuming segment for column {} of type {}", + _logger.info("Aggregate metrics is enabled. Will create dictionary in consuming segment for column {} of type {}", column, dataType); return false; } // So don't create dictionary if the column (1) is member of noDictionary, and (2) is single-value or multi-value // with a fixed-width field, and (3) doesn't have an inverted index - return (fieldSpec.isSingleValueField() || fieldSpec.getDataType().isFixedWidth()) - && indexConfigs.getConfig(StandardIndexes.inverted()).isDisabled(); + return (fieldSpec.isSingleValueField() || fieldSpec.getDataType().isFixedWidth()) && indexConfigs.getConfig( + StandardIndexes.inverted()).isDisabled(); } public SegmentPartitionConfig getSegmentPartitionConfig() { @@ -603,10 +613,7 @@ private void addNewRow(int docId, GenericRow row) { DataType dataType = fieldSpec.getDataType(); value = indexContainer._valueAggregator.getInitialAggregatedValue(value); - // aggregator value has to be numeric, but can be a different type of Number from the one expected on the column - // therefore we need to do some value changes here. - // TODO: Precision may change from one value to other, so we may need to study if this is actually what we want - // to do + // BIG_DECIMAL is actually stored as byte[] and hence can be supported here. 
switch (dataType.getStoredType()) { case INT: forwardIndex.add(((Number) value).intValue(), -1, docId); @@ -620,6 +627,10 @@ private void addNewRow(int docId, GenericRow row) { case DOUBLE: forwardIndex.add(((Number) value).doubleValue(), -1, docId); break; + case BIG_DECIMAL: + case BYTES: + forwardIndex.add(indexContainer._valueAggregator.serializeAggregatedValue(value), -1, docId); + break; default: throw new UnsupportedOperationException( "Unsupported data type: " + dataType + " for aggregation: " + column); @@ -796,6 +807,11 @@ private void aggregateMetrics(GenericRow row, int docId) { valueAggregator.getAggregatedValueType(), valueAggregator.getAggregationType(), dataType)); } break; + case BYTES: + Object oldValue = valueAggregator.deserializeAggregatedValue(forwardIndex.getBytes(docId)); + Object newValue = valueAggregator.applyRawValue(oldValue, value); + forwardIndex.setBytes(docId, valueAggregator.serializeAggregatedValue(newValue)); + break; default: throw new UnsupportedOperationException( String.format("Aggregation type %s of %s not supported for %s", valueAggregator.getAggregatedValueType(), @@ -1198,8 +1214,8 @@ private static Map> fromAggregateMetrics(R Map> columnNameToAggregator = new HashMap<>(); for (String metricName : segmentConfig.getSchema().getMetricNames()) { - columnNameToAggregator.put(metricName, - Pair.of(metricName, ValueAggregatorFactory.getValueAggregator(AggregationFunctionType.SUM))); + columnNameToAggregator.put(metricName, Pair.of(metricName, + ValueAggregatorFactory.getValueAggregator(AggregationFunctionType.SUM, Collections.emptyList()))); } return columnNameToAggregator; } @@ -1215,18 +1231,15 @@ private static Map> fromAggregationConfig( Preconditions.checkState(expressionContext.getType() == ExpressionContext.Type.FUNCTION, "aggregation function must be a function: %s", config); FunctionContext functionContext = expressionContext.getFunction(); - TableConfigUtils.validateIngestionAggregation(functionContext.getFunctionName()); - Preconditions.checkState(functionContext.getArguments().size() == 1, - "aggregation function can only have one argument: %s", config); + AggregationFunctionType functionType = + AggregationFunctionType.getAggregationFunctionType(functionContext.getFunctionName()); + TableConfigUtils.validateIngestionAggregation(functionType); ExpressionContext argument = functionContext.getArguments().get(0); Preconditions.checkState(argument.getType() == ExpressionContext.Type.IDENTIFIER, "aggregator function argument must be a identifier: %s", config); - AggregationFunctionType functionType = - AggregationFunctionType.getAggregationFunctionType(functionContext.getFunctionName()); - - columnNameToAggregator.put(config.getColumnName(), - Pair.of(argument.getIdentifier(), ValueAggregatorFactory.getValueAggregator(functionType))); + columnNameToAggregator.put(config.getColumnName(), Pair.of(argument.getIdentifier(), + ValueAggregatorFactory.getValueAggregator(functionType, functionContext.getArguments()))); } return columnNameToAggregator; @@ -1290,8 +1303,8 @@ public void close() { closeable.close(); } } catch (Exception e) { - _logger.error("Caught exception while closing {} index for column: {}, continuing with error", - indexType, column, e); + _logger.error("Caught exception while closing {} index for column: {}, continuing with error", indexType, + column, e); } }; diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/realtime/impl/forward/FixedByteSVMutableForwardIndex.java 
b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/realtime/impl/forward/FixedByteSVMutableForwardIndex.java index 13fce48a21f6..529e80ef5ba6 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/realtime/impl/forward/FixedByteSVMutableForwardIndex.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/realtime/impl/forward/FixedByteSVMutableForwardIndex.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.segment.local.realtime.impl.forward; +import com.google.common.base.Preconditions; import java.io.Closeable; import java.io.IOException; import java.math.BigDecimal; @@ -63,15 +64,21 @@ public class FixedByteSVMutableForwardIndex implements MutableForwardIndex { /** * @param storedType Data type of the values + * @param fixedLength Fixed length of values if known: only used for BYTES field (HyperLogLog and BigDecimal storage) * @param numRowsPerChunk Number of rows to pack in one chunk before a new chunk is created. * @param memoryManager Memory manager to be used for allocating memory. * @param allocationContext Allocation allocationContext. */ - public FixedByteSVMutableForwardIndex(boolean dictionaryEncoded, DataType storedType, int numRowsPerChunk, - PinotDataBufferMemoryManager memoryManager, String allocationContext) { + public FixedByteSVMutableForwardIndex(boolean dictionaryEncoded, DataType storedType, int fixedLength, + int numRowsPerChunk, PinotDataBufferMemoryManager memoryManager, String allocationContext) { _dictionaryEncoded = dictionaryEncoded; _storedType = storedType; - _valueSizeInBytes = storedType.size(); + if (!storedType.isFixedWidth()) { + Preconditions.checkState(fixedLength > 0, "Fixed length must be provided for type: %s", storedType); + _valueSizeInBytes = fixedLength; + } else { + _valueSizeInBytes = storedType.size(); + } _numRowsPerChunk = numRowsPerChunk; _chunkSizeInBytes = numRowsPerChunk * _valueSizeInBytes; _memoryManager = memoryManager; @@ -79,6 +86,11 @@ public FixedByteSVMutableForwardIndex(boolean dictionaryEncoded, DataType stored addBuffer(); } + public FixedByteSVMutableForwardIndex(boolean dictionaryEncoded, DataType valueType, int numRowsPerChunk, + PinotDataBufferMemoryManager memoryManager, String allocationContext) { + this(dictionaryEncoded, valueType, -1, numRowsPerChunk, memoryManager, allocationContext); + } + @Override public boolean isDictionaryEncoded() { return _dictionaryEncoded; @@ -195,6 +207,21 @@ public void setDouble(int docId, double value) { getWriterForRow(docId).setDouble(docId, value); } + @Override + public byte[] getBytes(int docId) { + int bufferId = getBufferId(docId); + return _readers.get(bufferId).getBytes(docId); + } + + @Override + public void setBytes(int docId, byte[] value) { + Preconditions.checkArgument(value.length == _valueSizeInBytes, "Expected value size to be: %s but got: %s ", + _valueSizeInBytes, value.length); + + addBufferIfNeeded(docId); + getWriterForRow(docId).setBytes(docId, value); + } + private WriterWithOffset getWriterForRow(int row) { return _writers.get(getBufferId(row)); } @@ -267,6 +294,10 @@ public void setFloat(int row, float value) { public void setDouble(int row, double value) { _writer.setDouble(row - _startRowId, 0, value); } + + public void setBytes(int row, byte[] value) { + _writer.setBytes(row - _startRowId, 0, value); + } } /** @@ -307,6 +338,10 @@ public BigDecimal getBigDecimal(int row) { return BigDecimalUtils.deserialize(_reader.getBytes(row - _startRowId, 0)); } + public byte[] getBytes(int row) { + return 
_reader.getBytes(row - _startRowId, 0); + } + public FixedByteSingleValueMultiColReader getReader() { return _reader; } diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/forward/ForwardIndexType.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/forward/ForwardIndexType.java index e45cca536e91..365d553b5886 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/forward/ForwardIndexType.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/forward/ForwardIndexType.java @@ -59,8 +59,7 @@ import org.apache.pinot.spi.data.Schema; -public class ForwardIndexType - extends AbstractIndexType +public class ForwardIndexType extends AbstractIndexType implements ConfigurableFromIndexLoadingConfig { public static final String INDEX_DISPLAY_NAME = "forward"; // For multi-valued column, forward-index. @@ -269,13 +268,15 @@ public MutableIndex createMutableIndex(MutableIndexContext context, ForwardIndex String column = context.getFieldSpec().getName(); String segmentName = context.getSegmentName(); FieldSpec.DataType storedType = context.getFieldSpec().getDataType().getStoredType(); + int fixedLengthBytes = context.getFixedLengthBytes(); boolean isSingleValue = context.getFieldSpec().isSingleValueField(); if (!context.hasDictionary()) { if (isSingleValue) { - String allocationContext = IndexUtil.buildAllocationContext(context.getSegmentName(), - context.getFieldSpec().getName(), V1Constants.Indexes.RAW_SV_FORWARD_INDEX_FILE_EXTENSION); - if (storedType.isFixedWidth()) { - return new FixedByteSVMutableForwardIndex(false, storedType, context.getCapacity(), + String allocationContext = + IndexUtil.buildAllocationContext(context.getSegmentName(), context.getFieldSpec().getName(), + V1Constants.Indexes.RAW_SV_FORWARD_INDEX_FILE_EXTENSION); + if (storedType.isFixedWidth() || fixedLengthBytes > 0) { + return new FixedByteSVMutableForwardIndex(false, storedType, fixedLengthBytes, context.getCapacity(), context.getMemoryManager(), allocationContext); } else { // RealtimeSegmentStatsHistory does not have the stats for no-dictionary columns from previous consuming diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/startree/v2/builder/BaseSingleTreeBuilder.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/startree/v2/builder/BaseSingleTreeBuilder.java index 1c1b03832e79..54da52c07040 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/startree/v2/builder/BaseSingleTreeBuilder.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/startree/v2/builder/BaseSingleTreeBuilder.java @@ -22,6 +22,7 @@ import java.io.File; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -142,7 +143,9 @@ static class Record { for (AggregationFunctionColumnPair functionColumnPair : functionColumnPairs) { _metrics[index] = functionColumnPair.toColumnName(); _functionColumnPairs[index] = functionColumnPair; - _valueAggregators[index] = ValueAggregatorFactory.getValueAggregator(functionColumnPair.getFunctionType()); + // TODO: Allow extra arguments in star-tree (e.g. 
log2m, precision) + _valueAggregators[index] = + ValueAggregatorFactory.getValueAggregator(functionColumnPair.getFunctionType(), Collections.emptyList()); // Ignore the column for COUNT aggregation function if (_valueAggregators[index].getAggregationType() != AggregationFunctionType.COUNT) { diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/HyperLogLogUtils.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/HyperLogLogUtils.java new file mode 100644 index 000000000000..15058e9639c1 --- /dev/null +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/HyperLogLogUtils.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.segment.local.utils; + +import com.clearspring.analytics.stream.cardinality.HyperLogLog; +import com.clearspring.analytics.stream.cardinality.RegisterSet; + + +public class HyperLogLogUtils { + private HyperLogLogUtils() { + } + + /** + * Returns the byte size of the given HyperLogLog. + */ + public static int byteSize(HyperLogLog value) { + // 8 bytes header (log2m & register set size) & register set data + return value.sizeof() + 2 * Integer.BYTES; + } + + /** + * Returns the byte size of HyperLogLog of a given log2m. + */ + public static int byteSize(int log2m) { + // 8 bytes header (log2m & register set size) & register set data + return (RegisterSet.getSizeForCount(1 << log2m) + 2) * Integer.BYTES; + } +} diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/TableConfigUtils.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/TableConfigUtils.java index a64170c60f55..0557d592195a 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/TableConfigUtils.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/TableConfigUtils.java @@ -106,7 +106,8 @@ private TableConfigUtils() { // hardcode the value here to avoid pulling the entire pinot-kinesis module as dependency. 
private static final String KINESIS_STREAM_TYPE = "kinesis"; private static final EnumSet SUPPORTED_INGESTION_AGGREGATIONS = - EnumSet.of(SUM, MIN, MAX, COUNT); + EnumSet.of(SUM, MIN, MAX, COUNT, DISTINCTCOUNTHLL, SUMPRECISION); + private static final Set UPSERT_DEDUP_ALLOWED_ROUTING_STRATEGIES = ImmutableSet.of(RoutingConfig.STRICT_REPLICA_GROUP_INSTANCE_SELECTOR_TYPE, RoutingConfig.MULTI_STAGE_REPLICA_GROUP_SELECTOR_TYPE); @@ -357,8 +358,9 @@ public static void validateIngestionConfig(TableConfig tableConfig, @Nullable Sc "columnName/aggregationFunction cannot be null in AggregationConfig " + aggregationConfig); } + FieldSpec fieldSpec = null; if (schema != null) { - FieldSpec fieldSpec = schema.getFieldSpecFor(columnName); + fieldSpec = schema.getFieldSpecFor(columnName); Preconditions.checkState(fieldSpec != null, "The destination column '" + columnName + "' of the aggregation function must be present in the schema"); Preconditions.checkState(fieldSpec.getFieldType() == FieldSpec.FieldType.METRIC, @@ -379,15 +381,52 @@ public static void validateIngestionConfig(TableConfig tableConfig, @Nullable Sc "aggregation function must be a function for: %s", aggregationConfig); FunctionContext functionContext = expressionContext.getFunction(); - validateIngestionAggregation(functionContext.getFunctionName()); - Preconditions.checkState(functionContext.getArguments().size() == 1, - "aggregation function can only have one argument: %s", aggregationConfig); - - ExpressionContext argument = functionContext.getArguments().get(0); - Preconditions.checkState(argument.getType() == ExpressionContext.Type.IDENTIFIER, - "aggregator function argument must be a identifier: %s", aggregationConfig); + AggregationFunctionType functionType = + AggregationFunctionType.getAggregationFunctionType(functionContext.getFunctionName()); + validateIngestionAggregation(functionType); + + List arguments = functionContext.getArguments(); + int numArguments = arguments.size(); + if (functionType == DISTINCTCOUNTHLL) { + Preconditions.checkState(numArguments >= 1 && numArguments <= 2, + "DISTINCT_COUNT_HLL can have at most two arguments: %s", aggregationConfig); + if (numArguments == 2) { + ExpressionContext secondArgument = arguments.get(1); + Preconditions.checkState(secondArgument.getType() == ExpressionContext.Type.LITERAL, + "Second argument of DISTINCT_COUNT_HLL must be literal: %s", aggregationConfig); + String literal = secondArgument.getLiteral().getStringValue(); + Preconditions.checkState(StringUtils.isNumeric(literal), + "Second argument of DISTINCT_COUNT_HLL must be a number: %s", aggregationConfig); + } + if (fieldSpec != null) { + DataType dataType = fieldSpec.getDataType(); + Preconditions.checkState(dataType == DataType.BYTES, + "Result type for DISTINCT_COUNT_HLL must be BYTES: %s", aggregationConfig); + } + } else if (functionType == SUMPRECISION) { + Preconditions.checkState(numArguments >= 2 && numArguments <= 3, + "SUM_PRECISION must specify precision (required), scale (optional): %s", aggregationConfig); + ExpressionContext secondArgument = arguments.get(1); + Preconditions.checkState(secondArgument.getType() == ExpressionContext.Type.LITERAL, + "Second argument of SUM_PRECISION must be literal: %s", aggregationConfig); + String literal = secondArgument.getLiteral().getStringValue(); + Preconditions.checkState(StringUtils.isNumeric(literal), + "Second argument of SUM_PRECISION must be a number: %s", aggregationConfig); + if (fieldSpec != null) { + DataType dataType = fieldSpec.getDataType(); + 
Preconditions.checkState(dataType == DataType.BIG_DECIMAL || dataType == DataType.BYTES, + "Result type for SUM_PRECISION must be BIG_DECIMAL or BYTES: %s", aggregationConfig); + } + } else { + Preconditions.checkState(numArguments == 1, "%s can only have one argument: %s", functionType, + aggregationConfig); + } + ExpressionContext firstArgument = arguments.get(0); + Preconditions.checkState(firstArgument.getType() == ExpressionContext.Type.IDENTIFIER, + "First argument of aggregation function: %s must be an identifier, got: %s", functionType, + firstArgument.getType()); - aggregationSourceColumns.add(argument.getIdentifier()); + aggregationSourceColumns.add(firstArgument.getIdentifier()); } if (schema != null) { Preconditions.checkState(new HashSet<>(schema.getMetricNames()).equals(aggregationColumns), @@ -455,21 +494,9 @@ public static void validateIngestionConfig(TableConfig tableConfig, @Nullable Sc } } - /** - * Currently only, ValueAggregators with fixed width types are allowed, so MIN, MAX, SUM, and COUNT. The reason - * is that only the {@link org.apache.pinot.segment.local.realtime.impl.forward.FixedByteSVMutableForwardIndex} - * supports random inserts and lookups. The - * {@link org.apache.pinot.segment.local.realtime.impl.forward.VarByteSVMutableForwardIndex only supports - * sequential inserts. - */ - public static void validateIngestionAggregation(String name) { - for (AggregationFunctionType functionType : SUPPORTED_INGESTION_AGGREGATIONS) { - if (functionType.getName().equals(name)) { - return; - } - } - throw new IllegalStateException( - String.format("aggregation function %s must be one of %s", name, SUPPORTED_INGESTION_AGGREGATIONS)); + public static void validateIngestionAggregation(AggregationFunctionType functionType) { + Preconditions.checkState(SUPPORTED_INGESTION_AGGREGATIONS.contains(functionType), + "Aggregation function: %s must be one of: %s", functionType, SUPPORTED_INGESTION_AGGREGATIONS); } @VisibleForTesting diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java index e0aea45373c7..5e048520c0ed 100644 --- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java +++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java @@ -18,6 +18,9 @@ */ package org.apache.pinot.segment.local.indexsegment.mutable; +import com.clearspring.analytics.stream.cardinality.CardinalityMergeException; +import com.clearspring.analytics.stream.cardinality.HyperLogLog; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -27,11 +30,15 @@ import java.util.Map; import java.util.Random; import java.util.Set; +import org.apache.pinot.common.request.Literal; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.segment.local.aggregator.DistinctCountHLLValueAggregator; import org.apache.pinot.spi.config.table.ingestion.AggregationConfig; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.Schema; import org.apache.pinot.spi.data.readers.GenericRow; import org.apache.pinot.spi.stream.StreamMessageMetadata; +import org.apache.pinot.spi.utils.BigDecimalUtils; import org.testng.Assert.
import org.testng.annotations.Test; @@ -48,7 +55,7 @@ public class MutableSegmentImplIngestionAggregationTest { private static final String KEY_SEPARATOR = "\t\t"; private static final int NUM_ROWS = 10001; - private static final Schema.SchemaBuilder getSchemaBuilder() { + private static Schema.SchemaBuilder getSchemaBuilder() { return new Schema.SchemaBuilder().setSchemaName("testSchema") .addSingleValueDimension(DIMENSION_1, FieldSpec.DataType.INT) .addSingleValueDimension(DIMENSION_2, FieldSpec.DataType.STRING) @@ -81,10 +88,10 @@ public void testSameSrcDifferentAggregations() for (List metrics : addRows(1, mutableSegmentImpl)) { expectedMin.put(metrics.get(0).getKey(), Math.min(expectedMin.getOrDefault(metrics.get(0).getKey(), Double.POSITIVE_INFINITY), - metrics.get(0).getValue())); + (Integer) metrics.get(0).getValue())); expectedMax.put(metrics.get(0).getKey(), Math.max(expectedMax.getOrDefault(metrics.get(0).getKey(), Double.NEGATIVE_INFINITY), - metrics.get(0).getValue())); + (Integer) metrics.get(0).getValue())); } GenericRow reuse = new GenericRow(); @@ -115,9 +122,9 @@ public void testSameAggregationDifferentSrc() Map expectedSum2 = new HashMap<>(); for (List metrics : addRows(2, mutableSegmentImpl)) { expectedSum1.put(metrics.get(0).getKey(), - expectedSum1.getOrDefault(metrics.get(0).getKey(), 0) + metrics.get(0).getValue()); + expectedSum1.getOrDefault(metrics.get(0).getKey(), 0) + (Integer) (metrics.get(0).getValue())); expectedSum2.put(metrics.get(1).getKey(), - expectedSum2.getOrDefault(metrics.get(1).getKey(), 0L) + metrics.get(1).getValue().longValue()); + expectedSum2.getOrDefault(metrics.get(1).getKey(), 0L) + ((Integer) metrics.get(1).getValue()).longValue()); } GenericRow reuse = new GenericRow(); @@ -132,7 +139,95 @@ public void testSameAggregationDifferentSrc() } @Test - public void testCOUNT() + public void testValuesAreNullThrowsException() + throws Exception { + String m1 = "sum1"; + + Schema schema = getSchemaBuilder().addMetric(m1, FieldSpec.DataType.INT).build(); + MutableSegmentImpl mutableSegmentImpl = + MutableSegmentImplTestUtils.createMutableSegmentImpl(schema, Collections.singleton(m1), VAR_LENGTH_SET, + INVERTED_INDEX_SET, Collections.singletonList(new AggregationConfig(m1, "SUM(metric)"))); + + long seed = 2; + Random random = new Random(seed); + StreamMessageMetadata defaultMetadata = new StreamMessageMetadata(System.currentTimeMillis(), null); + + // Generate random int to prevent overflow + GenericRow row = getRow(random, 1); + row.putValue(METRIC, null); + try { + mutableSegmentImpl.index(row, defaultMetadata); + Assert.fail(); + } catch (NullPointerException e) { + // expected + } + + mutableSegmentImpl.destroy(); + } + + @Test + public void testDistinctCountHLL() + throws Exception { + String m1 = "hll1"; + + Schema schema = getSchemaBuilder().addMetric(m1, FieldSpec.DataType.BYTES).build(); + MutableSegmentImpl mutableSegmentImpl = + MutableSegmentImplTestUtils.createMutableSegmentImpl(schema, Collections.singleton(m1), VAR_LENGTH_SET, + INVERTED_INDEX_SET, Collections.singletonList(new AggregationConfig(m1, "distinctCountHLL(metric, 12)"))); + + Map expected = new HashMap<>(); + List metrics = addRowsDistinctCountHLL(998, mutableSegmentImpl); + for (Metric metric : metrics) { + expected.put(metric.getKey(), (HLLTestData) metric.getValue()); + } + + List arguments = Arrays.asList(ExpressionContext.forIdentifier("metric"), + ExpressionContext.forLiteralContext(Literal.stringValue("12"))); + DistinctCountHLLValueAggregator valueAggregator = new 
DistinctCountHLLValueAggregator(arguments); + + Set integers = new HashSet<>(); + + // Assert that the distinct count is within an error margin. We assert on the cardinality of the HLL in the docID + // and the HLL we made, but also on the cardinality of the HLL in the docID and the actual cardinality from the set + // of integers. + GenericRow reuse = new GenericRow(); + for (int docId = 0; docId < expected.size(); docId++) { + GenericRow row = mutableSegmentImpl.getRecord(docId, reuse); + String key = buildKey(row); + + integers.addAll(expected.get(key)._integers); + + HyperLogLog expectedHLL = expected.get(key)._hll; + HyperLogLog actualHLL = valueAggregator.deserializeAggregatedValue((byte[]) row.getValue(m1)); + + Assert.assertEquals(actualHLL.cardinality(), expectedHLL.cardinality(), (int) (expectedHLL.cardinality() * 0.04), + "The HLL cardinality from the index is within a tolerable error margin (4%) of the cardinality of the " + + "expected HLL."); + Assert.assertEquals(actualHLL.cardinality(), expected.get(key)._integers.size(), + expected.get(key)._integers.size() * 0.04, + "The HLL cardinality from the index is within a tolerable error margin (4%) of the actual cardinality of " + + "the integers."); + } + + // Assert that the aggregated HyperLogLog is also within the error margin + HyperLogLog togetherHLL = new HyperLogLog(12); + expected.forEach((key, value) -> { + try { + togetherHLL.addAll(value._hll); + } catch (CardinalityMergeException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + }); + + Assert.assertEquals(togetherHLL.cardinality(), integers.size(), (int) (integers.size() * 0.04), + "The aggregated HLL cardinality is within a tolerable error margin (4%) of the actual cardinality of the " + + "integers."); + mutableSegmentImpl.destroy(); + } + + @Test + public void testCount() throws Exception { String m1 = "count1"; String m2 = "count2"; @@ -146,8 +241,7 @@ public void testCOUNT() Map expectedCount = new HashMap<>(); for (List metrics : addRows(3, mutableSegmentImpl)) { - expectedCount.put(metrics.get(0).getKey(), - expectedCount.getOrDefault(metrics.get(0).getKey(), 0L) + 1L); + expectedCount.put(metrics.get(0).getKey(), expectedCount.getOrDefault(metrics.get(0).getKey(), 0L) + 1L); } GenericRow reuse = new GenericRow(); @@ -166,22 +260,40 @@ private String buildKey(GenericRow row) { TIME_COLUMN1) + KEY_SEPARATOR + row.getValue(TIME_COLUMN2); } - private GenericRow getRow(Random random) { + private GenericRow getRow(Random random, Integer multiplicationFactor) { GenericRow row = new GenericRow(); - row.putValue(DIMENSION_1, random.nextInt(10)); + row.putValue(DIMENSION_1, random.nextInt(2 * multiplicationFactor)); row.putValue(DIMENSION_2, STRING_VALUES.get(random.nextInt(STRING_VALUES.size()))); - row.putValue(TIME_COLUMN1, random.nextInt(10)); - row.putValue(TIME_COLUMN2, random.nextInt(5)); + row.putValue(TIME_COLUMN1, random.nextInt(2 * multiplicationFactor)); + row.putValue(TIME_COLUMN2, random.nextInt(2 * multiplicationFactor)); return row; } + private class HLLTestData { + private HyperLogLog _hll; + private Set _integers; + + public HLLTestData(HyperLogLog hll, Set integers) { + _hll = hll; + _integers = integers; + } + + public HyperLogLog getHll() { + return _hll; + } + + public Set getIntegers() { + return _integers; + } + } + private class Metric { private final String _key; - private final Integer _value; + private final Object _value; - Metric(String key, Integer value) { + Metric(String key, Object value) { _key = key; _value = value; 
} @@ -190,23 +302,110 @@ public String getKey() { return _key; } - public Integer getValue() { + public Object getValue() { return _value; } } + private List<Metric> addRowsDistinctCountHLL(long seed, MutableSegmentImpl mutableSegmentImpl) + throws Exception { + List<Metric> metrics = new ArrayList<>(); + + Random random = new Random(seed); + StreamMessageMetadata defaultMetadata = new StreamMessageMetadata(System.currentTimeMillis(), null); + + HashMap<String, HyperLogLog> hllMap = new HashMap<>(); + HashMap<String, Set<Integer>> distinctMap = new HashMap<>(); + + int numRows = 500000; + + for (int i = 0; i < numRows; i++) { + GenericRow row = getRow(random, 1); + String key = buildKey(row); + + int metricValue = random.nextInt(5000000); + row.putValue(METRIC, metricValue); + + if (hllMap.containsKey(key)) { + hllMap.get(key).offer(row.getValue(METRIC)); + distinctMap.get(key).add(metricValue); + } else { + HyperLogLog hll = new HyperLogLog(12); + hll.offer(row.getValue(METRIC)); + hllMap.put(key, hll); + distinctMap.put(key, new HashSet<>(Collections.singleton(metricValue))); + } + + mutableSegmentImpl.index(row, defaultMetadata); + } + + distinctMap.forEach( + (key, value) -> metrics.add(new Metric(key, new HLLTestData(hllMap.get(key), distinctMap.get(key))))); + + int numDocsIndexed = mutableSegmentImpl.getNumDocsIndexed(); + Assert.assertEquals(numDocsIndexed, hllMap.keySet().size()); + + // Assert that aggregation happened. + Assert.assertTrue(numDocsIndexed < NUM_ROWS); + + return metrics; + } + + private List<Metric> addRowsSumPrecision(long seed, MutableSegmentImpl mutableSegmentImpl) + throws Exception { + List<Metric> metrics = new ArrayList<>(); + + Random random = new Random(seed); + StreamMessageMetadata defaultMetadata = new StreamMessageMetadata(System.currentTimeMillis(), null); + + HashMap<String, BigDecimal> bdMap = new HashMap<>(); + HashMap<String, List<BigDecimal>> bdIndividualMap = new HashMap<>(); + + int numRows = 50000; + for (int i = 0; i < numRows; i++) { + GenericRow row = getRow(random, 1); + String key = buildKey(row); + + BigDecimal metricValue = generateRandomBigDecimal(random, 5, 6); + row.putValue(METRIC, metricValue.toString()); + + if (bdMap.containsKey(key)) { + bdMap.put(key, bdMap.get(key).add(metricValue)); + bdIndividualMap.get(key).add(metricValue); + } else { + bdMap.put(key, metricValue); + ArrayList<BigDecimal> bdList = new ArrayList<>(); + bdList.add(metricValue); + bdIndividualMap.put(key, bdList); + } + + mutableSegmentImpl.index(row, defaultMetadata); + } + + for (String key : bdMap.keySet()) { + metrics.add(new Metric(key, bdMap.get(key))); + } + + int numDocsIndexed = mutableSegmentImpl.getNumDocsIndexed(); + Assert.assertEquals(numDocsIndexed, bdMap.keySet().size()); + + // Assert that aggregation happened. + Assert.assertTrue(numDocsIndexed < NUM_ROWS); + + return metrics; + } + private List<List<Metric>> addRows(long seed, MutableSegmentImpl mutableSegmentImpl) throws Exception { List<List<Metric>> metrics = new ArrayList<>(); Set<String> keys = new HashSet<>(); - Random random = new Random(seed); StreamMessageMetadata defaultMetadata = new StreamMessageMetadata(System.currentTimeMillis(), new GenericRow()); for (int i = 0; i < NUM_ROWS; i++) { - GenericRow row = getRow(random); - // This needs to be relatively low since it will tend to overflow with the Int-to-Double conversion.
+ // Generate random int to prevent overflow + GenericRow row = getRow(random, 1); Integer metricValue = random.nextInt(10000); Integer metric2Value = random.nextInt(); row.putValue(METRIC, metricValue); @@ -227,4 +426,74 @@ private List> addRows(long seed, MutableSegmentImpl mutableSegmentI return metrics; } + + @Test + public void testSumPrecision() + throws Exception { + String m1 = "sumPrecision1"; + Schema schema = getSchemaBuilder().addMetric(m1, FieldSpec.DataType.BIG_DECIMAL).build(); + + MutableSegmentImpl mutableSegmentImpl = + MutableSegmentImplTestUtils.createMutableSegmentImpl(schema, Collections.singleton(m1), VAR_LENGTH_SET, + INVERTED_INDEX_SET, + // Setting precision to 38 in the arguments for SUM_PRECISION + Collections.singletonList(new AggregationConfig(m1, "SUM_PRECISION(metric, 38)"))); + + Map expected = new HashMap<>(); + List metrics = addRowsSumPrecision(998, mutableSegmentImpl); + for (Metric metric : metrics) { + expected.put(metric.getKey(), (BigDecimal) metric.getValue()); + } + + // Assert that the aggregated values are correct + GenericRow reuse = new GenericRow(); + for (int docId = 0; docId < expected.size(); docId++) { + GenericRow row = mutableSegmentImpl.getRecord(docId, reuse); + String key = buildKey(row); + + BigDecimal expectedBigDecimal = expected.get(key); + BigDecimal actualBigDecimal = (BigDecimal) row.getValue(m1); + + Assert.assertEquals(actualBigDecimal, expectedBigDecimal, "The aggregated SUM does not match the expected SUM"); + } + mutableSegmentImpl.destroy(); + } + + @Test + public void testBigDecimalTooBig() { + String m1 = "sumPrecision1"; + Schema schema = getSchemaBuilder().addMetric(m1, FieldSpec.DataType.BIG_DECIMAL).build(); + + int seed = 1; + Random random = new Random(seed); + StreamMessageMetadata defaultMetadata = new StreamMessageMetadata(System.currentTimeMillis(), null); + + MutableSegmentImpl mutableSegmentImpl = + MutableSegmentImplTestUtils.createMutableSegmentImpl(schema, Collections.singleton(m1), VAR_LENGTH_SET, + INVERTED_INDEX_SET, Collections.singletonList(new AggregationConfig(m1, "SUM_PRECISION(metric, 3)"))); + + // Make a big decimal larger than 3 precision and try to index it + BigDecimal large = BigDecimalUtils.generateMaximumNumberWithPrecision(5); + GenericRow row = getRow(random, 1); + + row.putValue("metric", large); + Assert.assertThrows(IllegalArgumentException.class, () -> { + mutableSegmentImpl.index(row, defaultMetadata); + }); + } + + private BigDecimal generateRandomBigDecimal(Random random, int maxPrecision, int scale) { + int precision = 1 + random.nextInt(maxPrecision); + + String s = ""; + for (int i = 0; i < precision; i++) { + s = s + (1 + random.nextInt(9)); + } + + if ((1 + random.nextInt(2)) == 1) { + return (new BigDecimal(s).setScale(scale)).negate(); + } else { + return new BigDecimal(s).setScale(scale); + } + } } diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/segment/index/forward/mutable/FixedByteSVMutableForwardIndexTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/segment/index/forward/mutable/FixedByteSVMutableForwardIndexTest.java index fbdcefcdb7c1..6b17e906a55a 100644 --- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/segment/index/forward/mutable/FixedByteSVMutableForwardIndexTest.java +++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/segment/index/forward/mutable/FixedByteSVMutableForwardIndexTest.java @@ -18,6 +18,7 @@ */ package 
org.apache.pinot.segment.local.segment.index.forward.mutable; +import com.clearspring.analytics.stream.cardinality.HyperLogLog; import java.io.IOException; import java.util.Arrays; import java.util.Random; @@ -113,6 +114,77 @@ private void testDictId(final Random random, final int rows, final int div) readerWriter.close(); } + @Test + public void testBytes() + throws IOException { + int rows = 10; + Random r = new Random(); + final long seed = r.nextLong(); + r = new Random(seed); + for (int div = 1; div <= rows / 2; div++) { + testBytes(r, rows, div); + } + } + + private void testBytes(final Random random, final int rows, final int div) + throws IOException { + int hllLog2M12Size = 2740; + int log2m = 12; + + FixedByteSVMutableForwardIndex readerWriter; + readerWriter = + new FixedByteSVMutableForwardIndex(false, DataType.BYTES, hllLog2M12Size, rows / div, _memoryManager, "Long"); + byte[][] data = new byte[rows][]; + + for (int i = 0; i < rows; i++) { + HyperLogLog hll = new HyperLogLog(log2m); + hll.offer(random.nextLong()); + data[i] = hll.getBytes(); + Assert.assertEquals(data[i].length, hllLog2M12Size); + readerWriter.setBytes(i, data[i]); + Assert.assertEquals(readerWriter.getBytes(i).length, data[i].length); + Assert.assertEquals(readerWriter.getBytes(i), data[i]); + } + for (int i = 0; i < rows; i++) { + Assert.assertEquals(readerWriter.getBytes(i), data[i]); + } + + // Test mutability by overwriting randomly selected rows. + for (int i = 0; i < rows; i++) { + if (random.nextFloat() >= 0.5) { + HyperLogLog hll = new HyperLogLog(log2m); + hll.offer(random.nextLong()); + data[i] = hll.getBytes(); + readerWriter.setBytes(i, data[i]); + } + } + for (int i = 0; i < rows; i++) { + Assert.assertEquals(readerWriter.getBytes(i), data[i]); + } + + // Write to a large enough row index to ensure multiple chunks are correctly allocated. + int start = rows * 4; + for (int i = 0; i < rows; i++) { + HyperLogLog hll = new HyperLogLog(log2m); + hll.offer(random.nextLong()); + data[i] = hll.getBytes(); + readerWriter.setBytes(start + i, data[i]); + } + + for (int i = 0; i < rows; i++) { + Assert.assertEquals(readerWriter.getBytes(start + i), data[i]); + } + + // Ensure that rows not written default to an empty byte array. + byte[] emptyBytes = new byte[hllLog2M12Size]; + start = rows * 2; + for (int i = 0; i < 2 * rows; i++) { + byte[] bytes = readerWriter.getBytes(start + i); + Assert.assertEquals(bytes, emptyBytes); + } + readerWriter.close(); + } + @Test public void testLong() throws IOException { diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/utils/HyperLogLogUtilsTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/utils/HyperLogLogUtilsTest.java new file mode 100644 index 000000000000..44f7b741e580 --- /dev/null +++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/utils/HyperLogLogUtilsTest.java @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.segment.local.utils; + +import com.clearspring.analytics.stream.cardinality.HyperLogLog; +import java.io.IOException; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + + +public class HyperLogLogUtilsTest { + + @Test + public void testByteSize() + throws IOException { + for (int log2m = 0; log2m < 16; log2m++) { + HyperLogLog hll = new HyperLogLog(log2m); + int expectedByteSize = hll.getBytes().length; + assertEquals(HyperLogLogUtils.byteSize(log2m), expectedByteSize); + assertEquals(HyperLogLogUtils.byteSize(hll), expectedByteSize); + } + } +} diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/utils/TableConfigUtilsTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/utils/TableConfigUtilsTest.java index 9d58b42161bb..4ec499d58be3 100644 --- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/utils/TableConfigUtilsTest.java +++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/utils/TableConfigUtilsTest.java @@ -472,44 +472,125 @@ public void ingestionAggregationConfigsTest() { // expected } - ingestionConfig.setAggregationConfigs( - Collections.singletonList(new AggregationConfig("m1", "DISTINCTCOUNTHLL(s1)"))); + ingestionConfig.setAggregationConfigs(Collections.singletonList(new AggregationConfig("m1", "SUM(s1 - s2)"))); try { TableConfigUtils.validateIngestionConfig(tableConfig, schema); - Assert.fail("Should fail due to not supported aggregation function"); + Assert.fail("Should fail due to inner value not being a column"); } catch (IllegalStateException e) { // expected } - ingestionConfig.setAggregationConfigs(Collections.singletonList(new AggregationConfig("m1", "s1 + s2"))); + ingestionConfig.setAggregationConfigs(Collections.singletonList(new AggregationConfig("m1", "SUM(m1)"))); + TableConfigUtils.validateIngestionConfig(tableConfig, schema); + + ingestionConfig.setAggregationConfigs(Collections.singletonList(new AggregationConfig("m1", "SUM(s1)"))); + TableConfigUtils.validateIngestionConfig(tableConfig, schema); + + schema.addField(new MetricFieldSpec("m2", FieldSpec.DataType.DOUBLE)); try { TableConfigUtils.validateIngestionConfig(tableConfig, schema); - Assert.fail("Should fail due to multiple arguments"); + Assert.fail("Should fail due to one metric column not being aggregated"); } catch (IllegalStateException e) { // expected } - ingestionConfig.setAggregationConfigs(Collections.singletonList(new AggregationConfig("m1", "SUM(s1 - s2)"))); + schema = new Schema.SchemaBuilder().setSchemaName(TABLE_NAME).addMetric("d1", FieldSpec.DataType.BYTES).build(); + // distinctcounthllmv is not supported, we expect this to not validate + List aggregationConfigs = Arrays.asList(new AggregationConfig("d1", "DISTINCTCOUNTHLLMV(s1)")); + ingestionConfig.setAggregationConfigs(aggregationConfigs); + tableConfig = + new TableConfigBuilder(TableType.REALTIME).setTableName("myTable_REALTIME").setTimeColumnName("timeColumn") + .setIngestionConfig(ingestionConfig).build(); + try { 
TableConfigUtils.validateIngestionConfig(tableConfig, schema); - Assert.fail("Should fail due to inner value not being a column"); + Assert.fail("Should fail due to not supported aggregation function"); } catch (IllegalStateException e) { // expected } - ingestionConfig.setAggregationConfigs(Collections.singletonList(new AggregationConfig("m1", "SUM(m1)"))); - TableConfigUtils.validateIngestionConfig(tableConfig, schema); + // distinctcounthll, expect that the function name in various forms (with and without underscores) still validates + schema = new Schema.SchemaBuilder().setSchemaName(TABLE_NAME).addMetric("d1", FieldSpec.DataType.BYTES) + .addMetric("d2", FieldSpec.DataType.BYTES).addMetric("d3", FieldSpec.DataType.BYTES) + .addMetric("d4", FieldSpec.DataType.BYTES).addMetric("d5", FieldSpec.DataType.BYTES).build(); - ingestionConfig.setAggregationConfigs(Collections.singletonList(new AggregationConfig("m1", "SUM(s1)"))); - TableConfigUtils.validateIngestionConfig(tableConfig, schema); + aggregationConfigs = Arrays.asList(new AggregationConfig("d1", "distinct_count_hll(s1)"), + new AggregationConfig("d2", "DISTINCTCOUNTHLL(s1)"), new AggregationConfig("d3", "distinctcounthll(s1)"), + new AggregationConfig("d4", "DISTINCTCOUNT_HLL(s1)"), new AggregationConfig("d5", "DISTINCT_COUNT_HLL(s1)")); + + ingestionConfig.setAggregationConfigs(aggregationConfigs); + tableConfig = + new TableConfigBuilder(TableType.REALTIME).setTableName("myTable_REALTIME").setTimeColumnName("timeColumn") + .setIngestionConfig(ingestionConfig).build(); + + try { + TableConfigUtils.validateIngestionConfig(tableConfig, schema); + } catch (IllegalStateException e) { + Assert.fail("Should not fail due to valid aggregation function", e); + } + + // distinctcounthll, expect not specified log2m argument to default to 8 + schema = new Schema.SchemaBuilder().setSchemaName(TABLE_NAME).addMetric("d1", FieldSpec.DataType.BYTES).build(); + + aggregationConfigs = Arrays.asList(new AggregationConfig("d1", "DISTINCTCOUNTHLL(s1)")); + ingestionConfig.setAggregationConfigs(aggregationConfigs); + tableConfig = + new TableConfigBuilder(TableType.REALTIME).setTableName("myTable_REALTIME").setTimeColumnName("timeColumn") + .setIngestionConfig(ingestionConfig).build(); - schema.addField(new MetricFieldSpec("m2", FieldSpec.DataType.DOUBLE)); try { TableConfigUtils.validateIngestionConfig(tableConfig, schema); - Assert.fail("Should fail due to one metric column not being aggregated"); } catch (IllegalStateException e) { + Assert.fail("Log2m should have defaulted to 8", e); + } + + aggregationConfigs = Arrays.asList(new AggregationConfig("d1", "s1 + s2")); + ingestionConfig.setAggregationConfigs(aggregationConfigs); + tableConfig = + new TableConfigBuilder(TableType.REALTIME).setTableName("myTable_REALTIME").setTimeColumnName("timeColumn") + .setIngestionConfig(ingestionConfig).build(); + + try { + TableConfigUtils.validateIngestionConfig(tableConfig, schema); + Assert.fail("Should fail due to multiple arguments"); + } catch (IllegalArgumentException e) { // expected } + + // sumprecision, expect that the function name in various forms (with and without underscores) still validates + schema = new Schema.SchemaBuilder().setSchemaName(TABLE_NAME) + .addSingleValueDimension("s1", FieldSpec.DataType.BIG_DECIMAL).addMetric("d1", FieldSpec.DataType.BIG_DECIMAL) + .addMetric("d2", FieldSpec.DataType.BIG_DECIMAL).addMetric("d3", FieldSpec.DataType.BIG_DECIMAL) + .addMetric("d4", FieldSpec.DataType.BIG_DECIMAL).build(); + + aggregationConfigs = 
Arrays.asList(new AggregationConfig("d1", "sum_precision(s1, 10, 32)"), + new AggregationConfig("d2", "SUM_PRECISION(s1, 1)"), new AggregationConfig("d3", "sumprecision(s1, 2)"), + new AggregationConfig("d4", "SUMPRECISION(s1, 10, 99)")); + + ingestionConfig.setAggregationConfigs(aggregationConfigs); + tableConfig = + new TableConfigBuilder(TableType.REALTIME).setTableName("myTable_REALTIME").setTimeColumnName("timeColumn") + .setIngestionConfig(ingestionConfig).build(); + TableConfigUtils.validateIngestionConfig(tableConfig, schema); + + // with too many arguments should fail + schema = new Schema.SchemaBuilder().setSchemaName(TABLE_NAME) + .addSingleValueDimension("s1", FieldSpec.DataType.BIG_DECIMAL).addMetric("d1", FieldSpec.DataType.BIG_DECIMAL) + .build(); + + aggregationConfigs = Arrays.asList(new AggregationConfig("d1", "sum_precision(s1, 10, 32, 99)")); + + ingestionConfig.setAggregationConfigs(aggregationConfigs); + tableConfig = + new TableConfigBuilder(TableType.REALTIME).setTableName("myTable_REALTIME").setTimeColumnName("timeColumn") + .setIngestionConfig(ingestionConfig).build(); + + try { + TableConfigUtils.validateIngestionConfig(tableConfig, schema); + Assert.fail("Should have failed with too many arguments but didn't"); + } catch (IllegalStateException e) { + // Expected + } } @Test diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/MutableForwardIndex.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/MutableForwardIndex.java index 5b63c989d438..b1fd1f14ae58 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/MutableForwardIndex.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/MutableForwardIndex.java @@ -50,7 +50,17 @@ default void add(@Nonnull Object value, int dictId, int docId) { setDouble(docId, (double) value); break; case BIG_DECIMAL: - setBigDecimal(docId, (BigDecimal) value); + // If the Big Decimal is already serialized as byte[], use it directly. + // This is only possible when the Big Decimal is generated from a realtime pre-aggregation + // where SumPrecisionValueAggregator uses BigDecimalUtils.serializeWithSize() to serialize the value + // instead of the normal BigDecimalUtils.serialize(). + // setBigDecimal() underlying calls BigDecimalUtils.serialize() which is not the intention + // when the Big Decimal is already serialized. 
+ if (value instanceof byte[]) { + setBytes(docId, (byte[]) value); + } else { + setBigDecimal(docId, (BigDecimal) value); + } break; case STRING: setString(docId, (String) value); diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/provider/MutableIndexContext.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/provider/MutableIndexContext.java index f6ffdc128e51..ab152c0fd285 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/provider/MutableIndexContext.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/mutable/provider/MutableIndexContext.java @@ -28,6 +28,7 @@ public class MutableIndexContext { private final int _capacity; private final FieldSpec _fieldSpec; + private final int _fixedLengthBytes; private final boolean _hasDictionary; private final boolean _offHeap; private final int _estimatedColSize; @@ -37,10 +38,11 @@ public class MutableIndexContext { private final PinotDataBufferMemoryManager _memoryManager; private final File _consumerDir; - public MutableIndexContext(FieldSpec fieldSpec, boolean hasDictionary, String segmentName, + public MutableIndexContext(FieldSpec fieldSpec, int fixedLengthBytes, boolean hasDictionary, String segmentName, PinotDataBufferMemoryManager memoryManager, int capacity, boolean offHeap, int estimatedColSize, int estimatedCardinality, int avgNumMultiValues, File consumerDir) { _fieldSpec = fieldSpec; + _fixedLengthBytes = fixedLengthBytes; _hasDictionary = hasDictionary; _segmentName = segmentName; _memoryManager = memoryManager; @@ -64,6 +66,10 @@ public FieldSpec getFieldSpec() { return _fieldSpec; } + public int getFixedLengthBytes() { + return _fixedLengthBytes; + } + public boolean hasDictionary() { return _hasDictionary; } @@ -99,6 +105,7 @@ public static Builder builder() { public static class Builder { private FieldSpec _fieldSpec; + private int _fixedLengthBytes; private String _segmentName; private boolean _hasDictionary = true; private boolean _offHeap = true; @@ -160,8 +167,13 @@ public Builder withConsumerDir(File consumerDir) { return this; } + public Builder withFixedLengthBytes(int fixedLengthBytes) { + _fixedLengthBytes = fixedLengthBytes; + return this; + } + public MutableIndexContext build() { - return new MutableIndexContext(Objects.requireNonNull(_fieldSpec), _hasDictionary, + return new MutableIndexContext(Objects.requireNonNull(_fieldSpec), _fixedLengthBytes, _hasDictionary, Objects.requireNonNull(_segmentName), Objects.requireNonNull(_memoryManager), _capacity, _offHeap, _estimatedColSize, _estimatedCardinality, _avgNumMultiValues, _consumerDir); } diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/data/FieldSpec.java b/pinot-spi/src/main/java/org/apache/pinot/spi/data/FieldSpec.java index 3c931204a925..978862e6006f 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/data/FieldSpec.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/data/FieldSpec.java @@ -100,7 +100,7 @@ public abstract class FieldSpec implements Comparable, Serializable { protected DataType _dataType; protected boolean _isSingleValueField = true; - // NOTE: for STRING column, this is the max number of characters; for BYTES column, this is the max number of bytes + // NOTE: This only applies to STRING column, which is the max number of characters private int _maxLength = DEFAULT_MAX_LENGTH; protected Object _defaultNullValue; diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/BigDecimalUtils.java 
b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/BigDecimalUtils.java index 6827a7c860a2..c479acee0af5 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/BigDecimalUtils.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/BigDecimalUtils.java @@ -35,6 +35,15 @@ public static int byteSize(BigDecimal value) { return (unscaledValue.bitLength() >>> 3) + 3; } + /** + * This gets the expected byte size of a big decimal with a specific precision. + * It is equal to byteSize(10^precision - 1), i.e. the serialized size of the largest value with that precision. + */ + public static int byteSizeForFixedPrecision(int precision) { + BigDecimal bd = generateMaximumNumberWithPrecision(precision); + return byteSize(bd); + } + /** * Serializes a big decimal to a byte array. */ @@ -49,6 +58,34 @@ public static byte[] serialize(BigDecimal value) { return valueBytes; } + public static byte[] serializeWithSize(BigDecimal value, int fixedSize) { + int scale = value.scale(); + BigInteger unscaledValue = value.unscaledValue(); + byte[] unscaledValueBytes = unscaledValue.toByteArray(); + + int unscaledBytesStartingIndex = fixedSize - unscaledValueBytes.length; + if (unscaledValueBytes.length > (fixedSize - 2)) { + throw new IllegalArgumentException("Big decimal of size " + (unscaledValueBytes.length + 2) + + " is too big to serialize into a fixed size of " + fixedSize + " bytes"); + } + + byte[] valueBytes = new byte[fixedSize]; + valueBytes[0] = (byte) (scale >> 8); + valueBytes[1] = (byte) scale; + + byte paddingByte = 0; + if (value.signum() < 0) { + paddingByte = -1; + } + + for (int i = 2; i < unscaledBytesStartingIndex; i++) { + valueBytes[i] = paddingByte; + } + + System.arraycopy(unscaledValueBytes, 0, valueBytes, unscaledBytesStartingIndex, unscaledValueBytes.length); + return valueBytes; + } + /** * Deserializes a big decimal from a byte array.
*/ @@ -75,4 +112,8 @@ public static BigDecimal deserialize(ByteBuffer byteBuffer) { byteBuffer.get(bytes); return deserialize(bytes); } + + public static BigDecimal generateMaximumNumberWithPrecision(int precision) { + return (new BigDecimal("10")).pow(precision).subtract(new BigDecimal("1")); + } } diff --git a/pinot-spi/src/test/java/org/apache/pinot/spi/utils/BigDecimalUtilsTest.java b/pinot-spi/src/test/java/org/apache/pinot/spi/utils/BigDecimalUtilsTest.java index fe9387a14f90..f06afeda3682 100644 --- a/pinot-spi/src/test/java/org/apache/pinot/spi/utils/BigDecimalUtilsTest.java +++ b/pinot-spi/src/test/java/org/apache/pinot/spi/utils/BigDecimalUtilsTest.java @@ -20,6 +20,8 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayList; +import org.testng.Assert; import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; @@ -29,24 +31,96 @@ public class BigDecimalUtilsTest { @Test public void testBigDecimal() { - BigDecimal value = new BigDecimal("123456789.0123456789"); - byte[] serializedValue = BigDecimalUtils.serialize(value); - assertEquals(BigDecimalUtils.byteSize(value), serializedValue.length); - BigDecimal deserializedValue = BigDecimalUtils.deserialize(serializedValue); - assertEquals(deserializedValue, value); - - // Set the scale to a negative value - value = value.setScale(-1, RoundingMode.HALF_UP); - serializedValue = BigDecimalUtils.serialize(value); - assertEquals(BigDecimalUtils.byteSize(value), serializedValue.length); - deserializedValue = BigDecimalUtils.deserialize(serializedValue); - assertEquals(deserializedValue, value); - - // Set the scale to a negative value in byte - value = value.setScale(128, RoundingMode.HALF_UP); - serializedValue = BigDecimalUtils.serialize(value); - assertEquals(BigDecimalUtils.byteSize(value), serializedValue.length); - deserializedValue = BigDecimalUtils.deserialize(serializedValue); - assertEquals(deserializedValue, value); + BigDecimal[] testCases = { + new BigDecimal("0.123456789"), + new BigDecimal("-0.123456789"), + new BigDecimal("123456789"), + new BigDecimal("-123456789"), + new BigDecimal("123456789.0123456789"), + new BigDecimal("-123456789.0123456789"), + // Set the scale to a negative value + new BigDecimal("123456789.0123456789").setScale(-1, RoundingMode.HALF_UP), + new BigDecimal("-123456789.0123456789").setScale(-1, RoundingMode.HALF_UP), + // Set the scale to a negative value in byte + new BigDecimal("123456789.0123456789").setScale(128, RoundingMode.HALF_UP), + new BigDecimal("-123456789.0123456789").setScale(128, RoundingMode.HALF_UP) + }; + for (BigDecimal value : testCases) { + byte[] serializedValue = BigDecimalUtils.serialize(value); + assertEquals(BigDecimalUtils.byteSize(value), serializedValue.length); + BigDecimal deserializedValue = BigDecimalUtils.deserialize(serializedValue); + assertEquals(deserializedValue, value); + } + } + + @Test + public void testBigDecimalSerializeWithSize() { + BigDecimal[] testCases = { + new BigDecimal("0.123456789"), + new BigDecimal("-0.123456789"), + new BigDecimal("123456789"), + new BigDecimal("-123456789"), + new BigDecimal("123456789.0123456789"), + new BigDecimal("-123456789.0123456789"), + new BigDecimal("123456789.0123456789").setScale(-1, RoundingMode.HALF_UP), + new BigDecimal("-123456789.0123456789").setScale(-1, RoundingMode.HALF_UP), + new BigDecimal("123456789.0123456789").setScale(128, RoundingMode.HALF_UP), + new BigDecimal("-123456789.0123456789").setScale(128, RoundingMode.HALF_UP) + }; + // One case of 
serialization with and without padding + int[] sizes = {0, 4}; + for (BigDecimal value : testCases) { + int actualSize = BigDecimalUtils.byteSize(value); + for (int size : sizes) { + byte[] serializedValue = BigDecimalUtils.serializeWithSize(value, actualSize + size); + assertEquals(actualSize + size, serializedValue.length); + BigDecimal deserializedValue = BigDecimalUtils.deserialize(serializedValue); + assertEquals(deserializedValue, value); + } + } + } + + @Test + public void testGenerateMaximumNumberWithPrecision() { + int[] testCases = { 1, 3, 10, 38, 128 }; + for (int precision : testCases) { + BigDecimal bd = BigDecimalUtils.generateMaximumNumberWithPrecision(precision); + assertEquals(bd.precision(), precision); + assertEquals(bd.add(new BigDecimal("1")).precision(), precision + 1); + } + } + + @Test + public void testBigDecimalWithMaximumPrecisionSizeInBytes() { + Assert.assertEquals(BigDecimalUtils.byteSizeForFixedPrecision(18), 10); + Assert.assertEquals(BigDecimalUtils.byteSizeForFixedPrecision(32), 16); + Assert.assertEquals(BigDecimalUtils.byteSizeForFixedPrecision(38), 18); + } + + @Test + public void testBigDecimalSerializationWithSize() { + ArrayList bigDecimals = new ArrayList<>(); + bigDecimals.add(new BigDecimal("1000.123456")); + bigDecimals.add(new BigDecimal("1237663")); + bigDecimals.add(new BigDecimal("0.114141622")); + + for (BigDecimal bigDecimal : bigDecimals) { + int bytesNeeded = BigDecimalUtils.byteSize(bigDecimal); + + // Serialize big decimal equal to the target size + byte[] bytes = BigDecimalUtils.serializeWithSize(bigDecimal, bytesNeeded); + BigDecimal bigDecimalDeserialized = BigDecimalUtils.deserialize(bytes); + Assert.assertEquals(bigDecimalDeserialized, bigDecimal); + + // Serialize big decimal smaller than the target size + bytes = BigDecimalUtils.serializeWithSize(bigDecimal, bytesNeeded + 2); + bigDecimalDeserialized = BigDecimalUtils.deserialize(bytes); + Assert.assertEquals(bigDecimalDeserialized, bigDecimal); + + // Raise exception when trying to serialize a big decimal larger than target size + Assert.assertThrows(IllegalArgumentException.class, () -> { + BigDecimalUtils.serializeWithSize(bigDecimal, bytesNeeded - 4); + }); + } } }
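
A note on the testBytes addition in FixedByteSVMutableForwardIndexTest above: the hllLog2M12Size = 2740 constant is the exact serialized size of a clearspring HyperLogLog with log2m = 12, which is why each fixed-width row in that forward index can hold a full sketch. A rough sketch of the arithmetic, assuming the stream-lib RegisterSet packs six 5-bit registers per 32-bit word and that getBytes() prepends an 8-byte header (assumptions about the library's internals, consistent with the 2740 figure the test asserts):

import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import org.apache.pinot.segment.local.utils.HyperLogLogUtils;

public class HllByteSizeSketch {
  public static void main(String[] args) throws Exception {
    int log2m = 12;
    int registers = 1 << log2m;                      // 4096 five-bit registers
    int words = (registers + 5) / 6;                 // 6 registers per 32-bit word -> 683 words
    int expected = 8 + words * 4;                    // 8-byte header (log2m + byte count) + 2732 = 2740

    HyperLogLog hll = new HyperLogLog(log2m);
    System.out.println(expected);                         // 2740
    System.out.println(hll.getBytes().length);            // 2740, matches hllLog2M12Size in the test
    System.out.println(HyperLogLogUtils.byteSize(log2m)); // 2740, the helper covered by HyperLogLogUtilsTest
  }
}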
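The TableConfigUtilsTest changes above exercise the widened ingestion-aggregation validation: DISTINCTCOUNTHLL and SUM_PRECISION are accepted in any underscore variant, the HLL destinations in the tests are BYTES metrics while the SUM_PRECISION destinations are BIG_DECIMAL metrics, every metric column must be covered by an aggregation, and argument counts are checked. A minimal sketch of a config that should pass that validation, mirroring what the test builds (the no-arg IngestionConfig constructor and the column names are assumptions made for illustration, not something shown in this diff):

import java.util.Arrays;
import org.apache.pinot.segment.local.utils.TableConfigUtils;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.config.table.ingestion.AggregationConfig;
import org.apache.pinot.spi.config.table.ingestion.IngestionConfig;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;

public class IngestionAggregationSketch {
  public static void main(String[] args) {
    // Destination types match the aggregators' stored types:
    // BYTES for the serialized HyperLogLog, BIG_DECIMAL for SUM_PRECISION.
    Schema schema = new Schema.SchemaBuilder().setSchemaName("myTable")
        .addSingleValueDimension("s1", FieldSpec.DataType.STRING)
        .addMetric("hllColumn", FieldSpec.DataType.BYTES)
        .addMetric("sumColumn", FieldSpec.DataType.BIG_DECIMAL)
        .build();

    IngestionConfig ingestionConfig = new IngestionConfig();
    ingestionConfig.setAggregationConfigs(Arrays.asList(
        new AggregationConfig("hllColumn", "DISTINCTCOUNTHLL(s1)"),        // log2m defaults to 8
        new AggregationConfig("sumColumn", "SUM_PRECISION(s1, 10, 32)"))); // extra arguments as in the test

    TableConfig tableConfig = new TableConfigBuilder(TableType.REALTIME).setTableName("myTable_REALTIME")
        .setTimeColumnName("timeColumn").setIngestionConfig(ingestionConfig).build();

    // Throws IllegalStateException (or IllegalArgumentException for malformed expressions) when invalid.
    TableConfigUtils.validateIngestionConfig(tableConfig, schema);
  }
}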
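Finally, the serializeWithSize / byteSizeForFixedPrecision pair added to BigDecimalUtils encodes a value into a fixed-width slot: two scale bytes, then sign-extension padding (0x00 for non-negative values, 0xFF for negative ones), then the big-endian two's-complement unscaled value; the existing deserialize() reads it back unchanged (the tests above rely on the leading padding being treated as sign extension). A short round-trip sketch using only the methods in this diff (precision 38 is just an example; testBigDecimalWithMaximumPrecisionSizeInBytes asserts byteSizeForFixedPrecision(38) == 18):

import java.math.BigDecimal;
import org.apache.pinot.spi.utils.BigDecimalUtils;

public class FixedSizeBigDecimalSketch {
  public static void main(String[] args) {
    // A column sized for precision 38 reserves 18 bytes per row.
    int fixedSize = BigDecimalUtils.byteSizeForFixedPrecision(38);

    BigDecimal value = new BigDecimal("-123456789.0123456789");
    byte[] bytes = BigDecimalUtils.serializeWithSize(value, fixedSize);  // padded to exactly 18 bytes

    // The round trip is lossless; a value too large for the slot throws IllegalArgumentException instead.
    BigDecimal roundTripped = BigDecimalUtils.deserialize(bytes);
    System.out.println(roundTripped.equals(value) && bytes.length == fixedSize);  // true
  }
}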