From 2abace58b7a6a09519a5ba0c2fbc2b7b43c92a17 Mon Sep 17 00:00:00 2001 From: David Turner Date: Wed, 21 Feb 2018 07:36:06 +0000 Subject: [PATCH] Check consistency of sizes of collections in StreamOutput Today, we serialize collections prefixed by their lengths. If the serialized length is inconsistent with the number of objects in the collection then the serialization succeeds but any subsequent deserialization will (almost certainly) fail, reporting something unexpected at some later point in the stream. This can happen, for instance, because of a concurrent modification (e.g. #28718) to the collection in between getting its length and iterating through its objects. This is a royal pain to debug, because the failure is reported a long way from where it actually occurs. See also e.g. #25323. This change adds checks to StreamOutput to verify that it wrote as many objects as it said it would, in order to catch any similar problems in future closer to their sources. --- .../common/io/stream/StreamOutput.java | 72 +++++++-- .../common/io/stream/BytesStreamsTests.java | 140 +++++++++++++++++- 2 files changed, 201 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 98a126e75e5c3..8726cf7651991 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -459,13 +459,19 @@ public void writeMapWithConsistentOrder(@Nullable Map } assert false == (map instanceof LinkedHashMap); this.writeByte((byte) 10); - this.writeVInt(map.size()); + final int size = map.size(); + this.writeVInt(size); Iterator> iterator = map.entrySet().stream().sorted((a, b) -> a.getKey().compareTo(b.getKey())).iterator(); + int count = 0; while (iterator.hasNext()) { Map.Entry next = iterator.next(); this.writeString(next.getKey()); this.writeGenericValue(next.getValue()); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -482,9 +488,15 @@ public void writeMapWithConsistentOrder(@Nullable Map public final void writeMapOfLists(final Map> map, final Writer keyWriter, final Writer valueWriter) throws IOException { writeMap(map, keyWriter, (stream, list) -> { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (final V value : list) { valueWriter.write(this, value); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } }); } @@ -501,10 +513,16 @@ public final void writeMapOfLists(final Map> map, final Writer */ public final void writeMap(final Map map, final Writer keyWriter, final Writer valueWriter) throws IOException { - writeVInt(map.size()); + final int size = map.size(); + writeVInt(size); + int count = 0; for (final Map.Entry entry : map.entrySet()) { keyWriter.write(this, entry.getKey()); valueWriter.write(this, entry.getValue()); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -545,9 +563,15 @@ public final void writeMap(final Map map, final Writer keyWriter writers.put(List.class, (o, v) -> { o.writeByte((byte) 7); final List list = (List) v; - o.writeVInt(list.size()); + final int size = list.size(); + o.writeVInt(size); + int count = 0; for (Object item : list) { o.writeGenericValue(item); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } }); writers.put(Object[].class, (o, v) -> { @@ -566,10 +590,16 @@ public final void writeMap(final Map map, final Writer keyWriter } @SuppressWarnings("unchecked") final Map map = (Map) v; - o.writeVInt(map.size()); + final int size = map.size(); + o.writeVInt(size); + int count = 0; for (Map.Entry entry : map.entrySet()) { o.writeString(entry.getKey()); o.writeGenericValue(entry.getValue()); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } }); writers.put(Byte.class, (o, v) -> { @@ -926,9 +956,15 @@ public void writeOptionalTimeZone(@Nullable DateTimeZone timeZone) throws IOExce * Writes a list of {@link Streamable} objects */ public void writeStreamableList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (Streamable obj: list) { obj.writeTo(this); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -936,9 +972,15 @@ public void writeStreamableList(List list) throws IOExcept * Writes a list of {@link Writeable} objects */ public void writeList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (Writeable obj: list) { obj.writeTo(this); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -946,9 +988,15 @@ public void writeList(List list) throws IOException { * Writes a list of strings */ public void writeStringList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (String string: list) { this.writeString(string); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } @@ -956,9 +1004,15 @@ public void writeStringList(List list) throws IOException { * Writes a list of {@link NamedWriteable} objects. */ public void writeNamedWriteableList(List list) throws IOException { - writeVInt(list.size()); + final int size = list.size(); + writeVInt(size); + int count = 0; for (NamedWriteable obj: list) { writeNamedWriteable(obj); + count++; + } + if (size != count) { + throw new IllegalStateException("Serialized size header does not match number of objects written."); } } diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java index 27656e9bc092d..d7711ab78b1e4 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java @@ -21,7 +21,6 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.Constants; -import org.elasticsearch.Version; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.geo.GeoPoint; @@ -33,7 +32,6 @@ import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; -import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; @@ -812,4 +810,142 @@ public void testInvalidEnum() throws IOException { } assertEquals(0, input.available()); } + + public void testWriteMapWithConsistentOrderSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map liesAboutItsSize = new HashMap(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, + () -> output.writeMapWithConsistentOrder(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteMapOfListsSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map> liesAboutItsSize = new HashMap>(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, + () -> output.writeMapOfLists(liesAboutItsSize, StreamOutput::writeString, StreamOutput::writeString)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteMapSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map liesAboutItsSize = new HashMap(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, + () -> output.writeMap(liesAboutItsSize, StreamOutput::writeString, StreamOutput::writeString)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteGenericValueListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeGenericValue(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteGenericValueMapSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + Map liesAboutItsSize = new HashMap(){ + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeGenericValue(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteStreamableListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeStreamableList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteStringListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeStringList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } + + public void testWriteNamedWriteableListSizeCheck() { + try (BytesStreamOutput output = new BytesStreamOutput(1)) { + List liesAboutItsSize = new ArrayList() { + @Override + public int size() { + return 1; + } + }; + + final IllegalStateException illegalStateException + = expectThrows(IllegalStateException.class, () -> output.writeNamedWriteableList(liesAboutItsSize)); + assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written.")); + } + } }