Skip to content

Commit

Permalink
Optimize ser/de to avoid using output stream (#9278)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored Aug 25, 2022
1 parent f6e26c2 commit d778df2
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.pinot.core.common;

import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Longs;
import com.tdunning.math.stats.MergingDigest;
import com.tdunning.math.stats.TDigest;
Expand All @@ -45,14 +46,11 @@
import it.unimi.dsi.fastutil.longs.LongSet;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectSet;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -540,39 +538,39 @@ public byte[] serialize(Map<Object, Object> map) {
return new byte[Integer.BYTES];
}

// No need to close these 2 streams
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);

try {
// Write the size of the map
dataOutputStream.writeInt(size);

// Write the serialized key-value pairs
Iterator<Map.Entry<Object, Object>> iterator = map.entrySet().iterator();
// First write the key type and value type
Map.Entry<Object, Object> firstEntry = iterator.next();
Object firstKey = firstEntry.getKey();
Object firstValue = firstEntry.getValue();
int keyTypeValue = ObjectType.getObjectType(firstKey).getValue();
int valueTypeValue = ObjectType.getObjectType(firstValue).getValue();
dataOutputStream.writeInt(keyTypeValue);
dataOutputStream.writeInt(valueTypeValue);
// Then write each key-value pair
for (Map.Entry<Object, Object> entry : map.entrySet()) {
byte[] keyBytes = ObjectSerDeUtils.serialize(entry.getKey(), keyTypeValue);
dataOutputStream.writeInt(keyBytes.length);
dataOutputStream.write(keyBytes);

byte[] valueBytes = ObjectSerDeUtils.serialize(entry.getValue(), valueTypeValue);
dataOutputStream.writeInt(valueBytes.length);
dataOutputStream.write(valueBytes);
}
} catch (IOException e) {
throw new RuntimeException("Caught exception while serializing Map", e);
// Besides the value bytes, we store: size, key type, value type, length for each key, length for each value
long bufferSize = (3 + 2 * (long) size) * Integer.BYTES;
byte[][] keyBytesArray = new byte[size][];
byte[][] valueBytesArray = new byte[size][];
Map.Entry<Object, Object> firstEntry = map.entrySet().iterator().next();
int keyTypeValue = ObjectType.getObjectType(firstEntry.getKey()).getValue();
int valueTypeValue = ObjectType.getObjectType(firstEntry.getValue()).getValue();
ObjectSerDe keySerDe = SER_DES[keyTypeValue];
ObjectSerDe valueSerDe = SER_DES[valueTypeValue];
int index = 0;
for (Map.Entry<Object, Object> entry : map.entrySet()) {
byte[] keyBytes = keySerDe.serialize(entry.getKey());
bufferSize += keyBytes.length;
keyBytesArray[index] = keyBytes;
byte[] valueBytes = valueSerDe.serialize(entry.getValue());
bufferSize += valueBytes.length;
valueBytesArray[index++] = valueBytes;
}

return byteArrayOutputStream.toByteArray();
Preconditions.checkState(bufferSize <= Integer.MAX_VALUE, "Buffer size exceeds 2GB");
byte[] bytes = new byte[(int) bufferSize];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
byteBuffer.putInt(size);
byteBuffer.putInt(keyTypeValue);
byteBuffer.putInt(valueTypeValue);
for (int i = 0; i < index; i++) {
byte[] keyBytes = keyBytesArray[i];
byteBuffer.putInt(keyBytes.length);
byteBuffer.put(keyBytes);
byte[] valueBytes = valueBytesArray[i];
byteBuffer.putInt(valueBytes.length);
byteBuffer.put(valueBytes);
}
return bytes;
}

