Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write perf experiments #394

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,24 @@ private static String buildFieldWriteLines(
private static String generateFieldWriteLines(
final Field field, final String modelClassName, String getValueCode, boolean skipDefault) {
final String fieldDef = Common.camelToUpperSnake(field.name());
String prefix = "// [%d] - %s%n".formatted(field.fieldNumber(), field.name());
String prefix = "// ["+field.fieldNumber()+"] - "+field.name();
prefix += "\n";
String postFix = "";
int indent = 0;

if (field.parent() != null) {
final OneOfField oneOfField = field.parent();
final String oneOfType = "%s.%sOneOfType".formatted(modelClassName, oneOfField.nameCamelFirstUpper());
getValueCode = "data.%s().as()".formatted(oneOfField.nameCamelFirstLower());
prefix += "if (data.%s().kind() == %s.%s)%n"
.formatted(oneOfField.nameCamelFirstLower(), oneOfType, Common.camelToUpperSnake(field.name()));
final String oneOfType = modelClassName+"."+oneOfField.nameCamelFirstUpper()+"OneOfType";
getValueCode = "data."+oneOfField.nameCamelFirstLower()+"().as()";
prefix += "if (data."+oneOfField.nameCamelFirstLower()+"().kind() == "+ oneOfType +"."+
Common.camelToUpperSnake(field.name())+") {\n";
postFix += "}\n";
indent ++;
}
// spotless:off
final String writeMethodName = field.methodNameType();
if (field.optionalValueType()) {
return prefix + switch (field.messageType()) {
return prefix + (switch (field.messageType()) {
case "StringValue" -> "writeOptionalString(out, %s, %s);"
.formatted(fieldDef,getValueCode);
case "BoolValue" -> "writeOptionalBoolean(out, %s, %s);"
Expand All @@ -96,24 +101,17 @@ private static String generateFieldWriteLines(
.formatted(fieldDef, getValueCode);
case "BytesValue" -> "writeOptionalBytes(out, %s, %s);"
.formatted(fieldDef, getValueCode);
default -> throw new UnsupportedOperationException(
"Unhandled optional message type:%s".formatted(field.messageType()));
};
} else {
String codecReference = "";
if (Field.FieldType.MESSAGE.equals(field.type())) {
codecReference = "%s.%s.PROTOBUF".formatted(((SingleField) field).messageTypeModelPackage(),
Common.capitalizeFirstLetter(field.messageType()));
}
if (field.repeated()) {
return prefix + switch(field.type()) {
default -> throw new UnsupportedOperationException("Unhandled optional message type:"+field.messageType());
}).indent(indent) + postFix;
} else if (field.repeated()) {
return prefix + (switch(field.type()) {
case ENUM -> "writeEnumList(out, %s, %s);"
.formatted(fieldDef, getValueCode);
case MESSAGE -> "writeMessageList(out, %s, %s, %s);"
.formatted(fieldDef, getValueCode, codecReference);
default -> "write%sList(out, %s, %s);"
.formatted(writeMethodName, fieldDef, getValueCode);
};
}).indent(indent) + postFix;
} else if (field.type() == Field.FieldType.MAP) {
// https://protobuf.dev/programming-guides/proto3/#maps
// On the wire, a map is equivalent to:
Expand Down Expand Up @@ -160,22 +158,165 @@ private static String generateFieldWriteLines(
.replace("$K", mapField.keyField().type().boxedType)
.replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? ((SingleField)mapField.valueField()).messageType() : mapField.valueField().type().boxedType)
.replace("$fieldWriteLines", fieldWriteLines.indent(DEFAULT_INDENT))
.replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT));
.replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT))
.indent(indent) + postFix
;
} else {
return prefix + switch(field.type()) {
return prefix + (switch(field.type()) {
case ENUM -> "writeEnum(out, %s, %s);"
.formatted(fieldDef, getValueCode);
case STRING -> "writeString(out, %s, %s, %s);"
.formatted(fieldDef, getValueCode, skipDefault);
case MESSAGE -> "writeMessage(out, %s, %s, %s);"
.formatted(fieldDef, getValueCode, codecReference);
case MESSAGE -> writeMessageCode(field, fieldDef, getValueCode);
case BOOL -> "writeBoolean(out, %s, %s, %s);"
.formatted(fieldDef, getValueCode, skipDefault);
case INT32, UINT32, SINT32, FIXED32, SFIXED32, INT64, SINT64, UINT64, FIXED64, SFIXED64, BYTES ->
"write%s(out, %s, %s, %s);".formatted(writeMethodName, fieldDef, getValueCode, skipDefault);
case INT32, UINT32, SINT32, FIXED32, SFIXED32, INT64, SINT64, UINT64, FIXED64, SFIXED64 ->
writeNumberCode(field, getValueCode, skipDefault);
case BYTES ->
"write%s(out, %s, %s, %s);"
.formatted(writeMethodName, fieldDef, getValueCode, skipDefault);
default -> "write%s(out, %s, %s);"
.formatted(writeMethodName, fieldDef, getValueCode);
};
}).indent(indent) + postFix;
}
}

