From 732b067b29577faef393ee85d3cd0a0054e387f6 Mon Sep 17 00:00:00 2001 From: wrighe3 Date: Tue, 26 Apr 2016 19:00:02 -0500 Subject: [PATCH 1/6] add KryoSerializer, remove ClassifierResult2 --- .../examples/common/Parameters.scala | 6 +- flink-htm-streaming-java/build.gradle | 2 +- .../flink/serialization/KryoSerializer.java | 78 ++++++++ .../streaming/api/ClassifierResult2.java | 178 ------------------ .../nupic/flink/streaming/api/HTM.java | 4 + .../nupic/flink/streaming/api/Inference2.java | 13 +- .../streaming/api/HTMIntegrationTest.java | 12 +- .../flink/streaming/api/TestHarness.java | 2 +- .../nupic/encoders/scala/Encoders.scala | 1 - 9 files changed, 102 insertions(+), 194 deletions(-) create mode 100644 flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java delete mode 100644 flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/ClassifierResult2.java diff --git a/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/common/Parameters.scala b/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/common/Parameters.scala index 754390d..43b22d8 100644 --- a/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/common/Parameters.scala +++ b/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/common/Parameters.scala @@ -25,7 +25,7 @@ object NetworkDemoParameters { //SpatialPooler specific POTENTIAL_RADIUS -> 12, //3 POTENTIAL_PCT -> 0.5, //0.5 - GLOBAL_INHIBITIONS -> false, + GLOBAL_INHIBITION -> false, LOCAL_AREA_DENSITY -> -1.0, NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 5.0, STIMULUS_THRESHOLD -> 1.0, @@ -49,7 +49,7 @@ object NetworkDemoParameters { PERMANENCE_DECREMENT -> 0.05, ACTIVATION_THRESHOLD -> 4)) .union(Parameters( - GLOBAL_INHIBITIONS -> true, + GLOBAL_INHIBITION -> true, COLUMN_DIMENSIONS -> Array(2048), CELLS_PER_COLUMN -> 32, NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 40.0, @@ -80,7 +80,7 @@ trait WorkshopAnomalyParameters { // spParams POTENTIAL_PCT -> 0.8, COLUMN_DIMENSIONS -> Array(2048), - GLOBAL_INHIBITIONS -> true, + GLOBAL_INHIBITION -> true, /* inputWidth */ MAX_BOOST -> 1.0, NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 40, diff --git a/flink-htm-streaming-java/build.gradle b/flink-htm-streaming-java/build.gradle index 44eaecc..fcf7a30 100644 --- a/flink-htm-streaming-java/build.gradle +++ b/flink-htm-streaming-java/build.gradle @@ -19,7 +19,7 @@ dependencies { compile 'org.slf4j:slf4j-api:1.7.13' // htm.java - compile 'org.numenta:htm.java:0.6.5' + compile 'org.numenta:htm.java:0.6.7' // flink compile 'org.apache.flink:flink-java:1.0.0' diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java new file mode 100644 index 0000000..5d70567 --- /dev/null +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java @@ -0,0 +1,78 @@ +package org.numenta.nupic.flink.serialization; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoException; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.apache.flink.api.common.ExecutionConfig; +import org.numenta.nupic.Persistable; +import org.numenta.nupic.network.Network; +import org.numenta.nupic.serialize.HTMObjectInput; +import org.numenta.nupic.serialize.HTMObjectOutput; +import 
org.numenta.nupic.serialize.SerialConfig;
+import org.numenta.nupic.serialize.SerializerCore;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ * Kryo serializer for HTM network and related objects.
+ *
+ */
+public class KryoSerializer<T extends Persistable> extends Serializer<T> implements Serializable {
+
+    private final SerializerCore serializer = new SerializerCore(SerialConfig.DEFAULT_REGISTERED_TYPES);
+
+    /**
+     *
+     * @param kryo   instance of {@link Kryo} object
+     * @param output a Kryo {@link Output} object
+     * @param t      instance to serialize
+     */
+    @Override
+    public void write(Kryo kryo, Output output, T t) {
+        try {
+            HTMObjectOutput writer = serializer.getObjectOutput(output);
+            writer.writeObject(t, t.getClass());
+            writer.flush();
+        }
+        catch(IOException e) {
+            throw new KryoException(e);
+        }
+    }
+
+    /**
+     *
+     * @param kryo   instance of {@link Kryo} object
+     * @param input  a Kryo {@link Input}
+     * @param aClass The class of the object to be read in.
+     * @return an instance of type <T>
+     */
+    @Override
+    public T read(Kryo kryo, Input input, Class<T> aClass) {
+        try {
+            HTMObjectInput reader = serializer.getObjectInput(input);
+            T t = (T) reader.readObject(aClass);
+            return t;
+        }
+        catch(Exception e) {
+            throw new KryoException(e);
+        }
+    }
+
+    /**
+     * Register the HTM types with the Kryo serializer.
+     * @param config
+     */
+    public static void registerTypes(ExecutionConfig config) {
+        for(Class<?> c : SerialConfig.DEFAULT_REGISTERED_TYPES) {
+            config.registerTypeWithKryoSerializer(c, (Class<? extends Serializer<?>>) (Class<?>) KryoSerializer.class);
+        }
+        for(Class<?> c : KryoSerializer.ADDITIONAL_REGISTERED_TYPES) {
+            config.registerTypeWithKryoSerializer(c, (Class<? extends Serializer<?>>) (Class<?>) KryoSerializer.class);
+        }
+    }
+
+    static final Class<?>[] ADDITIONAL_REGISTERED_TYPES = { Network.class };
+}
\ No newline at end of file
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/ClassifierResult2.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/ClassifierResult2.java
deleted file mode 100644
index 8cec7d5..0000000
--- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/ClassifierResult2.java
+++ /dev/null
@@ -1,178 +0,0 @@
-package org.numenta.nupic.flink.streaming.api;
-
-import org.numenta.nupic.algorithms.ClassifierResult;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * Container for the results of a classification computation.
- */ -public class ClassifierResult2 implements Serializable { - /** Array of actual values */ - private final T[] actualValues; - - /** Map of step count -to- probabilities */ - private final Map probabilities; - - private ClassifierResult2(T[] actualValues, Map probabilities) { - this.actualValues = actualValues; - this.probabilities = probabilities; - } - - /** - * Returns the actual value for the specified bucket index - * - * @param bucketIndex - * @return - */ - public T getActualValue(int bucketIndex) { - if(actualValues == null || actualValues.length < bucketIndex + 1) { - return null; - } - return (T)actualValues[bucketIndex]; - } - - /** - * Returns all actual values entered - * - * @return array of type <T> - */ - public T[] getActualValues() { - return actualValues; - } - - /** - * Returns a count of actual values entered - * @return - */ - public int getActualValueCount() { - return actualValues.length; - } - - /** - * Returns the probability at the specified index for the given step - * @param step - * @param bucketIndex - * @return - */ - public double getStat(int step, int bucketIndex) { - return probabilities.get(step)[bucketIndex]; - } - - /** - * Returns the probabilities for the specified step - * @param step - * @return - */ - public double[] getStats(int step) { - return probabilities.get(step); - } - - /** - * Returns the input value corresponding with the highest probability - * for the specified step. - * - * @param step the step key under which the most probable value will be returned. - * @return - */ - public T getMostProbableValue(int step) { - int idx = -1; - if(probabilities.get(step) == null || (idx = getMostProbableBucketIndex(step)) == -1) { - return null; - } - return getActualValue(idx); - } - - /** - * Returns the bucket index corresponding with the highest probability - * for the specified step. - * - * @param step the step key under which the most probable index will be returned. - * @return -1 if there is no such entry - */ - public int getMostProbableBucketIndex(int step) { - if(probabilities.get(step) == null) return -1; - - double max = 0; - int bucketIdx = -1; - int i = 0; - for(double d : probabilities.get(step)) { - if(d > max) { - max = d; - bucketIdx = i; - } - ++i; - } - return bucketIdx; - } - - /** - * Returns the count of steps - * @return - */ - public int getStepCount() { - return probabilities.size(); - } - - /** - * Returns the count of probabilities for the specified step - * @param step the step indexing the probability values - * @return - */ - public int getStatCount(int step) { - return probabilities.get(step).length; - } - - /** - * Returns a set of steps being recorded. - * @return - */ - public int[] stepSet() { - int[] set = new int[probabilities.size()]; - int i = 0; - for(Integer key : probabilities.keySet()) set[i++] = key; - return set; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + Arrays.hashCode(actualValues); - result = prime * result + ((probabilities == null) ? 
0 : probabilities.hashCode());
-        return result;
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-        if(this == obj)
-            return true;
-        if(obj == null)
-            return false;
-        if(getClass() != obj.getClass())
-            return false;
-        @SuppressWarnings("rawtypes")
-        ClassifierResult2 other = (ClassifierResult2)obj;
-        if(!Arrays.equals(actualValues, other.actualValues))
-            return false;
-        if(probabilities == null) {
-            if(other.probabilities != null)
-                return false;
-        } else if(!probabilities.equals(other.probabilities))
-            return false;
-        return true;
-    }
-
-    public static <T> ClassifierResult2<T> fromClassifierResult(ClassifierResult<T> result) {
-        Map<Integer, double[]> probabilities = new HashMap<>();
-
-        for(int i = 0; i <= result.getStepCount(); i++) {
-            probabilities.put(i, result.getStats(i));
-        }
-
-        return new ClassifierResult2<T>(result.getActualValues(), probabilities);
-    }
-}
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java
index aaed1d6..71f5764 100644
--- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java
+++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java
@@ -4,6 +4,7 @@
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.numenta.nupic.flink.serialization.KryoSerializer;
 import org.numenta.nupic.flink.streaming.api.operator.GlobalHTMInferenceOperator;
 import org.numenta.nupic.flink.streaming.api.operator.KeyedHTMInferenceOperator;
 import org.apache.flink.streaming.api.TimeCharacteristic;
@@ -29,6 +30,9 @@ public class HTM {
      * @return Resulting HTM stream
      */
     public static <T> HTMStream<T> learn(DataStream<T> input, NetworkFactory<T> networkFactory) {
+
+        KryoSerializer.registerTypes(input.getExecutionConfig());
+
         final boolean isProcessingTime = input.getExecutionEnvironment().getStreamTimeCharacteristic() == TimeCharacteristic.ProcessingTime;
 
         final DataStream<Inference2<T>> inferenceStream;
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/Inference2.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/Inference2.java
index 5cc8d5c..b7babd3 100644
--- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/Inference2.java
+++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/Inference2.java
@@ -1,5 +1,6 @@
 package org.numenta.nupic.flink.streaming.api;
 
+import org.numenta.nupic.algorithms.Classification;
 import org.numenta.nupic.network.Inference;
 import org.numenta.nupic.network.Layer;
 
@@ -19,9 +20,9 @@ public class Inference2<IN> implements Serializable {
     private final IN input;
 
-    private final Map<String, ClassifierResult2<Object>> classifications;
+    private final Map<String, Classification<Object>> classifications;
 
-    public Inference2(IN input, double anomalyScore, Map<String, ClassifierResult2<Object>> classifications) {
+    public Inference2(IN input, double anomalyScore, Map<String, Classification<Object>> classifications) {
         this.input = input;
         this.anomalyScore = anomalyScore;
         this.classifications = classifications;
     }
@@ -32,20 +33,20 @@ public Inference2(IN input, double anomalyScore, Map<String, ClassifierResult2<
 
     public double getAnomalyScore() { return this.anomalyScore; }
 
     public IN getInput() { return this.input; }
 
     /**
      * Returns the most recent {@link Classification}
      *
      * @param fieldName
      * @return the classification result.
      */
-    public ClassifierResult2<Object> getClassification(String fieldName) {
+    public Classification<Object> getClassification(String fieldName) {
         if(classifications == null) throw new IllegalStateException("no classification results are available");
         return classifications.get(fieldName);
     }
 
     public static <IN> Inference2<IN> fromInference(IN input, Inference i) {
-        Map<String, ClassifierResult2<Object>>
classifications = new HashMap<>(); + Map> classifications = new HashMap<>(); for(String field : i.getClassifiers().keys()) { - classifications.put(field, ClassifierResult2.fromClassifierResult(i.getClassification(field))); + classifications.put(field, i.getClassification(field)); } return new Inference2(input, i.getAnomalyScore(), classifications); } diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java index 26ab319..231a590 100644 --- a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java +++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java @@ -1,6 +1,7 @@ package org.numenta.nupic.flink.streaming.api; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -48,12 +49,15 @@ public void testSimple() throws Exception { DataStream input = env.fromCollection(records); - DataStream> result = HTM + DataStream> result = HTM .learn(input, new TestHarness.DayDemoNetworkFactory()) - .select(new InferenceSelectFunction>() { + .select(new InferenceSelectFunction>() { @Override - public Tuple2 select(Inference2 inference) throws Exception { - return new Tuple2(inference.getInput().dayOfWeek, inference.getAnomalyScore()); + public Tuple3 select(Inference2 inference) throws Exception { + return new Tuple3( + inference.getInput().dayOfWeek, + (Double) inference.getClassification("dayOfWeek").getMostProbableValue(1), + inference.getAnomalyScore()); } }); diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestHarness.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestHarness.java index 7ec02a0..a6e8d6c 100644 --- a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestHarness.java +++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestHarness.java @@ -93,7 +93,7 @@ public static Parameters getParameters() { //SpatialPooler specific parameters.setParameterByKey(Parameters.KEY.POTENTIAL_RADIUS, 12);//3 parameters.setParameterByKey(Parameters.KEY.POTENTIAL_PCT, 0.5);//0.5 - parameters.setParameterByKey(Parameters.KEY.GLOBAL_INHIBITIONS, false); + parameters.setParameterByKey(Parameters.KEY.GLOBAL_INHIBITION, false); parameters.setParameterByKey(Parameters.KEY.LOCAL_AREA_DENSITY, -1.0); parameters.setParameterByKey(Parameters.KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0); parameters.setParameterByKey(Parameters.KEY.STIMULUS_THRESHOLD, 1.0); diff --git a/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/encoders/scala/Encoders.scala b/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/encoders/scala/Encoders.scala index 051a97b..694759a 100644 --- a/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/encoders/scala/Encoders.scala +++ b/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/encoders/scala/Encoders.scala @@ -21,7 +21,6 @@ object DateEncoder { def apply(): jencoders.DateEncoder.Builder = { jencoders.DateEncoder.builder() .defaults - .formatter(ISODateTimeFormat.dateTime()) } } From 088da9b94905eab022c901fb742d58ed19537862 Mon Sep 17 00:00:00 2001 From: wrighe3 
Date: Tue, 26 Apr 2016 21:15:14 -0500 Subject: [PATCH 2/6] expose nupic Inference object --- .../streaming/examples/hotgym/HotGym.scala | 8 ++++---- .../streaming/examples/traffic/Traffic.scala | 2 +- .../flink/serialization/KryoSerializer.java | 9 +++++++++ .../nupic/flink/streaming/api/HTM.java | 14 +++++++++---- .../nupic/flink/streaming/api/HTMStream.java | 10 ++++++---- .../api/InferenceSelectFunction.java | 4 +++- ...{Inference2.java => NetworkInference.java} | 20 +++++++------------ .../AbstractHTMInferenceOperator.java | 12 +++++------ .../streaming/api/HTMIntegrationTest.java | 9 +++++---- .../nupic/flink/streaming/api/scala/HTM.scala | 8 ++++---- 10 files changed, 54 insertions(+), 42 deletions(-) rename flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/{Inference2.java => NetworkInference.java} (64%) diff --git a/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/hotgym/HotGym.scala b/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/hotgym/HotGym.scala index 46a6762..6ba76e1 100644 --- a/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/hotgym/HotGym.scala +++ b/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/hotgym/HotGym.scala @@ -77,17 +77,17 @@ object Demo extends HotGymModel { .mapWithState { (inference, state: Option[Double]) => val prediction = Prediction( - inference.getInput.timestamp.toString(LOOSE_DATE_TIME), - inference.getInput.consumption, + inference._1.timestamp.toString(LOOSE_DATE_TIME), + inference._1.consumption, state match { case Some(prediction) => prediction case None => 0.0 }, - inference.getAnomalyScore) + inference._2.getAnomalyScore) // store the prediction about the next value as state for the next iteration, // so that actual vs predicted is a meaningful comparison - val predictedConsumption = inference.getClassification("consumption").getMostProbableValue(1).asInstanceOf[Any] match { + val predictedConsumption = inference._2.getClassification("consumption").getMostProbableValue(1).asInstanceOf[Any] match { case value: Double if value != 0.0 => value case _ => state.getOrElse(0.0) } diff --git a/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/traffic/Traffic.scala b/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/traffic/Traffic.scala index 0e7acf1..c716dae 100644 --- a/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/traffic/Traffic.scala +++ b/flink-htm-examples/src/main/scala/org/numenta/nupic/flink/streaming/examples/traffic/Traffic.scala @@ -67,7 +67,7 @@ object Demo extends TrafficModel { .filter { report => report.datetime.isBefore(investigationInterval.getEnd) } .keyBy("streamId") .learn(network) - .select(inference => (inference.getInput, inference.getAnomalyScore)) + .select(inference => (inference._1, inference._2.getAnomalyScore)) val anomalousRoutes = anomalyScores .filter { anomaly => investigationInterval.contains(anomaly._1.datetime) } diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java index 5d70567..a50a9dd 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java @@ -3,6 +3,7 @@ import 
com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.KryoException; import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.ByteBufferOutput; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import org.apache.flink.api.common.ExecutionConfig; @@ -12,6 +13,8 @@ import org.numenta.nupic.serialize.HTMObjectOutput; import org.numenta.nupic.serialize.SerialConfig; import org.numenta.nupic.serialize.SerializerCore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.Serializable; @@ -22,6 +25,8 @@ */ public class KryoSerializer extends Serializer implements Serializable { + protected static final Logger LOGGER = LoggerFactory.getLogger(KryoSerializer.class); + private final SerializerCore serializer = new SerializerCore(SerialConfig.DEFAULT_REGISTERED_TYPES); /** @@ -33,9 +38,13 @@ public class KryoSerializer extends Serializer implements Serializable { @Override public void write(Kryo kryo, Output output, T t) { try { + long total = output.total(); + HTMObjectOutput writer = serializer.getObjectOutput(output); writer.writeObject(t, t.getClass()); writer.flush(); + + LOGGER.debug("wrote {} bytes", output.total() - total); } catch(IOException e) { throw new KryoException(e); diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java index 71f5764..642f46d 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java @@ -3,6 +3,8 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.numenta.nupic.flink.serialization.KryoSerializer; import org.numenta.nupic.flink.streaming.api.operator.GlobalHTMInferenceOperator; @@ -10,6 +12,7 @@ import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.KeyedStream; +import org.numenta.nupic.network.Inference; /** * Utility class for stream processing with hierarchical temporal memory (HTM). @@ -35,7 +38,10 @@ public static HTMStream learn(DataStream input, NetworkFactory n final boolean isProcessingTime = input.getExecutionEnvironment().getStreamTimeCharacteristic() == TimeCharacteristic.ProcessingTime; - final DataStream> inferenceStream; + final TypeInformation inferenceTypeInfo = TypeExtractor.getForClass(Inference.class); + final TypeInformation> inferenceStreamTypeInfo = new TupleTypeInfo<>(input.getType(), inferenceTypeInfo); + + final DataStream> inferenceStream; if (input instanceof KeyedStream) { // each key will be processed by a dedicated Network instance. 
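
// Note: the explicit TupleTypeInfo above is needed because Java erases the
// type parameters of Tuple2 at runtime, so TypeExtractor alone cannot recover
// Tuple2<T, Inference> from the operator's signature. A minimal sketch of the
// equivalent construction, assuming a hypothetical SensorRecord input class:
//
//     TypeInformation<Inference> inferenceType = TypeExtractor.getForClass(Inference.class);
//     TypeInformation<Tuple2<SensorRecord, Inference>> streamType =
//             new TupleTypeInfo<>(TypeExtractor.getForClass(SensorRecord.class), inferenceType);
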
@@ -46,7 +52,7 @@ public static HTMStream learn(DataStream input, NetworkFactory n inferenceStream = input.transform( INFERENCE_OPERATOR_NAME, - (TypeInformation>) (TypeInformation) TypeExtractor.getForClass(Inference2.class), + inferenceStreamTypeInfo, new KeyedHTMInferenceOperator<>(input.getExecutionConfig(), input.getType(), isProcessingTime, keySelector, keySerializer, networkFactory) ).name("Learn"); @@ -54,8 +60,8 @@ public static HTMStream learn(DataStream input, NetworkFactory n // all stream elements will be processed by a single Network instance, hence parallelism -> 1. inferenceStream = input.transform( INFERENCE_OPERATOR_NAME, - (TypeInformation>) (TypeInformation) TypeExtractor.getForClass(Inference2.class), - new GlobalHTMInferenceOperator(input.getExecutionConfig(), input.getType(), isProcessingTime, networkFactory) + inferenceStreamTypeInfo, + new GlobalHTMInferenceOperator<>(input.getExecutionConfig(), input.getType(), isProcessingTime, networkFactory) ).name("Learn").setParallelism(1); } diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java index 25d45a1..63fe0ea 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java @@ -2,9 +2,11 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.numenta.nupic.network.Inference; /** * Stream abstraction for HTM inference. 
An HTM stream is a stream which emits
@@ -16,11 +18,11 @@ public class HTMStream<T> {
 
     // underlying data stream
-    private final DataStream<Inference2<T>> inferenceStream;
+    private final DataStream<Tuple2<T, Inference>> inferenceStream;
 
     //type information of input type T
     private final TypeInformation<T> inputType;
 
-    HTMStream(final DataStream<Inference2<T>> inferenceStream, final TypeInformation<T> inputType) {
+    HTMStream(final DataStream<Tuple2<T, Inference>> inferenceStream, final TypeInformation<T> inputType) {
         this.inferenceStream = inferenceStream;
         this.inputType = inputType;
     }
@@ -62,7 +64,7 @@ public <R> DataStream<R> select(final InferenceSelectFunction<T, R> inferenceSel
         ).returns(returnType);
     }
 
-    private static class InferenceSelectMapper<T, R> implements MapFunction<Inference2<T>, R> {
+    private static class InferenceSelectMapper<T, R> implements MapFunction<Tuple2<T, Inference>, R> {
 
         private final InferenceSelectFunction<T, R> inferenceSelectFunction;
 
@@ -71,7 +73,7 @@ public InferenceSelectMapper(InferenceSelectFunction<T, R> inferenceSelectFuncti
         }
 
         @Override
-        public R map(Inference2<T> value) throws Exception {
+        public R map(Tuple2<T, Inference> value) throws Exception {
             return inferenceSelectFunction.select(value);
         }
     }
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java
index edc32d8..49f8407 100644
--- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java
+++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java
@@ -1,6 +1,8 @@
 package org.numenta.nupic.flink.streaming.api;
 
 import org.apache.flink.api.common.functions.Function;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.numenta.nupic.network.Inference;
 
 import java.io.Serializable;
 
@@ -20,5 +22,5 @@ public interface InferenceSelectFunction<IN, OUT> extends Function, Serializable
      * @throws Exception This method may throw exceptions. Throwing an exception
      *                   will cause the operation to fail and may trigger recovery.
      */
-    OUT select(Inference2<IN> inference) throws Exception;
+    OUT select(Tuple2<IN, Inference> inference) throws Exception;
 }
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/Inference2.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/NetworkInference.java
similarity index 64%
rename from flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/Inference2.java
rename to flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/NetworkInference.java
index b7babd3..5eba34d 100644
--- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/Inference2.java
+++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/NetworkInference.java
@@ -2,36 +2,30 @@
 
 import org.numenta.nupic.algorithms.Classification;
 import org.numenta.nupic.network.Inference;
-import org.numenta.nupic.network.Layer;
+import org.numenta.nupic.network.Network;
 
 import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
 /**
- * Container for output from a given {@link Layer}. Represents the
+ * Container for output from a {@link Network}. Represents the
  * result accumulated by the computation of a sequence of algorithms
- * contained in a given Layer and contains information needed at
- * various stages in the sequence of calculations a Layer may contain.
+ * contained in the network.
 */
-public class Inference2<IN> implements Serializable {
+public class NetworkInference implements Serializable {
 
     private final double anomalyScore;
 
-    private final IN input;
-
     private final Map<String, Classification<Object>> classifications;
 
-    public Inference2(IN input, double anomalyScore, Map<String, Classification<Object>> classifications) {
-        this.input = input;
+    public NetworkInference(double anomalyScore, Map<String, Classification<Object>> classifications) {
         this.anomalyScore = anomalyScore;
         this.classifications = classifications;
     }
 
     public double getAnomalyScore() { return this.anomalyScore; }
 
-    public IN getInput() { return this.input; }
-
     /**
      * Returns the most recent {@link Classification}
      *
@@ -43,11 +37,11 @@ public Classification<Object> getClassification(String fieldName) {
         return classifications.get(fieldName);
     }
 
-    public static <IN> Inference2<IN> fromInference(IN input, Inference i) {
+    public static NetworkInference fromInference(Inference i) {
         Map<String, Classification<Object>> classifications = new HashMap<>();
         for(String field : i.getClassifiers().keys()) {
             classifications.put(field, i.getClassification(field));
         }
-        return new Inference2(input, i.getAnomalyScore(), classifications);
+        return new NetworkInference(i.getAnomalyScore(), classifications);
     }
 }
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java
index fad3a07..805d995 100644
--- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java
+++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java
@@ -7,7 +7,7 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.numenta.nupic.flink.streaming.api.Inference2;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.numenta.nupic.flink.streaming.api.NetworkFactory;
 import org.numenta.nupic.flink.streaming.api.codegen.GenerateEncoderInputFunction;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
@@ -33,8 +33,8 @@
  * @author Eron Wright
  */
 public abstract class AbstractHTMInferenceOperator<IN>
-        extends AbstractStreamOperator<Inference2<IN>>
-        implements OneInputStreamOperator<IN, Inference2<IN>> {
+        extends AbstractStreamOperator<Tuple2<IN, Inference>>
+        implements OneInputStreamOperator<IN, Tuple2<IN, Inference>> {
 
     protected static final Logger LOG = LoggerFactory.getLogger(AbstractHTMInferenceOperator.class);
 
@@ -109,10 +109,8 @@ protected void processInput(Network network, IN record, long timestamp) {
         Inference inference = network.computeImmediate(input);
 
         if(inference != null) {
-            Inference2<IN> outRecord = Inference2.fromInference(record, inference);
-
-            StreamRecord<Inference2<IN>> streamRecord = new StreamRecord<>(
-                outRecord,
+            StreamRecord<Tuple2<IN, Inference>> streamRecord = new StreamRecord<>(
+                new Tuple2(record, inference),
                 timestamp);
             output.collect(streamRecord);
         }
diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
index 231a590..7767516 100644
--- a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
+++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
@@ -11,6 +11,7 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
+import org.numenta.nupic.network.Inference; import java.util.List; import java.util.stream.Collectors; @@ -53,11 +54,11 @@ public void testSimple() throws Exception { .learn(input, new TestHarness.DayDemoNetworkFactory()) .select(new InferenceSelectFunction>() { @Override - public Tuple3 select(Inference2 inference) throws Exception { + public Tuple3 select(Tuple2 inference) throws Exception { return new Tuple3( - inference.getInput().dayOfWeek, - (Double) inference.getClassification("dayOfWeek").getMostProbableValue(1), - inference.getAnomalyScore()); + inference.f0.dayOfWeek, + (Double) inference.f1.getClassification("dayOfWeek").getMostProbableValue(1), + inference.f1.getAnomalyScore()); } }); diff --git a/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala b/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala index ecb0f3c..e688fef 100644 --- a/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala +++ b/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala @@ -5,8 +5,8 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala.ClosureCleaner import org.apache.flink.streaming.api.scala.DataStream import org.numenta.nupic.flink.streaming.{api => jnupic} -import org.numenta.nupic.network.Network - +import org.numenta.nupic.network.{Inference, Network} +import org.apache.flink.api.java.tuple.{Tuple2 => FlinkTuple2} import scala.reflect.ClassTag /** @@ -86,11 +86,11 @@ final class HTMStream[T: TypeInformation : ClassTag](jstream: jnupic.HTMStream[T * @tparam R the type of the output elements. * @return a new data stream. */ - def select[R: TypeInformation : ClassTag](fun: jnupic.Inference2[T] => R): DataStream[R] = { + def select[R: TypeInformation : ClassTag](fun: Tuple2[T,Inference] => R): DataStream[R] = { val outType : TypeInformation[R] = implicitly[TypeInformation[R]] val selector: jnupic.InferenceSelectFunction[T,R] = new jnupic.InferenceSelectFunction[T,R] { val cleanFun = clean(fun) - def select(in: jnupic.Inference2[T]): R = cleanFun(in) + def select(in: FlinkTuple2[T,Inference]): R = cleanFun(Tuple2(in.f0, in.f1)) } new DataStream[R](jstream.select(selector, outType)) } From b98b06e91a0462c43be9384938a37124a1f08e58 Mon Sep 17 00:00:00 2001 From: wrighe3 Date: Sat, 30 Apr 2016 19:26:37 -0700 Subject: [PATCH 3/6] introduce NetworkInference class - introduce NetworkInference because the Inference object is proving too big to use as a data flow element. 
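
A sketch of the resulting select() call pattern (the MyRecord type and the
anonymous function are illustrative, not part of this patch):

    HTM.learn(input, networkFactory)
       .select(new InferenceSelectFunction<MyRecord, Double>() {
           @Override
           public Double select(Tuple2<MyRecord, NetworkInference> inference) throws Exception {
               // f0 carries the original input element, f1 the slimmed-down inference
               return inference.f1.getAnomalyScore();
           }
       });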
--- .../flink/serialization/KryoSerializer.java | 76 ++++++++++++++++--- .../nupic/flink/streaming/api/HTM.java | 6 +- .../nupic/flink/streaming/api/HTMStream.java | 8 +- .../api/InferenceSelectFunction.java | 4 +- .../flink/streaming/api/NetworkInference.java | 27 ++++--- .../AbstractHTMInferenceOperator.java | 10 ++- .../serialization/KryoSerializerTest.java | 75 ++++++++++++++++++ .../streaming/api/HTMIntegrationTest.java | 2 +- .../nupic/flink/streaming/api/scala/HTM.scala | 9 ++- 9 files changed, 180 insertions(+), 37 deletions(-) create mode 100644 flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/serialization/KryoSerializerTest.java diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java index a50a9dd..15df554 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java @@ -3,7 +3,6 @@ import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.KryoException; import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.ByteBufferOutput; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import org.apache.flink.api.common.ExecutionConfig; @@ -16,20 +15,20 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.io.Serializable; +import java.io.*; /** * Kryo serializer for HTM network and related objects. * */ -public class KryoSerializer extends Serializer implements Serializable { +public class KryoSerializer extends Serializer implements Serializable { protected static final Logger LOGGER = LoggerFactory.getLogger(KryoSerializer.class); private final SerializerCore serializer = new SerializerCore(SerialConfig.DEFAULT_REGISTERED_TYPES); /** + * Write the given instance to the given output. * * @param kryo instance of {@link Kryo} object * @param output a Kryo {@link Output} object @@ -38,13 +37,19 @@ public class KryoSerializer extends Serializer implements Serializable { @Override public void write(Kryo kryo, Output output, T t) { try { - long total = output.total(); + try(ByteArrayOutputStream stream = new ByteArrayOutputStream(4096)) { - HTMObjectOutput writer = serializer.getObjectOutput(output); - writer.writeObject(t, t.getClass()); - writer.flush(); + // write the object using the HTM serializer + HTMObjectOutput writer = serializer.getObjectOutput(stream); + writer.writeObject(t, t.getClass()); + writer.close(); - LOGGER.debug("wrote {} bytes", output.total() - total); + // write the serialized data + output.writeInt(stream.size()); + stream.writeTo(output); + + LOGGER.debug("wrote {} bytes", stream.size()); + } } catch(IOException e) { throw new KryoException(e); @@ -52,6 +57,7 @@ public void write(Kryo kryo, Output output, T t) { } /** + * Read an instance of the given class from the given input. 
* * @param kryo instance of {@link Kryo} object * @param input a Kryo {@link Input} @@ -60,16 +66,62 @@ public void write(Kryo kryo, Output output, T t) { */ @Override public T read(Kryo kryo, Input input, Class aClass) { + + // read the serialized data + byte[] data = new byte[input.readInt()]; + input.readBytes(data); + + try { + try(ByteArrayInputStream stream = new ByteArrayInputStream(data)) { + HTMObjectInput reader = serializer.getObjectInput(stream); + T t = (T) reader.readObject(aClass); + return t; + } + } + catch(Exception e) { + throw new KryoException(e); + } + } + + /** + * Copy the given instance. + * @param kryo instance of {@link Kryo} object + * @param original an object to copy. + * @return + */ + @Override + public T copy(Kryo kryo, T original) { try { - HTMObjectInput reader = serializer.getObjectInput(input); - T t = (T) reader.readObject(aClass); - return t; + try(CopyStream output = new CopyStream(4096)) { + HTMObjectOutput writer = serializer.getObjectOutput(output); + writer.writeObject(original, original.getClass()); + writer.close(); + + try(InputStream input = output.toInputStream()) { + HTMObjectInput reader = serializer.getObjectInput(input); + T t = (T) reader.readObject(original.getClass()); + return t; + } + } } catch(Exception e) { throw new KryoException(e); } } + static class CopyStream extends ByteArrayOutputStream { + public CopyStream(int size) { super(size); } + + /** + * Get an input stream based on the contents of this output stream. + * Do not use the output stream after calling this method. + * @return an {@link InputStream} + */ + public InputStream toInputStream() { + return new ByteArrayInputStream(this.buf, 0, this.count); + } + } + /** * Register the HTM types with the Kryo serializer. * @param config diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java index 642f46d..db0eccc 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTM.java @@ -38,10 +38,10 @@ public static HTMStream learn(DataStream input, NetworkFactory n final boolean isProcessingTime = input.getExecutionEnvironment().getStreamTimeCharacteristic() == TimeCharacteristic.ProcessingTime; - final TypeInformation inferenceTypeInfo = TypeExtractor.getForClass(Inference.class); - final TypeInformation> inferenceStreamTypeInfo = new TupleTypeInfo<>(input.getType(), inferenceTypeInfo); + final TypeInformation inferenceTypeInfo = TypeExtractor.getForClass(NetworkInference.class); + final TypeInformation> inferenceStreamTypeInfo = new TupleTypeInfo<>(input.getType(), inferenceTypeInfo); - final DataStream> inferenceStream; + final DataStream> inferenceStream; if (input instanceof KeyedStream) { // each key will be processed by a dedicated Network instance. 
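
// Note: NetworkInference deliberately carries only the anomaly score and the
// per-field Classification results, keeping stream elements small. A hedged
// sketch of constructing and querying one directly, e.g. in a test (real
// instances come from NetworkInference.fromInference(...)):
//
//     Map<String, Classification<Object>> classifications = new HashMap<>();
//     NetworkInference inference = new NetworkInference(0.25, classifications);
//     double score = inference.getAnomalyScore();   // 0.25
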
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java index 63fe0ea..bc5e8c5 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/HTMStream.java @@ -18,11 +18,11 @@ public class HTMStream { // underlying data stream - private final DataStream> inferenceStream; + private final DataStream> inferenceStream; //type information of input type T private final TypeInformation inputType; - HTMStream(final DataStream> inferenceStream, final TypeInformation inputType) { + HTMStream(final DataStream> inferenceStream, final TypeInformation inputType) { this.inferenceStream = inferenceStream; this.inputType = inputType; } @@ -64,7 +64,7 @@ public DataStream select(final InferenceSelectFunction inferenceSel ).returns(returnType); } - private static class InferenceSelectMapper implements MapFunction, R> { + private static class InferenceSelectMapper implements MapFunction, R> { private final InferenceSelectFunction inferenceSelectFunction; @@ -73,7 +73,7 @@ public InferenceSelectMapper(InferenceSelectFunction inferenceSelectFuncti } @Override - public R map(Tuple2 value) throws Exception { + public R map(Tuple2 value) throws Exception { return inferenceSelectFunction.select(value); } } diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java index 49f8407..065418d 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/InferenceSelectFunction.java @@ -17,10 +17,10 @@ public interface InferenceSelectFunction extends Function, Serializable /** * Generates a result from the given raw inference. * - * @param inference Inference emitted by the network + * @param inference A tuple combining the input and associated inference emitted by the network * @return resulting element * @throws Exception This method may throw exceptions. Throwing an exception * will cause the operation to fail and may trigger recovery. 
*/ - OUT select(Tuple2 inference) throws Exception; + OUT select(Tuple2 inference) throws Exception; } diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/NetworkInference.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/NetworkInference.java index 5eba34d..ebf6522 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/NetworkInference.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/NetworkInference.java @@ -1,10 +1,13 @@ package org.numenta.nupic.flink.streaming.api; +import com.esotericsoftware.kryo.serializers.DefaultSerializers; +import com.esotericsoftware.kryo.serializers.FieldSerializer; +import com.esotericsoftware.kryo.serializers.MapSerializer; import org.numenta.nupic.algorithms.Classification; +import org.numenta.nupic.flink.serialization.KryoSerializer; import org.numenta.nupic.network.Inference; import org.numenta.nupic.network.Network; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -13,18 +16,24 @@ * result accumulated by the computation of a sequence of algorithms * contained in the network. */ -public class NetworkInference implements Serializable { +public class NetworkInference { + @FieldSerializer.Bind(DefaultSerializers.DoubleSerializer.class) private final double anomalyScore; - private final Map> classifications; + @MapSerializer.BindMap( + keySerializer = DefaultSerializers.StringSerializer.class, keyClass = String.class, + valueSerializer = KryoSerializer.class, valueClass = Classification.class) + private final Map> classifications; - public NetworkInference(double anomalyScore, Map> classifications) { + public NetworkInference(double anomalyScore, Map> classifications) { this.anomalyScore = anomalyScore; this.classifications = classifications; } - public double getAnomalyScore() { return this.anomalyScore; } + public double getAnomalyScore() { + return this.anomalyScore; + } /** * Returns the most recent {@link Classification} @@ -33,15 +42,15 @@ public NetworkInference(double anomalyScore, Map> * @return the classification result. 
*/ public Classification getClassification(String fieldName) { - if(classifications == null) throw new IllegalStateException("no classification results are available"); + if (classifications == null) throw new IllegalStateException("no classification results are available"); return classifications.get(fieldName); } - public static NetworkInference fromInference(Inference i) { + public static NetworkInference fromInference(Inference i) { Map> classifications = new HashMap<>(); - for(String field : i.getClassifiers().keys()) { + for (String field : i.getClassifiers().keys()) { classifications.put(field, i.getClassification(field)); } return new NetworkInference(i.getAnomalyScore(), classifications); } -} +} \ No newline at end of file diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java index 805d995..44e04da 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/AbstractHTMInferenceOperator.java @@ -9,6 +9,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.numenta.nupic.flink.streaming.api.NetworkFactory; +import org.numenta.nupic.flink.streaming.api.NetworkInference; import org.numenta.nupic.flink.streaming.api.codegen.GenerateEncoderInputFunction; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; @@ -33,8 +34,8 @@ * @author Eron Wright */ public abstract class AbstractHTMInferenceOperator - extends AbstractStreamOperator> - implements OneInputStreamOperator> { + extends AbstractStreamOperator> + implements OneInputStreamOperator> { protected static final Logger LOG = LoggerFactory.getLogger(AbstractHTMInferenceOperator.class); @@ -109,8 +110,9 @@ protected void processInput(Network network, IN record, long timestamp) { Inference inference = network.computeImmediate(input); if(inference != null) { - StreamRecord> streamRecord = new StreamRecord<>( - new Tuple2(record, inference), + NetworkInference outputInference = NetworkInference.fromInference(inference); + StreamRecord> streamRecord = new StreamRecord<>( + new Tuple2(record, outputInference), timestamp); output.collect(streamRecord); } diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/serialization/KryoSerializerTest.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/serialization/KryoSerializerTest.java new file mode 100644 index 0000000..73cf491 --- /dev/null +++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/serialization/KryoSerializerTest.java @@ -0,0 +1,75 @@ +package org.numenta.nupic.flink.serialization; + +import static org.junit.Assert.*; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.junit.Before; +import org.junit.Test; +import org.numenta.nupic.util.Tuple; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; + +import static org.junit.Assert.*; + +/** + * Test the KryoSerializer. 
+ */ +public class KryoSerializerTest { + + public Kryo createKryo() { + Kryo.DefaultInstantiatorStrategy initStrategy = new Kryo.DefaultInstantiatorStrategy(); + + // use Objenesis to create classes without calling the constructor (Flink's technique) + //initStrategy.setFallbackInstantiatorStrategy(new StdInstantiatorStrategy()); + + Kryo kryo = new Kryo(); + kryo.setInstantiatorStrategy(initStrategy); + return kryo; + } + + private Kryo kryo; + + @Before + public void before() { + kryo = createKryo(); + kryo.register(Tuple.class, new KryoSerializer<>()); + } + + @Test + public void testReadWrite() throws Exception { + + // write numerous objects to the stream, to verify that the read buffers + // aren't too greedy + + Tuple expected1 = new Tuple(42); + Tuple expected2 = new Tuple(101); + + ByteArrayOutputStream baout = new ByteArrayOutputStream(); + Output output = new Output(baout); + kryo.writeObject(output, expected1); + kryo.writeObject(output, expected2); + output.close(); + + ByteArrayInputStream bain = new ByteArrayInputStream(baout.toByteArray()); + Input input = new Input(bain); + + Tuple actual1 = kryo.readObject(input, Tuple.class); + assertNotSame(expected1, actual1); + assertEquals(expected1, actual1); + + Tuple actual2 = kryo.readObject(input, Tuple.class); + assertNotSame(expected2, actual2); + assertEquals(expected2, actual2); + } + + @Test + public void testCopy() throws Exception { + Tuple expected = new Tuple(42); + + Tuple actual = kryo.copy(expected); + assertNotSame(expected, actual); + assertEquals(expected, actual); + } +} \ No newline at end of file diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java index 7767516..0316897 100644 --- a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java +++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java @@ -54,7 +54,7 @@ public void testSimple() throws Exception { .learn(input, new TestHarness.DayDemoNetworkFactory()) .select(new InferenceSelectFunction>() { @Override - public Tuple3 select(Tuple2 inference) throws Exception { + public Tuple3 select(Tuple2 inference) throws Exception { return new Tuple3( inference.f0.dayOfWeek, (Double) inference.f1.getClassification("dayOfWeek").getMostProbableValue(1), diff --git a/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala b/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala index e688fef..fe9806b 100644 --- a/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala +++ b/flink-htm-streaming-scala/src/main/scala/org/numenta/nupic/flink/streaming/api/scala/HTM.scala @@ -4,6 +4,7 @@ import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala.ClosureCleaner import org.apache.flink.streaming.api.scala.DataStream +import org.numenta.nupic.flink.streaming.api.NetworkInference import org.numenta.nupic.flink.streaming.{api => jnupic} import org.numenta.nupic.network.{Inference, Network} import org.apache.flink.api.java.tuple.{Tuple2 => FlinkTuple2} @@ -28,6 +29,7 @@ object HTM { /** * Create an HTM stream based on the current [[DataStream]]. + * * @param input the input data stream to model. 
* @param factory the factory to create the HTM network. * @tparam T the type of the input elements. @@ -42,6 +44,7 @@ object HTM { /** * Create an HTM stream based on the current [[DataStream]]. + * * @param input the input data stream to model. * @param fun the factory to create the HTM network. * @tparam T the type of the input elements. @@ -71,6 +74,7 @@ final class HTMStream[T: TypeInformation : ClassTag](jstream: jnupic.HTMStream[T /** * Select output elements from the HTM stream. + * * @param selector the select function. * @tparam R the type of the output elements. * @return a new data stream. @@ -82,15 +86,16 @@ final class HTMStream[T: TypeInformation : ClassTag](jstream: jnupic.HTMStream[T /** * Select output elements from the HTM stream. + * * @param fun the select function. * @tparam R the type of the output elements. * @return a new data stream. */ - def select[R: TypeInformation : ClassTag](fun: Tuple2[T,Inference] => R): DataStream[R] = { + def select[R: TypeInformation : ClassTag](fun: Tuple2[T, NetworkInference] => R): DataStream[R] = { val outType : TypeInformation[R] = implicitly[TypeInformation[R]] val selector: jnupic.InferenceSelectFunction[T,R] = new jnupic.InferenceSelectFunction[T,R] { val cleanFun = clean(fun) - def select(in: FlinkTuple2[T,Inference]): R = cleanFun(Tuple2(in.f0, in.f1)) + def select(in: FlinkTuple2[T,NetworkInference]): R = cleanFun(Tuple2(in.f0, in.f1)) } new DataStream[R](jstream.select(selector, outType)) } From cc2e554c9a9532db2dc3b9adde6304c9e1a9597d Mon Sep 17 00:00:00 2001 From: wrighe3 Date: Sun, 1 May 2016 17:09:32 -0700 Subject: [PATCH 4/6] checkpoint support --- .../flink/serialization/KryoSerializer.java | 77 +++++++++++++++- .../operator/GlobalHTMInferenceOperator.java | 34 ++++++++ .../streaming/api/HTMIntegrationTest.java | 66 +++++++++++++- .../streaming/api/TestSourceFunction.java | 87 +++++++++++++++++++ 4 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestSourceFunction.java diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java index 15df554..ef9a63f 100644 --- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java +++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/serialization/KryoSerializer.java @@ -6,6 +6,8 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.util.FieldAccessor; import org.numenta.nupic.Persistable; import org.numenta.nupic.network.Network; import org.numenta.nupic.serialize.HTMObjectInput; @@ -16,6 +18,10 @@ import org.slf4j.LoggerFactory; import java.io.*; +import java.lang.reflect.Field; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Kryo serializer for HTM network and related objects. 
@@ -37,6 +43,8 @@ public class KryoSerializer extends Serializer impleme @Override public void write(Kryo kryo, Output output, T t) { try { + preSerialize(t); + try(ByteArrayOutputStream stream = new ByteArrayOutputStream(4096)) { // write the object using the HTM serializer @@ -75,6 +83,9 @@ public T read(Kryo kryo, Input input, Class aClass) { try(ByteArrayInputStream stream = new ByteArrayInputStream(data)) { HTMObjectInput reader = serializer.getObjectInput(stream); T t = (T) reader.readObject(aClass); + + postDeSerialize(t); + return t; } } @@ -92,6 +103,8 @@ public T read(Kryo kryo, Input input, Class aClass) { @Override public T copy(Kryo kryo, T original) { try { + preSerialize(original); + try(CopyStream output = new CopyStream(4096)) { HTMObjectOutput writer = serializer.getObjectOutput(output); writer.writeObject(original, original.getClass()); @@ -100,6 +113,9 @@ public T copy(Kryo kryo, T original) { try(InputStream input = output.toInputStream()) { HTMObjectInput reader = serializer.getObjectInput(input); T t = (T) reader.readObject(original.getClass()); + + postDeSerialize(t); + return t; } } @@ -122,18 +138,75 @@ public InputStream toInputStream() { } } + /** + * The HTM serializer handles the Persistable callbacks automatically, but + * this method is for any additional actions to be taken. + * @param t the instance to be serialized. + */ + protected void preSerialize(T t) { + } + + /** + * The HTM serializer handles the Persistable callbacks automatically, but + * this method is for any additional actions to be taken. + * @param t the instance newly deserialized. + */ + protected void postDeSerialize(T t) { + } + /** * Register the HTM types with the Kryo serializer. * @param config */ public static void registerTypes(ExecutionConfig config) { for(Class c : SerialConfig.DEFAULT_REGISTERED_TYPES) { - config.registerTypeWithKryoSerializer(c, (Class>) (Class) KryoSerializer.class); + Class serializerClass = DEFAULT_SERIALIZERS.getOrDefault(c, (Class) KryoSerializer.class); + config.registerTypeWithKryoSerializer(c, (Class>) serializerClass); } for(Class c : KryoSerializer.ADDITIONAL_REGISTERED_TYPES) { - config.registerTypeWithKryoSerializer(c, (Class>) (Class) KryoSerializer.class); + Class serializerClass = DEFAULT_SERIALIZERS.getOrDefault(c, (Class) KryoSerializer.class); + config.registerTypeWithKryoSerializer(c, (Class>) serializerClass); } } static final Class[] ADDITIONAL_REGISTERED_TYPES = { Network.class }; + + /** + * A map of serializers for various classes. 
diff --git a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/GlobalHTMInferenceOperator.java b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/GlobalHTMInferenceOperator.java
index 5e70bc4..bd3f4e8 100644
--- a/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/GlobalHTMInferenceOperator.java
+++ b/flink-htm-streaming-java/src/main/java/org/numenta/nupic/flink/streaming/api/operator/GlobalHTMInferenceOperator.java
@@ -1,7 +1,14 @@
 package org.numenta.nupic.flink.streaming.api.operator;
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskState;
 import org.numenta.nupic.flink.streaming.api.NetworkFactory;
 import org.numenta.nupic.network.Network;
 
@@ -19,6 +26,9 @@ public class GlobalHTMInferenceOperator<T> extends AbstractHTMInferenceOperator
     // global network for all elements
     transient private Network network;
 
+    // state serializer
+    private KryoSerializer<Network> stateSerializer;
+
     public GlobalHTMInferenceOperator(
             final ExecutionConfig executionConfig,
             final TypeInformation<T> inputType,
@@ -27,6 +37,8 @@ public GlobalHTMInferenceOperator(
         super(executionConfig, inputType, isProcessingTime, networkFactory);
 
         this.networkFactory = networkFactory;
+
+        stateSerializer = new KryoSerializer<>(Network.class, executionConfig);
     }
 
     @Override
@@ -43,4 +55,26 @@ public void open() throws Exception {
     protected Network getInputNetwork() throws Exception {
         return network;
     }
+
+    @Override
+    public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) throws Exception {
+        StreamTaskState taskState = super.snapshotOperatorState(checkpointId, timestamp);
+
+        final AbstractStateBackend.CheckpointStateOutputView ov =
+                this.getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp);
+        stateSerializer.serialize(network, ov);
+        taskState.setOperatorState(ov.closeAndGetHandle());
+
+        return taskState;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception {
+        super.restoreState(state, recoveryTimestamp);
+
+        final StateHandle<DataInputView> handle = (StateHandle<DataInputView>) state.getOperatorState();
+        final DataInputView iv = handle.getState(getUserCodeClassloader());
+        network = stateSerializer.deserialize(iv);
+    }
 }
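The two overrides above are the entire checkpoint story for the global operator: snapshotOperatorState streams the whole Network through Flink's KryoSerializer into the state backend, and restoreState reads it back on recovery; the bytes ultimately pass through the HTM serializer registered with the execution config. That round-trip can be exercised in isolation — a sketch, assuming `network` is a fully built HTM network and that Kryo, Input, and Output come from com.esotericsoftware.kryo:

    KryoSerializer.NetworkSerializer serializer = new KryoSerializer.NetworkSerializer();
    Kryo kryo = new Kryo();

    // write the network, as the operator does at checkpoint time
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    Output out = new Output(bytes);
    serializer.write(kryo, out, network);   // preSerialize clears shouldDoHalt
    out.close();

    // read it back, as the operator does at recovery time
    Input in = new Input(new ByteArrayInputStream(bytes.toByteArray()));
    Network restored = serializer.read(kryo, in, Network.class);
    in.close();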
diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
index 0316897..cd62764 100644
--- a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
+++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
@@ -1,5 +1,6 @@
 package org.numenta.nupic.flink.streaming.api;
 
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.core.fs.FileSystem;
@@ -11,9 +12,9 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
-import org.numenta.nupic.network.Inference;
 
 import java.util.List;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -70,4 +71,67 @@ public Tuple3<Integer, Double, Double> select(Tuple2<TestHarness.DayDemoRecord
 
         env.execute();
     }
+
+    @Test
+    public void testCheckpointWithBroadcastStream() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0));
+
+        DataStream<TestHarness.DayDemoRecord> source = env
+                .addSource(new DayDemoRecordSourceFunction(2, true))
+                .broadcast();
+
+        DataStream<Tuple3<Integer, Double, Double>> result =
+                HTM.learn(source, new TestHarness.DayDemoNetworkFactory())
+                        .select(new InferenceSelectFunction<TestHarness.DayDemoRecord, Tuple3<Integer, Double, Double>>() {
+                            @Override
+                            public Tuple3<Integer, Double, Double> select(Tuple2<TestHarness.DayDemoRecord, NetworkInference> inference) throws Exception {
+                                return new Tuple3<>(
+                                        inference.f0.dayOfWeek,
+                                        (Double) inference.f1.getClassification("dayOfWeek").getMostProbableValue(1),
+                                        inference.f1.getAnomalyScore());
+                            }
+                        });
+
+        result.print();
+
+        env.execute();
+    }
+
+
+    private static class DayDemoRecordSourceFunction extends TestSourceFunction<TestHarness.DayDemoRecord> {
+
+        private volatile int dayOfWeek = 0;
+
+        public DayDemoRecordSourceFunction(int numCheckpoints, boolean failAfterCheckpoint) {
+            super(numCheckpoints, failAfterCheckpoint);
+        }
+
+        @Override
+        protected Supplier<TestHarness.DayDemoRecord> generate() {
+            return new Supplier<TestHarness.DayDemoRecord>() {
+                @Override
+                public TestHarness.DayDemoRecord get() {
+                    return new TestHarness.DayDemoRecord(dayOfWeek++ % 7);
+                }
+            };
+        }
+
+        @Override
+        public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+            super.snapshotState(checkpointId, checkpointTimestamp);
+            return Long.valueOf(dayOfWeek);
+        }
+
+        @Override
+        public void restoreState(Long state) throws Exception {
+            super.restoreState(state);
+            dayOfWeek = state.intValue();
+        }
+    }
 }
diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestSourceFunction.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestSourceFunction.java
new file mode 100644
index 0000000..1b9aab5
--- /dev/null
+++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/TestSourceFunction.java
@@ -0,0 +1,87 @@
+package org.numenta.nupic.flink.streaming.api;
+
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Iterator;
+import java.util.function.Supplier;
+import java.util.stream.Stream;
+
+/**
+ * Emit records until exactly numCheckpoints occur.
+ *
+ * Checkpoints are based on stream barriers that flow with the records;
+ * by ceasing to emit records after checkpoint X, the sink is guaranteed
+ * not to receive any records after checkpoint X. This arrangement stabilizes the test.
+ */
+public abstract class TestSourceFunction<T> implements SourceFunction<T>, Checkpointed<Long>, CheckpointListener {
+    private static final long serialVersionUID = 1L;
+
+    private volatile boolean running = true;
+
+    private int numCheckpoints;
+    private boolean failAfterCheckpoint;
+
+    private volatile boolean failOnNext = false;
+
+    /**
+     * Create a test source function that runs until a given number of checkpoints.
+     * @param numCheckpoints the number of checkpoints to run for.
+     * @param failAfterCheckpoint indicates whether to simulate a failure after the first checkpoint.
+     */
+    public TestSourceFunction(int numCheckpoints, boolean failAfterCheckpoint) {
+        this.numCheckpoints = numCheckpoints;
+        this.failAfterCheckpoint = failAfterCheckpoint;
+    }
+
+    /**
+     * Generate an unbounded sequence of elements.
+     */
+    protected abstract Supplier<T> generate();
+
+    @Override
+    public void run(SourceContext<T> ctx) throws Exception {
+        Iterator<T> inputs = Stream.generate(generate()).iterator();
+        while(running) {
+            synchronized (ctx.getCheckpointLock()) {
+                if(running) {
+                    if(failOnNext) {
+                        failOnNext = false;
+                        throw new Exception("Artificial Failure");
+                    }
+                    assert(inputs.hasNext());
+                    ctx.collect(inputs.next());
+                    Thread.sleep(10);
+                }
+            }
+        }
+    }
+
+    @Override
+    public void cancel() {
+        running = false;
+    }
+
+    @Override
+    public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+        if(--numCheckpoints == 0) {
+            running = false;
+        }
+        return 0L;
+    }
+
+    @Override
+    public void restoreState(Long state) throws Exception {
+        // do not cause repeated failures
+        failAfterCheckpoint = false;
+    }
+
+    @Override
+    public void notifyCheckpointComplete(long checkpointId) throws Exception {
+        if(running && failAfterCheckpoint) {
+            // any simulated failure should come after the entire checkpoint has been taken.
+            failOnNext = true;
+        }
+    }
+}
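DayDemoRecordSourceFunction in the test diff above is one concrete subclass of TestSourceFunction; the contract is small enough that a minimal counter source shows all of it. A hypothetical example, not part of the patch:

    // emits 0, 1, 2, ..., stops after 3 checkpoints, and simulates one
    // failure after the first completed checkpoint
    public class CountingSourceFunction extends TestSourceFunction<Long> {
        private volatile long next = 0;

        public CountingSourceFunction() {
            super(3, true);
        }

        @Override
        protected Supplier<Long> generate() {
            return () -> next++;
        }

        @Override
        public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
            super.snapshotState(checkpointId, checkpointTimestamp);   // counts down numCheckpoints
            return next;                                              // the counter is the state
        }

        @Override
        public void restoreState(Long state) throws Exception {
            super.restoreState(state);   // disables further simulated failures
            next = state;
        }
    }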
From 60aeed8a628d42abf871e07be8cb2a724d6aade3 Mon Sep 17 00:00:00 2001
From: wrighe3
Date: Sun, 1 May 2016 23:22:03 -0700
Subject: [PATCH 5/6] update version to 0.6.7

---
 flink-htm-examples/build.gradle        | 2 +-
 flink-htm-streaming-scala/build.gradle | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/flink-htm-examples/build.gradle b/flink-htm-examples/build.gradle
index cb10eda..519f171 100644
--- a/flink-htm-examples/build.gradle
+++ b/flink-htm-examples/build.gradle
@@ -24,7 +24,7 @@ dependencies {
     compile 'org.scala-lang:scala-library:2.11.7'
 
     // htm.java
-    compile 'org.numenta:htm.java:0.6.5'
+    compile 'org.numenta:htm.java:0.6.7'
 
     // flink
     compile 'org.apache.flink:flink-scala_2.11:1.0.0'
diff --git a/flink-htm-streaming-scala/build.gradle b/flink-htm-streaming-scala/build.gradle
index 9cc8bf5..1dc0ce2 100644
--- a/flink-htm-streaming-scala/build.gradle
+++ b/flink-htm-streaming-scala/build.gradle
@@ -21,7 +21,7 @@ dependencies {
     compile 'org.scala-lang:scala-library:2.11.7'
 
     // htm.java
-    compile 'org.numenta:htm.java:0.6.5'
+    compile 'org.numenta:htm.java:0.6.7'
 
     // flink
     compile 'org.apache.flink:flink-scala_2.11:1.0.0'

From fd059b95f351abb9e55cabec30f4b72dde4ad0b1 Mon Sep 17 00:00:00 2001
From: wrighe3
Date: Sun, 1 May 2016 23:46:01 -0700
Subject: [PATCH 6/6] Keyed Stream serialization test

---
 .../streaming/api/HTMIntegrationTest.java | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
index cd62764..0a05e89 100644
--- a/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
+++ b/flink-htm-streaming-java/src/test/java/org/numenta/nupic/flink/streaming/api/HTMIntegrationTest.java
@@ -103,6 +103,37 @@ public Tuple3<Integer, Double, Double> select(Tuple2<TestHarness.DayDemoRecord
         result.print();
 
         env.execute();
     }
 
+    @Test
+    public void testCheckpointWithKeyedStream() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0));
+
+        DataStream<TestHarness.DayDemoRecord> source = env
+                .addSource(new DayDemoRecordSourceFunction(2, true))
+                .keyBy("dayOfWeek");
+
+        DataStream<Tuple3<Integer, Double, Double>> result =
+                HTM.learn(source, new TestHarness.DayDemoNetworkFactory())
+                        .select(new InferenceSelectFunction<TestHarness.DayDemoRecord, Tuple3<Integer, Double, Double>>() {
+                            @Override
+                            public Tuple3<Integer, Double, Double> select(Tuple2<TestHarness.DayDemoRecord, NetworkInference> inference) throws Exception {
+                                return new Tuple3<>(
+                                        inference.f0.dayOfWeek,
+                                        (Double) inference.f1.getClassification("dayOfWeek").getMostProbableValue(1),
+                                        inference.f1.getAnomalyScore());
+                            }
+                        });
+
+        result.print();
+
+        env.execute();
+    }
 
     private static class DayDemoRecordSourceFunction extends TestSourceFunction<TestHarness.DayDemoRecord> {