Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed May 7, 2024
1 parent 7282d0d commit 95d078a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 64 deletions.
14 changes: 4 additions & 10 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@
_get_terminal_size,
_split_tsd,
_TsdFrameSliceHelper,
convert_to_jax_array,
convert_to_numpy_array,
get_backend,
convert_to_array,
is_array_like,
)

Expand Down Expand Up @@ -73,11 +71,7 @@ class BaseTsd(Base, NDArrayOperatorsMixin, abc.ABC):
def __init__(self, t, d, time_units="s", time_support=None):
super().__init__(t, time_units, time_support)

# Check if jax backend
if get_backend() == "jax":
self.values = convert_to_jax_array(d, "d")
else:
self.values = convert_to_numpy_array(d, "d")
self.values = convert_to_array(d, "d")

assert len(self.index) == len(
self.values
Expand Down Expand Up @@ -455,7 +449,7 @@ def dropna(self, update_time_support=True):
if hasattr(self, "columns"):
kwargs["columns"] = self.columns

return self.__class__(t=t, d=d, time_support=ep)
return self.__class__(t=t, d=d, time_support=ep, **kwargs)

def convolve(self, array, ep=None, trim="both"):
"""Return the discrete linear convolution of the time series with a one dimensional sequence.
Expand All @@ -472,7 +466,7 @@ def convolve(self, array, ep=None, trim="both"):
----------
array : array-like
One dimensional input array-like.
ep : None, optional
The epochs to apply the convolution
trim : str, optional
Expand Down
50 changes: 10 additions & 40 deletions pynapple/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
from .config import nap_config


def convert_to_array(array, array_name):
# Check if jax backend
if get_backend() == "jax":
from pynajax.utils import convert_to_jax_array

return convert_to_jax_array(array, array_name)
else:
return convert_to_numpy_array(array, array_name)


def convert_to_numpy_array(array, array_name):
"""Convert any array like object to numpy ndarray.
Expand Down Expand Up @@ -48,46 +58,6 @@ def convert_to_numpy_array(array, array_name):
)


def convert_to_jax_array(array, array_name):
"""Convert any array like object to jax Array.
Parameters
----------
array : ArrayLike
array_name : str
Array name if RuntimeError is raised or object is casted to numpy
Returns
-------
jax.Array
Jax array object
Raises
------
RuntimeError
If input can't be converted to jax array
"""
import jax.numpy as jnp

if isinstance(array, Number):
return jnp.array([array])
elif isinstance(array, (list, tuple)):
return jnp.array(array)
elif isinstance(array, jnp.ndarray):
return array
elif isinstance(array, np.ndarray):
return cast_to_jax(array, array_name)
elif is_array_like(array):
return cast_to_jax(array, array_name)
else:
raise RuntimeError(
"Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format(
array_name
)
)


def get_backend():
"""
Return the current backend of pynapple. Possible backends are
Expand Down
6 changes: 0 additions & 6 deletions tests/test_correlograms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
# -*- coding: utf-8 -*-
# @Author: gviejo
# @Date: 2022-03-30 11:16:22
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2024-04-04 10:46:31
#!/usr/bin/env python

"""Tests of correlograms for `pynapple` package."""

Expand Down
8 changes: 0 additions & 8 deletions tests/test_numpy_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
# -*- coding: utf-8 -*-
# @Author: Guillaume Viejo
# @Date: 2023-09-18 18:11:24
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2024-04-04 11:34:35



import pynapple as nap
import numpy as np
import pytest
Expand Down

0 comments on commit 95d078a

Please sign in to comment.