private static String writeMessageCode(final Field field, final String fieldDef, final String getValueCode) {
String code = "";
// When not a oneOf don't write default value
if (field.parent() != null) {
code += "if (%s == null) {\n".formatted(getValueCode);
code += writeTagCode(field, ProtoConstants.WIRE_TYPE_DELIMITED).indent(DEFAULT_INDENT);
code += "out.writeByte((byte)0);\n".indent(DEFAULT_INDENT);
code += "}\n";
}
code += "if (%s != null) {\n".formatted(getValueCode);
code += writeTagCode(field, ProtoConstants.WIRE_TYPE_DELIMITED).indent(DEFAULT_INDENT);
if(field.parent() != null) {
code += "final int msgSize = ((%s)%s).protobufSize();\n".formatted(field.messageType(), getValueCode)
.indent(DEFAULT_INDENT);
} else {
code += "final int msgSize = %s.protobufSize();\n".formatted(getValueCode).indent(DEFAULT_INDENT);
}
code += "out.writeVarInt(msgSize, false);\n".indent(DEFAULT_INDENT);
code += "if (msgSize > 0) %s.write(%s, out);\n".formatted(field.messageType()+".PROTOBUF", getValueCode).indent(DEFAULT_INDENT);
code += "}\n";
return code;
}

private static String writeNumberCode(final Field field, final String getValueCode, final boolean skipDefault) {
assert !field.repeated() : "Use write***List methods with repeated types";
final String objectCastName = switch(field.type()) {
case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> "Integer";
case INT64, UINT64, SINT64, FIXED64, SFIXED64 -> "Long";
default -> throw new RuntimeException("Unsupported field type. Bug in ProtoOutputStream, shouldn't happen.");
};
String code = "";
int indent = 0;
if (skipDefault) {
if(field.parent() == null) {
code += "if (%s != 0) {\n".formatted(getValueCode);
} else {
code += "if ((%s)%s != 0) {\n".formatted(objectCastName, getValueCode);
}
indent ++;
}
String writeCode = switch (field.type()) {
case INT32, INT64, UINT64 ->
writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) +
"out.writeVarLong("+getValueCode+", false);\n";
case UINT32 ->
writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) +
"out.writeVarLong(Integer.toUnsignedLong("+getValueCode+"), false);\n";
case SINT32, SINT64 ->
writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) +
"out.writeVarLong("+getValueCode+", true);\n";
case SFIXED32, FIXED32 ->
// The bytes in protobuf are in little-endian order -- backwards for Java.
// Smallest byte first.
writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) +
"out.writeInt("+getValueCode+", ByteOrder.LITTLE_ENDIAN);\n";
case SFIXED64, FIXED64 ->
// The bytes in protobuf are in little-endian order -- backwards for Java.
// Smallest byte first.
writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) +
"out.writeLong("+getValueCode+", ByteOrder.LITTLE_ENDIAN);\n";
default -> throw new RuntimeException("Unsupported field type. Bug in ProtoOutputStream, shouldn't happen.");
};
code += writeCode.indent(DEFAULT_INDENT * indent);

if (skipDefault) {
indent --;
code += "}\n".indent(DEFAULT_INDENT * indent);
}
return code;
}

