Skip to content

Commit

Permalink
Merge branch 'vectorized_features' of github.com:predict-idlab/tsflex…
Browse files Browse the repository at this point in the history
… into vectorized_features
  • Loading branch information
jonasvdd committed Mar 22, 2022
2 parents 3467631 + a4862b1 commit 9eb7b06
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
8 changes: 7 additions & 1 deletion tsflex/features/segmenter/strided_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,13 @@ def apply_func(self, func: FuncWrapper) -> pd.DataFrame:
views.append(
_sliding_strided_window_1d(sc.values, windows[0], strides[0])
)
out = np.asarray(func(*views)).T # .T to comply with expected output format
out = func(*views)

out_type = type(out)
out = np.asarray(out)
# When multiple outputs are returned (= tuple) they should be transposed
# when combining into an array
out = out.T if out_type is tuple else out

else:
# Sequential function execution (default)
Expand Down
37 changes: 31 additions & 6 deletions tsflex/features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,35 @@ def _get_name(func: Callable) -> str:
except:
return type(func).__name__

def _get_funcwrapper_func_and_kwargs(func: FuncWrapper) -> Tuple[Callable, dict]:
"""Extract the function and keyword arguments from the given FuncWrapper.
Parameters
----------
func: FuncWrapper
The FuncWrapper to extract the function and kwargs from.
Returns
-------
Tuple[Callable, dict]
Tuple of 1st the function of the FuncWrapper (is a Callable) and 2nd the keyword
arguments of the FuncWrapper.
"""
assert isinstance(func, FuncWrapper)

# Extract the function (is a Callable)
function = func.func

# Extract the keyword arguments
func_wrapper_kwargs = {}
func_wrapper_kwargs["output_names"] = func.output_names
func_wrapper_kwargs["input_type"] = func.input_type
func_wrapper_kwargs["vectorized"] = func.vectorized
func_wrapper_kwargs.update(func.kwargs)

return function, func_wrapper_kwargs


def _make_single_func_robust(
func: Union[Callable, FuncWrapper],
Expand Down Expand Up @@ -105,14 +134,10 @@ def _make_single_func_robust(
"""
assert isinstance(func, FuncWrapper) or isinstance(func, Callable)

# Extract the keyword arguments from the function wrapper
func_wrapper_kwargs = {}
if isinstance(func, FuncWrapper):
_func = func
func = _func.func
func_wrapper_kwargs["output_names"] = _func.output_names
func_wrapper_kwargs["input_type"] = _func.input_type
func_wrapper_kwargs.update(_func.kwargs)
# Extract the function and keyword arguments from the function wrapper
func, func_wrapper_kwargs = _get_funcwrapper_func_and_kwargs(func)

output_names = func_wrapper_kwargs.get("output_names")

Expand Down

0 comments on commit 9eb7b06

Please sign in to comment.