diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 98a126e75e5c3..8726cf7651991 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -459,13 +459,19 @@ public void writeMapWithConsistentOrder(@Nullable Map } assert false == (map instanceof LinkedHashMap); this.writeByte((byte) 10); - this.writeVInt(map.size()); + final int size = map.size(); + this.writeVInt(size); Iterator> iterator = map.entrySet().stream().sorted((a, b) -> a.getKey().compareTo(b.getKey())).iterator(); + int count = 0; while (iterator.hasNext()) { Map.Entry next = iterator.next(); this.writeString(next.getKey()); this.writeGenericValue(next.getValue()); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -482,9 +488,15 @@ public void writeMapWithConsistentOrder(@Nullable Map public final void writeMapOfLists(final Map> map, final Writer keyWriter, final Writer valueWriter) throws IOException { writeMap(map, keyWriter, (stream, list) -> { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (final V value : list) { valueWriter.write(this, value); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } }); } @@ -501,10 +513,16 @@ public final void writeMapOfLists(final Map> map, final Writer */ public final void writeMap(final Map map, final Writer keyWriter, final Writer valueWriter) throws IOException { - writeVInt(map.size()); + final int size = map.size(); + writeVInt(size); + int count = 0; for (final Map.Entry entry : map.entrySet()) { keyWriter.write(this, entry.getKey()); valueWriter.write(this, entry.getValue()); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -545,9 +563,15 @@ public final void writeMap(final Map map, final Writer keyWriter writers.put(List.class, (o, v) -> { o.writeByte((byte) 7); final List list = (List) v; - o.writeVInt(list.size()); + final int size = list.size(); + o.writeVInt(size); + int count = 0; for (Object item : list) { o.writeGenericValue(item); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } }); writers.put(Object[].class, (o, v) -> { @@ -566,10 +590,16 @@ public final void writeMap(final Map map, final Writer keyWriter } @SuppressWarnings("unchecked") final Map map = (Map) v; - o.writeVInt(map.size()); + final int size = map.size(); + o.writeVInt(size); + int count = 0; for (Map.Entry entry : map.entrySet()) { o.writeString(entry.getKey()); o.writeGenericValue(entry.getValue()); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } }); writers.put(Byte.class, (o, v) -> { @@ -926,9 +956,15 @@ public void writeOptionalTimeZone(@Nullable DateTimeZone timeZone) throws IOExce * Writes a list of {@link Streamable} objects */ public void writeStreamableList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (Streamable obj: list) { obj.writeTo(this); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -936,9 +972,15 @@ public void writeStreamableList(List list) throws IOExcept * Writes a list of {@link Writeable} objects */ public void writeList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (Writeable obj: list) { obj.writeTo(this); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -946,9 +988,15 @@ public void writeList(List list) throws IOException { * Writes a list of strings */ public void writeStringList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (String string: list) { this.writeString(string); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -956,9 +1004,15 @@ public void writeStringList(List list) throws IOException { * Writes a list of {@link NamedWriteable} objects. */ public void writeNamedWriteableList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (NamedWriteable obj: list) { writeNamedWriteable(obj); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java index 27656e9bc092d..d7711ab78b1e4 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java @@ -21,7 +21,6 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.Constants; -import org.elasticsearch.Version; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.geo.GeoPoint; @@ -33,7 +32,6 @@ import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; -import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; @@ -812,4 +810,142 @@ public void testInvalidEnum() throws IOException { } assertEquals(0, input.available()); } + + public void testWriteMapWithConsistentOrderSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map liesAboutItsSize = new HashMap(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, + () -> output.writeMapWithConsistentOrder(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteMapOfListsSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map> liesAboutItsSize = new HashMap>(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, + () -> output.writeMapOfLists(liesAboutItsSize, StreamOutput::writeString, StreamOutput::writeString)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteMapSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map liesAboutItsSize = new HashMap(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, + () -> output.writeMap(liesAboutItsSize, StreamOutput::writeString, StreamOutput::writeString)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteGenericValueListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeGenericValue(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteGenericValueMapSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map liesAboutItsSize = new HashMap(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeGenericValue(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteStreamableListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeStreamableList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteStringListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeStringList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteNamedWriteableListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeNamedWriteableList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } }