Skip to content

Commit

Permalink
correct bad generate predictions with sampling period correction (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanJamesLew committed Feb 27, 2023
1 parent f70a5d3 commit bd9ecf1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
5 changes: 5 additions & 0 deletions autokoopman/autokoopman.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,12 @@ def _sanitize_training_data(
# convert the data to autokoopman trajectories
if isinstance(training_data, TrajectoriesData):
if not isinstance(training_data, UniformTimeTrajectoriesData):
print(f"resampling trajectories as they need to be uniform time (sampling period {sampling_period})")
training_data = training_data.interp_uniform_time(sampling_period)
else:
if not np.isclose(training_data.sampling_period, sampling_period):
print(f"resampling trajectories because the sampling periods differ (original {training_data.sampling_period}, new {sampling_period})")
training_data = training_data.interp_uniform_time(sampling_period)
else:
# figure out how to add inputs
training_iter = (
Expand Down
7 changes: 6 additions & 1 deletion autokoopman/core/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,18 @@ def generate_predictions(
preds = {}
# get the predictions
for k, v in holdout_data._trajs.items():
# ugh, this is a hack--sampling period is meaningful if the system is discrete
# TODO: make the behavior of sampling period less problematic
_kwargs = {}
if hasattr(v, "sampling_period"):
_kwargs["sampling_period"] = v.sampling_period
sivp_interp = trained_model.model.solve_ivp(
v.states[0],
(np.min(v.times), np.max(v.times)),
inputs=v.inputs,
teval=v.times,
**_kwargs
)
# sivp_interp = sivp_interp.interp1d(v.times)
preds[k] = sivp_interp
return TrajectoriesData(preds)

Expand Down

0 comments on commit bd9ecf1

Please sign in to comment.