feat: add fit_curve and predict_curve #139
Merged (+188 −2)
Changes from all commits (23 commits)
76cf964  add prototype for fit_curve
40a7c43  add proper parameter parsing
9a66759  use partial function in test
e99dceb  add assert to test
f716b87  Merge branch 'main' into add-fit-curve-that-works
a55decd  cleanups in fit_curve
23a68b2  minor edits
5e421ce  Merge branch 'main' into add-fit-curve-that-works
991970a  progress with parameter passing
f4e68f8  revert changes to process decorator
7b903c6  add to tests
f626450  progress on predict
623626b  get predict to work
a40449e  remove comment
f42f812  fix up output datacube
688a3f1  add assertions
0b2b64a  preserve dimension order
c2e6bf2  Merge branch 'main' into add-fit-curve-that-works
1fbbdd5  add cast to datetime if appropriate
021e459  keep attrs
407f4fe  fix typo
53fbfbd  add ignore_nodata
20a8326  bump submodule
@@ -1 +1,2 @@
from .curve_fitting import *
from .random_forest import *
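If this hunk edits the ml subpackage's __init__.py (the file name is not shown above), the wildcard re-export would make the new functions importable directly from the subpackage. A hypothetical usage under that assumption:

# Assumes the hunk above is openeo_processes_dask/process_implementations/ml/__init__.py
from openeo_processes_dask.process_implementations.ml import fit_curve, predict_curve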
openeo_processes_dask/process_implementations/ml/curve_fitting.py (125 additions, 0 deletions)
@@ -0,0 +1,125 @@
from typing import Callable, Optional

import numpy as np
import pandas as pd
import xarray as xr
from numpy.typing import ArrayLike

from openeo_processes_dask.process_implementations.cubes import apply_dimension
from openeo_processes_dask.process_implementations.data_model import RasterCube
from openeo_processes_dask.process_implementations.exceptions import (
    DimensionNotAvailable,
)

__all__ = ["fit_curve", "predict_curve"]


def fit_curve(
    data: RasterCube,
    parameters: list,
    function: Callable,
    dimension: str,
    ignore_nodata: bool = True,
):
    if dimension not in data.dims:
        raise DimensionNotAvailable(
            f"Provided dimension ({dimension}) not found in data.dims: {data.dims}"
        )

    dims_before = list(data.dims)

    # In the spec, parameters is a list, but xr.curvefit requires names for them,
    # so we do this to generate names locally
    parameters = {f"param_{i}": v for i, v in enumerate(parameters)}

    # The dimension along which to fit the curves cannot be chunked!
    rechunked_data = data.chunk({dimension: -1})

    def wrapper(f):
        def _wrap(*args, **kwargs):
            return f(
                *args,
                **kwargs,
                positional_parameters={"x": 0, "parameters": slice(1, None)},
            )

        return _wrap

    expected_dims_after = list(dims_before)
    expected_dims_after[dims_before.index(dimension)] = "param"

    # .curvefit returns some extra information that isn't required by the OpenEO process
    # so we simply drop these here.
    fit_result = (
        rechunked_data.curvefit(
            dimension,
            wrapper(function),
            p0=parameters,
            param_names=list(parameters.keys()),
            skipna=ignore_nodata,
        )
        .drop_dims(["cov_i", "cov_j"])
        .to_array()
        .squeeze()
        .transpose(*expected_dims_after)
    )

    fit_result.attrs = data.attrs

    return fit_result


def predict_curve(
    parameters: RasterCube,
    function: Callable,
    dimension: str,
    labels: ArrayLike,
):
    labels_were_datetime = False
    dims_before = list(parameters.dims)

    try:
        # Try parsing as datetime first
        labels = np.asarray(labels, dtype=np.datetime64)
    except ValueError:
        labels = np.asarray(labels)

    if np.issubdtype(labels.dtype, np.datetime64):
        labels = labels.astype(int)
        labels_were_datetime = True

    # This is necessary to pipe the arguments correctly through @process
    def wrapper(f):
        def _wrap(*args, **kwargs):
            return f(
                *args,
                positional_parameters={"parameters": 0},
                named_parameters={"x": labels},
                **kwargs,
            )

        return _wrap

    expected_dims_after = list(dims_before)
    expected_dims_after[dims_before.index("param")] = dimension

    predictions = xr.apply_ufunc(
        wrapper(function),
        parameters,
        vectorize=True,
        input_core_dims=[["param"]],
        output_core_dims=[[dimension]],
        dask="parallelized",
        output_dtypes=[np.float64],
        dask_gufunc_kwargs={
            "allow_rechunk": True,
            "output_sizes": {dimension: len(labels)},
        },
    ).transpose(*expected_dims_after)

    predictions = predictions.assign_coords({dimension: labels.data})

    if labels_were_datetime:
        predictions[dimension] = pd.DatetimeIndex(predictions[dimension].values)

    return predictions
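For context, here is a minimal standalone sketch of the pattern fit_curve builds on, using plain xarray and numpy rather than the openEO process-graph callable that fit_curve wraps. The model function, toy data, dimension name, and initial guesses are made up for illustration:

import numpy as np
import xarray as xr

# Hypothetical model: exponential decay y = a * exp(-b * x)
def model(x, a, b):
    return a * np.exp(-b * x)

# Toy series along a single "t" dimension
t = np.linspace(0, 10, 50)
data = xr.DataArray(2.5 * np.exp(-0.7 * t), coords={"t": t}, dims=["t"])

# Same pattern as fit_curve: the fit dimension must live in a single chunk,
# and the covariance output of .curvefit is dropped afterwards.
fit = (
    data.chunk({"t": -1})
    .curvefit("t", model, p0={"a": 1.0, "b": 1.0}, skipna=True)
    .drop_dims(["cov_i", "cov_j"])
)
print(fit["curvefit_coefficients"].values)  # recovers roughly [2.5, 0.7]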
Shouldn't we do the same in predict_curve as well?
I don't think it matters for predict, because there each timestep can be inferred independently of the other steps!
@LukeWeidenwalker that's true. But I was wondering if it would be faster when predicting on a datacube with many timesteps?
Hmm - not sure tbh, I haven't profiled predict_curve at all yet - I think I'll merge and deploy this now so we can start a training run at least, and revisit this if performance of inference turns out to be a problem!
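To illustrate the point above that each prediction is computed independently, here is a minimal standalone sketch of the apply_ufunc shape transformation predict_curve performs, again with a made-up model and coefficient values rather than the openEO process wrapper:

import numpy as np
import xarray as xr

# Hypothetical fitted coefficients [a, b] for y = a * exp(-b * x),
# one pair per pixel of a 2 x 2 grid.
coeffs = xr.DataArray(
    np.array([[[2.5, 0.7], [1.0, 0.3]],
              [[3.0, 1.1], [0.5, 0.2]]]),
    dims=["y", "x", "param"],
)

labels = np.linspace(0, 10, 25)  # positions to predict at

def model(params, x):
    a, b = params
    return a * np.exp(-b * x)

# Same transformation as predict_curve: the "param" dimension is consumed and a
# new "t" dimension of len(labels) is produced; every label is evaluated from
# the coefficients alone, independently of all other labels.
predictions = xr.apply_ufunc(
    lambda p: model(p, labels),
    coeffs,
    vectorize=True,
    input_core_dims=[["param"]],
    output_core_dims=[["t"]],
).assign_coords({"t": labels})
print(predictions.dims)  # ('y', 'x', 't')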