From fceadd0fa8991bcf7ce60f7b2244748f114fc42a Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" Date: Sat, 10 Feb 2024 01:35:46 -0800 Subject: [PATCH] [Multi-stage] Optimize group key generation --- ...ictionaryMultiColumnGroupKeyGenerator.java | 189 ++++++++++++------ .../groupby/utils/BaseValueToIdMap.java | 87 -------- .../groupby/utils/BytesToIdMap.java | 67 ------- .../groupby/utils/DoubleToIdMap.java | 33 +-- .../groupby/utils/FloatToIdMap.java | 33 +-- .../aggregation/groupby/utils/IntToIdMap.java | 33 +-- .../groupby/utils/LongToIdMap.java | 33 +-- ...{StringToIdMap.java => ObjectToIdMap.java} | 32 ++- .../groupby/utils/ValueToIdMap.java | 44 ++-- .../groupby/utils/ValueToIdMapFactory.java | 11 +- .../core/util/DataBlockExtractUtils.java | 14 +- .../operator/MultistageGroupByExecutor.java | 92 ++++----- .../operator/groupby/GroupIdGenerator.java | 50 +++++ .../groupby/GroupIdGeneratorFactory.java | 48 +++++ .../groupby/MultiKeysGroupIdGenerator.java | 106 ++++++++++ .../groupby/OneDoubleKeyGroupIdGenerator.java | 92 +++++++++ .../groupby/OneFloatKeyGroupIdGenerator.java | 90 +++++++++ .../groupby/OneIntKeyGroupIdGenerator.java | 91 +++++++++ .../groupby/OneLongKeyGroupIdGenerator.java | 91 +++++++++ .../groupby/OneObjectKeyGroupIdGenerator.java | 72 +++++++ .../groupby/TwoKeysGroupIdGenerator.java | 111 ++++++++++ .../operator/AggregateOperatorTest.java | 2 +- 22 files changed, 1033 insertions(+), 388 deletions(-) delete mode 100644 pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BaseValueToIdMap.java delete mode 100644 pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BytesToIdMap.java rename pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/{StringToIdMap.java => ObjectToIdMap.java} (64%) create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGenerator.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGeneratorFactory.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/MultiKeysGroupIdGenerator.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneDoubleKeyGroupIdGenerator.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneFloatKeyGroupIdGenerator.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneIntKeyGroupIdGenerator.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneLongKeyGroupIdGenerator.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneObjectKeyGroupIdGenerator.java create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/TwoKeysGroupIdGenerator.java diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/NoDictionaryMultiColumnGroupKeyGenerator.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/NoDictionaryMultiColumnGroupKeyGenerator.java index 91becf3bbeeb..9c7cf193a662 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/NoDictionaryMultiColumnGroupKeyGenerator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/NoDictionaryMultiColumnGroupKeyGenerator.java @@ -57,11 +57,9 @@ public class NoDictionaryMultiColumnGroupKeyGenerator implements GroupKeyGenerat private final ValueToIdMap[] _onTheFlyDictionaries; private final Object2IntOpenHashMap _groupKeyMap; private final boolean[] _isSingleValueExpressions; - private final int _globalGroupIdUpperBound; + private final int _numGroupsLimit; private final boolean _nullHandlingEnabled; - private int _numGroups = 0; - public NoDictionaryMultiColumnGroupKeyGenerator(BaseProjectOperator projectOperator, ExpressionContext[] groupByExpressions, int numGroupsLimit, boolean nullHandlingEnabled) { _groupByExpressions = groupByExpressions; @@ -87,12 +85,12 @@ public NoDictionaryMultiColumnGroupKeyGenerator(BaseProjectOperator projectOp _groupKeyMap = new Object2IntOpenHashMap<>(); _groupKeyMap.defaultReturnValue(INVALID_ID); - _globalGroupIdUpperBound = numGroupsLimit; + _numGroupsLimit = numGroupsLimit; } @Override public int getGlobalGroupKeyUpperBound() { - return _globalGroupIdUpperBound; + return _numGroupsLimit; } @Override @@ -117,6 +115,9 @@ public void generateKeysForBlock(ValueBlock valueBlock, int[] groupKeys) { case DOUBLE: values[i] = blockValSet.getDoubleValuesSV(); break; + case BIG_DECIMAL: + values[i] = blockValSet.getBigDecimalValuesSV(); + break; case STRING: values[i] = blockValSet.getStringValuesSV(); break; @@ -137,53 +138,134 @@ public void generateKeysForBlock(ValueBlock valueBlock, int[] groupKeys) { nullBitmaps[i] = valueBlock.getBlockValueSet(_groupByExpressions[i]).getNullBitmap(); } for (int row = 0; row < numDocs; row++) { - for (int col = 0; col < _numGroupByExpressions; col++) { - if (nullBitmaps[col] != null && nullBitmaps[col].contains(row)) { - keyValues[col] = ID_FOR_NULL; - } else { + int numGroups = _groupKeyMap.size(); + boolean hasInvalidKeyValue = false; + if (numGroups < _numGroupsLimit) { + for (int col = 0; col < _numGroupByExpressions; col++) { + if (nullBitmaps[col] != null && nullBitmaps[col].contains(row)) { + keyValues[col] = ID_FOR_NULL; + } else { + Object columnValues = values[col]; + ValueToIdMap onTheFlyDictionary = _onTheFlyDictionaries[col]; + int keyValue; + if (columnValues instanceof int[]) { + keyValue = onTheFlyDictionary.put(((int[]) columnValues)[row]); + } else if (columnValues instanceof long[]) { + keyValue = onTheFlyDictionary.put(((long[]) columnValues)[row]); + } else if (columnValues instanceof float[]) { + keyValue = onTheFlyDictionary.put(((float[]) columnValues)[row]); + } else if (columnValues instanceof double[]) { + keyValue = onTheFlyDictionary.put(((double[]) columnValues)[row]); + } else if (columnValues instanceof byte[][]) { + keyValue = onTheFlyDictionary.put(new ByteArray(((byte[][]) columnValues)[row])); + } else { + keyValue = onTheFlyDictionary.put(((Object[]) columnValues)[row]); + } + keyValues[col] = keyValue; + } + } + } else { + for (int col = 0; col < _numGroupByExpressions; col++) { + if (nullBitmaps[col] != null && nullBitmaps[col].contains(row)) { + keyValues[col] = ID_FOR_NULL; + } else { + Object columnValues = values[col]; + ValueToIdMap onTheFlyDictionary = _onTheFlyDictionaries[col]; + int keyValue; + if (columnValues instanceof int[]) { + keyValue = onTheFlyDictionary.getId(((int[]) columnValues)[row]); + } else if (columnValues instanceof long[]) { + keyValue = onTheFlyDictionary.getId(((long[]) columnValues)[row]); + } else if (columnValues instanceof float[]) { + keyValue = onTheFlyDictionary.getId(((float[]) columnValues)[row]); + } else if (columnValues instanceof double[]) { + keyValue = onTheFlyDictionary.getId(((double[]) columnValues)[row]); + } else if (columnValues instanceof byte[][]) { + keyValue = onTheFlyDictionary.getId(new ByteArray(((byte[][]) columnValues)[row])); + } else { + keyValue = onTheFlyDictionary.getId(((Object[]) columnValues)[row]); + } + if (keyValue == INVALID_ID) { + hasInvalidKeyValue = true; + break; + } + } + } + } + if (hasInvalidKeyValue) { + groupKeys[row] = INVALID_ID; + } else { + int groupId = getGroupIdForKey(flyweightKey); + if (groupId == numGroups) { + // When a new group is added, create a new FixedIntArray + keyValues = new int[_numGroupByExpressions]; + flyweightKey = new FixedIntArray(keyValues); + } + groupKeys[row] = groupId; + } + } + } else { + for (int row = 0; row < numDocs; row++) { + int numGroups = _groupKeyMap.size(); + boolean hasInvalidKeyValue = false; + if (numGroups < _numGroupsLimit) { + for (int col = 0; col < _numGroupByExpressions; col++) { Object columnValues = values[col]; ValueToIdMap onTheFlyDictionary = _onTheFlyDictionaries[col]; + int keyValue; if (columnValues instanceof int[]) { - keyValues[col] = onTheFlyDictionary.put(((int[]) columnValues)[row]); + int columnValue = ((int[]) columnValues)[row]; + keyValue = onTheFlyDictionary != null ? onTheFlyDictionary.put(columnValue) : columnValue; } else if (columnValues instanceof long[]) { - keyValues[col] = onTheFlyDictionary.put(((long[]) columnValues)[row]); + keyValue = onTheFlyDictionary.put(((long[]) columnValues)[row]); } else if (columnValues instanceof float[]) { - keyValues[col] = onTheFlyDictionary.put(((float[]) columnValues)[row]); + keyValue = onTheFlyDictionary.put(((float[]) columnValues)[row]); } else if (columnValues instanceof double[]) { - keyValues[col] = onTheFlyDictionary.put(((double[]) columnValues)[row]); - } else if (columnValues instanceof String[]) { - keyValues[col] = onTheFlyDictionary.put(((String[]) columnValues)[row]); + keyValue = onTheFlyDictionary.put(((double[]) columnValues)[row]); } else if (columnValues instanceof byte[][]) { - keyValues[col] = onTheFlyDictionary.put(new ByteArray(((byte[][]) columnValues)[row])); + keyValue = onTheFlyDictionary.put(new ByteArray(((byte[][]) columnValues)[row])); + } else { + keyValue = onTheFlyDictionary.put(((Object[]) columnValues)[row]); } + keyValues[col] = keyValue; } - } - groupKeys[row] = getGroupIdForFlyweightKey(flyweightKey); - } - } else { - for (int row = 0; row < numDocs; row++) { - for (int col = 0; col < _numGroupByExpressions; col++) { - Object columnValues = values[col]; - ValueToIdMap onTheFlyDictionary = _onTheFlyDictionaries[col]; - if (columnValues instanceof int[]) { - if (onTheFlyDictionary == null) { - keyValues[col] = ((int[]) columnValues)[row]; + } else { + for (int col = 0; col < _numGroupByExpressions; col++) { + Object columnValues = values[col]; + ValueToIdMap onTheFlyDictionary = _onTheFlyDictionaries[col]; + int keyValue; + if (columnValues instanceof int[]) { + int columnValue = ((int[]) columnValues)[row]; + keyValue = onTheFlyDictionary != null ? onTheFlyDictionary.getId(columnValue) : columnValue; + } else if (columnValues instanceof long[]) { + keyValue = onTheFlyDictionary.getId(((long[]) columnValues)[row]); + } else if (columnValues instanceof float[]) { + keyValue = onTheFlyDictionary.getId(((float[]) columnValues)[row]); + } else if (columnValues instanceof double[]) { + keyValue = onTheFlyDictionary.getId(((double[]) columnValues)[row]); + } else if (columnValues instanceof byte[][]) { + keyValue = onTheFlyDictionary.getId(new ByteArray(((byte[][]) columnValues)[row])); } else { - keyValues[col] = onTheFlyDictionary.put(((int[]) columnValues)[row]); + keyValue = onTheFlyDictionary.getId(((Object[]) columnValues)[row]); + } + if (keyValue == INVALID_ID) { + hasInvalidKeyValue = true; + break; } - } else if (columnValues instanceof long[]) { - keyValues[col] = onTheFlyDictionary.put(((long[]) columnValues)[row]); - } else if (columnValues instanceof float[]) { - keyValues[col] = onTheFlyDictionary.put(((float[]) columnValues)[row]); - } else if (columnValues instanceof double[]) { - keyValues[col] = onTheFlyDictionary.put(((double[]) columnValues)[row]); - } else if (columnValues instanceof String[]) { - keyValues[col] = onTheFlyDictionary.put(((String[]) columnValues)[row]); - } else if (columnValues instanceof byte[][]) { - keyValues[col] = onTheFlyDictionary.put(new ByteArray(((byte[][]) columnValues)[row])); + keyValues[col] = keyValue; + } + } + if (hasInvalidKeyValue) { + groupKeys[row] = INVALID_ID; + } else { + int groupId = getGroupIdForKey(flyweightKey); + if (groupId == numGroups) { + // When a new group is added, create a new FixedIntArray + keyValues = new int[_numGroupByExpressions]; + flyweightKey = new FixedIntArray(keyValues); } + groupKeys[row] = groupId; } - groupKeys[row] = getGroupIdForFlyweightKey(flyweightKey); } } } @@ -329,23 +411,6 @@ public Iterator getGroupKeys() { return new GroupKeyIterator(); } - /** - * Helper method to get or create group-id for a group key. - * - * @param flyweight Group key, that is a list of objects to be grouped, will be cloned on first occurrence - * @return Group id - */ - private int getGroupIdForFlyweightKey(FixedIntArray flyweight) { - int groupId = _groupKeyMap.getInt(flyweight); - if (groupId == INVALID_ID) { - if (_numGroups < _globalGroupIdUpperBound) { - groupId = _numGroups; - _groupKeyMap.put(flyweight.clone(), _numGroups++); - } - } - return groupId; - } - /** * Helper method to get or create group-id for a group key. * @@ -353,14 +418,12 @@ private int getGroupIdForFlyweightKey(FixedIntArray flyweight) { * @return Group id */ private int getGroupIdForKey(FixedIntArray keyList) { - int groupId = _groupKeyMap.getInt(keyList); - if (groupId == INVALID_ID) { - if (_numGroups < _globalGroupIdUpperBound) { - groupId = _numGroups; - _groupKeyMap.put(keyList, _numGroups++); - } + int numGroups = _groupKeyMap.size(); + if (numGroups < _numGroupsLimit) { + return _groupKeyMap.computeIfAbsent(keyList, k -> numGroups); + } else { + return _groupKeyMap.getInt(keyList); } - return groupId; } /** diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BaseValueToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BaseValueToIdMap.java deleted file mode 100644 index 6d889d3442d3..000000000000 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BaseValueToIdMap.java +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.core.query.aggregation.groupby.utils; - -import org.apache.pinot.spi.utils.ByteArray; - - -/** - * Abstract base class for {@link ValueToIdMap} interface. - */ -public abstract class BaseValueToIdMap implements ValueToIdMap { - @Override - public int put(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public int put(long value) { - throw new UnsupportedOperationException(); - } - - @Override - public int put(float value) { - throw new UnsupportedOperationException(); - } - - @Override - public int put(double value) { - throw new UnsupportedOperationException(); - } - - @Override - public int put(String value) { - throw new UnsupportedOperationException(); - } - - @Override - public int put(ByteArray value) { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(int id) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(int id) { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(int id) { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(int id) { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(int id) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteArray getBytes(int id) { - throw new UnsupportedOperationException(); - } -} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BytesToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BytesToIdMap.java deleted file mode 100644 index 2f2fe6ec2e15..000000000000 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/BytesToIdMap.java +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.core.query.aggregation.groupby.utils; - -import it.unimi.dsi.fastutil.objects.Object2IntMap; -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; -import it.unimi.dsi.fastutil.objects.ObjectArrayList; -import it.unimi.dsi.fastutil.objects.ObjectList; -import org.apache.pinot.spi.utils.ByteArray; - - -/** - * Implementation of {@link ValueToIdMap} for ByteArray. - */ -public class BytesToIdMap extends BaseValueToIdMap { - Object2IntMap _valueToIdMap; - ObjectList _idToValueMap; - - public BytesToIdMap() { - _valueToIdMap = new Object2IntOpenHashMap<>(); - _valueToIdMap.defaultReturnValue(INVALID_KEY); - _idToValueMap = new ObjectArrayList<>(); - } - - @Override - public int put(ByteArray value) { - int id = _valueToIdMap.getInt(value); - if (id == INVALID_KEY) { - id = _idToValueMap.size(); - _valueToIdMap.put(value, id); - _idToValueMap.add(value); - } - return id; - } - - @Override - public String getString(int id) { - return getBytes(id).toHexString(); - } - - @Override - public ByteArray getBytes(int id) { - assert id < _idToValueMap.size(); - return _idToValueMap.get(id); - } - - @Override - public Object get(int id) { - return getBytes(id); - } -} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/DoubleToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/DoubleToIdMap.java index 4bc5a1d8d13b..0b3754ee4fe5 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/DoubleToIdMap.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/DoubleToIdMap.java @@ -18,18 +18,16 @@ */ package org.apache.pinot.core.query.aggregation.groupby.utils; -import it.unimi.dsi.fastutil.doubles.Double2IntMap; import it.unimi.dsi.fastutil.doubles.Double2IntOpenHashMap; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; -import it.unimi.dsi.fastutil.doubles.DoubleList; /** * Implementation of {@link ValueToIdMap} for double. */ -public class DoubleToIdMap extends BaseValueToIdMap { - Double2IntMap _valueToIdMap; - DoubleList _idToValueMap; +public class DoubleToIdMap implements ValueToIdMap { + private final Double2IntOpenHashMap _valueToIdMap; + private final DoubleArrayList _idToValueMap; public DoubleToIdMap() { _valueToIdMap = new Double2IntOpenHashMap(); @@ -39,28 +37,31 @@ public DoubleToIdMap() { @Override public int put(double value) { - int id = _valueToIdMap.get(value); - if (id == INVALID_KEY) { - id = _idToValueMap.size(); - _valueToIdMap.put(value, id); + int numValues = _valueToIdMap.size(); + int id = _valueToIdMap.computeIfAbsent(value, k -> numValues); + if (id == numValues) { _idToValueMap.add(value); } return id; } @Override - public double getDouble(int id) { - assert id < _idToValueMap.size(); - return _idToValueMap.getDouble(id); + public int put(Object value) { + return put((double) value); } @Override - public String getString(int id) { - return Double.toString(getDouble(id)); + public int getId(double value) { + return _valueToIdMap.get(value); } @Override - public Object get(int id) { - return getDouble(id); + public int getId(Object value) { + return getId((double) value); + } + + @Override + public Double get(int id) { + return _idToValueMap.getDouble(id); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/FloatToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/FloatToIdMap.java index a928c2de0b1f..8b4e41aba3a5 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/FloatToIdMap.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/FloatToIdMap.java @@ -18,18 +18,16 @@ */ package org.apache.pinot.core.query.aggregation.groupby.utils; -import it.unimi.dsi.fastutil.floats.Float2IntMap; import it.unimi.dsi.fastutil.floats.Float2IntOpenHashMap; import it.unimi.dsi.fastutil.floats.FloatArrayList; -import it.unimi.dsi.fastutil.floats.FloatList; /** * Implementation of {@link ValueToIdMap} for float. */ -public class FloatToIdMap extends BaseValueToIdMap { - Float2IntMap _valueToIdMap; - FloatList _idToValueMap; +public class FloatToIdMap implements ValueToIdMap { + private final Float2IntOpenHashMap _valueToIdMap; + private final FloatArrayList _idToValueMap; public FloatToIdMap() { _valueToIdMap = new Float2IntOpenHashMap(); @@ -39,28 +37,31 @@ public FloatToIdMap() { @Override public int put(float value) { - int id = _valueToIdMap.get(value); - if (id == INVALID_KEY) { - id = _idToValueMap.size(); - _valueToIdMap.put(value, id); + int numValues = _valueToIdMap.size(); + int id = _valueToIdMap.computeIfAbsent(value, k -> numValues); + if (id == numValues) { _idToValueMap.add(value); } return id; } @Override - public float getFloat(int id) { - assert id < _idToValueMap.size(); - return _idToValueMap.getFloat(id); + public int put(Object value) { + return put((float) value); } @Override - public String getString(int id) { - return Float.toString(getFloat(id)); + public int getId(float value) { + return _valueToIdMap.get(value); } @Override - public Object get(int id) { - return getFloat(id); + public int getId(Object value) { + return getId((float) value); + } + + @Override + public Float get(int id) { + return _idToValueMap.getFloat(id); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/IntToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/IntToIdMap.java index d07e4d62dbf6..2a05d8a2a534 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/IntToIdMap.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/IntToIdMap.java @@ -18,18 +18,16 @@ */ package org.apache.pinot.core.query.aggregation.groupby.utils; -import it.unimi.dsi.fastutil.ints.Int2IntMap; import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArrayList; -import it.unimi.dsi.fastutil.ints.IntList; /** * Implementation of {@link ValueToIdMap} for int. */ -public class IntToIdMap extends BaseValueToIdMap { - Int2IntMap _valueToIdMap; - IntList _idToValueMap; +public class IntToIdMap implements ValueToIdMap { + private final Int2IntOpenHashMap _valueToIdMap; + private final IntArrayList _idToValueMap; public IntToIdMap() { _valueToIdMap = new Int2IntOpenHashMap(); @@ -39,28 +37,31 @@ public IntToIdMap() { @Override public int put(int value) { - int id = _valueToIdMap.get(value); - if (id == INVALID_KEY) { - id = _idToValueMap.size(); - _valueToIdMap.put(value, id); + int numValues = _valueToIdMap.size(); + int id = _valueToIdMap.computeIfAbsent(value, k -> numValues); + if (id == numValues) { _idToValueMap.add(value); } return id; } @Override - public int getInt(int id) { - assert id < _idToValueMap.size(); - return _idToValueMap.getInt(id); + public int put(Object value) { + return put((int) value); } @Override - public String getString(int id) { - return Integer.toString(getInt(id)); + public int getId(int value) { + return _valueToIdMap.get(value); } @Override - public Object get(int id) { - return getInt(id); + public int getId(Object value) { + return getId((int) value); + } + + @Override + public Integer get(int id) { + return _idToValueMap.getInt(id); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/LongToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/LongToIdMap.java index 06df55c2cb1d..cc2259d27892 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/LongToIdMap.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/LongToIdMap.java @@ -18,18 +18,16 @@ */ package org.apache.pinot.core.query.aggregation.groupby.utils; -import it.unimi.dsi.fastutil.longs.Long2IntMap; import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; -import it.unimi.dsi.fastutil.longs.LongList; /** * Implementation of {@link ValueToIdMap} for long. */ -public class LongToIdMap extends BaseValueToIdMap { - Long2IntMap _valueToIdMap; - LongList _idToValueMap; +public class LongToIdMap implements ValueToIdMap { + private final Long2IntOpenHashMap _valueToIdMap; + private final LongArrayList _idToValueMap; public LongToIdMap() { _valueToIdMap = new Long2IntOpenHashMap(); @@ -39,28 +37,31 @@ public LongToIdMap() { @Override public int put(long value) { - int id = _valueToIdMap.get(value); - if (id == INVALID_KEY) { - id = _idToValueMap.size(); - _valueToIdMap.put(value, id); + int numValues = _valueToIdMap.size(); + int id = _valueToIdMap.computeIfAbsent(value, k -> numValues); + if (id == numValues) { _idToValueMap.add(value); } return id; } @Override - public long getLong(int id) { - assert id < _idToValueMap.size(); - return _idToValueMap.getLong(id); + public int put(Object value) { + return put((long) value); } @Override - public String getString(int id) { - return Long.toString(getLong(id)); + public int getId(long value) { + return _valueToIdMap.get(value); } @Override - public Object get(int id) { - return getLong(id); + public int getId(Object value) { + return getId((long) value); + } + + @Override + public Long get(int id) { + return _idToValueMap.getLong(id); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/StringToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ObjectToIdMap.java similarity index 64% rename from pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/StringToIdMap.java rename to pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ObjectToIdMap.java index 290eaddaf5a5..9a0734362dbc 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/StringToIdMap.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ObjectToIdMap.java @@ -18,44 +18,40 @@ */ package org.apache.pinot.core.query.aggregation.groupby.utils; -import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; -import it.unimi.dsi.fastutil.objects.ObjectArrayList; -import it.unimi.dsi.fastutil.objects.ObjectList; +import java.util.ArrayList; /** - * Implementation of {@link ValueToIdMap} for String. + * Implementation of {@link ValueToIdMap} for Object. */ -public class StringToIdMap extends BaseValueToIdMap { - Object2IntMap _valueToIdMap; - ObjectList _idToValueMap; +public class ObjectToIdMap implements ValueToIdMap { + private final Object2IntOpenHashMap _valueToIdMap; + private final ArrayList _idToValueMap; - public StringToIdMap() { + public ObjectToIdMap() { _valueToIdMap = new Object2IntOpenHashMap<>(); _valueToIdMap.defaultReturnValue(INVALID_KEY); - _idToValueMap = new ObjectArrayList<>(); + _idToValueMap = new ArrayList<>(); } @Override - public int put(String value) { - int id = _valueToIdMap.getInt(value); - if (id == INVALID_KEY) { - id = _idToValueMap.size(); - _valueToIdMap.put(value, id); + public int put(Object value) { + int numValues = _valueToIdMap.size(); + int id = _valueToIdMap.computeIntIfAbsent(value, k -> numValues); + if (id == numValues) { _idToValueMap.add(value); } return id; } @Override - public String getString(int id) { - assert id < _idToValueMap.size(); - return _idToValueMap.get(id); + public int getId(Object value) { + return _valueToIdMap.getInt(value); } @Override public Object get(int id) { - return getString(id); + return _idToValueMap.get(id); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMap.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMap.java index 858e814c5197..93d65383e7f4 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMap.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMap.java @@ -18,38 +18,47 @@ */ package org.apache.pinot.core.query.aggregation.groupby.utils; -import org.apache.pinot.spi.utils.ByteArray; - - /** * Interface for mapping primitive values to contiguous id's. */ public interface ValueToIdMap { int INVALID_KEY = -1; - int put(int value); - - int put(long value); - - int put(float value); + default int put(int value) { + throw new UnsupportedOperationException(); + } - int put(double value); + default int put(long value) { + throw new UnsupportedOperationException(); + } - int put(String value); + default int put(float value) { + throw new UnsupportedOperationException(); + } - int put(ByteArray value); + default int put(double value) { + throw new UnsupportedOperationException(); + } - int getInt(int id); + int put(Object value); - long getLong(int id); + default int getId(int value) { + throw new UnsupportedOperationException(); + } - float getFloat(int id); + default int getId(long value) { + throw new UnsupportedOperationException(); + } - double getDouble(int id); + default int getId(float value) { + throw new UnsupportedOperationException(); + } - String getString(int id); + default int getId(double value) { + throw new UnsupportedOperationException(); + } - ByteArray getBytes(int id); + int getId(Object value); /** * Returns the value for the given id. @@ -59,6 +68,7 @@ public interface ValueToIdMap { *
  • LONG -> Long
  • *
  • FLOAT -> Float
  • *
  • DOUBLE -> Double
  • + *
  • BIG_DECIMAL -> BigDecimal
  • *
  • STRING -> String
  • *
  • BYTES -> ByteArray
  • * diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMapFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMapFactory.java index 444899e4b73a..ee6222566020 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMapFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/utils/ValueToIdMapFactory.java @@ -18,18 +18,17 @@ */ package org.apache.pinot.core.query.aggregation.groupby.utils; -import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.FieldSpec.DataType; /** * Factory for various implementations for {@link ValueToIdMap} */ public class ValueToIdMapFactory { - // Private constructor to prevent instantiating the class. private ValueToIdMapFactory() { } - public static ValueToIdMap get(FieldSpec.DataType dataType) { + public static ValueToIdMap get(DataType dataType) { switch (dataType) { case INT: return new IntToIdMap(); @@ -39,12 +38,8 @@ public static ValueToIdMap get(FieldSpec.DataType dataType) { return new FloatToIdMap(); case DOUBLE: return new DoubleToIdMap(); - case STRING: - return new StringToIdMap(); - case BYTES: - return new BytesToIdMap(); default: - throw new IllegalArgumentException("Illegal data type for ValueToIdMapFactory: " + dataType); + return new ObjectToIdMap(); } } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java index 823527102f7a..445881527788 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/util/DataBlockExtractUtils.java @@ -28,7 +28,6 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.common.ObjectSerDeUtils; -import org.apache.pinot.core.data.table.Key; import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; import org.roaringbitmap.PeekableIntIterator; @@ -105,7 +104,7 @@ private static Object extractValue(DataBlock dataBlock, ColumnDataType storedTyp } } - public static Key[] extractKeys(DataBlock dataBlock, int[] keyIds) { + public static Object[][] extractKeys(DataBlock dataBlock, int[] keyIds) { DataSchema dataSchema = dataBlock.getDataSchema(); int numKeys = keyIds.length; ColumnDataType[] storedTypes = new ColumnDataType[numKeys]; @@ -115,7 +114,7 @@ public static Key[] extractKeys(DataBlock dataBlock, int[] keyIds) { nullBitmaps[colId] = dataBlock.getNullRowIds(keyIds[colId]); } int numRows = dataBlock.getNumberOfRows(); - Key[] keys = new Key[numRows]; + Object[][] keys = new Object[numRows][]; for (int rowId = 0; rowId < numRows; rowId++) { Object[] values = new Object[numKeys]; for (int colId = 0; colId < numKeys; colId++) { @@ -124,12 +123,13 @@ public static Key[] extractKeys(DataBlock dataBlock, int[] keyIds) { values[colId] = extractValue(dataBlock, storedTypes[colId], rowId, keyIds[colId]); } } - keys[rowId] = new Key(values); + keys[rowId] = values; } return keys; } - public static Key[] extractKeys(DataBlock dataBlock, int[] keyIds, int numMatchedRows, RoaringBitmap matchedBitmap) { + public static Object[][] extractKeys(DataBlock dataBlock, int[] keyIds, int numMatchedRows, + RoaringBitmap matchedBitmap) { DataSchema dataSchema = dataBlock.getDataSchema(); int numKeys = keyIds.length; ColumnDataType[] storedTypes = new ColumnDataType[numKeys]; @@ -138,7 +138,7 @@ public static Key[] extractKeys(DataBlock dataBlock, int[] keyIds, int numMatche storedTypes[colId] = dataSchema.getColumnDataType(keyIds[colId]).getStoredType(); nullBitmaps[colId] = dataBlock.getNullRowIds(keyIds[colId]); } - Key[] keys = new Key[numMatchedRows]; + Object[][] keys = new Object[numMatchedRows][]; PeekableIntIterator iterator = matchedBitmap.getIntIterator(); for (int matchedRowId = 0; matchedRowId < numMatchedRows; matchedRowId++) { int rowId = iterator.next(); @@ -149,7 +149,7 @@ public static Key[] extractKeys(DataBlock dataBlock, int[] keyIds, int numMatche values[colId] = extractValue(dataBlock, storedTypes[colId], rowId, keyIds[colId]); } } - keys[matchedRowId] = new Key(values); + keys[matchedRowId] = values; } return keys; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java index 7b4cb28071a4..a89125d048d1 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java @@ -18,10 +18,9 @@ */ package org.apache.pinot.query.runtime.operator; -import it.unimi.dsi.fastutil.objects.Object2IntMap; -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; import java.util.ArrayList; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import javax.annotation.Nullable; @@ -32,7 +31,6 @@ import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.config.QueryOptionsUtils; import org.apache.pinot.core.common.BlockValSet; -import org.apache.pinot.core.data.table.Key; import org.apache.pinot.core.plan.maker.InstancePlanMakerImplV2; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; @@ -41,6 +39,8 @@ import org.apache.pinot.query.planner.plannode.AbstractPlanNode; import org.apache.pinot.query.planner.plannode.AggregateNode.AggType; import org.apache.pinot.query.runtime.blocks.TransferableBlock; +import org.apache.pinot.query.runtime.operator.groupby.GroupIdGenerator; +import org.apache.pinot.query.runtime.operator.groupby.GroupIdGeneratorFactory; import org.apache.pinot.query.runtime.operator.utils.TypeUtils; import org.roaringbitmap.PeekableIntIterator; import org.roaringbitmap.RoaringBitmap; @@ -65,7 +65,7 @@ public class MultistageGroupByExecutor { // Mapping from the row-key to a zero based integer index. This is used when we invoke the v1 aggregation functions // because they use the zero based integer indexes to store results. - private final Object2IntOpenHashMap _groupKeyToIdMap; + private final GroupIdGenerator _groupIdGenerator; public MultistageGroupByExecutor(int[] groupKeyIds, AggregationFunction[] aggFunctions, int[] filterArgIds, int maxFilterArgId, AggType aggType, DataSchema resultSchema, Map opChainMetadata, @@ -92,8 +92,9 @@ public MultistageGroupByExecutor(int[] groupKeyIds, AggregationFunction[] aggFun _aggregateResultHolders = null; } - _groupKeyToIdMap = new Object2IntOpenHashMap<>(); - _groupKeyToIdMap.defaultReturnValue(GroupKeyGenerator.INVALID_ID); + _groupIdGenerator = + GroupIdGeneratorFactory.getGroupIdGenerator(_resultSchema.getStoredColumnDataTypes(), groupKeyIds.length, + _numGroupsLimit); } private int getNumGroupsLimit(Map opChainMetadata, @Nullable AbstractPlanNode.NodeHint nodeHint) { @@ -146,39 +147,27 @@ public void processBlock(TransferableBlock block) { * Fetches the result. */ public List getResult() { - if (_groupKeyToIdMap.isEmpty()) { + int numGroups = _groupIdGenerator.getNumGroups(); + if (numGroups == 0) { return Collections.emptyList(); } - List rows = new ArrayList<>(_groupKeyToIdMap.size()); + List rows = new ArrayList<>(numGroups); int numKeys = _groupKeyIds.length; int numFunctions = _aggFunctions.length; - int numColumns = numKeys + numFunctions; ColumnDataType[] resultStoredTypes = _resultSchema.getStoredColumnDataTypes(); - if (numKeys == 1) { - for (Object2IntMap.Entry entry : _groupKeyToIdMap.object2IntEntrySet()) { - Object[] row = new Object[numColumns]; - row[0] = entry.getKey(); - int groupId = entry.getIntValue(); - for (int i = 0; i < numFunctions; i++) { - row[i + 1] = getResultValue(i, groupId); - } - // Convert the results from AggregationFunction to the desired type - TypeUtils.convertRow(row, resultStoredTypes); - rows.add(row); - } - } else { - for (Object2IntMap.Entry entry : _groupKeyToIdMap.object2IntEntrySet()) { - Object[] row = new Object[numColumns]; - Object[] keyValues = ((Key) entry.getKey()).getValues(); - System.arraycopy(keyValues, 0, row, 0, numKeys); - int groupId = entry.getIntValue(); - for (int i = 0; i < numFunctions; i++) { - row[numKeys + i] = getResultValue(i, groupId); - } - // Convert the results from AggregationFunction to the desired type - TypeUtils.convertRow(row, resultStoredTypes); - rows.add(row); + Iterator groupKeyIterator = + _groupIdGenerator.getGroupKeyIterator(numKeys + numFunctions); + while (groupKeyIterator.hasNext()) { + GroupIdGenerator.GroupKey groupKey = groupKeyIterator.next(); + int groupId = groupKey._groupId; + Object[] row = groupKey._row; + int columnId = numKeys; + for (int i = 0; i < numFunctions; i++) { + row[columnId++] = getResultValue(i, groupId); } + // Convert the results from AggregationFunction to the desired type + TypeUtils.convertRow(row, resultStoredTypes); + rows.add(row); } return rows; } @@ -201,7 +190,7 @@ private Object getResultValue(int functionId, int groupId) { } public boolean isNumGroupsLimitReached() { - return _groupKeyToIdMap.size() == _numGroupsLimit; + return _groupIdGenerator.getNumGroups() == _numGroupsLimit; } private void processAggregate(TransferableBlock block) { @@ -212,7 +201,7 @@ private void processAggregate(TransferableBlock block) { AggregationFunction aggFunction = _aggFunctions[i]; Map blockValSetMap = AggregateOperator.getBlockValSetMap(aggFunction, block); GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i]; - groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size()); + groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups()); aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap); } } else { @@ -231,7 +220,7 @@ private void processAggregate(TransferableBlock block) { } Map blockValSetMap = AggregateOperator.getBlockValSetMap(aggFunction, block); GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i]; - groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size()); + groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups()); aggFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap); } else { // Need to filter the block before aggregation @@ -248,7 +237,7 @@ private void processAggregate(TransferableBlock block) { Map blockValSetMap = AggregateOperator.getFilteredBlockValSetMap(aggFunction, block, numMatchedRows, matchedBitmap); GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i]; - groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size()); + groupByResultHolder.ensureCapacity(_groupIdGenerator.getNumGroups()); aggFunction.aggregateGroupBySV(numMatchedRows, filteredIntKeys, groupByResultHolder, blockValSetMap); } } @@ -308,16 +297,16 @@ private int[] generateGroupByKeys(List rows) { if (numKeys == 1) { int groupKeyId = _groupKeyIds[0]; for (int i = 0; i < numRows; i++) { - intKeys[i] = getGroupId(rows.get(i)[groupKeyId]); + intKeys[i] = _groupIdGenerator.getGroupId(rows.get(i)[groupKeyId]); } } else { + Object[] key = new Object[numKeys]; for (int i = 0; i < numRows; i++) { Object[] row = rows.get(i); - Object[] keyValues = new Object[numKeys]; for (int j = 0; j < numKeys; j++) { - keyValues[j] = row[_groupKeyIds[j]]; + key[j] = row[_groupKeyIds[j]]; } - intKeys[i] = getGroupId(new Key(keyValues)); + intKeys[i] = _groupIdGenerator.getGroupId(key); } } return intKeys; @@ -333,7 +322,7 @@ private int[] generateGroupByKeys(DataBlock dataBlock) { int numRows = keys.length; int[] intKeys = new int[numRows]; for (int i = 0; i < numRows; i++) { - intKeys[i] = getGroupId(keys[i]); + intKeys[i] = _groupIdGenerator.getGroupId(keys[i]); } return intKeys; } @@ -354,17 +343,17 @@ private int[] generateGroupByKeys(List rows, int numMatchedRows, Roari if (numKeys == 1) { int groupKeyId = _groupKeyIds[0]; for (int i = 0; i < numMatchedRows; i++) { - intKeys[i] = getGroupId(rows.get(iterator.next())[groupKeyId]); + intKeys[i] = _groupIdGenerator.getGroupId(rows.get(iterator.next())[groupKeyId]); } } else { + Object[] key = new Object[numKeys]; for (int i = 0; i < numMatchedRows; i++) { int rowId = iterator.next(); Object[] row = rows.get(rowId); - Object[] keyValues = new Object[numKeys]; for (int j = 0; j < numKeys; j++) { - keyValues[j] = row[_groupKeyIds[j]]; + key[j] = row[_groupKeyIds[j]]; } - intKeys[i] = getGroupId(new Key(keyValues)); + intKeys[i] = _groupIdGenerator.getGroupId(key); } } return intKeys; @@ -379,17 +368,8 @@ private int[] generateGroupByKeys(DataBlock dataBlock, int numMatchedRows, Roari } int[] intKeys = new int[numMatchedRows]; for (int i = 0; i < numMatchedRows; i++) { - intKeys[i] = getGroupId(keys[i]); + intKeys[i] = _groupIdGenerator.getGroupId(keys[i]); } return intKeys; } - - private int getGroupId(Object key) { - int numGroups = _groupKeyToIdMap.size(); - if (numGroups < _numGroupsLimit) { - return _groupKeyToIdMap.computeIntIfAbsent(key, k -> numGroups); - } else { - return _groupKeyToIdMap.getInt(key); - } - } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGenerator.java new file mode 100644 index 000000000000..de95033ab28c --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGenerator.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import java.util.Iterator; + + +public interface GroupIdGenerator { + int INVALID_ID = -1; + int NULL_ID = -2; + + /** + * Returns the group id for the given key. When a new key is encountered, it assigns a new group id to it before + * reaching the groups limit, or returns {@link #INVALID_ID} when the limit is reached. + * For single key column, the input is a single Object. For multi key columns, the input is an Object[] containing + * the values for each key column. + */ + int getGroupId(Object key); + + int getNumGroups(); + + Iterator getGroupKeyIterator(int numColumns); + + class GroupKey { + public final int _groupId; + // Row is pre-allocated for key and value columns, and is safe to be modified + public final Object[] _row; + + public GroupKey(int groupId, Object[] row) { + _groupId = groupId; + _row = row; + } + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGeneratorFactory.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGeneratorFactory.java new file mode 100644 index 000000000000..16be037f3813 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/GroupIdGeneratorFactory.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; + + +public class GroupIdGeneratorFactory { + private GroupIdGeneratorFactory() { + } + + public static GroupIdGenerator getGroupIdGenerator(ColumnDataType[] keyTypes, int numKeyColumns, int numGroupsLimit) { + if (numKeyColumns == 1) { + switch (keyTypes[0]) { + case INT: + return new OneIntKeyGroupIdGenerator(numGroupsLimit); + case LONG: + return new OneLongKeyGroupIdGenerator(numGroupsLimit); + case FLOAT: + return new OneFloatKeyGroupIdGenerator(numGroupsLimit); + case DOUBLE: + return new OneDoubleKeyGroupIdGenerator(numGroupsLimit); + default: + return new OneObjectKeyGroupIdGenerator(numGroupsLimit); + } + } else if (numKeyColumns == 2) { + return new TwoKeysGroupIdGenerator(keyTypes[0], keyTypes[1], numGroupsLimit); + } else { + return new MultiKeysGroupIdGenerator(keyTypes, numKeyColumns, numGroupsLimit); + } + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/MultiKeysGroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/MultiKeysGroupIdGenerator.java new file mode 100644 index 000000000000..30019746b468 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/MultiKeysGroupIdGenerator.java @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.Iterator; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.core.query.aggregation.groupby.utils.ValueToIdMap; +import org.apache.pinot.core.query.aggregation.groupby.utils.ValueToIdMapFactory; +import org.apache.pinot.spi.utils.FixedIntArray; + + +public class MultiKeysGroupIdGenerator implements GroupIdGenerator { + private final Object2IntOpenHashMap _groupIdMap; + private final ValueToIdMap[] _keyToIdMaps; + private final int _numGroupsLimit; + + public MultiKeysGroupIdGenerator(ColumnDataType[] keyTypes, int numKeyColumns, int numGroupsLimit) { + _groupIdMap = new Object2IntOpenHashMap<>(); + _groupIdMap.defaultReturnValue(INVALID_ID); + _keyToIdMaps = new ValueToIdMap[numKeyColumns]; + for (int i = 0; i < numKeyColumns; i++) { + _keyToIdMaps[i] = ValueToIdMapFactory.get(keyTypes[i].toDataType()); + } + _numGroupsLimit = numGroupsLimit; + } + + @Override + public int getGroupId(Object key) { + Object[] keyValues = (Object[]) key; + int numKeyColumns = keyValues.length; + int[] keyIds = new int[numKeyColumns]; + int numGroups = _groupIdMap.size(); + if (numGroups < _numGroupsLimit) { + for (int i = 0; i < numKeyColumns; i++) { + Object keyValue = keyValues[i]; + keyIds[i] = keyValue != null ? _keyToIdMaps[i].put(keyValue) : NULL_ID; + } + return _groupIdMap.computeIntIfAbsent(new FixedIntArray(keyIds), k -> numGroups); + } else { + for (int i = 0; i < numKeyColumns; i++) { + Object keyValue = keyValues[i]; + if (keyValue == null) { + keyIds[i] = NULL_ID; + } else { + int keyId = _keyToIdMaps[i].getId(keyValue); + if (keyId == INVALID_ID) { + return INVALID_ID; + } + keyIds[i] = keyId; + } + } + return _groupIdMap.getInt(new FixedIntArray(keyIds)); + } + } + + @Override + public int getNumGroups() { + return _groupIdMap.size(); + } + + @Override + public Iterator getGroupKeyIterator(int numColumns) { + return new Iterator() { + final ObjectIterator> _entryIterator = + _groupIdMap.object2IntEntrySet().fastIterator(); + + @Override + public boolean hasNext() { + return _entryIterator.hasNext(); + } + + @Override + public GroupKey next() { + Object2IntOpenHashMap.Entry entry = _entryIterator.next(); + int[] keyIds = entry.getKey().elements(); + Object[] row = new Object[numColumns]; + int numKeyColumns = keyIds.length; + for (int i = 0; i < numKeyColumns; i++) { + int keyId = keyIds[i]; + if (keyId != NULL_ID) { + row[i] = _keyToIdMaps[i].get(keyId); + } + } + return new GroupKey(entry.getIntValue(), row); + } + }; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneDoubleKeyGroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneDoubleKeyGroupIdGenerator.java new file mode 100644 index 000000000000..cf3f920d22db --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneDoubleKeyGroupIdGenerator.java @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import it.unimi.dsi.fastutil.doubles.Double2IntMap; +import it.unimi.dsi.fastutil.doubles.Double2IntOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.Iterator; + + +public class OneDoubleKeyGroupIdGenerator implements GroupIdGenerator { + private final Double2IntOpenHashMap _groupIdMap; + private final int _numGroupsLimit; + + private int _numGroups = 0; + private int _nullGroupId = INVALID_ID; + + public OneDoubleKeyGroupIdGenerator(int numGroupsLimit) { + _groupIdMap = new Double2IntOpenHashMap(); + _groupIdMap.defaultReturnValue(INVALID_ID); + _numGroupsLimit = numGroupsLimit; + } + + @Override + public int getGroupId(Object key) { + if (_numGroups < _numGroupsLimit) { + if (key == null) { + if (_nullGroupId == INVALID_ID) { + _nullGroupId = _numGroups++; + } + return _nullGroupId; + } + int groupId = _groupIdMap.computeIfAbsent((double) key, k -> _numGroups); + if (groupId == _numGroups) { + _numGroups++; + } + return groupId; + } else { + if (key == null) { + return _nullGroupId; + } + return _groupIdMap.get((double) key); + } + } + + @Override + public int getNumGroups() { + return _numGroups; + } + + @Override + public Iterator getGroupKeyIterator(int numColumns) { + return new Iterator() { + final ObjectIterator _entryIterator = + _groupIdMap.double2IntEntrySet().fastIterator(); + boolean _returnNull = _nullGroupId != INVALID_ID; + + @Override + public boolean hasNext() { + return _returnNull || _entryIterator.hasNext(); + } + + @Override + public GroupKey next() { + Object[] row = new Object[numColumns]; + if (_returnNull) { + _returnNull = false; + return new GroupKey(_nullGroupId, row); + } + Double2IntMap.Entry entry = _entryIterator.next(); + row[0] = entry.getDoubleKey(); + return new GroupKey(entry.getIntValue(), row); + } + }; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneFloatKeyGroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneFloatKeyGroupIdGenerator.java new file mode 100644 index 000000000000..5d3005dc6b68 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneFloatKeyGroupIdGenerator.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import it.unimi.dsi.fastutil.floats.Float2IntOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.Iterator; + + +public class OneFloatKeyGroupIdGenerator implements GroupIdGenerator { + private final Float2IntOpenHashMap _groupIdMap; + private final int _numGroupsLimit; + + private int _numGroups = 0; + private int _nullGroupId = INVALID_ID; + + public OneFloatKeyGroupIdGenerator(int numGroupsLimit) { + _groupIdMap = new Float2IntOpenHashMap(); + _groupIdMap.defaultReturnValue(INVALID_ID); + _numGroupsLimit = numGroupsLimit; + } + + @Override + public int getGroupId(Object key) { + if (_numGroups < _numGroupsLimit) { + if (key == null) { + if (_nullGroupId == INVALID_ID) { + _nullGroupId = _numGroups++; + } + return _nullGroupId; + } + int groupId = _groupIdMap.computeIfAbsent((float) key, k -> _numGroups); + if (groupId == _numGroups) { + _numGroups++; + } + return groupId; + } else { + if (key == null) { + return _nullGroupId; + } + return _groupIdMap.get((float) key); + } + } + + @Override + public int getNumGroups() { + return _numGroups; + } + + @Override + public Iterator getGroupKeyIterator(int numColumns) { + return new Iterator() { + final ObjectIterator _entryIterator = _groupIdMap.float2IntEntrySet().fastIterator(); + boolean _returnNull = _nullGroupId != INVALID_ID; + + @Override + public boolean hasNext() { + return _returnNull || _entryIterator.hasNext(); + } + + @Override + public GroupKey next() { + Object[] row = new Object[numColumns]; + if (_returnNull) { + _returnNull = false; + return new GroupKey(_nullGroupId, row); + } + Float2IntOpenHashMap.Entry entry = _entryIterator.next(); + row[0] = entry.getFloatKey(); + return new GroupKey(entry.getIntValue(), row); + } + }; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneIntKeyGroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneIntKeyGroupIdGenerator.java new file mode 100644 index 000000000000..77064f8b3e49 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneIntKeyGroupIdGenerator.java @@ -0,0 +1,91 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import it.unimi.dsi.fastutil.ints.Int2IntMap; +import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.Iterator; + + +public class OneIntKeyGroupIdGenerator implements GroupIdGenerator { + private final Int2IntOpenHashMap _groupIdMap; + private final int _numGroupsLimit; + + private int _numGroups = 0; + private int _nullGroupId = INVALID_ID; + + public OneIntKeyGroupIdGenerator(int numGroupsLimit) { + _groupIdMap = new Int2IntOpenHashMap(); + _groupIdMap.defaultReturnValue(INVALID_ID); + _numGroupsLimit = numGroupsLimit; + } + + @Override + public int getGroupId(Object key) { + if (_numGroups < _numGroupsLimit) { + if (key == null) { + if (_nullGroupId == INVALID_ID) { + _nullGroupId = _numGroups++; + } + return _nullGroupId; + } + int groupId = _groupIdMap.computeIfAbsent((int) key, k -> _numGroups); + if (groupId == _numGroups) { + _numGroups++; + } + return groupId; + } else { + if (key == null) { + return _nullGroupId; + } + return _groupIdMap.get((int) key); + } + } + + @Override + public int getNumGroups() { + return _numGroups; + } + + @Override + public Iterator getGroupKeyIterator(int numColumns) { + return new Iterator() { + final ObjectIterator _entryIterator = _groupIdMap.int2IntEntrySet().fastIterator(); + boolean _returnNull = _nullGroupId != INVALID_ID; + + @Override + public boolean hasNext() { + return _returnNull || _entryIterator.hasNext(); + } + + @Override + public GroupKey next() { + Object[] row = new Object[numColumns]; + if (_returnNull) { + _returnNull = false; + return new GroupKey(_nullGroupId, row); + } + Int2IntMap.Entry entry = _entryIterator.next(); + row[0] = entry.getIntKey(); + return new GroupKey(entry.getIntValue(), row); + } + }; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneLongKeyGroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneLongKeyGroupIdGenerator.java new file mode 100644 index 000000000000..5862df3defef --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneLongKeyGroupIdGenerator.java @@ -0,0 +1,91 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import it.unimi.dsi.fastutil.longs.Long2IntMap; +import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.Iterator; + + +public class OneLongKeyGroupIdGenerator implements GroupIdGenerator { + private final Long2IntOpenHashMap _groupIdMap; + private final int _numGroupsLimit; + + private int _numGroups = 0; + private int _nullGroupId = INVALID_ID; + + public OneLongKeyGroupIdGenerator(int numGroupsLimit) { + _groupIdMap = new Long2IntOpenHashMap(); + _groupIdMap.defaultReturnValue(INVALID_ID); + _numGroupsLimit = numGroupsLimit; + } + + @Override + public int getGroupId(Object key) { + if (_numGroups < _numGroupsLimit) { + if (key == null) { + if (_nullGroupId == INVALID_ID) { + _nullGroupId = _numGroups++; + } + return _nullGroupId; + } + int groupId = _groupIdMap.computeIfAbsent((long) key, k -> _numGroups); + if (groupId == _numGroups) { + _numGroups++; + } + return groupId; + } else { + if (key == null) { + return _nullGroupId; + } + return _groupIdMap.get((long) key); + } + } + + @Override + public int getNumGroups() { + return _numGroups; + } + + @Override + public Iterator getGroupKeyIterator(int numColumns) { + return new Iterator() { + final ObjectIterator _entryIterator = _groupIdMap.long2IntEntrySet().fastIterator(); + boolean _returnNull = _nullGroupId != INVALID_ID; + + @Override + public boolean hasNext() { + return _returnNull || _entryIterator.hasNext(); + } + + @Override + public GroupKey next() { + Object[] row = new Object[numColumns]; + if (_returnNull) { + _returnNull = false; + return new GroupKey(_nullGroupId, row); + } + Long2IntMap.Entry entry = _entryIterator.next(); + row[0] = entry.getLongKey(); + return new GroupKey(entry.getIntValue(), row); + } + }; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneObjectKeyGroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneObjectKeyGroupIdGenerator.java new file mode 100644 index 000000000000..e7d7bc38153a --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/OneObjectKeyGroupIdGenerator.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.Iterator; + + +public class OneObjectKeyGroupIdGenerator implements GroupIdGenerator { + private final Object2IntOpenHashMap _groupIdMap; + private final int _numGroupsLimit; + + public OneObjectKeyGroupIdGenerator(int numGroupsLimit) { + _groupIdMap = new Object2IntOpenHashMap<>(); + _groupIdMap.defaultReturnValue(INVALID_ID); + _numGroupsLimit = numGroupsLimit; + } + + @Override + public int getGroupId(Object key) { + int numGroups = _groupIdMap.size(); + if (numGroups < _numGroupsLimit) { + return _groupIdMap.computeIntIfAbsent(key, k -> numGroups); + } else { + return _groupIdMap.getInt(key); + } + } + + @Override + public int getNumGroups() { + return _groupIdMap.size(); + } + + @Override + public Iterator getGroupKeyIterator(int numColumns) { + return new Iterator() { + final ObjectIterator> _entryIterator = + _groupIdMap.object2IntEntrySet().fastIterator(); + + @Override + public boolean hasNext() { + return _entryIterator.hasNext(); + } + + @Override + public GroupKey next() { + Object2IntMap.Entry entry = _entryIterator.next(); + Object[] row = new Object[numColumns]; + row[0] = entry.getKey(); + return new GroupKey(entry.getIntValue(), row); + } + }; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/TwoKeysGroupIdGenerator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/TwoKeysGroupIdGenerator.java new file mode 100644 index 000000000000..21e8fcf1448b --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/groupby/TwoKeysGroupIdGenerator.java @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.groupby; + +import it.unimi.dsi.fastutil.longs.Long2IntMap; +import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.Iterator; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.core.query.aggregation.groupby.utils.ValueToIdMap; +import org.apache.pinot.core.query.aggregation.groupby.utils.ValueToIdMapFactory; + + +public class TwoKeysGroupIdGenerator implements GroupIdGenerator { + private final Long2IntOpenHashMap _groupIdMap; + private final ValueToIdMap _firstKeyToIdMap; + private final ValueToIdMap _secondKeyToIdMap; + private final int _numGroupsLimit; + + public TwoKeysGroupIdGenerator(ColumnDataType firstKeyType, ColumnDataType secondKeyType, int numGroupsLimit) { + _groupIdMap = new Long2IntOpenHashMap(); + _groupIdMap.defaultReturnValue(INVALID_ID); + _firstKeyToIdMap = ValueToIdMapFactory.get(firstKeyType.toDataType()); + _secondKeyToIdMap = ValueToIdMapFactory.get(secondKeyType.toDataType()); + _numGroupsLimit = numGroupsLimit; + } + + @Override + public int getGroupId(Object key) { + Object[] keyValues = (Object[]) key; + Object firstKey = keyValues[0]; + Object secondKey = keyValues[1]; + int numGroups = _groupIdMap.size(); + if (numGroups < _numGroupsLimit) { + int firstKeyId = firstKey != null ? _firstKeyToIdMap.put(firstKey) : NULL_ID; + int secondKeyId = secondKey != null ? _secondKeyToIdMap.put(secondKey) : NULL_ID; + long longKey = ((long) firstKeyId << 32) | (secondKeyId & 0xFFFFFFFFL); + return _groupIdMap.computeIfAbsent(longKey, k -> numGroups); + } else { + int firstKeyId; + if (firstKey != null) { + firstKeyId = _firstKeyToIdMap.getId(firstKey); + if (firstKeyId == INVALID_ID) { + return INVALID_ID; + } + } else { + firstKeyId = NULL_ID; + } + int secondKeyId; + if (secondKey != null) { + secondKeyId = _secondKeyToIdMap.getId(secondKey); + if (secondKeyId == INVALID_ID) { + return INVALID_ID; + } + } else { + secondKeyId = NULL_ID; + } + long longKey = ((long) firstKeyId << 32) | (secondKeyId & 0xFFFFFFFFL); + return _groupIdMap.get(longKey); + } + } + + @Override + public int getNumGroups() { + return _groupIdMap.size(); + } + + @Override + public Iterator getGroupKeyIterator(int numColumns) { + return new Iterator() { + final ObjectIterator _entryIterator = _groupIdMap.long2IntEntrySet().fastIterator(); + + @Override + public boolean hasNext() { + return _entryIterator.hasNext(); + } + + @Override + public GroupKey next() { + Long2IntMap.Entry entry = _entryIterator.next(); + long longKey = entry.getLongKey(); + Object[] row = new Object[numColumns]; + int firstKeyId = (int) (longKey >>> 32); + int secondKeyId = (int) longKey; + if (firstKeyId != NULL_ID) { + row[0] = _firstKeyToIdMap.get(firstKeyId); + } + if (secondKeyId != NULL_ID) { + row[1] = _secondKeyToIdMap.get(secondKeyId); + } + return new GroupKey(entry.getIntValue(), row); + } + }; + } +} diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java index 93b65dad7f49..c1e5255f8576 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java @@ -266,7 +266,7 @@ public void shouldReturnErrorBlockOnUnexpectedInputType() { // Then: Assert.assertTrue(block.isErrorBlock(), "expected ERROR block from invalid computation"); - Assert.assertTrue(block.getExceptions().get(1000).contains("String cannot be cast to class"), + Assert.assertTrue(block.getExceptions().get(1000).contains("cannot be cast to class"), "expected it to fail with class cast exception"); }