Skip to content

Commit

Permalink
Merge pull request #55 from EthanJamesLew/bugbix/sp
Browse files Browse the repository at this point in the history
correct bad generate predictions with sampling period correction (#52)
  • Loading branch information
EthanJamesLew authored Feb 27, 2023
2 parents 890c222 + bd9ecf1 commit 09e1a82
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
5 changes: 5 additions & 0 deletions autokoopman/autokoopman.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,12 @@ def _sanitize_training_data(
# convert the data to autokoopman trajectories
if isinstance(training_data, TrajectoriesData):
if not isinstance(training_data, UniformTimeTrajectoriesData):
print(f"resampling trajectories as they need to be uniform time (sampling period {sampling_period})")
training_data = training_data.interp_uniform_time(sampling_period)
else:
if not np.isclose(training_data.sampling_period, sampling_period):
print(f"resampling trajectories because the sampling periods differ (original {training_data.sampling_period}, new {sampling_period})")
training_data = training_data.interp_uniform_time(sampling_period)
else:
# figure out how to add inputs
training_iter = (
Expand Down
7 changes: 6 additions & 1 deletion autokoopman/core/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,18 @@ def generate_predictions(
preds = {}
# get the predictions
for k, v in holdout_data._trajs.items():
# ugh, this is a hack--sampling period is meaningful if the system is discrete
# TODO: make the behavior of sampling period less problematic
_kwargs = {}
if hasattr(v, "sampling_period"):
_kwargs["sampling_period"] = v.sampling_period
sivp_interp = trained_model.model.solve_ivp(
v.states[0],
(np.min(v.times), np.max(v.times)),
inputs=v.inputs,
teval=v.times,
**_kwargs
)
# sivp_interp = sivp_interp.interp1d(v.times)
preds[k] = sivp_interp
return TrajectoriesData(preds)

Expand Down

0 comments on commit 09e1a82

Please sign in to comment.