Skip to content
This repository was archived by the owner on Dec 31, 2020. It is now read-only.

Commit

Permalink
Merge pull request #15 from nupic-community/update-to-0.6.7
Browse files Browse the repository at this point in the history
Update to 0.6.7
  • Loading branch information
EronWright committed May 2, 2016
2 parents 3fc91f3 + fd059b9 commit b18a5f2
Show file tree
Hide file tree
Showing 21 changed files with 620 additions and 268 deletions.
2 changes: 1 addition & 1 deletion flink-htm-examples/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies {
compile 'org.scala-lang:scala-library:2.11.7'

// htm.java
compile 'org.numenta:htm.java:0.6.5'
compile 'org.numenta:htm.java:0.6.7'

// flink
compile 'org.apache.flink:flink-scala_2.11:1.0.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object NetworkDemoParameters {
//SpatialPooler specific
POTENTIAL_RADIUS -> 12, //3
POTENTIAL_PCT -> 0.5, //0.5
GLOBAL_INHIBITIONS -> false,
GLOBAL_INHIBITION -> false,
LOCAL_AREA_DENSITY -> -1.0,
NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 5.0,
STIMULUS_THRESHOLD -> 1.0,
Expand All @@ -49,7 +49,7 @@ object NetworkDemoParameters {
PERMANENCE_DECREMENT -> 0.05,
ACTIVATION_THRESHOLD -> 4))
.union(Parameters(
GLOBAL_INHIBITIONS -> true,
GLOBAL_INHIBITION -> true,
COLUMN_DIMENSIONS -> Array(2048),
CELLS_PER_COLUMN -> 32,
NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 40.0,
Expand Down Expand Up @@ -80,7 +80,7 @@ trait WorkshopAnomalyParameters {
// spParams
POTENTIAL_PCT -> 0.8,
COLUMN_DIMENSIONS -> Array(2048),
GLOBAL_INHIBITIONS -> true,
GLOBAL_INHIBITION -> true,
/* inputWidth */
MAX_BOOST -> 1.0,
NUM_ACTIVE_COLUMNS_PER_INH_AREA -> 40,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ object Demo extends HotGymModel {
.mapWithState { (inference, state: Option[Double]) =>

val prediction = Prediction(
inference.getInput.timestamp.toString(LOOSE_DATE_TIME),
inference.getInput.consumption,
inference._1.timestamp.toString(LOOSE_DATE_TIME),
inference._1.consumption,
state match {
case Some(prediction) => prediction
case None => 0.0
},
inference.getAnomalyScore)
inference._2.getAnomalyScore)

// store the prediction about the next value as state for the next iteration,
// so that actual vs predicted is a meaningful comparison
val predictedConsumption = inference.getClassification("consumption").getMostProbableValue(1).asInstanceOf[Any] match {
val predictedConsumption = inference._2.getClassification("consumption").getMostProbableValue(1).asInstanceOf[Any] match {
case value: Double if value != 0.0 => value
case _ => state.getOrElse(0.0)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object Demo extends TrafficModel {
.filter { report => report.datetime.isBefore(investigationInterval.getEnd) }
.keyBy("streamId")
.learn(network)
.select(inference => (inference.getInput, inference.getAnomalyScore))
.select(inference => (inference._1, inference._2.getAnomalyScore))

val anomalousRoutes = anomalyScores
.filter { anomaly => investigationInterval.contains(anomaly._1.datetime) }
Expand Down
2 changes: 1 addition & 1 deletion flink-htm-streaming-java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies {
compile 'org.slf4j:slf4j-api:1.7.13'

// htm.java
compile 'org.numenta:htm.java:0.6.5'
compile 'org.numenta:htm.java:0.6.7'

// flink
compile 'org.apache.flink:flink-java:1.0.0'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package org.numenta.nupic.flink.serialization;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoException;
import com.esotericsoftware.kryo.Serializer;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.util.FieldAccessor;
import org.numenta.nupic.Persistable;
import org.numenta.nupic.network.Network;
import org.numenta.nupic.serialize.HTMObjectInput;
import org.numenta.nupic.serialize.HTMObjectOutput;
import org.numenta.nupic.serialize.SerialConfig;
import org.numenta.nupic.serialize.SerializerCore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.lang.reflect.Field;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Kryo serializer for HTM network and related objects.
*
*/
public class KryoSerializer<T extends Persistable> extends Serializer<T> implements Serializable {

protected static final Logger LOGGER = LoggerFactory.getLogger(KryoSerializer.class);

private final SerializerCore serializer = new SerializerCore(SerialConfig.DEFAULT_REGISTERED_TYPES);

/**
* Write the given instance to the given output.
*
* @param kryo instance of {@link Kryo} object
* @param output a Kryo {@link Output} object
* @param t instance to serialize
*/
@Override
public void write(Kryo kryo, Output output, T t) {
try {
preSerialize(t);

try(ByteArrayOutputStream stream = new ByteArrayOutputStream(4096)) {

// write the object using the HTM serializer
HTMObjectOutput writer = serializer.getObjectOutput(stream);
writer.writeObject(t, t.getClass());
writer.close();

// write the serialized data
output.writeInt(stream.size());
stream.writeTo(output);

LOGGER.debug("wrote {} bytes", stream.size());
}
}
catch(IOException e) {
throw new KryoException(e);
}
}

/**
* Read an instance of the given class from the given input.
*
* @param kryo instance of {@link Kryo} object
* @param input a Kryo {@link Input}
* @param aClass The class of the object to be read in.
* @return an instance of type &lt;T&gt;
*/
@Override
public T read(Kryo kryo, Input input, Class<T> aClass) {

// read the serialized data
byte[] data = new byte[input.readInt()];
input.readBytes(data);

try {
try(ByteArrayInputStream stream = new ByteArrayInputStream(data)) {
HTMObjectInput reader = serializer.getObjectInput(stream);
T t = (T) reader.readObject(aClass);

postDeSerialize(t);

return t;
}
}
catch(Exception e) {
throw new KryoException(e);
}
}

/**
* Copy the given instance.
* @param kryo instance of {@link Kryo} object
* @param original an object to copy.
* @return
*/
@Override
public T copy(Kryo kryo, T original) {
try {
preSerialize(original);

try(CopyStream output = new CopyStream(4096)) {
HTMObjectOutput writer = serializer.getObjectOutput(output);
writer.writeObject(original, original.getClass());
writer.close();

try(InputStream input = output.toInputStream()) {
HTMObjectInput reader = serializer.getObjectInput(input);
T t = (T) reader.readObject(original.getClass());

postDeSerialize(t);

return t;
}
}
}
catch(Exception e) {
throw new KryoException(e);
}
}

static class CopyStream extends ByteArrayOutputStream {
public CopyStream(int size) { super(size); }

/**
* Get an input stream based on the contents of this output stream.
* Do not use the output stream after calling this method.
* @return an {@link InputStream}
*/
public InputStream toInputStream() {
return new ByteArrayInputStream(this.buf, 0, this.count);
}
}

/**
* The HTM serializer handles the Persistable callbacks automatically, but
* this method is for any additional actions to be taken.
* @param t the instance to be serialized.
*/
protected void preSerialize(T t) {
}

/**
* The HTM serializer handles the Persistable callbacks automatically, but
* this method is for any additional actions to be taken.
* @param t the instance newly deserialized.
*/
protected void postDeSerialize(T t) {
}

/**
* Register the HTM types with the Kryo serializer.
* @param config
*/
public static void registerTypes(ExecutionConfig config) {
for(Class<?> c : SerialConfig.DEFAULT_REGISTERED_TYPES) {
Class<?> serializerClass = DEFAULT_SERIALIZERS.getOrDefault(c, (Class<?>) KryoSerializer.class);
config.registerTypeWithKryoSerializer(c, (Class<? extends Serializer<?>>) serializerClass);
}
for(Class<?> c : KryoSerializer.ADDITIONAL_REGISTERED_TYPES) {
Class<?> serializerClass = DEFAULT_SERIALIZERS.getOrDefault(c, (Class<?>) KryoSerializer.class);
config.registerTypeWithKryoSerializer(c, (Class<? extends Serializer<?>>) serializerClass);
}
}

static final Class<?>[] ADDITIONAL_REGISTERED_TYPES = { Network.class };

/**
* A map of serializers for various classes.
*/
static final Map<Class<?>,Class<?>> DEFAULT_SERIALIZERS = Stream.of(
new Tuple2<>(Network.class, NetworkSerializer.class)
).collect(Collectors.toMap(kv -> kv.f0, kv -> kv.f1));


public static class NetworkSerializer extends KryoSerializer<Network> {

private final static Field shouldDoHaltField;

static {
try {
shouldDoHaltField = Network.class.getDeclaredField("shouldDoHalt");
shouldDoHaltField.setAccessible(true);
} catch (NoSuchFieldException e) {
throw new UnsupportedOperationException("unable to locate Network::shouldDoHalt", e);
}

}

@Override
protected void preSerialize(Network network) {
super.preSerialize(network);
try {
// issue: HTM.java #417
shouldDoHaltField.set(network, false);
} catch (IllegalAccessException e) {
throw new UnsupportedOperationException("unable to set Network::shouldDoHalt", e);
}
}

@Override
protected void postDeSerialize(Network network) {
super.postDeSerialize(network);
}
}
}
Loading

0 comments on commit b18a5f2

Please sign in to comment.