Skip to content

Commit

Permalink
Merge pull request #395 from nasa/feature/split_model
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert authored Sep 13, 2022
2 parents 6948891 + 23929b2 commit 44c4524
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 22 deletions.
2 changes: 1 addition & 1 deletion examples/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def future_loading3(t, x = None):
inputs = input_data,
outputs = output_data,
window=12,
epochs=3,
epochs=5,
units=64, # Additional units given the increased complexity of the system
input_keys = ['i', 'dt'],
output_keys = ['t', 'v'])
Expand Down
75 changes: 56 additions & 19 deletions src/prog_models/data_models/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class LSTMStateTransitionModel(DataModel):
Most users will use the `LSTMStateTransitionModel.from_data` method to create a model, but the model can be created by passing a model directly into the constructor. The LSTM model in this method maps from [u_t-n+1, z_t-n, ..., u_t, z_t-1] to z_t. Past :term:`input` are stored in the :term:`model` internal :term:`state`. Actual calculation of :term:`output` is performed when :py:func:`LSTMStateTransitionModel.output` is called. When used in simulation, that may not be until the simulation results are accessed.
Args:
model (keras.Model): Keras model to use for state transition
output_model (keras.Model): If a state model is present, maps from the state_model outputs to model :term:`output`. Otherwise, maps from model inputs to model :term:`output`
state_model (keras.Model, optional): Keras model to use for state transition
Keyword Args:
input_keys (list[str]): List of input keys
Expand All @@ -42,12 +43,15 @@ class LSTMStateTransitionModel(DataModel):
'measurement_noise': 0, # Default 0 noise
}

def __init__(self, model, **kwargs):
# Setup inputs, outputs, states
self.outputs = kwargs.get('output_keys', [f'z{i}' for i in range(model.output.shape[1])])
def __init__(self, output_model, state_model = None, **kwargs):
n_outputs = output_model.output.shape[1]
n_internal = 0 if state_model is None else state_model.output.shape[1]
input_shape = output_model.input.shape if state_model is None else state_model.input.shape
n_inputs = input_shape[-1]-n_outputs

input_shape = model.input.shape
input_keys = kwargs.get('input_keys', [f'u{i}' for i in range(input_shape[2]-len(self.outputs))])
# Setup inputs, outputs, states
self.outputs = kwargs.get('output_keys', [f'z{i}' for i in range(n_outputs)])
input_keys = kwargs.get('input_keys', [f'u{i}' for i in range(n_inputs)])
self.inputs = input_keys.copy()
# Outputs from the last step are part of input
self.inputs.extend([f'{z_key}_t-1' for z_key in self.outputs])
Expand All @@ -57,14 +61,16 @@ def __init__(self, model, **kwargs):
for j in range(input_shape[1]-1, -1, -1):
self.states.extend([f'{input_i}_t-{j}' for input_i in input_keys])
self.states.extend([f'{output_i}_t-{j+1}' for output_i in self.outputs])
self.states.extend([f'_model_output{i}' for i in range(n_internal)])

kwargs['window'] = input_shape[1]
kwargs['model'] = model # Putting it in the parameters dictionary simplifies pickling
kwargs['state_model'] = state_model
kwargs['output_model'] = output_model
# Putting it in the parameters dictionary simplifies pickling

super().__init__(**kwargs)

# Save Model
self.model = model
self.history = kwargs.get('history', None)

def __getstate__(self):
Expand Down Expand Up @@ -95,21 +101,37 @@ def initialize(self, u=None, z=None):
def next_state(self, x, u, _):
# Rotate new input into state
input_data = u.matrix

if self.parameters['state_model'] is None:
states = x.matrix[len(input_data):]
return self.StateContainer(np.vstack((states, input_data)))

states = x.matrix[len(input_data):]
return self.StateContainer(np.vstack((states, input_data)))
states = x.matrix[len(input_data):-self.parameters['state_model'].output_shape[1]]
states = np.vstack((states, input_data))

if states[0,0] is None:
return self.StateContainer(np.vstack((states, x.matrix[-self.parameters['state_model'].output_shape[1]:])))
else:
# Enough data has been received to calculate output
# Format input into np array with shape (1, window, num_inputs)
m_input = states.reshape(1, self.parameters['window'], len(self.inputs))
m_input = np.array(m_input, dtype=np.float)
internal_states = self.parameters['state_model'](m_input).numpy().T
return self.StateContainer(np.vstack((states, internal_states)))

def output(self, x):
if x.matrix[0,0] is None:
warn(f"Output estimation is not available until at least {1+self.parameters['window']} timesteps have passed.")
return self.OutputContainer(np.array([[None] for _ in self.outputs]))