@Override
Expand Down Expand Up @@ -736,20 +734,24 @@ public DoubleOpenHashSet deserialize(ByteBuffer byteBuffer) {
@Override
public byte[] serialize(Set<String> stringSet) {
int size = stringSet.size();
// NOTE: No need to close the ByteArrayOutputStream.
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);
try {
dataOutputStream.writeInt(size);
for (String value : stringSet) {
byte[] bytes = value.getBytes(UTF_8);
dataOutputStream.writeInt(bytes.length);
dataOutputStream.write(bytes);
}
} catch (IOException e) {
throw new RuntimeException("Caught exception while serializing Set<String>", e);
// Besides the value bytes, we store: size, length for each value
long bufferSize = (1 + (long) size) * Integer.BYTES;
byte[][] valueBytesArray = new byte[size][];
int index = 0;
for (String value : stringSet) {
byte[] valueBytes = value.getBytes(UTF_8);
bufferSize += valueBytes.length;
valueBytesArray[index++] = valueBytes;
}
return byteArrayOutputStream.toByteArray();
Preconditions.checkState(bufferSize <= Integer.MAX_VALUE, "Buffer size exceeds 2GB");
byte[] bytes = new byte[(int) bufferSize];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
byteBuffer.putInt(size);
for (byte[] valueBytes : valueBytesArray) {
byteBuffer.putInt(valueBytes.length);
byteBuffer.put(valueBytes);
}
return bytes;
}

@Override
Expand All @@ -776,20 +778,21 @@ public ObjectOpenHashSet<String> deserialize(ByteBuffer byteBuffer) {
@Override
public byte[] serialize(Set<ByteArray> bytesSet) {
int size = bytesSet.size();
// NOTE: No need to close the ByteArrayOutputStream.
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);
try {
dataOutputStream.writeInt(size);
for (ByteArray value : bytesSet) {
byte[] bytes = value.getBytes();
dataOutputStream.writeInt(bytes.length);
dataOutputStream.write(bytes);
}
} catch (IOException e) {
throw new RuntimeException("Caught exception while serializing Set<ByteArray>", e);
// Besides the value bytes, we store: size, length for each value
long bufferSize = (1 + (long) size) * Integer.BYTES;
for (ByteArray value : bytesSet) {
bufferSize += value.length();
}
Preconditions.checkState(bufferSize <= Integer.MAX_VALUE, "Buffer size exceeds 2GB");
byte[] bytes = new byte[(int) bufferSize];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
byteBuffer.putInt(size);
for (ByteArray value : bytesSet) {
byte[] valueBytes = value.getBytes();
byteBuffer.putInt(valueBytes.length);
byteBuffer.put(valueBytes);
}
return byteArrayOutputStream.toByteArray();
return bytes;
}

@Override
Expand Down Expand Up @@ -941,30 +944,27 @@ public byte[] serialize(List<Object> list) {
return new byte[Integer.BYTES];
}

// No need to close these 2 streams (close() is no-op)
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);

try {
// Write the size of the list
dataOutputStream.writeInt(size);

// Write the value type
Object firstValue = list.get(0);
int valueType = ObjectType.getObjectType(firstValue).getValue();
dataOutputStream.writeInt(valueType);

// Write the serialized values
for (Object value : list) {
byte[] bytes = ObjectSerDeUtils.serialize(value, valueType);
dataOutputStream.writeInt(bytes.length);
dataOutputStream.write(bytes);
}
} catch (IOException e) {
throw new RuntimeException("Caught exception while serializing List", e);
// Besides the value bytes, we store: size, value type, length for each value
long bufferSize = (2 + (long) size) * Integer.BYTES;
byte[][] valueBytesArray = new byte[size][];
int valueType = ObjectType.getObjectType(list.get(0)).getValue();
ObjectSerDe serDe = SER_DES[valueType];
int index = 0;
for (Object value : list) {
byte[] valueBytes = serDe.serialize(value);
bufferSize += valueBytes.length;
valueBytesArray[index++] = valueBytes;
}

return byteArrayOutputStream.toByteArray();
Preconditions.checkState(bufferSize <= Integer.MAX_VALUE, "Buffer size exceeds 2GB");
byte[] bytes = new byte[(int) bufferSize];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
byteBuffer.putInt(size);
byteBuffer.putInt(valueType);
for (byte[] valueBytes : valueBytesArray) {
byteBuffer.putInt(valueBytes.length);
byteBuffer.put(valueBytes);
}
return bytes;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.perf;

import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.pinot.common.utils.HashUtil;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.spi.utils.ByteArray;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.OptionsBuilder;

import static java.nio.charset.StandardCharsets.UTF_8;


@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(1)
@Warmup(iterations = 3, time = 10)
@Measurement(iterations = 5, time = 10)
@State(Scope.Benchmark)
public class BenchmarkObjectSerDe {
private static final int NUM_VALUES = 5_000_000;

List<String> _stringList = new ArrayList<>(NUM_VALUES);
Set<String> _stringSet = new ObjectOpenHashSet<>(NUM_VALUES);
Map<String, String> _stringToStringMap = new HashMap<>(HashUtil.getHashMapCapacity(NUM_VALUES));
Set<ByteArray> _bytesSet = new ObjectOpenHashSet<>(NUM_VALUES);

@Setup
public void setUp()
throws IOException {
for (int i = 0; i < NUM_VALUES; i++) {
String stringValue = RandomStringUtils.randomAlphanumeric(10, 201);
_stringList.add(stringValue);
_stringSet.add(stringValue);
_stringToStringMap.put(stringValue, stringValue);
_bytesSet.add(new ByteArray(stringValue.getBytes(UTF_8)));
}
}

@Benchmark
public int stringList() {
return ObjectSerDeUtils.serialize(_stringList).length;
}

@Benchmark
public int stringSet() {
return ObjectSerDeUtils.serialize(_stringSet).length;
}

@Benchmark
public int stringToStringMap() {
return ObjectSerDeUtils.serialize(_stringToStringMap).length;
}

@Benchmark
public int bytesSet() {
return ObjectSerDeUtils.serialize(_bytesSet).length;
}

public static void main(String[] args)
throws Exception {
new Runner(new OptionsBuilder().include(BenchmarkObjectSerDe.class.getSimpleName()).build()).run();
}
}

0 comments on commit d778df2

Please sign in to comment.