/**
* Generate manually inlined code to write tag for field
*
* @param field The field to generate tag for
* @param wireType The wire type for the field
* @return java code to write tag for field
*/
private static String writeTagCode(final Field field, final ProtoConstants wireType) {
return writeVarLongCode(((long)field.fieldNumber() << TAG_TYPE_BITS) | wireType.ordinal(), false);
}

/**
* Generate manually inlined code to write varint
*
* @param value The value to write
* @param zigZag If true, use zigzag encoding
* @return java code to write varint
*/
private static String writeVarLongCode(long value, final boolean zigZag) {
if (zigZag) {
value = (value << 1) ^ (value >> 63);
}
StringBuilder code = new StringBuilder();
while (true) {
if ((value & ~0x7FL) == 0) {
code.append("out.writeByte((byte) 0x%08X);\n".formatted(value));
break;
} else {
code.append("out.writeByte((byte) 0x%08X );\n".formatted((byte) (((int) value & 0x7F) | 0x80)));
value >>>= 7;
}
}
return code.toString();
}

/** The number of leading bits of the tag that are used to store field type, the rest is field number */
private static final int TAG_TYPE_BITS = 3;

/**
* Protobuf field types
*/
private enum ProtoConstants {
/** On wire encoded type for varint */
WIRE_TYPE_VARINT_OR_ZIGZAG,
/** On wire encoded type for fixed 64bit */
WIRE_TYPE_FIXED_64_BIT,
/** On wire encoded type for length delimited */
WIRE_TYPE_DELIMITED,
/** On wire encoded type for group start, deprecated */
WIRE_TYPE_GROUP_START,
/** On wire encoded type for group end, deprecated */
WIRE_TYPE_GROUP_END,
/** On wire encoded type for fixed 32bit */
WIRE_TYPE_FIXED_32_BIT;

// values() seems to allocate a new array on each call, so let's cache it here
private static final ProtoConstants[] values = values();

/**
* Mask used to extract the wire type from the "tag" byte
*/
public static final int TAG_WIRE_TYPE_MASK = 0b0000_0111;

public static ProtoConstants get(int ordinal) {
return values[ordinal];
}
}
// spotless:on
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package com.hedera.pbj.intergration.jmh;

import com.google.protobuf.GeneratedMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.hedera.hapi.node.base.Timestamp;
import com.hedera.hapi.node.token.AccountDetails;
import com.hedera.pbj.integration.AccountDetailsPbj;
import com.hedera.pbj.integration.EverythingTestData;
import com.hedera.pbj.runtime.Codec;
import com.hedera.pbj.runtime.JsonCodec;
import com.hedera.pbj.runtime.ParseException;
import com.hedera.pbj.runtime.io.buffer.BufferedData;
import com.hedera.pbj.test.proto.pbj.Everything;
import com.hederahashgraph.api.proto.java.GetAccountDetailsResponse;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