# Enough data has been received to calculate output
# Format input into np array with shape (1, window, num_inputs)
m_input = x.matrix[:self.parameters['window']*len(self.inputs)].reshape(1, self.parameters['window'], len(self.inputs))

# Pass into model to calculate output
m_output = self.model(m_input)
# Pass internal states into model to calculate output
if self.parameters['state_model'] is None:
m_input = x.matrix.reshape(1, self.parameters['window'], len(self.inputs))
internal_states = np.array(m_input, dtype=np.float)
else:
internal_states = x.matrix[-self.parameters['state_model'].output_shape[1]:].T
m_output = self.parameters['output_model'](internal_states)

if 'normalization' in self.parameters:
m_output *= self.parameters['normalization'][1]
Expand All @@ -122,7 +144,13 @@ def summary(self, file= sys.stdout, expand_nested=False, show_trainable=False):
print("Inputs: ", self.inputs, file = file)
print("Outputs: ", self.outputs, file = file)
print("Window_size: ", self.parameters['window'], file = file)
self.model.summary(print_fn= file.write, expand_nested = expand_nested, show_trainable = show_trainable)
if self.parameters['state_model'] is not None:
print('\nState Model: ', file = file)
self.parameters['state_model'].summary(print_fn= file.write, expand_nested = expand_nested, show_trainable = show_trainable)

print('\nOutput Model: ', file = file)
self.parameters['output_model'].summary(print_fn= file.write, expand_nested = expand_nested, show_trainable = show_trainable)


@staticmethod
def pre_process_data(inputs, outputs, window, **kwargs):
Expand Down Expand Up @@ -197,7 +225,7 @@ def pre_process_data(inputs, outputs, window, **kwargs):
n_outputs = len(z[0])
z_i = [[z[i][k] for k in range(n_outputs)] for i in range(window+1, len(z))]
else:
raise TypeError(f"Unsupported input type: {type(z)} for internal element (data[0][i]")
raise TypeError(f"Unsupported input type: {type(z)} for internal element (output[i])")

# Also add to input (past outputs are part of input)
if len(u_i) == 0:
Expand Down Expand Up @@ -364,14 +392,23 @@ def from_data(cls, inputs, outputs, event_states = None, thresh_met = None, **kw
# Dropout prevents overfitting
x = layers.Dropout(params['dropout'])(x)

x = layers.Dense(z_all.shape[1] if z_all.ndim == 2 else 1)(x)
x = layers.Dense(z_all.shape[1] if z_all.ndim == 2 else 1, name='output')(x)
model = keras.Model(inputs, x)
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])

# Train model
history = model.fit(u_all, z_all, epochs=params['epochs'], callbacks = callbacks, validation_split = params['validation_split'])

return cls(keras.models.load_model("best_model.keras"), history = history, **params)
model = keras.models.load_model("best_model.keras")

# Split model into separate models
n_state_layers = params['layers'] + 1 + (params['dropout'] > 0) + (params['normalize'])
output_layer_input = layers.Input(model.layers[n_state_layers-1].output.shape[1:])
output_layer = model.get_layer('output')(output_layer_input)
state_model = keras.Model(model.input, model.layers[n_state_layers-1].output)
output_model = keras.Model(output_layer_input, output_layer)

return cls(output_model, state_model, history = history, **params)

def simulate_to_threshold(self, future_loading_eqn, first_output = None, threshold_keys = None, **kwargs):
t = kwargs.get('t0', 0)
Expand Down
6 changes: 4 additions & 2 deletions tests/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ def test_lstm_simple(self):
m = self._test_simple_case(LSTMStateTransitionModel, window=5, epochs=20, max_error=3)
self.assertListEqual(m.inputs, ['x_t-1'])
# Use set below so there's no issue with ordering
self.assertSetEqual(set(m.states), set(['x_t-1', 'x_t-2', 'x_t-3', 'x_t-4', 'x_t-5']))
keys = ['x_t-1', 'x_t-2', 'x_t-3', 'x_t-4', 'x_t-5']
keys.extend([f'_model_output{i}' for i in range(16)])
self.assertSetEqual(set(m.states), set(keys))

# Create from model
LSTMStateTransitionModel(m.model, output_keys = ['x'])
LSTMStateTransitionModel(m.parameters['output_model'], m.parameters['state_model'], output_keys = ['x'])
try:
# Test pickling model m
with self.assertWarns(RuntimeWarning):
Expand Down

0 comments on commit 44c4524

Please sign in to comment.