diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java index cca56b46..601c3a69 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java @@ -69,19 +69,24 @@ private static String buildFieldWriteLines( private static String generateFieldWriteLines( final Field field, final String modelClassName, String getValueCode, boolean skipDefault) { final String fieldDef = Common.camelToUpperSnake(field.name()); - String prefix = "// [%d] - %s%n".formatted(field.fieldNumber(), field.name()); + String prefix = "// ["+field.fieldNumber()+"] - "+field.name(); + prefix += "\n"; + String postFix = ""; + int indent = 0; if (field.parent() != null) { final OneOfField oneOfField = field.parent(); - final String oneOfType = "%s.%sOneOfType".formatted(modelClassName, oneOfField.nameCamelFirstUpper()); - getValueCode = "data.%s().as()".formatted(oneOfField.nameCamelFirstLower()); - prefix += "if (data.%s().kind() == %s.%s)%n" - .formatted(oneOfField.nameCamelFirstLower(), oneOfType, Common.camelToUpperSnake(field.name())); + final String oneOfType = modelClassName+"."+oneOfField.nameCamelFirstUpper()+"OneOfType"; + getValueCode = "data."+oneOfField.nameCamelFirstLower()+"().as()"; + prefix += "if (data."+oneOfField.nameCamelFirstLower()+"().kind() == "+ oneOfType +"."+ + Common.camelToUpperSnake(field.name())+") {\n"; + postFix += "}\n"; + indent ++; } // spotless:off final String writeMethodName = field.methodNameType(); if (field.optionalValueType()) { - return prefix + switch (field.messageType()) { + return prefix + (switch (field.messageType()) { case "StringValue" -> "writeOptionalString(out, %s, %s);" 
.formatted(fieldDef,getValueCode); case "BoolValue" -> "writeOptionalBoolean(out, %s, %s);" @@ -96,24 +101,17 @@ private static String generateFieldWriteLines( .formatted(fieldDef, getValueCode); case "BytesValue" -> "writeOptionalBytes(out, %s, %s);" .formatted(fieldDef, getValueCode); - default -> throw new UnsupportedOperationException( - "Unhandled optional message type:%s".formatted(field.messageType())); - }; - } else { - String codecReference = ""; - if (Field.FieldType.MESSAGE.equals(field.type())) { - codecReference = "%s.%s.PROTOBUF".formatted(((SingleField) field).messageTypeModelPackage(), - Common.capitalizeFirstLetter(field.messageType())); - } - if (field.repeated()) { - return prefix + switch(field.type()) { + default -> throw new UnsupportedOperationException("Unhandled optional message type:"+field.messageType()); + }).indent(indent) + postFix; + } else if (field.repeated()) { + return prefix + (switch(field.type()) { case ENUM -> "writeEnumList(out, %s, %s);" .formatted(fieldDef, getValueCode); case MESSAGE -> "writeMessageList(out, %s, %s, %s);" .formatted(fieldDef, getValueCode, codecReference); default -> "write%sList(out, %s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); - }; + }).indent(indent) + postFix; } else if (field.type() == Field.FieldType.MAP) { // https://protobuf.dev/programming-guides/proto3/#maps // On the wire, a map is equivalent to: @@ -160,22 +158,165 @@ private static String generateFieldWriteLines( .replace("$K", mapField.keyField().type().boxedType) .replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? 
((SingleField)mapField.valueField()).messageType() : mapField.valueField().type().boxedType) .replace("$fieldWriteLines", fieldWriteLines.indent(DEFAULT_INDENT)) - .replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT)); + .replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT)) + .indent(indent) + postFix + ; } else { - return prefix + switch(field.type()) { + return prefix + (switch(field.type()) { case ENUM -> "writeEnum(out, %s, %s);" .formatted(fieldDef, getValueCode); case STRING -> "writeString(out, %s, %s, %s);" .formatted(fieldDef, getValueCode, skipDefault); - case MESSAGE -> "writeMessage(out, %s, %s, %s);" - .formatted(fieldDef, getValueCode, codecReference); + case MESSAGE -> writeMessageCode(field, fieldDef, getValueCode); case BOOL -> "writeBoolean(out, %s, %s, %s);" .formatted(fieldDef, getValueCode, skipDefault); - case INT32, UINT32, SINT32, FIXED32, SFIXED32, INT64, SINT64, UINT64, FIXED64, SFIXED64, BYTES -> - "write%s(out, %s, %s, %s);".formatted(writeMethodName, fieldDef, getValueCode, skipDefault); + case INT32, UINT32, SINT32, FIXED32, SFIXED32, INT64, SINT64, UINT64, FIXED64, SFIXED64 -> + writeNumberCode(field, getValueCode, skipDefault); + case BYTES -> + "write%s(out, %s, %s, %s);" + .formatted(writeMethodName, fieldDef, getValueCode, skipDefault); default -> "write%s(out, %s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); - }; + }).indent(indent) + postFix; + } + } + + private static String writeMessageCode(final Field field, final String fieldDef, final String getValueCode) { + String code = ""; + // When not a oneOf don't write default value + if (field.parent() != null) { + code += "if (%s == null) {\n".formatted(getValueCode); + code += writeTagCode(field, ProtoConstants.WIRE_TYPE_DELIMITED).indent(DEFAULT_INDENT); + code += "out.writeByte((byte)0);\n".indent(DEFAULT_INDENT); + code += "}\n"; + } + code += "if (%s != null) {\n".formatted(getValueCode); + code += writeTagCode(field, 
ProtoConstants.WIRE_TYPE_DELIMITED).indent(DEFAULT_INDENT); + if(field.parent() != null) { + code += "final int msgSize = ((%s)%s).protobufSize();\n".formatted(field.messageType(), getValueCode) + .indent(DEFAULT_INDENT); + } else { + code += "final int msgSize = %s.protobufSize();\n".formatted(getValueCode).indent(DEFAULT_INDENT); + } + code += "out.writeVarInt(msgSize, false);\n".indent(DEFAULT_INDENT); + code += "if (msgSize > 0) %s.write(%s, out);\n".formatted(field.messageType()+".PROTOBUF", getValueCode).indent(DEFAULT_INDENT); + code += "}\n"; + return code; + } + + private static String writeNumberCode(final Field field, final String getValueCode, final boolean skipDefault) { + assert !field.repeated() : "Use write***List methods with repeated types"; + final String objectCastName = switch(field.type()) { + case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> "Integer"; + case INT64, UINT64, SINT64, FIXED64, SFIXED64 -> "Long"; + default -> throw new RuntimeException("Unsupported field type. Bug in ProtoOutputStream, shouldn't happen."); + }; + String code = ""; + int indent = 0; + if (skipDefault) { + if(field.parent() == null) { + code += "if (%s != 0) {\n".formatted(getValueCode); + } else { + code += "if ((%s)%s != 0) {\n".formatted(objectCastName, getValueCode); + } + indent ++; + } + String writeCode = switch (field.type()) { + case INT32, INT64, UINT64 -> + writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) + + "out.writeVarLong("+getValueCode+", false);\n"; + case UINT32 -> + writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) + + "out.writeVarLong(Integer.toUnsignedLong("+getValueCode+"), false);\n"; + case SINT32, SINT64 -> + writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) + + "out.writeVarLong("+getValueCode+", true);\n"; + case SFIXED32, FIXED32 -> + // The bytes in protobuf are in little-endian order -- backwards for Java. + // Smallest byte first. 
+ writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) + + "out.writeInt("+getValueCode+", ByteOrder.LITTLE_ENDIAN);\n"; + case SFIXED64, FIXED64 -> + // The bytes in protobuf are in little-endian order -- backwards for Java. + // Smallest byte first. + writeTagCode(field, ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG) + + "out.writeLong("+getValueCode+", ByteOrder.LITTLE_ENDIAN);\n"; + default -> throw new RuntimeException("Unsupported field type. Bug in ProtoOutputStream, shouldn't happen."); + }; + code += writeCode.indent(DEFAULT_INDENT * indent); + + if (skipDefault) { + indent --; + code += "}\n".indent(DEFAULT_INDENT * indent); + } + return code; + } + + /** + * Generate manually inlined code to write tag for field + * + * @param field The field to generate tag for + * @param wireType The wire type for the field + * @return java code to write tag for field + */ + private static String writeTagCode(final Field field, final ProtoConstants wireType) { + return writeVarLongCode(((long)field.fieldNumber() << TAG_TYPE_BITS) | wireType.ordinal(), false); + } + + /** + * Generate manually inlined code to write varint + * + * @param value The value to write + * @param zigZag If true, use zigzag encoding + * @return java code to write varint + */ + private static String writeVarLongCode(long value, final boolean zigZag) { + if (zigZag) { + value = (value << 1) ^ (value >> 63); + } + StringBuilder code = new StringBuilder(); + while (true) { + if ((value & ~0x7FL) == 0) { + code.append("out.writeByte((byte) 0x%08X);\n".formatted(value)); + break; + } else { + code.append("out.writeByte((byte) 0x%08X );\n".formatted((byte) (((int) value & 0x7F) | 0x80))); + value >>>= 7; + } + } + return code.toString(); + } + + /** The number of leading bits of the tag that are used to store field type, the rest is field number */ + private static final int TAG_TYPE_BITS = 3; + + /** + * Protobuf field types + */ + private enum ProtoConstants { + /** On wire encoded type for 
varint */ + WIRE_TYPE_VARINT_OR_ZIGZAG, + /** On wire encoded type for fixed 64bit */ + WIRE_TYPE_FIXED_64_BIT, + /** On wire encoded type for length delimited */ + WIRE_TYPE_DELIMITED, + /** On wire encoded type for group start, deprecated */ + WIRE_TYPE_GROUP_START, + /** On wire encoded type for group end, deprecated */ + WIRE_TYPE_GROUP_END, + /** On wire encoded type for fixed 32bit */ + WIRE_TYPE_FIXED_32_BIT; + + // values() seems to allocate a new array on each call, so let's cache it here + private static final ProtoConstants[] values = values(); + + /** + * Mask used to extract the wire type from the "tag" byte + */ + public static final int TAG_WIRE_TYPE_MASK = 0b0000_0111; + + public static ProtoConstants get(int ordinal) { + return values[ordinal]; } } // spotless:on diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/JsonBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/JsonBench.java new file mode 100644 index 00000000..e41de404 --- /dev/null +++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/JsonBench.java @@ -0,0 +1,155 @@ +package com.hedera.pbj.intergration.jmh; + +import com.google.protobuf.GeneratedMessage; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import com.hedera.hapi.node.base.Timestamp; +import com.hedera.hapi.node.token.AccountDetails; +import com.hedera.pbj.integration.AccountDetailsPbj; +import com.hedera.pbj.integration.EverythingTestData; +import com.hedera.pbj.runtime.Codec; +import com.hedera.pbj.runtime.JsonCodec; +import com.hedera.pbj.runtime.ParseException; +import com.hedera.pbj.runtime.io.buffer.BufferedData; +import com.hedera.pbj.test.proto.pbj.Everything; +import com.hederahashgraph.api.proto.java.GetAccountDetailsResponse; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import org.openjdk.jmh.annotations.Benchmark; +import 
org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@SuppressWarnings("unused") +@Fork(1) +@Warmup(iterations = 2, time = 2) +@Measurement(iterations = 5, time = 2) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.AverageTime) +public abstract class JsonBench
{ + + @SuppressWarnings("rawtypes") + @State(Scope.Benchmark) + public static class JsonBenchmarkState
{ + private JsonCodec
pbjJsonCodec;
+ private Supplier pbjProtoCodec, JsonCodec pbjJsonCodec,
+ ProtobufObjectBench.ProtobufParseFunction benchmarkState, Blackhole blackhole) throws ParseException {
+ benchmarkState.jsonDataBuffer.position(0);
+ blackhole.consume(benchmarkState.pbjJsonCodec.parse(benchmarkState.jsonDataBuffer));
+ }
+
+    /**
+     * Benchmarks parsing {@code jsonString} with protoc's {@code JsonFormat} parser into a builder
+     * obtained from {@code builderSupplier}, then building the message.
+     *
+     * @param benchmarkState supplies the builder factory and the JSON text to parse
+     * @param blackhole JMH blackhole that consumes the built message
+     * @throws IOException declared for failures raised while merging the JSON text
+     */
+    @Benchmark
+    public void parseProtoC(JsonBenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var messageBuilder = benchmarkState.builderSupplier.get();
+        final var jsonParser = JsonFormat.parser();
+        jsonParser.merge(benchmarkState.jsonString, messageBuilder);
+        final var parsedMessage = messageBuilder.build();
+        blackhole.consume(parsedMessage);
+    }
+
+    /**
+     * Benchmarks writing {@code pbjModelObject} through {@code pbjJsonCodec} into the reusable
+     * {@code outDataBuffer}, which is reset before the write.
+     * NOTE(review): the earlier comment called this "same as writePbjByteBuffer"; no such method is
+     * visible in this class from here — confirm whether that remark was copied from another benchmark.
+     *
+     * @param benchmarkState holds the codec, the model object and the output buffer
+     * @param blackhole JMH blackhole that consumes the output buffer after the write
+     * @throws IOException declared for failures raised by the codec while writing
+     */
+    @Benchmark
+    public void writePbj(JsonBenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var target = benchmarkState.outDataBuffer;
+        target.reset();
+        benchmarkState.pbjJsonCodec.write(benchmarkState.pbjModelObject, target);
+        blackhole.consume(target);
+    }
+
+    /**
+     * Benchmarks printing {@code googleModelObject} with protoc's {@code JsonFormat} printer.
+     *
+     * @param benchmarkState holds the protoc model object to print
+     * @param blackhole JMH blackhole that consumes the printed text
+     * @throws InvalidProtocolBufferException declared for failures raised by the printer
+     */
+    @Benchmark
+    public void writeProtoC(JsonBenchmarkState benchmarkState, Blackhole blackhole) throws InvalidProtocolBufferException {
+        final var jsonPrinter = JsonFormat.printer();
+        final var renderedJson = jsonPrinter.print(benchmarkState.googleModelObject);
+        blackhole.consume(renderedJson);
+    }
+
+ /** Custom interface for method references, as java.util.function.Function cannot throw IOException */
+ public interface ProtobufParseFunction {
+ /** We repeat all operations 1000 times so that the measured times are big enough. */
+ private static final int OPERATION_COUNT = 1000;
+
+ @State(Scope.Benchmark)
+ public static class BenchmarkState {
+ private Codec pbjCodec;
+ private ProtobufParseFunction pbjCodec,
+ ProtobufParseFunction benchmarkState, Blackhole blackhole) throws ParseException {
+ for (int i = 0; i < OPERATION_COUNT; i++) {
+ benchmarkState.protobufDataBuffer.resetPosition();
+ blackhole.consume(benchmarkState.pbjCodec.parse(benchmarkState.protobufDataBuffer));
+ }
+ }
+
+    /**
+     * Parses {@code protobufDataBuffer} with {@code pbjCodec} {@code OPERATION_COUNT} times,
+     * calling {@code resetPosition()} on the buffer before every pass.
+     *
+     * @param benchmarkState holds the codec and the input buffer
+     * @param blackhole JMH blackhole that consumes each parsed object
+     * @throws ParseException declared for parse failures raised by the codec
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void parsePbjByteBuffer(BenchmarkState benchmarkState, Blackhole blackhole) throws ParseException {
+        final var input = benchmarkState.protobufDataBuffer;
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            input.resetPosition();
+            final var parsed = benchmarkState.pbjCodec.parse(input);
+            blackhole.consume(parsed);
+        }
+    }
+
+    /**
+     * Parses {@code protobufDataBufferDirect} with {@code pbjCodec} {@code OPERATION_COUNT} times,
+     * calling {@code resetPosition()} on the buffer before every pass.
+     *
+     * @param benchmarkState holds the codec and the input buffer
+     * @param blackhole JMH blackhole that consumes each parsed object
+     * @throws ParseException declared for parse failures raised by the codec
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void parsePbjByteBufferDirect(BenchmarkState benchmarkState, Blackhole blackhole) throws ParseException {
+        final var input = benchmarkState.protobufDataBufferDirect;
+        int remaining = OPERATION_COUNT;
+        while (remaining-- > 0) {
+            input.resetPosition();
+            final var parsed = benchmarkState.pbjCodec.parse(input);
+            blackhole.consume(parsed);
+        }
+    }
+    /**
+     * Parses from {@code bin} with {@code pbjCodec} {@code OPERATION_COUNT} times. Each pass calls
+     * {@code resetPosition()} on {@code bin} and wraps it in a new {@code ReadableStreamingData}.
+     *
+     * @param benchmarkState holds the codec and the input stream
+     * @param blackhole JMH blackhole that consumes each parsed object
+     * @throws ParseException declared for parse failures raised by the codec
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void parsePbjInputStream(BenchmarkState benchmarkState, Blackhole blackhole) throws ParseException {
+        final var source = benchmarkState.bin;
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            source.resetPosition();
+            final var streamingInput = new ReadableStreamingData(source);
+            blackhole.consume(benchmarkState.pbjCodec.parse(streamingInput));
+        }
+    }
+
+    /**
+     * Parses {@code protobuf} with {@code googleByteArrayParseMethod} {@code OPERATION_COUNT} times.
+     *
+     * @param benchmarkState holds the parse function and the encoded input
+     * @param blackhole JMH blackhole that consumes each parsed object
+     * @throws IOException declared for failures raised by the parse function
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void parseProtoCByteArray(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        int remaining = OPERATION_COUNT;
+        while (remaining-- > 0) {
+            final var parsed = benchmarkState.googleByteArrayParseMethod.parse(benchmarkState.protobuf);
+            blackhole.consume(parsed);
+        }
+    }
+    /**
+     * Parses {@code protobufByteBufferDirect} with {@code googleByteBufferParseMethod}
+     * {@code OPERATION_COUNT} times, setting the buffer position to 0 before every pass.
+     *
+     * @param benchmarkState holds the parse function and the input buffer
+     * @param blackhole JMH blackhole that consumes each parsed object
+     * @throws IOException declared for failures raised by the parse function
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void parseProtoCByteBufferDirect(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var input = benchmarkState.protobufByteBufferDirect;
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            input.position(0);
+            final var parsed = benchmarkState.googleByteBufferParseMethod.parse(input);
+            blackhole.consume(parsed);
+        }
+    }
+    /**
+     * Parses {@code protobufByteBuffer} with {@code googleByteBufferParseMethod}
+     * {@code OPERATION_COUNT} times. The buffer position is not touched between passes.
+     *
+     * @param benchmarkState holds the parse function and the input buffer
+     * @param blackhole JMH blackhole that consumes each parsed object
+     * @throws IOException declared for failures raised by the parse function
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void parseProtoCByteBuffer(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        int remaining = OPERATION_COUNT;
+        while (remaining-- > 0) {
+            final var parsed = benchmarkState.googleByteBufferParseMethod.parse(benchmarkState.protobufByteBuffer);
+            blackhole.consume(parsed);
+        }
+    }
+    /**
+     * Parses from {@code bin} with {@code googleInputStreamParseMethod} {@code OPERATION_COUNT}
+     * times, calling {@code resetPosition()} on the stream before every pass.
+     *
+     * @param benchmarkState holds the parse function and the input stream
+     * @param blackhole JMH blackhole that consumes each parsed object
+     * @throws IOException declared for failures raised by the parse function
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void parseProtoCInputStream(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var source = benchmarkState.bin;
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            source.resetPosition();
+            final var parsed = benchmarkState.googleInputStreamParseMethod.parse(source);
+            blackhole.consume(parsed);
+        }
+    }
+
+    /**
+     * Writes {@code pbjModelObject} with {@code pbjCodec} into the reusable {@code outDataBuffer}
+     * {@code OPERATION_COUNT} times, resetting the buffer before every write.
+     * Deliberately the same as writePbjByteBuffer: per the original author, DataBuffer.wrap(byte[])
+     * uses ByteBuffer today, and this variant exists only because it makes result plotting easier.
+     *
+     * @param benchmarkState holds the codec, the model object and the output buffer
+     * @param blackhole JMH blackhole that consumes the output buffer after each write
+     * @throws IOException declared for failures raised by the codec while writing
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writePbjByteArray(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var target = benchmarkState.outDataBuffer;
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            target.reset();
+            benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, target);
+            blackhole.consume(target);
+        }
+    }
+
+    /**
+     * Writes {@code pbjModelObject} with {@code pbjCodec} into the reusable {@code outDataBuffer}
+     * {@code OPERATION_COUNT} times, resetting the buffer before every write.
+     *
+     * @param benchmarkState holds the codec, the model object and the output buffer
+     * @param blackhole JMH blackhole that consumes the output buffer after each write
+     * @throws IOException declared for failures raised by the codec while writing
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writePbjByteBuffer(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var target = benchmarkState.outDataBuffer;
+        int remaining = OPERATION_COUNT;
+        while (remaining-- > 0) {
+            target.reset();
+            benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, target);
+            blackhole.consume(target);
+        }
+    }
+    /**
+     * Writes {@code pbjModelObject} with {@code pbjCodec} into the reusable
+     * {@code outDataBufferDirect} {@code OPERATION_COUNT} times, resetting the buffer before every write.
+     *
+     * @param benchmarkState holds the codec, the model object and the output buffer
+     * @param blackhole JMH blackhole that consumes the output buffer after each write
+     * @throws IOException declared for failures raised by the codec while writing
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writePbjByteDirect(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var target = benchmarkState.outDataBufferDirect;
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            target.reset();
+            benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, target);
+            blackhole.consume(target);
+        }
+    }
+    /**
+     * Writes {@code pbjModelObject} with {@code pbjCodec} to {@code bout} {@code OPERATION_COUNT}
+     * times. Each pass resets {@code bout}, wraps it in a new {@code WritableStreamingData}, and
+     * passes the result of {@code bout.toByteArray()} to the blackhole.
+     *
+     * @param benchmarkState holds the codec, the model object and the output stream
+     * @param blackhole JMH blackhole that consumes the bytes written on each pass
+     * @throws IOException declared for failures raised by the codec while writing
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writePbjOutputStream(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var sink = benchmarkState.bout;
+        int remaining = OPERATION_COUNT;
+        while (remaining-- > 0) {
+            sink.reset();
+            final var streamingOutput = new WritableStreamingData(sink);
+            benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, streamingOutput);
+            final var written = sink.toByteArray();
+            blackhole.consume(written);
+        }
+    }
+
+    /**
+     * Calls {@code toByteArray()} on {@code googleModelObject} {@code OPERATION_COUNT} times.
+     *
+     * @param benchmarkState holds the protoc model object to serialize
+     * @param blackhole JMH blackhole that consumes each produced array
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writeProtoCByteArray(BenchmarkState benchmarkState, Blackhole blackhole) {
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            final var encoded = benchmarkState.googleModelObject.toByteArray();
+            blackhole.consume(encoded);
+        }
+    }
+
+    /**
+     * Writes {@code googleModelObject} through a new {@code CodedOutputStream} over {@code bbout}
+     * {@code OPERATION_COUNT} times, passing {@code bbout} to the blackhole after each write.
+     *
+     * @param benchmarkState holds the protoc model object and the output buffer
+     * @param blackhole JMH blackhole that consumes the output buffer after each write
+     * @throws IOException declared for failures raised while writing the message
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writeProtoCByteBuffer(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var target = benchmarkState.bbout;
+        int remaining = OPERATION_COUNT;
+        while (remaining-- > 0) {
+            final CodedOutputStream encoder = CodedOutputStream.newInstance(target);
+            benchmarkState.googleModelObject.writeTo(encoder);
+            blackhole.consume(target);
+        }
+    }
+
+    /**
+     * Writes {@code googleModelObject} through a new {@code CodedOutputStream} over
+     * {@code bboutDirect} {@code OPERATION_COUNT} times, passing that same buffer to the blackhole
+     * after each write.
+     *
+     * @param benchmarkState holds the protoc model object and the output buffer
+     * @param blackhole JMH blackhole that consumes the output buffer after each write
+     * @throws IOException declared for failures raised while writing the message
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writeProtoCByteBufferDirect(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
+            CodedOutputStream cout = CodedOutputStream.newInstance(benchmarkState.bboutDirect);
+            benchmarkState.googleModelObject.writeTo(cout);
+            // Consume the buffer that was actually written; this used to consume bbout,
+            // a copy-paste slip from writeProtoCByteBuffer.
+            blackhole.consume(benchmarkState.bboutDirect);
+        }
+    }
+
+    /**
+     * Writes {@code googleModelObject} to {@code bout} {@code OPERATION_COUNT} times. Each pass
+     * resets {@code bout} first and passes the result of {@code bout.toByteArray()} to the blackhole.
+     *
+     * @param benchmarkState holds the protoc model object and the output stream
+     * @param blackhole JMH blackhole that consumes the bytes written on each pass
+     * @throws IOException declared for failures raised while writing the message
+     */
+    @Benchmark
+    @OperationsPerInvocation(OPERATION_COUNT)
+    public void writeProtoCOutputStream(BenchmarkState benchmarkState, Blackhole blackhole) throws IOException {
+        final var sink = benchmarkState.bout;
+        for (int remaining = OPERATION_COUNT; remaining > 0; remaining--) {
+            sink.reset();
+            benchmarkState.googleModelObject.writeTo(sink);
+            final var written = sink.toByteArray();
+            blackhole.consume(written);
+        }
+    }
+
+ /** Custom interface for method references, as java.util.function.Function cannot throw IOException */
+ public interface ProtobufParseFunction