Skip to content

Commit

Permalink
Check consistency of sizes of collections in StreamOutput
Browse files Browse the repository at this point in the history
Today, we serialize collections prefixed by their lengths. If the serialized
length is inconsistent with the number of objects in the collection then the
serialization succeeds but any subsequent deserialization will (almost
certainly) fail, reporting something unexpected at some later point in the
stream. This can happen, for instance, because of a concurrent modification
(e.g. elastic#28718) to the collection in between getting its length and iterating through
its objects.

This is a royal pain to debug, because the failure is reported a long way from
where it actually occurs. See also e.g. elastic#25323.

This change adds checks to StreamOutput to verify that it wrote as many objects
as it said it would, in order to catch any similar problems in future closer to
their sources.
  • Loading branch information
DaveCTurner committed Feb 21, 2018
1 parent 86e5e38 commit 2abace5
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -459,13 +459,19 @@ public void writeMapWithConsistentOrder(@Nullable Map<String, ? extends Object>
}
assert false == (map instanceof LinkedHashMap);
this.writeByte((byte) 10);
this.writeVInt(map.size());
final int size = map.size();
this.writeVInt(size);
Iterator<? extends Map.Entry<String, ?>> iterator =
map.entrySet().stream().sorted((a, b) -> a.getKey().compareTo(b.getKey())).iterator();
int count = 0;
while (iterator.hasNext()) {
Map.Entry<String, ?> next = iterator.next();
this.writeString(next.getKey());
this.writeGenericValue(next.getValue());
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
}

Expand All @@ -482,9 +488,15 @@ public void writeMapWithConsistentOrder(@Nullable Map<String, ? extends Object>
public final <K, V> void writeMapOfLists(final Map<K, List<V>> map, final Writer<K> keyWriter, final Writer<V> valueWriter)
throws IOException {
writeMap(map, keyWriter, (stream, list) -> {
writeVInt(list.size());
final int size = list.size();
writeVInt(size);
int count = 0;
for (final V value : list) {
valueWriter.write(this, value);
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
});
}
Expand All @@ -501,10 +513,16 @@ public final <K, V> void writeMapOfLists(final Map<K, List<V>> map, final Writer
*/
public final <K, V> void writeMap(final Map<K, V> map, final Writer<K> keyWriter, final Writer<V> valueWriter)
throws IOException {
writeVInt(map.size());
final int size = map.size();
writeVInt(size);
int count = 0;
for (final Map.Entry<K, V> entry : map.entrySet()) {
keyWriter.write(this, entry.getKey());
valueWriter.write(this, entry.getValue());
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
}

Expand Down Expand Up @@ -545,9 +563,15 @@ public final <K, V> void writeMap(final Map<K, V> map, final Writer<K> keyWriter
writers.put(List.class, (o, v) -> {
o.writeByte((byte) 7);
final List list = (List) v;
o.writeVInt(list.size());
final int size = list.size();
o.writeVInt(size);
int count = 0;
for (Object item : list) {
o.writeGenericValue(item);
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
});
writers.put(Object[].class, (o, v) -> {
Expand All @@ -566,10 +590,16 @@ public final <K, V> void writeMap(final Map<K, V> map, final Writer<K> keyWriter
}
@SuppressWarnings("unchecked")
final Map<String, Object> map = (Map<String, Object>) v;
o.writeVInt(map.size());
final int size = map.size();
o.writeVInt(size);
int count = 0;
for (Map.Entry<String, Object> entry : map.entrySet()) {
o.writeString(entry.getKey());
o.writeGenericValue(entry.getValue());
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
});
writers.put(Byte.class, (o, v) -> {
Expand Down Expand Up @@ -926,39 +956,63 @@ public void writeOptionalTimeZone(@Nullable DateTimeZone timeZone) throws IOExce
* Writes a list of {@link Streamable} objects
*/
public void writeStreamableList(List<? extends Streamable> list) throws IOException {
writeVInt(list.size());
final int size = list.size();
writeVInt(size);
int count = 0;
for (Streamable obj: list) {
obj.writeTo(this);
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
}

/**
* Writes a list of {@link Writeable} objects
*/
public void writeList(List<? extends Writeable> list) throws IOException {
writeVInt(list.size());
final int size = list.size();
writeVInt(size);
int count = 0;
for (Writeable obj: list) {
obj.writeTo(this);
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
}

/**
* Writes a list of strings
*/
public void writeStringList(List<String> list) throws IOException {
writeVInt(list.size());
final int size = list.size();
writeVInt(size);
int count = 0;
for (String string: list) {
this.writeString(string);
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
}

/**
* Writes a list of {@link NamedWriteable} objects.
*/
public void writeNamedWriteableList(List<? extends NamedWriteable> list) throws IOException {
writeVInt(list.size());
final int size = list.size();
writeVInt(size);
int count = 0;
for (NamedWriteable obj: list) {
writeNamedWriteable(obj);
count++;
}
if (size != count) {
throw new IllegalStateException("Serialized size header does not match number of objects written.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Constants;
import org.elasticsearch.Version;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.geo.GeoPoint;
Expand All @@ -33,7 +32,6 @@
import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -812,4 +810,142 @@ public void testInvalidEnum() throws IOException {
}
assertEquals(0, input.available());
}

public void testWriteMapWithConsistentOrderSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
Map<String,String> liesAboutItsSize = new HashMap<String,String>(){
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class,
() -> output.writeMapWithConsistentOrder(liesAboutItsSize));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteMapOfListsSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
Map<String,List<String>> liesAboutItsSize = new HashMap<String,List<String>>(){
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class,
() -> output.writeMapOfLists(liesAboutItsSize, StreamOutput::writeString, StreamOutput::writeString));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteMapSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
Map<String,String> liesAboutItsSize = new HashMap<String,String>(){
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class,
() -> output.writeMap(liesAboutItsSize, StreamOutput::writeString, StreamOutput::writeString));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteGenericValueListSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
List<String> liesAboutItsSize = new ArrayList<String>() {
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class, () -> output.writeGenericValue(liesAboutItsSize));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteGenericValueMapSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
Map<String,String> liesAboutItsSize = new HashMap<String,String>(){
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class, () -> output.writeGenericValue(liesAboutItsSize));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteStreamableListSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
List<Streamable> liesAboutItsSize = new ArrayList<Streamable>() {
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class, () -> output.writeStreamableList(liesAboutItsSize));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteListSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
List<Writeable> liesAboutItsSize = new ArrayList<Writeable>() {
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class, () -> output.writeList(liesAboutItsSize));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteStringListSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
List<String> liesAboutItsSize = new ArrayList<String>() {
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class, () -> output.writeStringList(liesAboutItsSize));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}

public void testWriteNamedWriteableListSizeCheck() {
try (BytesStreamOutput output = new BytesStreamOutput(1)) {
List<NamedWriteable> liesAboutItsSize = new ArrayList<NamedWriteable>() {
@Override
public int size() {
return 1;
}
};

final IllegalStateException illegalStateException
= expectThrows(IllegalStateException.class, () -> output.writeNamedWriteableList(liesAboutItsSize));
assertThat(illegalStateException.getMessage(), is("Serialized size header does not match number of objects written."));
}
}
}

0 comments on commit 2abace5

Please sign in to comment.