@SuppressWarnings("unused")
@Fork(1)
@Warmup(iterations = 2, time = 2)
@Measurement(iterations = 5, time = 2)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@BenchmarkMode(Mode.AverageTime)
public abstract class JsonBench<P,G extends GeneratedMessage> {

@SuppressWarnings("rawtypes")
@State(Scope.Benchmark)
public static class JsonBenchmarkState<P,G extends GeneratedMessage> {
private JsonCodec<P> pbjJsonCodec;
private Supplier<GeneratedMessage.Builder> builderSupplier;
// input objects
private P pbjModelObject;
private G googleModelObject;

// input bytes
private BufferedData jsonDataBuffer;
private String jsonString;

// output buffers
private BufferedData outDataBuffer;
public void configure(P pbjModelObject, Codec<P> pbjProtoCodec, JsonCodec<P> pbjJsonCodec,
ProtobufObjectBench.ProtobufParseFunction<byte[],G> googleByteArrayParseMethod,
Supplier<GeneratedMessage.Builder> builderSupplier) {
try {
this.pbjModelObject = pbjModelObject;
this.pbjJsonCodec = pbjJsonCodec;
this.builderSupplier = builderSupplier;
// write to JSON for parse tests
jsonDataBuffer = BufferedData.allocate(5 * 1024 * 1024);
pbjJsonCodec.write(pbjModelObject, jsonDataBuffer);
jsonDataBuffer.flip();
// get as string for parse tests
jsonString = jsonDataBuffer.asUtf8String();

// write to temp data buffer and then read into byte array
BufferedData tempDataBuffer = BufferedData.allocate(5 * 1024 * 1024);
pbjProtoCodec.write(pbjModelObject, tempDataBuffer);
tempDataBuffer.flip();
byte[] protoBytes = new byte[(int)tempDataBuffer.length()];
tempDataBuffer.getBytes(0,protoBytes);
// convert to protobuf
googleModelObject = googleByteArrayParseMethod.parse(protoBytes);

// input buffers
// output buffers
this.outDataBuffer = BufferedData.allocate(jsonString.length());
} catch (IOException e) {
e.getStackTrace();
System.err.flush();
throw new RuntimeException(e);
}
}
}

/** Same as parsePbjByteBuffer because DataBuffer.wrap(byte[]) uses ByteBuffer today, added this because makes result plotting easier */
@Benchmark
public void parsePbj(JsonBenchmarkState<P,G> benchmarkState, Blackhole blackhole) throws ParseException {
benchmarkState.jsonDataBuffer.position(0);
blackhole.consume(benchmarkState.pbjJsonCodec.parse(benchmarkState.jsonDataBuffer));
}

@Benchmark
public void parseProtoC(JsonBenchmarkState<P,G> benchmarkState, Blackhole blackhole) throws IOException {
var builder = benchmarkState.builderSupplier.get();
JsonFormat.parser().merge(benchmarkState.jsonString, builder);
blackhole.consume(builder.build());
}

/** Same as writePbjByteBuffer because DataBuffer.wrap(byte[]) uses ByteBuffer today, added this because makes result plotting easier */
@Benchmark
public void writePbj(JsonBenchmarkState<P,G> benchmarkState, Blackhole blackhole) throws IOException {
benchmarkState.outDataBuffer.reset();
benchmarkState.pbjJsonCodec.write(benchmarkState.pbjModelObject, benchmarkState.outDataBuffer);
blackhole.consume(benchmarkState.outDataBuffer);
}

@Benchmark
public void writeProtoC(JsonBenchmarkState<P,G> benchmarkState, Blackhole blackhole) throws InvalidProtocolBufferException {
blackhole.consume(JsonFormat.printer().print(benchmarkState.googleModelObject));
}

/** Custom interface for method references as java.util.Function does not throw IOException */
public interface ProtobufParseFunction<D, G> {
G parse(D data) throws IOException;
}

@State(Scope.Benchmark)
public static class EverythingBench extends JsonBench<Everything, com.hedera.pbj.test.proto.java.Everything> {
@Setup
public void setup(JsonBenchmarkState<Everything, com.hedera.pbj.test.proto.java.Everything> benchmarkState) {
benchmarkState.configure(EverythingTestData.EVERYTHING,
Everything.PROTOBUF,
Everything.JSON,
com.hedera.pbj.test.proto.java.Everything::parseFrom,
com.hedera.pbj.test.proto.java.Everything::newBuilder);
}
}

@State(Scope.Benchmark)
public static class TimeStampBench extends JsonBench<Timestamp , com.hederahashgraph.api.proto.java.Timestamp> {
@Setup
public void setup(JsonBenchmarkState<Timestamp , com.hederahashgraph.api.proto.java.Timestamp> benchmarkState) {
benchmarkState.configure(new Timestamp(5678L, 1234),
Timestamp.PROTOBUF,
Timestamp.JSON,
com.hederahashgraph.api.proto.java.Timestamp::parseFrom,
com.hederahashgraph.api.proto.java.Timestamp::newBuilder);
}
}

@State(Scope.Benchmark)
public static class AccountDetailsBench extends JsonBench<com.hedera.hapi.node.token.AccountDetails, GetAccountDetailsResponse.AccountDetails> {
@Setup
public void setup(JsonBenchmarkState<com.hedera.hapi.node.token.AccountDetails, GetAccountDetailsResponse.AccountDetails> benchmarkState) {
benchmarkState.configure(AccountDetailsPbj.ACCOUNT_DETAILS,
AccountDetails.PROTOBUF,
AccountDetails.JSON,
GetAccountDetailsResponse.AccountDetails::parseFrom,
GetAccountDetailsResponse.AccountDetails::newBuilder);
}
}
}
Loading
Loading