Skip to content

Commit

Permalink
tidy up pupil diameter
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaprins committed Feb 24, 2025
1 parent fd325fb commit f8ca302
Showing 1 changed file with 50 additions and 44 deletions.
94 changes: 50 additions & 44 deletions examples/mouse_eye_movements.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,25 @@ def plot_x_y_keypoints(da_dict, ylim=None, **kwargs):
eye_midpoint = da.sel(keypoints=["eye-L", "eye-R"]).mean("keypoints")
position_norm_dict[da_name] = da - eye_midpoint
# %%
# We plot the x and y positions again, but now using the processed data.
# We plot the x and y positions again, but now using the normalised data.
fig = plot_x_y_keypoints(
position_norm_dict, ylim=(-200, 200), time=time_points
)
fig.show()
# %%
# Pupil position over time
# ---------------------------------
# A pupil centroid keypoint can be added to the data array using
# ------------------------
# To look at pupil position, and later also velocity, over time we use the
# pupil centroid (in this case the midpoint between "pupil-L" and "pupil-R").
# The pupil centroid keypoint ("pupil-C") is be added to the data array using
# ``xarray.DataArray.assign_coords`` and ``xarray.concat``.
for da_name, da in position_norm_dict.items():
pupil_centroid = da.sel(keypoints=["pupil-L", "pupil-R"]).mean("keypoints")
pupil_centroid = pupil_centroid.assign_coords({"keypoints": "pupil-C"})
position_norm_dict[da_name] = xr.concat([da, pupil_centroid], "keypoints")

# %%
# Now the position of the pupil centroid ("pupil-C") can be plotted.
# Now the position of the pupil centroid can be plotted.
fig, axes = plt.subplots(2, 1, figsize=(6, 4))
for i, (da_name, da) in enumerate(position_norm_dict.items()):
da = da.sel(keypoints="pupil-C", time=time_points) # select data to plot
Expand All @@ -178,7 +180,7 @@ def plot_x_y_keypoints(da_dict, ylim=None, **kwargs):
# Pupil velocity over time
# ------------------------
# We use ``compute_velocity`` from ``movement``'s ``kinematics`` module to
# calculate the velocity with which the centre of the pupil ("pupil-C") moves.
# calculate the velocity with which the pupil centroid moves.
velocity_dict = {}
for da_name, da in position_norm_dict.items():
velocity_dict[da_name] = kin.compute_velocity(da.sel(keypoints="pupil-C"))
Expand All @@ -204,57 +206,61 @@ def plot_x_y_keypoints(da_dict, ylim=None, **kwargs):
# %%
# The positive peaks correspond to rapid eye movements to the right, the
# negative peaks correspond to rapid eye movements to the left.

# %%
# Pupil diameter
# Pupil diameter
# --------------
# In these datasets, the distance between the two pupil keypoints
# is used to quantify the pupil diameter.
# Here we define the pupil diameter as the distance between the two pupil
# keypoints. We use ``compute_pairwise_distances`` from ``movement.kinematics``
# to calculate the Euclidean distance between "pupil-L" and "pupil-R".
pupil_diameter_dict = {}
for da_name, da in position_norm_dict.items():
left = da.sel(keypoints="pupil-L").squeeze()
right = da.sel(keypoints="pupil-R").squeeze()
dx = right.sel(space="x") - left.sel(space="x")
dy = right.sel(space="y") - left.sel(space="y")
# TODO: can this be done with compute_pairwise_distances from kinematics?
pupil_diameter_dict[da_name] = (dx**2 + dy**2) ** 0.5 # euclidean distance
pupil_diameter_dict[da_name] = kin.compute_pairwise_distances(
da, "keypoints", {"pupil-L": "pupil-R"}
)

fig, ax = plt.subplots(figsize=(5, 3))
for da_name, da in pupil_diameter_dict.items():
ax.plot(da.time, da, label=da_name)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Pupil Diameter (pixels)")
ax.legend()
ax.set_title("Pupil Diameter")
plt.tight_layout()
plt.show()

# %%
# Pupil Diameter after filter
# ---------------------------
# The pupil diameter is plotted using a function so that we can easily compare
# the effect of filtering.
def plot_pupil_diameter(da_dict, **kwargs):
fig, ax = plt.subplots(figsize=(5, 3))
for da_name, da in da_dict.items():
ax.plot(da.sel(**kwargs), label=da_name)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Pupil Diameter (pixels)")
ax.legend()
ax.set_title("Pupil Diameter")
plt.tight_layout()
return fig, ax


fig, ax = plot_pupil_diameter(pupil_diameter_dict)
fig.show()
# %%
# Filtered Pupil Diameter
# -----------------------
# A filter can be used to smooth out pupil size data. Unlike eye movements,
# which can be extremely fast, pupil size is unlikely to change rapidly. A
# Moving Average Filter is used here to smooth the data by averaging a
# specified number of data points (defined by the window size) to reduce noise.

# TODO: Use movement filter!
filter = 150 # number of frames over which the filter is used
fig, ax = plt.subplots(figsize=(5, 3))
filter = np.ones(150)
filter_dict = {}
for da_name, da in pupil_diameter_dict.items():
filtered_data = np.convolve(da, np.ones(filter) / filter, mode="same")
# remove the first and last timepoints that are distorted by the filter
plot_slice = slice(filter // 2, -filter // 2)
ax.plot(
da.coords["time"][plot_slice],
filtered_data[plot_slice],
label=da_name + " (filter)",
)
ax.set(
xlabel="Time (frames)",
ylabel="Pupil Diameter (pixels)",
title="Filtered Pupil Diameter",
)
ax.legend()
plt.tight_layout()
plt.show()
da_filter = da.copy(deep=True)
# Calculate smooth average using numpy.convolve
da_filter.data = np.convolve(da_filter, filter / len(filter), mode="same")
filter_dict[da_name] = da_filter

# %%
# The moving average filter distorts the first few and last frames of the pupil
# diameter data which is why we exclude the first and last number of frames
# corresponding to half the filter length.

plot_slice = slice(len(filter) // 2, -len(filter) // 2)
time_points = filter_dict["black"].time[plot_slice] # same for both datasets
fig, ax = plot_pupil_diameter(filter_dict, time=time_points)
ax.set_title("Filtered Pupil Diameter")
fig.show()
# %%

0 comments on commit f8ca302

Please sign in to comment.