Skip to content

Commit

Permalink
changing tests for pynajax
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed May 6, 2024
1 parent d86965c commit bf866f0
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
37 changes: 34 additions & 3 deletions draft_pynapple_fastplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def get_memory_map(filepath, nChannels, frequency=20000):



fig = fpl.Figure(shape=(2,1))
fig = fpl.Figure(canvas="glfw", shape=(2,1))
fig[0,0].add_line(data=lfp2, thickness=1, cmap="autumn")
fig[1,0].add_scatter(tmp)
fig.show()

fig.show(maintain_aspect=False)
# fpl.run()



Expand All @@ -68,6 +68,37 @@ def get_memory_map(filepath, nChannels, frequency=20000):
# grid_plot['lfp'].add_line(lfp.t, lfp[:,14].d)


import numpy as np
import fastplotlib as fpl

fig = fpl.Figure(canvas="glfw")#, shape=(2,1), controller_ids="sync")
fig[0,0].add_line(data=np.random.randn(1000))
fig.show(maintain_aspect=False)

fig2 = fpl.Figure(canvas="glfw", controllers=fig.controllers)#, shape=(2,1), controller_ids="sync")
fig2[0,0].add_line(data=np.random.randn(1000)*1000)
fig2.show(maintain_aspect=False)



# Not sure about this :
fig[1,0].controller.controls["mouse1"] = "pan", "drag", (1.0, 0.0)

fig[1,0].controller.controls.pop("mouse2")
fig[1,0].controller.controls.pop("mouse4")
fig[1,0].controller.controls.pop("wheel")

import pygfx

controller = pygfx.PanZoomController()
controller.controls.pop("mouse1")
controller.add_camera(fig[0, 0].camera)
controller.register_events(fig[0, 0].viewport)

controller2 = pygfx.PanZoomController()
controller2.add_camera(fig[1, 0].camera)
controller2.controls.pop("mouse1")
controller2.register_events(fig[1, 0].viewport)



Expand Down
2 changes: 2 additions & 0 deletions pynapple/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def convert_to_jax_array(array, array_name):
return array
elif isinstance(array, np.ndarray):
return cast_to_jax(array, array_name)
elif is_array_like(array):
return cast_to_jax(array, array_name)
else:
raise RuntimeError(
"Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format(
Expand Down
24 changes: 15 additions & 9 deletions tests/test_non_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def test_tsd_type_d(self, time, data, expectation):
"""Verify that the data attribute 'd' of a Tsd object is stored as a numpy.ndarray."""
with expectation:
ts = nap.Tsd(t=time, d=data)
assert isinstance(ts.d, np.ndarray)
if nap.nap_config.backend == "numba":
assert isinstance(ts.d, np.ndarray)

@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize(
Expand Down Expand Up @@ -132,8 +133,9 @@ def test_tsd_type_t(self, time, data, expectation):
)
def test_tsd_warn(self, data, expectation):
"""Check for warnings when the data attribute 'd' is automatically converted to numpy.ndarray."""
with expectation:
nap.Tsd(t=np.array(data), d=data)
if nap.nap_config.backend == "numba":
with expectation:
nap.Tsd(t=np.array(data), d=data)


class TestTsdFrameArray:
Expand Down Expand Up @@ -169,7 +171,8 @@ def test_tsdframe_type(self, time, data, expectation):
"""Verify that the data attribute 'd' of a TsdFrame object is stored as a numpy.ndarray."""
with expectation:
ts = nap.TsdFrame(t=time, d=data)
assert isinstance(ts.d, np.ndarray)
if nap.nap_config.backend == "numba":
assert isinstance(ts.d, np.ndarray)

@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize(
Expand Down Expand Up @@ -202,8 +205,9 @@ def test_tsdframe_type_t(self, time, data, expectation):
)
def test_tsdframe_warn(self, data, expectation):
"""Check for warnings when the data attribute 'd' is automatically converted to numpy.ndarray."""
with expectation:
nap.TsdFrame(t=np.array(data), d=data)
if nap.nap_config.backend == "numba":
with expectation:
nap.TsdFrame(t=np.array(data), d=data)


class TestTsdTensorArray:
Expand Down Expand Up @@ -245,7 +249,8 @@ def test_tsdtensor_type_d(self, time, data, expectation):
"""Verify that the data attribute 'd' of a TsdTensor object is stored as a numpy.ndarray."""
with expectation:
ts = nap.TsdTensor(t=time, d=data)
assert isinstance(ts.d, np.ndarray)
if nap.nap_config.backend == "numba":
assert isinstance(ts.d, np.ndarray)

@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize(
Expand Down Expand Up @@ -278,6 +283,7 @@ def test_tsdtensor_type_t(self, time, data, expectation):
)
def test_tsdtensor_warn(self, data, expectation):
"""Check for warnings when the data attribute 'd' is automatically converted to numpy.ndarray."""
with expectation:
nap.TsdTensor(t=np.ravel(np.array(data)), d=data)
if nap.nap_config.backend == "numba":
with expectation:
nap.TsdTensor(t=np.ravel(np.array(data)), d=data)

0 comments on commit bf866f0

Please sign in to comment.