{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DATA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from plotting_utils import *\n",
    "from plotting_ACs import *\n",
    "from clustering_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Amacrine cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_pickle('AC_data.pkl') \n",
    "nb_clusters = 25\n",
    "cluster_id= df['cluster ID (diag)']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bipolar cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_bc = pd.read_pickle('BC_data.pkl') \n",
    "bc_nb_clusters = 24\n",
    "bc_cluster_id= df_bc['cluster ID (full)']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SPECTRAL CONTRAST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Amacrine cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spectral_contrast = np.zeros((df['roi'].shape[0], 3))\n",
    "\n",
    "for i in range(df['roi'].shape[0]):\n",
    "    spectral_contrast[i,0] = (abs(df['green_center amplitude'][i]) - abs(df['uv_center amplitude'][i])) / (abs(df['green_center amplitude'][i]) + abs(df['uv_center amplitude'][i]))\n",
    "    spectral_contrast[i,1] = (abs(df['green_ring amplitude'][i]) - abs(df['uv_ring amplitude'][i])) / (abs(df['green_ring amplitude'][i]) + abs(df['uv_ring amplitude'][i])) \n",
    "    spectral_contrast[i,2] = (abs(df['green_surround amplitude'][i]) - abs(df['uv_surround amplitude'][i])) / (abs(df['green_surround amplitude'][i]) + abs(df['uv_surround amplitude'][i])) \n",
    "\n",
    "df['spectral_contrast_center'] = spectral_contrast[:,0]\n",
    "df['spectral_contrast_ring'] = spectral_contrast[:,1]\n",
    "df['spectral_contrast_surround'] = spectral_contrast[:,2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bipolar cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_spectral_contrast = np.zeros((df_bc['roi'].shape[0], 2))\n",
    "\n",
    "for i in range(df_bc['roi'].shape[0]):\n",
    "    bc_spectral_contrast[i,0] = (abs(df_bc['green_center amplitude'][i]) - abs(df_bc['uv_center amplitude'][i])) / (abs(df_bc['green_center amplitude'][i]) + abs(df_bc['uv_center amplitude'][i]))\n",
    "    bc_spectral_contrast[i,1] = (abs(df_bc['green_surround amplitude'][i]) - abs(df_bc['uv_surround amplitude'][i])) / (abs(df_bc['green_surround amplitude'][i]) + abs(df_bc['uv_surround amplitude'][i])) \n",
    "\n",
    "df_bc['spectral_contrast_center'] = bc_spectral_contrast[:,0]\n",
    "df_bc['spectral_contrast_surround'] = bc_spectral_contrast[:,1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ON-OFF INDEX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Amacrine cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "line_duration= 1.6 #in ms\n",
    "baseline_time_ms = int(1000/line_duration)\n",
    "on_time_ms = int(1000/line_duration) # in 1 s how many data points we have \n",
    "off_time_ms = int(1000/line_duration)\n",
    "\n",
    "a=np.zeros(baseline_time_ms)\n",
    "b=np.ones(on_time_ms)\n",
    "c=np.zeros(off_time_ms)\n",
    "\n",
    "stimulus= np.concatenate((a,b,c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "convolved_response_c = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "convolved_response_r = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "convolved_response_s = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "\n",
    "line_duration_s = 0.0016\n",
    "kernel_length_line = np.int(np.floor(2/line_duration_s))\n",
    "offset_after = np.int(np.floor(kernel_length_line*.25)) #lines to include into the future (using 1/4 of kernel length)\n",
    "offset_before = kernel_length_line-offset_after\n",
    "kernel_past = offset_before - 1\n",
    "kernel_future = offset_after\n",
    "\n",
    "for i in range(df['roi'].shape[0]):\n",
    "    convolved_response_uv_c = np.convolve(stimulus, np.flip(df['uv_center'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_green_c = np.convolve(stimulus, np.flip(df['green_center'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_uv_r = np.convolve(stimulus, np.flip(df['uv_ring'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_green_r = np.convolve(stimulus, np.flip(df['green_ring'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_uv_s = np.convolve(stimulus, np.flip(df['uv_surround'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_green_s = np.convolve(stimulus, np.flip(df['green_surround'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    \n",
    "    convolved_response_c_avg = (convolved_response_uv_c + convolved_response_green_c)/2\n",
    "    convolved_response_r_avg = (convolved_response_uv_r + convolved_response_green_r)/2\n",
    "    convolved_response_s_avg = (convolved_response_uv_s + convolved_response_green_s)/2\n",
    "    \n",
    "    convolved_response_c[i,:] = convolved_response_c_avg-convolved_response_c_avg.min()\n",
    "    convolved_response_r[i,:] = convolved_response_r_avg-convolved_response_r_avg.min()\n",
    "    convolved_response_s[i,:] = convolved_response_s_avg-convolved_response_s_avg.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "window_on_start = 625\n",
    "window_on_end = 1250\n",
    "window_off_start = 1250\n",
    "window_off_end = 1875\n",
    "\n",
    "ooindex = np.zeros((df['roi'].shape[0],3))\n",
    "\n",
    "for i in range(df['roi'].shape[0]):\n",
    "    ooindex[i,0] = (convolved_response_c[i,window_on_start:window_on_end].mean() - convolved_response_c[i,window_off_start:window_off_end].mean())/(convolved_response_c[i,window_on_start:window_on_end].mean() + convolved_response_c[i,window_off_start:window_off_end].mean())\n",
    "    ooindex[i,1] = (convolved_response_r[i,window_on_start:window_on_end].mean() - convolved_response_r[i,window_off_start:window_off_end].mean())/(convolved_response_r[i,window_on_start:window_on_end].mean() + convolved_response_r[i,window_off_start:window_off_end].mean())\n",
    "    ooindex[i,2] = (convolved_response_s[i,window_on_start:window_on_end].mean() - convolved_response_s[i,window_off_start:window_off_end].mean())/(convolved_response_s[i,window_on_start:window_on_end].mean() + convolved_response_s[i,window_off_start:window_off_end].mean())\n",
    "\n",
    "df['OOi_center'] = ooindex[:,0]\n",
    "df['OOi_ring'] = ooindex[:,1]\n",
    "df['OOi_surround'] = ooindex[:,2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bipolar cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "line_duration= 1.6 #in ms\n",
    "baseline_time_ms = int(1000/line_duration)\n",
    "on_time_ms = int(1000/line_duration) # in 1 s how many data points we have\n",
    "off_time_ms = int(1000/line_duration)\n",
    "\n",
    "a=np.zeros(baseline_time_ms)\n",
    "b=np.ones(on_time_ms)\n",
    "c=np.zeros(off_time_ms)\n",
    "\n",
    "stimulus= np.concatenate((a,b,c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_convolved_response_c = np.zeros((df_bc['roi'].shape[0],stimulus.shape[0]))\n",
    "bc_convolved_response_s = np.zeros((df_bc['roi'].shape[0],stimulus.shape[0]))\n",
    "\n",
    "line_duration_s = 0.0016\n",
    "kernel_length_line = np.int(np.floor(2/line_duration_s))\n",
    "offset_after = np.int(np.floor(kernel_length_line*.25)) #lines to include into the future (using 1/4 of kernel length)\n",
    "offset_before = kernel_length_line-offset_after\n",
    "kernel_past = offset_before - 1\n",
    "kernel_future = offset_after\n",
    "\n",
    "for i in range(df_bc['roi'].shape[0]):\n",
    "    bc_convolved_response_uv_c = np.convolve(stimulus, np.flip(df_bc['uv_center'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    bc_convolved_response_green_c = np.convolve(stimulus, np.flip(df_bc['green_center'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    bc_convolved_response_uv_s = np.convolve(stimulus, np.flip(df_bc['uv_surround'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    bc_convolved_response_green_s = np.convolve(stimulus, np.flip(df_bc['green_surround'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    \n",
    "    bc_convolved_response_c_avg = (bc_convolved_response_uv_c + bc_convolved_response_green_c)/2\n",
    "    bc_convolved_response_s_avg = (bc_convolved_response_uv_s + bc_convolved_response_green_s)/2\n",
    "    \n",
    "    bc_convolved_response_c[i,:] = bc_convolved_response_c_avg - bc_convolved_response_c_avg.min()\n",
    "    bc_convolved_response_s[i,:] = bc_convolved_response_s_avg - bc_convolved_response_s_avg.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "window_on_start = 625\n",
    "window_on_end = 1250\n",
    "\n",
    "window_off_start = 1250\n",
    "window_off_end = 1875\n",
    "\n",
    "bc_ooindex = np.zeros((df_bc['roi'].shape[0],2))\n",
    "\n",
    "for i in range(df_bc['roi'].shape[0]):\n",
    "    bc_ooindex[i,0] = (bc_convolved_response_c[i,window_on_start:window_on_end].mean() - bc_convolved_response_c[i,window_off_start:window_off_end].mean())/(bc_convolved_response_c[i,window_on_start:window_on_end].mean() + bc_convolved_response_c[i,window_off_start:window_off_end].mean())\n",
    "    bc_ooindex[i,1] = (bc_convolved_response_s[i,window_on_start:window_on_end].mean() - bc_convolved_response_s[i,window_off_start:window_off_end].mean())/(bc_convolved_response_s[i,window_on_start:window_on_end].mean() + bc_convolved_response_s[i,window_off_start:window_off_end].mean())\n",
    "\n",
    "df_bc['OOi_center'] = bc_ooindex[:,0]\n",
    "df_bc['OOi_surround'] = bc_ooindex[:,1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example fields - Figure 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#field 1 (ventral)\n",
    "import datetime\n",
    "one_date = datetime.date(2022, 4, 13) \n",
    "one_exp = 1\n",
    "one_field = 1\n",
    "           \n",
    "first_field = df[(df['date'] == one_date) & \n",
    "                (df['exp_num'] == one_exp) &\n",
    "                (df['field_id'] == one_field)]\n",
    "\n",
    "first_field = first_field.reset_index(drop = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#field 2 (dorsal)\n",
    "two_date = datetime.date(2022, 5, 18) \n",
    "two_exp = 4\n",
    "two_field = 1\n",
    "           \n",
    "second_field = df[(df['date'] == two_date) & \n",
    "                (df['exp_num'] == two_exp) &\n",
    "                (df['field_id'] == two_field)]\n",
    "\n",
    "second_field = second_field.reset_index(drop = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Calculate amplitudes for plotting\n",
    "nb_rois= 6\n",
    "roi_contrast_center= np.zeros(((nb_rois),2))\n",
    "roi_contrast_ring= np.zeros(((nb_rois),2))\n",
    "roi_contrast_surround= np.zeros(((nb_rois),2))\n",
    "\n",
    "roi_contrast_center[0,0]= first_field['green_center amplitude'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_contrast_center[0,1]= first_field['uv_center amplitude'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_contrast_ring[0,0]= first_field['green_ring amplitude'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_contrast_ring[0,1]= first_field['uv_ring amplitude'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_contrast_surround[0,0]= first_field['green_surround amplitude'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_contrast_surround[0,1]= first_field['uv_surround amplitude'].loc[first_field['roi'] == 14].values[0]\n",
    "\n",
    "roi_contrast_center[1,0]= first_field['green_center amplitude'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_contrast_center[1,1]= first_field['uv_center amplitude'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_contrast_ring[1,0]= first_field['green_ring amplitude'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_contrast_ring[1,1]= first_field['uv_ring amplitude'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_contrast_surround[1,0]= first_field['green_surround amplitude'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_contrast_surround[1,1]= first_field['uv_surround amplitude'].loc[first_field['roi'] == 23].values[0]\n",
    "\n",
    "roi_contrast_center[2,0]= first_field['green_center amplitude'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_contrast_center[2,1]= first_field['uv_center amplitude'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_contrast_ring[2,0]= first_field['green_ring amplitude'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_contrast_ring[2,1]= first_field['uv_ring amplitude'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_contrast_surround[2,0]= first_field['green_surround amplitude'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_contrast_surround[2,1]= first_field['uv_surround amplitude'].loc[first_field['roi'] == 13].values[0]\n",
    "\n",
    "roi_contrast_center[3,0]= second_field['green_center amplitude'].loc[second_field['roi'] == 8].values[0] \n",
    "roi_contrast_center[3,1]= second_field['uv_center amplitude'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_contrast_ring[3,0]= second_field['green_ring amplitude'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_contrast_ring[3,1]= second_field['uv_ring amplitude'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_contrast_surround[3,0]= second_field['green_surround amplitude'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_contrast_surround[3,1]= second_field['uv_surround amplitude'].loc[second_field['roi'] == 8].values[0]\n",
    "\n",
    "roi_contrast_center[4,0]= second_field['green_center amplitude'].loc[second_field['roi'] == 5].values[0] \n",
    "roi_contrast_center[4,1]= second_field['uv_center amplitude'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_contrast_ring[4,0]= second_field['green_ring amplitude'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_contrast_ring[4,1]= second_field['uv_ring amplitude'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_contrast_surround[4,0]= second_field['green_surround amplitude'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_contrast_surround[4,1]= second_field['uv_surround amplitude'].loc[second_field['roi'] == 5].values[0]\n",
    "\n",
    "roi_contrast_center[5,0]= second_field['green_center amplitude'].loc[second_field['roi'] == 17].values[0] \n",
    "roi_contrast_center[5,1]= second_field['uv_center amplitude'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_contrast_ring[5,0]= second_field['green_ring amplitude'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_contrast_ring[5,1]= second_field['uv_ring amplitude'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_contrast_surround[5,0]= second_field['green_surround amplitude'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_contrast_surround[5,1]= second_field['uv_surround amplitude'].loc[second_field['roi'] == 17].values[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# PLOT ROI KERNELS\n",
    "roi_1 = first_field[first_field['roi'] == 14]\n",
    "roi_2 = first_field[first_field['roi'] == 23] \n",
    "roi_3 = first_field[first_field['roi'] == 13]\n",
    "roi_4 = second_field[second_field['roi'] == 8]\n",
    "roi_5 = second_field[second_field['roi'] == 5]\n",
    "roi_6 = second_field[second_field['roi'] == 17]\n",
    "\n",
    "fig1, ax = plt.subplots(6,6, figsize=(15,15),sharey = 'row')\n",
    "bar_colors = ['purple', 'darkgreen']\n",
    "\n",
    "ax[0,0].plot(roi_1['uv_center'].values[0], 'purple')\n",
    "ax[0,0].plot(roi_1['green_center'].values[0], 'darkgreen') \n",
    "ax[0,1].bar([0,1],[roi_contrast_center[0,1], roi_contrast_center[0,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[0,2].plot(roi_1['uv_ring'].values[0], 'purple') \n",
    "ax[0,2].plot(roi_1['green_ring'].values[0], 'darkgreen') \n",
    "ax[0,3].bar([0,1],[roi_contrast_ring[0,1], roi_contrast_ring[0,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[0,4].plot(roi_1['uv_surround'].values[0],'purple') \n",
    "ax[0,4].plot(roi_1['green_surround'].values[0], 'darkgreen') \n",
    "ax[0,5].bar([0,1],[roi_contrast_surround[0,1], roi_contrast_surround[0,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[0,0].vlines(x=[0], ymin=[0], ymax=[0.1], lw=1) #scale bar 0.1\n",
    "\n",
    "ax[1,0].plot(roi_2['uv_center'].values[0], 'purple')\n",
    "ax[1,0].plot(roi_2['green_center'].values[0], 'darkgreen') \n",
    "ax[1,1].bar([0,1],[roi_contrast_center[1,1], roi_contrast_center[1,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[1,2].plot(roi_2['uv_ring'].values[0], 'purple') \n",
    "ax[1,2].plot(roi_2['green_ring'].values[0], 'darkgreen') \n",
    "ax[1,3].bar([0,1],[roi_contrast_ring[1,1], roi_contrast_ring[1,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[1,4].plot(roi_2['uv_surround'].values[0],'purple') \n",
    "ax[1,4].plot(roi_2['green_surround'].values[0], 'darkgreen') \n",
    "ax[1,5].bar([0,1],[roi_contrast_surround[1,1], roi_contrast_surround[1,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[1,0].vlines(x=[0], ymin=[0], ymax=[0.1], lw=1) #scale bar 0.1\n",
    "\n",
    "ax[2,0].plot(roi_3['uv_center'].values[0], 'purple')\n",
    "ax[2,0].plot(roi_3['green_center'].values[0], 'darkgreen') \n",
    "ax[2,1].bar([0,1],[roi_contrast_center[2,1], roi_contrast_center[2,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[2,2].plot(roi_3['uv_ring'].values[0], 'purple') \n",
    "ax[2,2].plot(roi_3['green_ring'].values[0], 'darkgreen') \n",
    "ax[2,3].bar([0,1],[roi_contrast_ring[2,1], roi_contrast_ring[2,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[2,4].plot(roi_3['uv_surround'].values[0],'purple') \n",
    "ax[2,4].plot(roi_3['green_surround'].values[0], 'darkgreen') \n",
    "ax[2,5].bar([0,1],[roi_contrast_surround[2,1], roi_contrast_surround[2,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[2,0].vlines(x=[0], ymin=[0], ymax=[0.1], lw=1) #scale bar 0.1\n",
    "\n",
    "ax[3,0].plot(roi_4['uv_center'].values[0], 'purple')\n",
    "ax[3,0].plot(roi_4['green_center'].values[0], 'darkgreen') \n",
    "ax[3,1].bar([0,1],[roi_contrast_center[3,1], roi_contrast_center[3,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[3,2].plot(roi_4['uv_ring'].values[0], 'purple') \n",
    "ax[3,2].plot(roi_4['green_ring'].values[0], 'darkgreen') \n",
    "ax[3,3].bar([0,1],[roi_contrast_ring[3,1], roi_contrast_ring[3,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[3,4].plot(roi_4['uv_surround'].values[0],'purple') \n",
    "ax[3,4].plot(roi_4['green_surround'].values[0], 'darkgreen') \n",
    "ax[3,5].bar([0,1],[roi_contrast_surround[3,1], roi_contrast_surround[3,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[3,0].vlines(x=[0], ymin=[0], ymax=[0.1], lw=1) #scale bar 0.1\n",
    "\n",
    "ax[4,0].plot(roi_5['uv_center'].values[0], 'purple')\n",
    "ax[4,0].plot(roi_5['green_center'].values[0], 'darkgreen') \n",
    "ax[4,1].bar([0,1],[roi_contrast_center[4,1], roi_contrast_center[4,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[4,2].plot(roi_5['uv_ring'].values[0], 'purple') \n",
    "ax[4,2].plot(roi_5['green_ring'].values[0], 'darkgreen') \n",
    "ax[4,3].bar([0,1],[roi_contrast_ring[4,1], roi_contrast_ring[4,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[4,4].plot(roi_5['uv_surround'].values[0],'purple') \n",
    "ax[4,4].plot(roi_5['green_surround'].values[0], 'darkgreen') \n",
    "ax[4,5].bar([0,1],[roi_contrast_surround[4,1], roi_contrast_surround[4,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[4,0].vlines(x=[0], ymin=[0], ymax=[0.1], lw=1) #scale bar 0.1\n",
    "\n",
    "ax[5,0].plot(roi_6['uv_center'].values[0], 'purple')\n",
    "ax[5,0].plot(roi_6['green_center'].values[0], 'darkgreen') \n",
    "ax[5,1].bar([0,1],[roi_contrast_center[5,1], roi_contrast_center[5,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[5,2].plot(roi_6['uv_ring'].values[0], 'purple') \n",
    "ax[5,2].plot(roi_6['green_ring'].values[0], 'darkgreen') \n",
    "ax[5,3].bar([0,1],[roi_contrast_ring[5,1], roi_contrast_ring[5,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[5,4].plot(roi_6['uv_surround'].values[0],'purple') \n",
    "ax[5,4].plot(roi_6['green_surround'].values[0], 'darkgreen') \n",
    "ax[5,5].bar([0,1],[roi_contrast_surround[5,1], roi_contrast_surround[5,0]],align='center', capsize=2, width=1, color=bar_colors)\n",
    "ax[5,0].vlines(x=[0], ymin=[0], ymax=[0.1], lw=1) #scale bar 0.1\n",
    "\n",
    "ax[0,0].axvline(x= 938, color='grey', linestyle='--')\n",
    "ax[0,2].axvline(x= 938, color='grey', linestyle='--')\n",
    "ax[0,4].axvline(x= 938, color='grey', linestyle='--')\n",
    "ax[0,0].axhline(y= 0, color='grey', linestyle='--')\n",
    "ax[1,0].axhline(y= 0, color='grey', linestyle='--')\n",
    "ax[2,0].axhline(y= 0, color='grey', linestyle='--')\n",
    "ax[3,0].axhline(y= 0, color='grey', linestyle='--')\n",
    "ax[4,0].axhline(y= 0, color='grey', linestyle='--')\n",
    "ax[5,0].axhline(y= 0, color='grey', linestyle='--')\n",
    "\n",
    "sns.despine()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# SC CIRCLES\n",
    "nb_rois = 6\n",
    "roi_spectral_contrast = np.zeros(((nb_rois),3))\n",
    "\n",
    "roi_spectral_contrast[0,0] = first_field['spectral_contrast_surround'].loc[first_field['roi'] == 14].values[0] \n",
    "roi_spectral_contrast[0,1] = first_field['spectral_contrast_ring'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_spectral_contrast[0,2] = first_field['spectral_contrast_center'].loc[first_field['roi'] == 14].values[0]\n",
    "\n",
    "roi_spectral_contrast[1,0] = first_field['spectral_contrast_surround'].loc[first_field['roi'] == 23].values[0] \n",
    "roi_spectral_contrast[1,1] = first_field['spectral_contrast_ring'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_spectral_contrast[1,2] = first_field['spectral_contrast_center'].loc[first_field['roi'] == 23].values[0]\n",
    "\n",
    "roi_spectral_contrast[2,0] = first_field['spectral_contrast_surround'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_spectral_contrast[2,1] = first_field['spectral_contrast_ring'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_spectral_contrast[2,2] = first_field['spectral_contrast_center'].loc[first_field['roi'] == 13].values[0]\n",
    "\n",
    "roi_spectral_contrast[3,0] = second_field['spectral_contrast_surround'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_spectral_contrast[3,1] = second_field['spectral_contrast_ring'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_spectral_contrast[3,2] = second_field['spectral_contrast_center'].loc[second_field['roi'] == 8].values[0]\n",
    "\n",
    "roi_spectral_contrast[4,0] = second_field['spectral_contrast_surround'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_spectral_contrast[4,1] = second_field['spectral_contrast_ring'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_spectral_contrast[4,2] = second_field['spectral_contrast_center'].loc[second_field['roi'] == 5].values[0]\n",
    "\n",
    "roi_spectral_contrast[5,0] = second_field['spectral_contrast_surround'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_spectral_contrast[5,1] = second_field['spectral_contrast_ring'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_spectral_contrast[5,2] = second_field['spectral_contrast_center'].loc[second_field['roi'] == 17].values[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "PiYG = cm.get_cmap('PiYG', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([1, 1, 1, 1])\n",
    "\n",
    "# center1 = np.amax((first_field['spectral_contrast_center'].max(), np.abs(first_field['spectral_contrast_center'].min())))\n",
    "# ring1 = np.amax((first_field['spectral_contrast_ring'].max(), np.abs(first_field['spectral_contrast_ring'].min())))\n",
    "# surround1 = np.amax((first_field['spectral_contrast_surround'].max(), np.abs(first_field['spectral_contrast_surround'].min())))\n",
    "# center2 = np.amax((second_field['spectral_contrast_center'].max(), np.abs(second_field['spectral_contrast_center'].min())))\n",
    "# ring2 = np.amax((second_field['spectral_contrast_ring'].max(), np.abs(second_field['spectral_contrast_ring'].min())))\n",
    "# surround2 = np.amax((second_field['spectral_contrast_surround'].max(), np.abs(second_field['spectral_contrast_surround'].min())))\n",
    "\n",
    "limit = 0.5 #np.amax((center1, ring1, surround1, center2, ring2, surround2))\n",
    "\n",
    "vmax = limit \n",
    "vmin = -limit\n",
    "white_start = 1/4* limit  \n",
    "white_stop = -1/4* limit \n",
    "\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.collections import PatchCollection\n",
    "from matplotlib.patches import Circle\n",
    "\n",
    "def plot_colorCircles(color_pref, cmap, limit = None, colorbar = False, savefig = False, path = None):\n",
    "    center = [0.5, 0.5]\n",
    "    radii = [0.3, 0.2, 0.1]\n",
    "    nb_rois = 6\n",
    "    if limit is None:\n",
    "        limit = np.amax((np.amax(color_pref), np.abs(np.amin(color_pref))))\n",
    "    fig, ax = plt.subplots(nb_rois, 1, figsize=(1.3,nb_rois*1.5))\n",
    "    for k in range(nb_rois):\n",
    "        patches = []\n",
    "        colors = color_pref[k,:]\n",
    "        for r in radii: \n",
    "            circle = Circle((center), r)\n",
    "            patches.append(circle)\n",
    "        p = PatchCollection(patches, cmap = cmap, edgecolor = 'black', linewidth = 0.5)\n",
    "        p.set_array(colors)\n",
    "        p.set_clim([-limit, limit])\n",
    "        ax[k].set_axis_off()\n",
    "        ax[k].add_collection(p)\n",
    "    if colorbar:\n",
    "        fig.colorbar(p, ax=ax[-1])\n",
    "    if savefig:\n",
    "        plt.savefig(path, dpi = 600)\n",
    "    plt.show() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_colorCircles(roi_spectral_contrast, newcmp, limit=0.5, colorbar = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OOi CIRCLES\n",
    "nb_rois = 6\n",
    "roi_ooindex = np.zeros(((nb_rois),3))\n",
    "\n",
    "roi_ooindex[0,0] = first_field['OOi_surround'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_ooindex[0,1] = first_field['OOi_ring'].loc[first_field['roi'] == 14].values[0]\n",
    "roi_ooindex[0,2] = first_field['OOi_center'].loc[first_field['roi'] == 14].values[0]\n",
    "\n",
    "roi_ooindex[1,0] = first_field['OOi_surround'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_ooindex[1,1] = first_field['OOi_ring'].loc[first_field['roi'] == 23].values[0]\n",
    "roi_ooindex[1,2] = first_field['OOi_center'].loc[first_field['roi'] == 23].values[0]\n",
    "\n",
    "roi_ooindex[2,0] = first_field['OOi_surround'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_ooindex[2,1] = first_field['OOi_ring'].loc[first_field['roi'] == 13].values[0]\n",
    "roi_ooindex[2,2] = first_field['OOi_center'].loc[first_field['roi'] == 13].values[0]\n",
    "\n",
    "roi_ooindex[3,0] = second_field['OOi_surround'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_ooindex[3,1] = second_field['OOi_ring'].loc[second_field['roi'] == 8].values[0]\n",
    "roi_ooindex[3,2] = second_field['OOi_center'].loc[second_field['roi'] == 8].values[0]\n",
    "\n",
    "roi_ooindex[4,0] = second_field['OOi_surround'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_ooindex[4,1] = second_field['OOi_ring'].loc[second_field['roi'] == 5].values[0]\n",
    "roi_ooindex[4,2] = second_field['OOi_center'].loc[second_field['roi'] == 5].values[0]\n",
    "\n",
    "roi_ooindex[5,0] = second_field['OOi_surround'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_ooindex[5,1] = second_field['OOi_ring'].loc[second_field['roi'] == 17].values[0]\n",
    "roi_ooindex[5,2] = second_field['OOi_center'].loc[second_field['roi'] == 17].values[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "PiYG = cm.get_cmap('Greys_r', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([0.5, 0.5, 0.5, 0.5])\n",
    "\n",
    "center1 = np.amax((first_field['OOi_center'].max(), np.abs(first_field['OOi_center'].min())))\n",
    "ring1 = np.amax((first_field['OOi_ring'].max(), np.abs(first_field['OOi_ring'].min())))\n",
    "surround1 = np.amax((first_field['OOi_surround'].max(), np.abs(first_field['OOi_surround'].min())))\n",
    "center2 = np.amax((second_field['OOi_center'].max(), np.abs(second_field['OOi_center'].min())))\n",
    "ring2 = np.amax((second_field['OOi_ring'].max(), np.abs(second_field['OOi_ring'].min())))\n",
    "surround2 = np.amax((second_field['OOi_surround'].max(), np.abs(second_field['OOi_surround'].min())))\n",
    "\n",
    "limit = np.amax((center1, ring1, surround1, center2, ring2, surround2))\n",
    "\n",
    "vmax = limit \n",
    "vmin = -limit\n",
    "white_start = 1/4* limit  \n",
    "white_stop = -1/4* limit \n",
    "\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.collections import PatchCollection\n",
    "from matplotlib.patches import Circle\n",
    "\n",
    "def plot_polarityCircles(color_pref, cmap, limit = None, colorbar = False, savefig = False, path = None):\n",
    "    center = [0.5, 0.5]\n",
    "    radii = [0.3, 0.2, 0.1]\n",
    "    nb_rois = 6\n",
    "    if limit is None:\n",
    "        limit = np.amax((np.amax(color_pref), np.abs(np.amin(color_pref))))\n",
    "    fig, ax = plt.subplots(nb_rois, 1, figsize=(1.3,nb_rois*1.5))\n",
    "    for k in range(nb_rois):\n",
    "        patches = []\n",
    "        colors = color_pref[k,:]\n",
    "        for r in radii: \n",
    "            circle = Circle((center), r)\n",
    "            patches.append(circle)\n",
    "        p = PatchCollection(patches, cmap = cmap, edgecolor = 'black', linewidth = 0.5) #'Greys_r'\n",
    "        p.set_array(colors)\n",
    "        p.set_clim([-limit, limit])\n",
    "        ax[k].set_axis_off()\n",
    "        ax[k].add_collection(p)\n",
    "    if colorbar:\n",
    "        fig.colorbar(p, ax=ax[-1])\n",
    "    if savefig:\n",
    "        plt.savefig(path, dpi = 600)\n",
    "    plt.show() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_polarityCircles(roi_ooindex, newcmp, limit=limit, colorbar = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# FOR ONE FIELD: PLOT SC ON THE FIELD'S ROI MASK \n",
    "\n",
    "import matplotlib.cm as cm\n",
    "cmap = cm.get_cmap('PiYG')\n",
    "cmap.set_bad(color='grey')\n",
    "\n",
    "ipl_mask = (Field.RoiMask() & first_field).fetch1('roi_mask')\n",
    "\n",
    "SCmap_center = ipl_mask*(-1)\n",
    "SCmap_ring = ipl_mask*(-1)\n",
    "SCmap_surround = ipl_mask*(-1)\n",
    "\n",
    "xlength = SCmap_center.shape[0]\n",
    "zlength = SCmap_center.shape[1]\n",
    "\n",
    "roi_list = first_field['roi'].values\n",
    "\n",
    "for x in range(xlength):\n",
    "    for z in range(zlength):  \n",
    "        if SCmap_center[x,z]== 0:\n",
    "            SCmap_center[x,z] = np.nan\n",
    "        elif SCmap_center[x,z] in roi_list:\n",
    "            roi_index = np.int(SCmap_center[x, z])\n",
    "            SCmap_center[x,z] = first_field['spectral_contrast_center'].loc[first_field['roi'] == roi_index].values[0]\n",
    "        else:\n",
    "            SCmap_center[x,z] = np.nan\n",
    "            \n",
    "        if SCmap_ring[x,z]== 0:\n",
    "            SCmap_ring[x,z] = np.nan\n",
    "        elif SCmap_ring[x,z] in roi_list:\n",
    "            roi_index = np.int(SCmap_ring[x, z])\n",
    "            SCmap_ring[x,z] = first_field['spectral_contrast_ring'].loc[first_field['roi'] == roi_index].values[0]\n",
    "        else:\n",
    "            SCmap_ring[x,z] = np.nan\n",
    "            \n",
    "        if SCmap_surround[x,z]== 0:\n",
    "            SCmap_surround[x,z] = np.nan\n",
    "        elif SCmap_surround[x,z] in roi_list:\n",
    "            roi_index = np.int(SCmap_surround[x, z])\n",
    "            SCmap_surround[x,z] = first_field['spectral_contrast_surround'].loc[first_field['roi'] == roi_index].values[0]\n",
    "        else:\n",
    "            SCmap_surround[x,z] = np.nan\n",
    "            \n",
    "# masked_array_center = np.ma.array (SCmap_center, mask=np.isnan(SCmap_center))\n",
    "# masked_array_ring = np.ma.array (SCmap_ring, mask=np.isnan(SCmap_ring))\n",
    "# masked_array_surround = np.ma.array (SCmap_surround, mask=np.isnan(SCmap_surround))\n",
    "\n",
    "fig, ax = plt.subplots(1,3, figsize=(15,15))\n",
    "\n",
    "img= ax[0].imshow(SCmap_center.T, cmap=newcmp, vmin= -limit, vmax= limit, origin='lower') \n",
    "ax[1].imshow(SCmap_ring.T, cmap=newcmp, vmin= -limit, vmax= limit, origin='lower') \n",
    "ax[2].imshow(SCmap_surround.T, cmap=newcmp, vmin= -limit, vmax= limit, origin='lower') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chromatic preference - Figure 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the x retinal position of each ROI\n",
    "key = {\n",
    "                'experimenter': 'Korympidou',\n",
    "        }\n",
    "x = pd.DataFrame.from_dict(((RelativeFieldLocation()) &\n",
    "\n",
    "                                    key).fetch(as_dict=True))\n",
    "\n",
    "df['Rel. field location x'] = ''\n",
    "for index, row in df.iterrows():\n",
    "    \n",
    "    df.at[index, 'Rel. field location x'] = x['relx'][(x['date'] == row['date']) & (x['exp_num'] == row['exp_num']) & (x['field_id'] == row['field_id'])].values[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# discard ROIs that have wrong y location assignment\n",
    "df= df.loc[df['y_retinal_location'] != 'no_info']\n",
    "df = df.reset_index(drop = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "location_y = df['y_retinal_location']\n",
    "location_x = df['Rel. field location x']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set color map\n",
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "PiYG = cm.get_cmap('PiYG', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([1, 1, 1, 1])\n",
    "\n",
    "limit = 0.75 \n",
    "\n",
    "vmax = limit \n",
    "vmin = -limit\n",
    "white_start = 1/4* limit \n",
    "white_stop = -1/4* limit \n",
    "\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add a random value in x & y for each ROI to distance them from one another in their field region\n",
    "# value = 700 \n",
    "# random_addition_1 = np.random.uniform(low=-value, high=value, size=len(df))\n",
    "# random_addition_2 = np.random.uniform(low=-value, high=value, size=len(df))\n",
    "\n",
    "random_addition_1 = np.load('random_array_xposition.npy')\n",
    "random_addition_2 = np.load('random_array_yposition.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['Rel. field location x + random'] = df['Rel. field location x'] + random_addition_1/1.7 \n",
    "df['Rel. field location y + random'] = df['Rel. field location y'] + random_addition_2/1.7 \n",
    "\n",
    "fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(20,6))\n",
    "sc=ax1.scatter(x=df['Rel. field location x + random'].values, \n",
    "                y= df['Rel. field location y + random'].values, \n",
    "                c= df['spectral_contrast_center'].values,  \n",
    "                cmap=newcmp, vmax = limit, vmin= -limit, s=15) \n",
    "\n",
    "ax2.scatter(x=df['Rel. field location x + random'].values, \n",
    "                y= df['Rel. field location y + random'].values, \n",
    "                c= df['spectral_contrast_ring'].values, \n",
    "                cmap=newcmp, vmax = limit, vmin= -limit, s=15)\n",
    "\n",
    "ax3.scatter(x=df['Rel. field location x + random'].values, \n",
    "                y= df['Rel. field location y + random'].values, \n",
    "                c= df['spectral_contrast_surround'].loc[df['spectral_contrast_surround']<-0.2].values, \n",
    "                cmap=newcmp, vmax=limit, vmin= -limit, s=15)\n",
    "# plt.colorbar(sc)\n",
    "lims = 3000\n",
    "circle1=plt.Circle((0,0),lims, color='black', fill=False)\n",
    "circle2=plt.Circle((0,0),lims, color='black', fill=False)\n",
    "circle3=plt.Circle((0,0),lims, color='black', fill=False)\n",
    "circle4=plt.Circle((0,0),50, color='black', fill=True)\n",
    "circle5=plt.Circle((0,0),50, color='black', fill=True)\n",
    "circle6=plt.Circle((0,0),50, color='black', fill=True)\n",
    "ax1.set_xlim(-lims,lims)\n",
    "ax1.set_ylim(-lims,lims)\n",
    "ax2.set_xlim(-lims,lims)\n",
    "ax2.set_ylim(-lims,lims)\n",
    "ax3.set_xlim(-lims,lims)\n",
    "ax3.set_ylim(-lims,lims)\n",
    "ax1.axvline(x=0, ymin=-lims, ymax=lims)\n",
    "ax2.axvline(x=0, ymin=-lims, ymax=lims)\n",
    "ax3.axvline(x=0, ymin=-lims, ymax=lims)\n",
    "ax1.axhline(y=0, xmin=-lims, xmax=lims)\n",
    "ax2.axhline(y=0, xmin=-lims, xmax=lims)\n",
    "ax3.axhline(y=0, xmin=-lims, xmax=lims)\n",
    "ax1.add_patch(circle1)\n",
    "ax2.add_patch(circle2)\n",
    "ax3.add_patch(circle3)\n",
    "ax1.add_patch(circle4)\n",
    "ax2.add_patch(circle5)\n",
    "ax3.add_patch(circle6)\n",
    "ax1.axis('off')\n",
    "ax2.axis('off')\n",
    "ax3.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make bins\n",
    "# y_new = (df['Rel. field location y'].unique())\n",
    "# y_new.sort()\n",
    "y_new= df['Rel. field location y']\n",
    "bins = np.array([2192,1347.1,829.6,195.28,-633.5,-1058.4,-1340.9,-1829])\n",
    "x=np.digitize(y_new,bins)\n",
    "df['bin_id'] = x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make bins - BCs\n",
    "\n",
    "# BC_new = (df_bc['Rel. field location y'].unique())\n",
    "# BC_new.sort()\n",
    "BC_new= df_bc['Rel. field location y']\n",
    "BC_bins = np.array([1857,1262.,785.18,344.,-1438.6,-1690.8,-1797.62,-1890])\n",
    "a=np.digitize(BC_new,BC_bins)\n",
    "df_bc['bin_id'] = a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_bins = 7\n",
    "fig, ax = plt.subplots(nb_bins,2, figsize=(5,10), sharex='all')\n",
    "\n",
    "for i in range(0,nb_bins):\n",
    "    sns.kdeplot(x='spectral_contrast_ring',\n",
    "            data= df.loc[df['bin_id'] == i+1],linestyle='dashed', color= 'black', ax=ax[i,0])\n",
    "    sns.kdeplot(x='spectral_contrast_center',\n",
    "                data= df.loc[df['bin_id'] == i+1],linestyle= 'solid', color= 'black', ax=ax[i,0])\n",
    "    sns.kdeplot(x='spectral_contrast_center',\n",
    "                data= df_bc.loc[df_bc['bin_id'] == i+1],linestyle= 'solid', color= 'blue', ax=ax[i,0])\n",
    "    \n",
    "    sns.kdeplot(x='spectral_contrast_surround',\n",
    "                data= df.loc[df['bin_id'] == i+1],linestyle='solid', color= 'black', ax=ax[i,1])\n",
    "    sns.kdeplot(x='spectral_contrast_surround',\n",
    "                data= df_bc.loc[df_bc['bin_id'] == i+1],linestyle='solid', color= 'blue', ax=ax[i,1])\n",
    "    ax[6,0].axvline(0, color = 'black', linestyle= '--')\n",
    "    ax[6,1].axvline(0, color = 'black', linestyle= '--')\n",
    "    plt.xlim(-1.5,1.5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Amacrine cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of average SC per bin\n",
    "\n",
    "bin_id= df['bin_id']\n",
    "nb_bins = 7\n",
    "\n",
    "mean_bin_contrast= np.zeros(((nb_bins),3))\n",
    "\n",
    "for i in range(0,nb_bins):\n",
    "    \n",
    "    mean_bin_contrast[i,0]= df['spectral_contrast_center'][np.where(bin_id == i+1)[0]].mean()\n",
    "    mean_bin_contrast[i,1]= df['spectral_contrast_ring'][np.where(bin_id == i+1)[0]].mean()\n",
    "    mean_bin_contrast[i,2]= df['spectral_contrast_surround'][np.where(bin_id == i+1)[0]].mean()\n",
    "sns.heatmap(mean_bin_contrast,cmap=newcmp,vmin=-limit,vmax=limit)\n",
    "\n",
    "plt.xticks([0.5, 1.5, 2.5], ['Center', 'Ring', 'Surround'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set color map\n",
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "PiYG = cm.get_cmap('PiYG', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([1, 1, 1, 1])\n",
    "\n",
    "limit = np.amax((mean_bin_contrast.max(), np.abs(mean_bin_contrast.min())))\n",
    "\n",
    "vmax = limit # These values change if you work w drug data\n",
    "vmin = -limit\n",
    "white_start = 1/4* limit  #0.09 # Choose yourself as you like and keep stable for all further analysis\n",
    "white_stop = -1/4* limit \n",
    "\n",
    "# To explain white borders take 1/4 of the max SC value across clusters and spatial conditions\n",
    "# np.amax((np.abs(mean_cluster_spectral_contrast.min()), np.abs(mean_cluster_spectral_contrast.max())))\n",
    "\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range#*2\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bipolar cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of average SC per bin - BIPOLAR CELLS\n",
    "\n",
    "bin_id= df_bc['bin_id']\n",
    "nb_bins = 7\n",
    "\n",
    "bc_mean_bin_contrast= np.zeros(((nb_bins),2))\n",
    "\n",
    "for i in range(0,nb_bins):\n",
    "    \n",
    "    bc_mean_bin_contrast[i,0]= df_bc['spectral_contrast_center'][np.where(bin_id == i+1)[0]].mean()\n",
    "    bc_mean_bin_contrast[i,1]= df_bc['spectral_contrast_surround'][np.where(bin_id == i+1)[0]].mean()\n",
    "\n",
    "sns.heatmap(bc_mean_bin_contrast,cmap=newcmp,vmin=-limit,vmax=limit)\n",
    "\n",
    "plt.xticks([0.5, 1.5], ['Center', 'Surround'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make bins across the IPL\n",
    "\n",
    "depth_new= df['ipl_depth']\n",
    "bins = np.array([0,0.2,0.4,0.6,0.8,1.01])\n",
    "# bins = np.array([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.01])\n",
    "x=np.digitize(depth_new,bins)\n",
    "df['depth_id'] = x\n",
    "\n",
    "bc_depth_new= df_bc['ipl_depth']\n",
    "bc_bins = np.array([0,0.2,0.4,0.6,0.8,1.01])\n",
    "a=np.digitize(bc_depth_new,bc_bins)\n",
    "df_bc['depth_id'] = a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dorsal= df.loc[df['y_retinal_location'] == 'dorsal']\n",
    "df_ventral= df.loc[df['y_retinal_location'] == 'ventral']\n",
    "\n",
    "bc_df_dorsal= df_bc.loc[df_bc['Rel. field location y'] > 0]\n",
    "bc_df_ventral= df_bc.loc[df_bc['Rel. field location y'] < 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_bins = 5\n",
    "\n",
    "mean_bin_contrast_d= np.zeros(((nb_bins),3))\n",
    "mean_bin_contrast_v= np.zeros(((nb_bins),3))\n",
    "bc_mean_bin_contrast_d= np.zeros(((nb_bins),2))\n",
    "bc_mean_bin_contrast_v= np.zeros(((nb_bins),2))\n",
    "\n",
    "for i in range(0,nb_bins):\n",
    "    \n",
    "    mean_bin_contrast_d[i,0]= df_dorsal['spectral_contrast_center'].loc[df_dorsal['depth_id'] == i+1].values.mean()\n",
    "    mean_bin_contrast_d[i,1]= df_dorsal['spectral_contrast_ring'].loc[df_dorsal['depth_id'] == i+1].values.mean()\n",
    "    mean_bin_contrast_d[i,2]= df_dorsal['spectral_contrast_surround'].loc[df_dorsal['depth_id'] == i+1].values.mean()\n",
    "\n",
    "    bc_mean_bin_contrast_d[i,0]= bc_df_dorsal['spectral_contrast_center'].loc[bc_df_dorsal['depth_id'] == i+1].values.mean()\n",
    "    bc_mean_bin_contrast_d[i,1]= bc_df_dorsal['spectral_contrast_surround'].loc[bc_df_dorsal['depth_id'] == i+1].values.mean()\n",
    "    \n",
    "    mean_bin_contrast_v[i,0]= df_ventral['spectral_contrast_center'].loc[df_ventral['depth_id'] == i+1].values.mean()\n",
    "    mean_bin_contrast_v[i,1]= df_ventral['spectral_contrast_ring'].loc[df_ventral['depth_id'] == i+1].values.mean()\n",
    "    mean_bin_contrast_v[i,2]= df_ventral['spectral_contrast_surround'].loc[df_ventral['depth_id'] == i+1].values.mean()\n",
    "\n",
    "    bc_mean_bin_contrast_v[i,0]= bc_df_ventral['spectral_contrast_center'].loc[bc_df_ventral['depth_id'] == i+1].values.mean()\n",
    "    bc_mean_bin_contrast_v[i,1]= bc_df_ventral['spectral_contrast_surround'].loc[bc_df_ventral['depth_id'] == i+1].values.mean() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(3,2, figsize=(8,15),sharex='all',sharey='all')\n",
    "\n",
    "sns.histplot(data=df_ventral, \n",
    "                x='spectral_contrast_center', \n",
    "                y= 'ipl_depth',ax=ax[0,0],bins=40,cmap='binary',vmax=5)\n",
    "sns.histplot(data=df_ventral, \n",
    "                x='spectral_contrast_ring', \n",
    "                y= 'ipl_depth',ax=ax[1,0],bins=40,cmap='binary',vmax=5)\n",
    "sns.histplot(data=df_ventral, \n",
    "                x='spectral_contrast_surround', \n",
    "                y= 'ipl_depth',ax=ax[2,0],bins=40,cmap='binary',vmax=5)\n",
    "ax[0,0].plot(mean_bin_contrast_v[:,0],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'white')\n",
    "ax[1,0].plot(mean_bin_contrast_v[:,1],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'white')\n",
    "ax[2,0].plot(mean_bin_contrast_v[:,2],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'white')\n",
    "ax[0,0].plot(bc_mean_bin_contrast_v[:,0],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'blue')\n",
    "ax[2,0].plot(bc_mean_bin_contrast_v[:,1],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'blue')\n",
    "\n",
    "sns.histplot(data=df_dorsal, \n",
    "                x='spectral_contrast_center', \n",
    "                y= 'ipl_depth',ax=ax[0,1],bins=40,cmap='binary',vmax=5)\n",
    "sns.histplot(data=df_dorsal, \n",
    "                x='spectral_contrast_ring', \n",
    "                y= 'ipl_depth',ax=ax[1,1],bins=40,cmap='binary',vmax=5)\n",
    "sns.histplot(data=df_dorsal, \n",
    "                x='spectral_contrast_surround', \n",
    "                y= 'ipl_depth',ax=ax[2,1],bins=40,cmap='binary',vmax=5)\n",
    "\n",
    "ax[0,1].plot(mean_bin_contrast_d[:,0],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'white')\n",
    "ax[1,1].plot(mean_bin_contrast_d[:,1],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'white')\n",
    "ax[2,1].plot(mean_bin_contrast_d[:,2],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'white')\n",
    "ax[0,1].plot(bc_mean_bin_contrast_d[:,0],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'blue')\n",
    "ax[2,1].plot(bc_mean_bin_contrast_d[:,1],[0.0,0.25,0.5,0.75,1.0],'-',linewidth=4,color= 'blue')\n",
    "\n",
    "ax[0,0].set_title('Ventral')\n",
    "ax[0,1].set_title('Dorsal')\n",
    "ax[2,0].axvline(0, color = 'red', linestyle='dashed')\n",
    "ax[2,1].axvline(0, color = 'red', linestyle='dashed')\n",
    "\n",
    "plt.xlim(-1.2,1.2)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Amacrine cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1,2, figsize=(12,5))\n",
    "\n",
    "sns.heatmap(mean_bin_contrast_v,cmap=newcmp,vmin=-limit,vmax=limit,ax=ax[0])\n",
    "sns.heatmap(mean_bin_contrast_d,cmap=newcmp,vmin=-limit,vmax=limit,ax=ax[1])\n",
    "ax[0].set_title('Ventral')\n",
    "ax[1].set_title('Dorsal')\n",
    "ax[0].set_xticks([0.5, 1.5, 2.5])\n",
    "ax[0].set_xticklabels(['Center', 'Ring', 'Surround'])\n",
    "ax[1].set_xticks([0.5, 1.5, 2.5])\n",
    "ax[1].set_xticklabels(['Center', 'Ring', 'Surround'])\n",
    "ax[0].invert_yaxis()\n",
    "ax[1].invert_yaxis()\n",
    "\n",
    "# plt.savefig('SC_ipldistribution_mean_per_bin.eps', format='eps')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bipolar cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1,2, figsize=(12,5))\n",
    "\n",
    "sns.heatmap(bc_mean_bin_contrast_v,cmap=newcmp,vmin=-limit,vmax=limit,ax=ax[0])\n",
    "sns.heatmap(bc_mean_bin_contrast_d,cmap=newcmp,vmin=-limit,vmax=limit,ax=ax[1])\n",
    "ax[0].set_title('Ventral')\n",
    "ax[1].set_title('Dorsal')\n",
    "ax[0].set_xticks([0.5, 1.5])\n",
    "ax[0].set_xticklabels(['Center', 'Surround'])\n",
    "ax[1].set_xticks([0.5, 1.5])\n",
    "ax[1].set_xticklabels(['Center', 'Surround'])\n",
    "ax[0].invert_yaxis()\n",
    "ax[1].invert_yaxis()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clustering amacrine cells - Figure 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot kernel heatmap\n",
    "name = 'cluster ID (diag)'\n",
    "plot_Kernels(df, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot location-ipldepth histos\n",
    "df_part= df.loc[df['y_retinal_location'] != 'no_info']\n",
    "df_part = df_part.reset_index(drop = True)\n",
    "name = 'cluster ID (diag)'\n",
    "plot_histogram(df_part, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# Plot kernel averages\n",
    "def plot_Kernels_MK(dataframe, clustering_name,  nb_clusters = None, savefig = False, path_means = None):\n",
    "    size_kernel = len(dataframe['uv_center'].iloc[0])\n",
    "    assert size_kernel == 1250\n",
    "    labels = ['UV-c', 'UV-r', 'UV-s', 'G-c', 'G-r', 'G-s']\n",
    "    # colors_uv = ['purple', 'purple', 'purple', 'darkgreen', 'limegreen', 'lime']\n",
    "    cluster_IDs = dataframe[clustering_name].to_numpy()\n",
    "    assert np.amin(cluster_IDs) == 0\n",
    "    if nb_clusters is None:\n",
    "        nb_clusters = np.unique(cluster_IDs).shape[0]\n",
    "    orderROIs = np.argsort(cluster_IDs)\n",
    "    cluster_sizes = np.zeros(nb_clusters)\n",
    "    for index in range(nb_clusters):\n",
    "        cluster_sizes[index] = np.where(cluster_IDs == index)[0].shape[0]\n",
    "    assert cluster_sizes.sum() == len(dataframe)\n",
    "    data_uv = np.concatenate((np.vstack(dataframe['uv_center'].to_numpy()),\n",
    "                              np.vstack(dataframe['uv_ring'].to_numpy()),\n",
    "                              np.vstack(dataframe['uv_surround'].to_numpy())), axis = 1)\n",
    "    data_green = np.concatenate((np.vstack(dataframe['green_center'].to_numpy()),\n",
    "                                 np.vstack(dataframe['green_ring'].to_numpy()),\n",
    "                                 np.vstack(dataframe['green_surround'].to_numpy())), axis = 1)\n",
    "\n",
    "    uv_center_amplitude = dataframe['uv_center amplitude']\n",
    "    uv_ring_amplitude = dataframe['uv_ring amplitude']\n",
    "    uv_surround_amplitude = dataframe['uv_surround amplitude']\n",
    "    green_center_amplitude = dataframe['green_center amplitude']\n",
    "    green_ring_amplitude = dataframe['green_ring amplitude']\n",
    "    green_surround_amplitude = dataframe['green_surround amplitude']\n",
    "    \n",
    "    cluster_means_uv = np.zeros((nb_clusters,size_kernel*3))\n",
    "    cluster_means_green = np.zeros((nb_clusters,size_kernel*3))\n",
    "    cluster_amplitude_uv = np.zeros((nb_clusters,3))\n",
    "    cluster_amplitude_green = np.zeros((nb_clusters,3))\n",
    "    std_amplitude_uv = np.zeros((nb_clusters,3))\n",
    "    std_amplitude_green = np.zeros((nb_clusters,3))\n",
    "\n",
    "    for current_cluster_ID in range(nb_clusters): \n",
    "        cluster_mask = np.where(cluster_IDs == current_cluster_ID)[0]\n",
    "        \n",
    "        current_data_uv = data_uv[cluster_mask,:]\n",
    "        current_data_green = data_green[cluster_mask,:]\n",
    "        cluster_means_uv[current_cluster_ID,:]  = np.mean(current_data_uv, axis = 0)\n",
    "        cluster_means_green[current_cluster_ID,:]  = np.mean(current_data_green, axis = 0)\n",
    "        \n",
    "        cluster_amplitude_uv[current_cluster_ID,0] = uv_center_amplitude[cluster_mask].mean()\n",
    "        cluster_amplitude_uv[current_cluster_ID,1] = uv_ring_amplitude[cluster_mask].mean()\n",
    "        cluster_amplitude_uv[current_cluster_ID,2] = uv_surround_amplitude[cluster_mask].mean()\n",
    "        cluster_amplitude_green[current_cluster_ID,0] = green_center_amplitude[cluster_mask].mean()\n",
    "        cluster_amplitude_green[current_cluster_ID,1] = green_ring_amplitude[cluster_mask].mean()\n",
    "        cluster_amplitude_green[current_cluster_ID,2] = green_surround_amplitude[cluster_mask].mean()\n",
    "        \n",
    "        std_amplitude_uv[current_cluster_ID,0] = np.std(uv_center_amplitude[cluster_mask])\n",
    "        std_amplitude_uv[current_cluster_ID,1] = np.std(uv_ring_amplitude[cluster_mask])\n",
    "        std_amplitude_uv[current_cluster_ID,2] = np.std(uv_surround_amplitude[cluster_mask])\n",
    "        std_amplitude_green[current_cluster_ID,0] = np.std(green_center_amplitude[cluster_mask])\n",
    "        std_amplitude_green[current_cluster_ID,1] = np.std(green_ring_amplitude[cluster_mask])\n",
    "        std_amplitude_green[current_cluster_ID,2] = np.std(green_surround_amplitude[cluster_mask])\n",
    "   \n",
    "    fig, ax = plt.subplots(nb_clusters, 6, sharey='all', figsize=(5,20))  \n",
    "\n",
    "    bar_colors = ['purple', 'darkgreen']\n",
    "\n",
    "    for current_cluster_ID in range(nb_clusters):\n",
    "        for current_kernel_ID in range(3):\n",
    "            #plotting kernels\n",
    "            start_kernel = current_kernel_ID*size_kernel\n",
    "            my_ax = ax[current_cluster_ID, current_kernel_ID*2] \n",
    "\n",
    "            my_ax.axis('off')\n",
    "            my_ax.axhline(0, color = 'black', linestyle = 'dashed', linewidth = 0.75)\n",
    "            my_ax.axvline(size_kernel*3/4, color = 'black', linestyle = 'dashed', linewidth = 0.75) \n",
    "            my_ax.plot(cluster_means_uv[current_cluster_ID,start_kernel:start_kernel+size_kernel],\n",
    "                       color = 'purple', linewidth = 1.0)\n",
    "            my_ax.plot(cluster_means_green[current_cluster_ID,start_kernel:start_kernel+size_kernel],\n",
    "                       color = 'darkgreen', linewidth = 1.0)\n",
    "            if current_cluster_ID == 0:\n",
    "                my_ax.set_title(labels[current_kernel_ID])\n",
    "            if current_kernel_ID == 0:\n",
    "                my_ax.vlines(x=[0], ymin=[0], ymax=[0.1], lw=1) #scale bar 0.1\n",
    "                \n",
    "            #plotting bars\n",
    "            second_ax = ax[current_cluster_ID, current_kernel_ID*2+1] \n",
    "            second_ax.bar([0,1],[cluster_amplitude_uv[current_cluster_ID,current_kernel_ID], \n",
    "                                 cluster_amplitude_green[current_cluster_ID,current_kernel_ID]],\n",
    "                          yerr=[std_amplitude_uv[current_cluster_ID,current_kernel_ID], \n",
    "                                 std_amplitude_green[current_cluster_ID,current_kernel_ID]],\n",
    "                          align='center', capsize=2, width=1, color=bar_colors) \n",
    "            second_ax.axhline(linestyle='dotted',c='k')\n",
    "            second_ax.axis('off')  \n",
    "\n",
    "    if savefig:\n",
    "        plt.savefig(path_means, dpi = 600, transparent=True, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot kernel averages\n",
    "name = 'cluster ID (diag)'\n",
    "plot_Kernels_MK(df, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot SC rings\n",
    "mean_cluster_spectral_contrast = np.zeros(((nb_clusters),3))\n",
    "\n",
    "for i in range(nb_clusters):\n",
    "\n",
    "    mean_cluster_spectral_contrast[i,0]= df['spectral_contrast_surround'].loc[df['cluster ID (diag)'] == i].mean() \n",
    "    mean_cluster_spectral_contrast[i,1]= df['spectral_contrast_ring'].loc[df['cluster ID (diag)'] == i].mean()\n",
    "    mean_cluster_spectral_contrast[i,2]= df['spectral_contrast_center'].loc[df['cluster ID (diag)'] == i].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "PiYG = cm.get_cmap('PiYG', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([1, 1, 1, 1])\n",
    "\n",
    "limit = np.amax((np.amax(mean_cluster_spectral_contrast), np.abs(np.amin(mean_cluster_spectral_contrast))))\n",
    "\n",
    "vmax = limit \n",
    "vmin = -limit\n",
    "\n",
    "# To explain white borders take 1/4 of the max SC value across clusters and spatial conditions\n",
    "white_start = 1/4* limit \n",
    "white_stop = -1/4* limit \n",
    "\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range#*2\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.collections import PatchCollection\n",
    "from matplotlib.patches import Circle\n",
    "\n",
    "def plot_colorCircles(color_pref, cmap, limit = None, colorbar = False, savefig = False, path = None):\n",
    "    \"\"\"Draw a column of concentric-circle glyphs, one glyph per row of `color_pref`.\n",
    "\n",
    "    Each glyph is made of three circles of radius 0.3, 0.2 and 0.1 that share the\n",
    "    same centre. The values of `color_pref[k, :]` are mapped through `cmap` and\n",
    "    assigned to the circles in that order (largest circle first).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    color_pref : 2D array with one row per cluster; expected to hold one value per circle.\n",
    "    cmap : matplotlib colormap used for the face colours.\n",
    "    limit : float or None. The colour range is [-limit, limit]; when None it is the\n",
    "        largest absolute value found in `color_pref`.\n",
    "    colorbar : if True, attach a colorbar to the last axis.\n",
    "    savefig : if True, save the figure to `path` at 600 dpi.\n",
    "    path : output file, only used when `savefig` is True.\n",
    "    \"\"\"\n",
    "    center = [0.5, 0.5]\n",
    "    radii = [0.3, 0.2, 0.1]\n",
    "    nb_clusters = color_pref.shape[0]\n",
    "    if limit is None:\n",
    "        limit = np.amax((np.amax(color_pref), np.abs(np.amin(color_pref))))\n",
    "    # squeeze=False keeps the result indexable when there is a single cluster\n",
    "    # (plt.subplots would otherwise return a bare Axes object).\n",
    "    fig, axes = plt.subplots(nb_clusters, 1, figsize=(1.3,nb_clusters*1.5), squeeze=False)\n",
    "    axes = axes[:, 0]\n",
    "    for k in range(nb_clusters):\n",
    "        patches = [Circle(center, r) for r in radii]\n",
    "        p = PatchCollection(patches, cmap = cmap, edgecolor = 'black', linewidth = 0.5)\n",
    "        p.set_array(color_pref[k,:])\n",
    "        p.set_clim([-limit, limit])\n",
    "        axes[k].set_axis_off()\n",
    "        axes[k].add_collection(p)\n",
    "    if colorbar:\n",
    "        fig.colorbar(p, ax=axes[-1])\n",
    "    if savefig:\n",
    "        plt.savefig(path, dpi = 600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# limit=None: plot_colorCircles derives the symmetric colour range from the array it is given\n",
    "plot_colorCircles(mean_cluster_spectral_contrast, newcmp, limit=None, colorbar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster means of the OOi values used for the ring plot.\n",
    "# The columns of `ooindex` are read in the order 2, 1, 0 so that the surround\n",
    "# mean (column 2 of `ooindex`) ends up in column 0 of `mean_ooindex`.\n",
    "mean_ooindex = np.zeros((nb_clusters, 3))\n",
    "\n",
    "for i in range(nb_clusters):\n",
    "    members = np.where(df['cluster ID (diag)'] == i)\n",
    "    for target_col, source_col in enumerate((2, 1, 0)):\n",
    "        mean_ooindex[i, target_col] = ooindex[members, source_col].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "# Build a 'Greys_r' colormap whose central band is overwritten with a single colour.\n",
    "# NOTE: despite their names, `PiYG` holds the 'Greys_r' colormap here and `white`\n",
    "# is a mid-grey with alpha 0.5.\n",
    "PiYG = cm.get_cmap('Greys_r', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([0.5, 0.5, 0.5, 0.5])\n",
    "\n",
    "# Largest absolute cluster-mean OOi; defines the symmetric range [-limit, limit]\n",
    "limit = np.amax((np.amax(mean_ooindex), np.abs(np.amin(mean_ooindex))))\n",
    "\n",
    "vmax = limit \n",
    "vmin = -limit\n",
    "\n",
    "# The replaced band spans -1/4*limit .. +1/4*limit\n",
    "white_start = 1/4* limit\n",
    "white_stop = -1/4* limit \n",
    "\n",
    "# Convert the band edges from data units into row offsets of the colour table:\n",
    "# white_stop_pixel is counted from the vmin end, white_start_pixel from the vmax end\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "# Overwrite every colour-table row between the two offsets\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.collections import PatchCollection\n",
    "from matplotlib.patches import Circle\n",
    "\n",
    "def plot_polarityCircles(color_pref, cmap, limit = None, colorbar = False, savefig = False, path = None):\n",
    "    \"\"\"Draw a column of concentric-circle glyphs, one glyph per row of `color_pref`.\n",
    "\n",
    "    Each glyph is made of three circles of radius 0.3, 0.2 and 0.1 that share the\n",
    "    same centre. The values of `color_pref[k, :]` are mapped through `cmap` and\n",
    "    assigned to the circles in that order (largest circle first).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    color_pref : 2D array with one row per cluster; expected to hold one value per circle.\n",
    "    cmap : matplotlib colormap used for the face colours.\n",
    "    limit : float or None. The colour range is [-limit, limit]; when None it is the\n",
    "        largest absolute value found in `color_pref`.\n",
    "    colorbar : if True, attach a colorbar to the last axis.\n",
    "    savefig : if True, save the figure to `path` at 600 dpi.\n",
    "    path : output file, only used when `savefig` is True.\n",
    "    \"\"\"\n",
    "    center = [0.5, 0.5]\n",
    "    radii = [0.3, 0.2, 0.1]\n",
    "    nb_clusters = color_pref.shape[0]\n",
    "    if limit is None:\n",
    "        limit = np.amax((np.amax(color_pref), np.abs(np.amin(color_pref))))\n",
    "    # squeeze=False keeps the result indexable when there is a single cluster\n",
    "    # (plt.subplots would otherwise return a bare Axes object).\n",
    "    fig, axes = plt.subplots(nb_clusters, 1, figsize=(1.3,nb_clusters*1.5), squeeze=False)\n",
    "    axes = axes[:, 0]\n",
    "    for k in range(nb_clusters):\n",
    "        patches = [Circle(center, r) for r in radii]\n",
    "        p = PatchCollection(patches, cmap = cmap, edgecolor = 'black', linewidth = 0.5) #'Greys_r'\n",
    "        p.set_array(color_pref[k,:])\n",
    "        p.set_clim([-limit, limit])\n",
    "        axes[k].set_axis_off()\n",
    "        axes[k].add_collection(p)\n",
    "    if colorbar:\n",
    "        fig.colorbar(p, ax=axes[-1])\n",
    "    if savefig:\n",
    "        plt.savefig(path, dpi = 600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# limit=None: plot_polarityCircles derives the symmetric colour range from the array it is given\n",
    "plot_polarityCircles(mean_ooindex, newcmp, limit=None, colorbar = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chromatic tuning organizing principles - Figure 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average each of the six chromatic kernels (uv/green x center/ring/surround)\n",
    "# over the ROIs of every cluster; the results are used to group clusters together.\n",
    "kernel_table_shape = (nb_clusters, 1250)  # one row per cluster, 1250 samples per kernel\n",
    "\n",
    "mean_cluster_uv_center = np.zeros(kernel_table_shape)\n",
    "mean_cluster_uv_ring = np.zeros(kernel_table_shape)\n",
    "mean_cluster_uv_surround = np.zeros(kernel_table_shape)\n",
    "mean_cluster_green_center = np.zeros(kernel_table_shape)\n",
    "mean_cluster_green_ring = np.zeros(kernel_table_shape)\n",
    "mean_cluster_green_surround = np.zeros(kernel_table_shape)\n",
    "\n",
    "for i in range(nb_clusters):\n",
    "    # Boolean row selector for the ROIs assigned to cluster i\n",
    "    in_cluster = df['cluster ID (diag)'] == i\n",
    "    mean_cluster_uv_center[i,:] = np.stack(df['uv_center'].loc[in_cluster]).mean(axis=0)\n",
    "    mean_cluster_uv_ring[i,:] = np.stack(df['uv_ring'].loc[in_cluster]).mean(axis=0)\n",
    "    mean_cluster_uv_surround[i,:] = np.stack(df['uv_surround'].loc[in_cluster]).mean(axis=0)\n",
    "    mean_cluster_green_center[i,:] = np.stack(df['green_center'].loc[in_cluster]).mean(axis=0)\n",
    "    mean_cluster_green_ring[i,:] = np.stack(df['green_ring'].loc[in_cluster]).mean(axis=0)\n",
    "    mean_cluster_green_surround[i,:] = np.stack(df['green_surround'].loc[in_cluster]).mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature matrix: one row per cluster, the six cluster-mean kernels joined along the column axis\n",
    "array = np.concatenate((mean_cluster_uv_center,mean_cluster_uv_ring,mean_cluster_uv_surround,\n",
    "                       mean_cluster_green_center,mean_cluster_green_ring,mean_cluster_green_surround),axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide every cluster row by its largest absolute value, so the peak magnitude of each row becomes 1\n",
    "row_peak = np.abs(array[:nb_clusters]).max(axis=1, keepdims=True)\n",
    "array_norm = array[:nb_clusters] / row_peak"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# Pairwise distance between clusters: mean squared error between their rows of `array_norm`.\n",
    "# NOTE(review): despite the `corr_` prefix, this matrix holds MSE values, not correlations.\n",
    "corr_dist_features = np.zeros((nb_clusters,nb_clusters))\n",
    "\n",
    "for i in range(nb_clusters):\n",
    "    for j in range(nb_clusters):\n",
    "        corr_dist_features[i,j] = mean_squared_error(array_norm[i,:],array_norm[j,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distance dendrograms for ACs & BCs\n",
    "from scipy.spatial.distance import squareform\n",
    "from scipy.cluster.hierarchy import dendrogram, linkage\n",
    "\n",
    "# Condensed (vector) form of the square distance matrix;\n",
    "# checks=False skips the symmetry / zero-diagonal validation\n",
    "dists = squareform(corr_dist_features, checks=False)\n",
    "\n",
    "fig = plt.figure(figsize=(15,10))\n",
    "# Hierarchical clustering with the 'weighted' linkage method; optimal_ordering=True\n",
    "# reorders the linkage so that the distance between successive leaves is minimal\n",
    "linkage_matrix = linkage(dists, \"weighted\",optimal_ordering=True)\n",
    "\n",
    "a=dendrogram(linkage_matrix,orientation='top')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clustering bipolar cells - Supplementary Figure 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# Plot kernel heatmap\n",
    "def updated_plot_heatmap_BIPOLAR_cells(dataframe, clustering_name,  nb_clusters = None,\n",
    "                                       savefig = False, path_ROIs = None, path_means = None):\n",
    "    \"\"\"Show all ROI kernels as a heatmap, with rows grouped by (re-ordered) cluster.\n",
    "\n",
    "    Each row is one ROI and joins its 'uv_center', 'uv_surround', 'green_center'\n",
    "    and 'green_surround' kernels side by side (1250 samples each, asserted below).\n",
    "    The cluster IDs read from `dataframe[clustering_name]` are first remapped to a\n",
    "    manual display order; rows are then sorted by the remapped ID.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataframe : table holding the kernel columns and the cluster-ID column.\n",
    "    clustering_name : name of the cluster-ID column.\n",
    "    nb_clusters : number of clusters; inferred from the remapped IDs when None.\n",
    "    savefig : if True, save the figure to `path_ROIs` at 600 dpi.\n",
    "    path_ROIs : output file, only used when `savefig` is True.\n",
    "    path_means : not used by this function.\n",
    "    \"\"\"\n",
    "    size_kernel = len(dataframe['uv_center'].iloc[0])\n",
    "    assert size_kernel == 1250\n",
    "    labels = ['UV-c', 'UV-s', 'G-c', 'G-s']\n",
    "    cluster_IDs_original = dataframe[clustering_name].to_numpy() # Based on IPL depth\n",
    "    cluster_IDs = np.zeros(len(cluster_IDs_original))\n",
    "    # Manually defined based on how we want to display the clusters\n",
    "    # Based on 1) dorso-ventral location and then on 2) IPL depth\n",
    "    # NOTE(review): this order differs from the one used in\n",
    "    # updated_plot_histogram_BIPOLAR_cells and updated_plot_Kernels_BIPOLAR_cells\n",
    "    # (those match the commented-out list below) - confirm which one is intended.\n",
    "    order = np.array([2,5,7,10,12,14,16,21,22,23,3,4,6,8,9,11,13,15,17,18,19,20,24,1]) - 1\n",
    "    #2,5,7,10,12,14,16,21,22,23,3,6,8,9,11,15,18,19,20,1,4,13,17,24\n",
    "    # Position in `order` becomes the new ID; IDs absent from `order` keep the initial value 0\n",
    "    for new_id, old_id in enumerate(order):\n",
    "        cluster_IDs[np.where(cluster_IDs_original == int(old_id))[0]] = int(new_id)\n",
    "    assert np.amin(cluster_IDs) == 0\n",
    "    if nb_clusters is None:\n",
    "        nb_clusters = np.unique(cluster_IDs).shape[0]\n",
    "    orderROIs = np.argsort(cluster_IDs)\n",
    "    cluster_sizes = np.zeros(nb_clusters)\n",
    "    for index in range(nb_clusters):\n",
    "        cluster_sizes[index] = np.where(cluster_IDs == index)[0].shape[0]\n",
    "    # Every ROI must belong to one of the nb_clusters remapped IDs\n",
    "    assert cluster_sizes.sum() == len(dataframe)\n",
    "    data = np.concatenate((np.vstack(dataframe['uv_center'].to_numpy()),\n",
    "                           np.vstack(dataframe['uv_surround'].to_numpy()),\n",
    "                           np.vstack(dataframe['green_center'].to_numpy()),\n",
    "                           np.vstack(dataframe['green_surround'].to_numpy())), axis = 1)\n",
    "    plt.figure(figsize=(8,15))\n",
    "    plt.imshow(data[orderROIs,:], aspect = 'auto', cmap = 'binary_r', interpolation = 'None')\n",
    "    # x ticks at the middle of each kernel, y ticks at the middle of each cluster's band of rows\n",
    "    plt.xticks(np.array([(size_kernel/2)+i*size_kernel for i in range(4)]) - 0.5, labels)\n",
    "    plt.yticks(np.array([cluster_sizes[i]/2 + cluster_sizes[:i].sum() for i in range(nb_clusters)]) - 0.5,\n",
    "               ['C$_{' +  str(i) + \"}$\" for i in range(nb_clusters)])\n",
    "    plt.tick_params(length = 0)\n",
    "    # White separators between the four kernels ...\n",
    "    for index in range(1, 4):\n",
    "        plt.axvline((index*size_kernel)-0.5, color = 'white', linewidth = 0.75)\n",
    "    # ... and between consecutive clusters\n",
    "    current_line_location = -0.5\n",
    "    for cluster_id in range(nb_clusters-1):\n",
    "        current_line_location = current_line_location + cluster_sizes[cluster_id]\n",
    "        plt.axhline(current_line_location, color = 'white', linewidth = 0.75)\n",
    "    if savefig:\n",
    "        plt.savefig(path_ROIs, dpi = 600, transparent=True, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the kernel heatmap for `df_bc`, grouping ROIs by the 'cluster ID (full)' column\n",
    "name = 'cluster ID (full)'\n",
    "updated_plot_heatmap_BIPOLAR_cells(df_bc, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# Plot location-ipldepth histos\n",
    "def updated_plot_histogram_BIPOLAR_cells(dataframe, clustering_name, savefig = False, path = None):\n",
    "    \"\"\"Plot, for every (re-ordered) cluster, histograms of ROI location and IPL depth.\n",
    "\n",
    "    One row of two axes is drawn per cluster: the left column is a histogram of\n",
    "    'Rel. field location y' (bin width 200), the right column a histogram of\n",
    "    'ipl_depth' (bin width 0.05). Bins are shared by all clusters.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataframe : table holding the cluster-ID, 'Rel. field location y' and 'ipl_depth' columns.\n",
    "    clustering_name : name of the cluster-ID column.\n",
    "    savefig : if True, save the figure to `path` at 600 dpi.\n",
    "    path : output file, only used when `savefig` is True.\n",
    "    \"\"\"\n",
    "    cluster_IDs_original = dataframe[clustering_name].to_numpy() # Based on IPL depth\n",
    "    cluster_IDs = np.zeros(len(cluster_IDs_original))\n",
    "    # Manually defined based on how we want to display the clusters\n",
    "    # Based on 1) dorso-ventral location and then on 2) IPL depth\n",
    "    order = np.array([2,5,7,10,12,14,16,21,22,23,3,6,8,9,11,15,18,19,20,1,4,13,17,24]) - 1\n",
    "    # Position in `order` becomes the new ID; IDs absent from `order` keep the initial value 0\n",
    "    for new_id, old_id in enumerate(order):\n",
    "        cluster_IDs[np.where(cluster_IDs_original == int(old_id))[0]] = int(new_id)\n",
    "\n",
    "    nb_clusters = np.unique(cluster_IDs).shape[0]\n",
    "    y_location = dataframe['Rel. field location y'].to_numpy()\n",
    "    IPL_depth = dataframe['ipl_depth'].to_numpy()\n",
    "\n",
    "    # Hist for V-D location\n",
    "    max_y_location = np.amax(y_location)\n",
    "    min_y_location = np.amin(y_location)\n",
    "    total_max_y = np.amax((max_y_location, np.abs(min_y_location)))\n",
    "    binwidth_y_location = 200\n",
    "    bins_y = np.arange(min_y_location, max_y_location + binwidth_y_location, binwidth_y_location)\n",
    "\n",
    "    # Hist for IPL depth\n",
    "    max_IPL = np.amax(IPL_depth)\n",
    "    min_IPL = np.amin(IPL_depth)\n",
    "    binwidth_IPL = 0.05\n",
    "    bins_IPL = np.arange(min_IPL, max_IPL + binwidth_IPL, binwidth_IPL)\n",
    "\n",
    "    # squeeze=False keeps `ax` two-dimensional even when there is a single cluster\n",
    "    fig, ax = plt.subplots(nb_clusters, 2, sharex='col', sharey='col', figsize=(4,15), squeeze=False)\n",
    "    last_cluster_ID = nb_clusters - 1\n",
    "    for current_cluster_ID in range(nb_clusters):\n",
    "\n",
    "        cluster_mask = np.where(cluster_IDs == current_cluster_ID)[0]\n",
    "        current_y = y_location[cluster_mask]\n",
    "        current_IPL = IPL_depth[cluster_mask]\n",
    "\n",
    "        # V-D\n",
    "        my_ax = ax[current_cluster_ID, 0]\n",
    "        my_ax.axvline(0, color = 'black', linewidth = 0.5)\n",
    "        my_ax.hist(current_y, bins = bins_y, color = 'black')\n",
    "        # Symmetric x range around 0 with a margin of 300 on each side\n",
    "        my_ax.set_xlim([-total_max_y-300,total_max_y+300])\n",
    "\n",
    "        if current_cluster_ID == 0:\n",
    "            my_ax.set_title('V-D location')\n",
    "        # Only the bottom row carries axis labels; ticks are hidden on all other rows\n",
    "        if current_cluster_ID == last_cluster_ID:\n",
    "            my_ax.set_ylabel('Nb of Rois')\n",
    "            my_ax.set_xlabel('Ventral - dorsal \\n(Rel. y location)')\n",
    "        else:\n",
    "            my_ax.tick_params(axis='x', which='both', bottom=False)\n",
    "            my_ax.tick_params(axis='y', which='both', left=False, labelsize=0.0, labelcolor = 'white')\n",
    "\n",
    "        #IPL\n",
    "        my_ax = ax[current_cluster_ID, 1]\n",
    "        my_ax.axvline(0, color = 'black', linewidth = 0.5)\n",
    "        my_ax.axvline(1, color = 'black', linewidth = 0.5)\n",
    "        my_ax.hist(current_IPL, bins = bins_IPL, color = 'black')\n",
    "\n",
    "        if current_cluster_ID == 0:\n",
    "            my_ax.set_title('IPL depth')\n",
    "        if current_cluster_ID == last_cluster_ID:\n",
    "            my_ax.set_xlabel('IPL depth')\n",
    "            my_ax.set_xticks([0,1])\n",
    "            my_ax.set_xticklabels(['GCL', 'INL'])\n",
    "        else:\n",
    "            my_ax.tick_params(axis='x', which='both', bottom=False)\n",
    "            my_ax.tick_params(axis='y', which='both', left=False, labelsize=0.0, labelcolor = 'white')\n",
    "\n",
    "    if savefig:\n",
    "        plt.savefig(path, dpi = 600, transparent=True, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the per-cluster location and IPL-depth histograms for `df_bc`, using the 'cluster ID (full)' column\n",
    "name = 'cluster ID (full)'\n",
    "updated_plot_histogram_BIPOLAR_cells(df_bc, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# Plot kernel averages\n",
    "def updated_plot_Kernels_BIPOLAR_cells(dataframe, clustering_name,  nb_clusters = None, savefig = False, path_means = None):\n",
    "    \"\"\"Plot cluster-mean kernels and mean amplitudes for every (re-ordered) cluster.\n",
    "\n",
    "    One row of four axes is drawn per cluster. Columns 0 and 2 show the mean UV\n",
    "    (purple) and green (darkgreen) kernels for center and surround; columns 1 and 3\n",
    "    show bars of the mean UV / green amplitude with the standard deviation as error bar.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataframe : table holding the kernel columns, the '* amplitude' columns and the\n",
    "        cluster-ID column.\n",
    "    clustering_name : name of the cluster-ID column.\n",
    "    nb_clusters : number of clusters; inferred from the remapped IDs when None.\n",
    "    savefig : if True, save the figure to `path_means` at 600 dpi.\n",
    "    path_means : output file, only used when `savefig` is True.\n",
    "    \"\"\"\n",
    "    size_kernel = len(dataframe['uv_center'].iloc[0])\n",
    "    assert size_kernel == 1250\n",
    "    labels = ['UV-c', 'UV-s', 'G-c', 'G-s']\n",
    "    cluster_IDs_original = dataframe[clustering_name].to_numpy() # Based on IPL depth\n",
    "    cluster_IDs = np.zeros(len(cluster_IDs_original))\n",
    "    # Manually defined based on how we want to display the clusters\n",
    "    # Based on 1) dorso-ventral location and then on 2) IPL depth\n",
    "    order = np.array([2,5,7,10,12,14,16,21,22,23,3,6,8,9,11,15,18,19,20,1,4,13,17,24]) - 1\n",
    "    # Position in `order` becomes the new ID; IDs absent from `order` keep the initial value 0\n",
    "    for new_id, old_id in enumerate(order):\n",
    "        cluster_IDs[np.where(cluster_IDs_original == int(old_id))[0]] = int(new_id)\n",
    "    assert np.amin(cluster_IDs) == 0\n",
    "    if nb_clusters is None:\n",
    "        nb_clusters = np.unique(cluster_IDs).shape[0]\n",
    "    cluster_sizes = np.zeros(nb_clusters)\n",
    "    for index in range(nb_clusters):\n",
    "        cluster_sizes[index] = np.where(cluster_IDs == index)[0].shape[0]\n",
    "    # Every ROI must belong to one of the nb_clusters remapped IDs\n",
    "    assert cluster_sizes.sum() == len(dataframe)\n",
    "    # Center and surround kernels side by side, per colour channel\n",
    "    data_uv = np.concatenate((np.vstack(dataframe['uv_center'].to_numpy()),\n",
    "                              np.vstack(dataframe['uv_surround'].to_numpy())), axis = 1)\n",
    "    data_green = np.concatenate((np.vstack(dataframe['green_center'].to_numpy()),\n",
    "                                 np.vstack(dataframe['green_surround'].to_numpy())), axis = 1)\n",
    "    # Amplitude columns indexed by zone: 0 = center, 1 = surround\n",
    "    amplitudes_uv = (dataframe['uv_center amplitude'], dataframe['uv_surround amplitude'])\n",
    "    amplitudes_green = (dataframe['green_center amplitude'], dataframe['green_surround amplitude'])\n",
    "    cluster_means_uv = np.zeros((nb_clusters,size_kernel*2))\n",
    "    cluster_means_green = np.zeros((nb_clusters,size_kernel*2))\n",
    "    cluster_amplitude_uv = np.zeros((nb_clusters,2))\n",
    "    cluster_amplitude_green = np.zeros((nb_clusters,2))\n",
    "    std_amplitude_uv = np.zeros((nb_clusters,2))\n",
    "    std_amplitude_green = np.zeros((nb_clusters,2))\n",
    "    for current_cluster_ID in range(nb_clusters): # Calculate cluster means\n",
    "        cluster_mask = np.where(cluster_IDs == current_cluster_ID)[0]\n",
    "        cluster_means_uv[current_cluster_ID,:] = np.mean(data_uv[cluster_mask,:], axis = 0)\n",
    "        cluster_means_green[current_cluster_ID,:] = np.mean(data_green[cluster_mask,:], axis = 0)\n",
    "        # Filling both zones in one loop keeps the center/surround columns aligned;\n",
    "        # previously the green center std was written to column 1 and then overwritten,\n",
    "        # leaving column 0 (and hence the green center error bar) at zero.\n",
    "        for zone in range(2):\n",
    "            cluster_amplitude_uv[current_cluster_ID,zone] = amplitudes_uv[zone][cluster_mask].mean()\n",
    "            cluster_amplitude_green[current_cluster_ID,zone] = amplitudes_green[zone][cluster_mask].mean()\n",
    "            std_amplitude_uv[current_cluster_ID,zone] = np.std(amplitudes_uv[zone][cluster_mask])\n",
    "            std_amplitude_green[current_cluster_ID,zone] = np.std(amplitudes_green[zone][cluster_mask])\n",
    "    # squeeze=False keeps `ax` two-dimensional even when there is a single cluster\n",
    "    fig, ax = plt.subplots(nb_clusters, 4, sharey='row', figsize=(5,20), squeeze=False)\n",
    "    bar_colors = ['purple', 'darkgreen']\n",
    "    for current_cluster_ID in range(nb_clusters):\n",
    "        for current_kernel_ID in range(2):\n",
    "            #plotting kernels\n",
    "            start_kernel = current_kernel_ID*size_kernel\n",
    "            stop_kernel = start_kernel + size_kernel\n",
    "            my_ax = ax[current_cluster_ID, current_kernel_ID*2]\n",
    "            my_ax.axis('off')\n",
    "            my_ax.axhline(0, color = 'black', linestyle = 'dashed', linewidth = 0.75)\n",
    "            my_ax.axvline(size_kernel*3/4, color = 'black', linestyle = 'dashed', linewidth = 0.75)\n",
    "            my_ax.plot(cluster_means_uv[current_cluster_ID,start_kernel:stop_kernel],\n",
    "                       color = 'purple', linewidth = 1.0)\n",
    "            my_ax.plot(cluster_means_green[current_cluster_ID,start_kernel:stop_kernel],\n",
    "                       color = 'darkgreen', linewidth = 1.0)\n",
    "            if current_cluster_ID == 0:\n",
    "                my_ax.set_title(labels[current_kernel_ID])\n",
    "            if current_kernel_ID == 0:\n",
    "                my_ax.vlines(x=[0], ymin=[0], ymax=[0.2], lw=1) #scale bar 0.2\n",
    "            #plotting bars\n",
    "            second_ax = ax[current_cluster_ID, current_kernel_ID*2+1]\n",
    "            second_ax.bar([0,1],[cluster_amplitude_uv[current_cluster_ID,current_kernel_ID],\n",
    "                                 cluster_amplitude_green[current_cluster_ID,current_kernel_ID]],\n",
    "                          yerr=[std_amplitude_uv[current_cluster_ID,current_kernel_ID],\n",
    "                                 std_amplitude_green[current_cluster_ID,current_kernel_ID]],\n",
    "                          align='center', capsize=2, width=1, color=bar_colors)\n",
    "            second_ax.axhline(linestyle='dotted',c='k')\n",
    "            second_ax.axis('off')\n",
    "    if savefig:\n",
    "        plt.savefig(path_means, dpi = 600, transparent=True, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the cluster-mean kernels and amplitude bars for `df_bc`, using the 'cluster ID (full)' column\n",
    "name = 'cluster ID (full)'\n",
    "updated_plot_Kernels_BIPOLAR_cells(df_bc, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster means of the spectral contrast used for the SC ring plot.\n",
    "# Column 0 holds the surround mean and column 1 the center mean\n",
    "# (surround first for plot with circles).\n",
    "bc_mean_cluster_spectral_contrast = np.zeros((bc_nb_clusters, 2))\n",
    "for i in range(bc_nb_clusters):\n",
    "    members = np.where(bc_cluster_id == i)[0]\n",
    "    for target_col, column_name in enumerate(('spectral_contrast_surround', 'spectral_contrast_center')):\n",
    "        bc_mean_cluster_spectral_contrast[i, target_col] = df_bc[column_name][members].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "# Build a PiYG colormap whose central band is overwritten with white, so that\n",
    "# values close to zero are drawn without colour.\n",
    "PiYG = cm.get_cmap('PiYG', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([1, 1, 1, 1])\n",
    "\n",
    "# NOTE(review): the range is taken from `mean_cluster_spectral_contrast`, not from\n",
    "# `bc_mean_cluster_spectral_contrast` - presumably so that both plots share one\n",
    "# colour scale; confirm this is intended.\n",
    "limit = np.amax((np.amax(mean_cluster_spectral_contrast), np.abs(np.amin(mean_cluster_spectral_contrast))))\n",
    "\n",
    "vmax = limit \n",
    "vmin = -limit\n",
    "\n",
    "# The white band spans -1/4*limit .. +1/4*limit\n",
    "white_start = 1/4* limit \n",
    "white_stop = -1/4* limit \n",
    "\n",
    "# Convert the band edges from data units into row offsets of the colour table:\n",
    "# white_stop_pixel is counted from the vmin end, white_start_pixel from the vmax end\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "# Blank every colour-table row between the two offsets\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.collections import PatchCollection\n",
    "from matplotlib.patches import Circle\n",
    "\n",
    "def plot_colorCircles_BIPOLAR_cells(color_pref, cmap, limit = None, colorbar = False, savefig = False, path = None):\n",
    "    \"\"\"Draw a column of concentric-circle glyphs, one glyph per row of `color_pref`.\n",
    "\n",
    "    Each glyph is made of two circles of radius 0.3 and 0.1 that share the same\n",
    "    centre. The values of `color_pref[k, :]` are mapped through `cmap` and\n",
    "    assigned to the circles in that order (largest circle first).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    color_pref : 2D array with one row per cluster; expected to hold one value per circle.\n",
    "    cmap : matplotlib colormap used for the face colours.\n",
    "    limit : float or None. The colour range is [-limit, limit]; when None it is the\n",
    "        largest absolute value found in `color_pref`.\n",
    "    colorbar : if True, attach a colorbar to the last axis.\n",
    "    savefig : if True, save the figure to `path` at 600 dpi.\n",
    "    path : output file, only used when `savefig` is True.\n",
    "    \"\"\"\n",
    "    center = [0.5, 0.5]\n",
    "    radii = [0.3, 0.1]\n",
    "    bc_nb_clusters = color_pref.shape[0]\n",
    "    if limit is None:\n",
    "        limit = np.amax((np.amax(color_pref), np.abs(np.amin(color_pref))))\n",
    "    # squeeze=False keeps the result indexable when there is a single cluster\n",
    "    # (plt.subplots would otherwise return a bare Axes object).\n",
    "    fig, axes = plt.subplots(bc_nb_clusters, 1, figsize=(1.3,bc_nb_clusters*1.5), squeeze=False)\n",
    "    axes = axes[:, 0]\n",
    "    for k in range(bc_nb_clusters):\n",
    "        patches = [Circle(center, r) for r in radii]\n",
    "        p = PatchCollection(patches, cmap = cmap, edgecolor = 'black', linewidth = 0.5)\n",
    "        p.set_array(color_pref[k,:])\n",
    "        p.set_clim([-limit, limit])\n",
    "        axes[k].set_axis_off()\n",
    "        axes[k].add_collection(p)\n",
    "    if colorbar:\n",
    "        fig.colorbar(p, ax=axes[-1])\n",
    "    if savefig:\n",
    "        plt.savefig(path, dpi = 600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook-level `limit` is passed explicitly, so the colour range is fixed by it rather than derived from the array being plotted\n",
    "plot_colorCircles_BIPOLAR_cells(bc_mean_cluster_spectral_contrast, newcmp, limit=limit, colorbar=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot OOi rings\n",
    "# Two columns only: column 0 = surround mean (column 1 of `bc_ooindex`), column 1 =\n",
    "# mean of column 0 of `bc_ooindex`. A third column was previously allocated but never\n",
    "# filled, so an all-zero value was carried along into the two-circle plot.\n",
    "bc_mean_ooindex = np.zeros((bc_nb_clusters, 2))\n",
    "\n",
    "for i in range(bc_nb_clusters):\n",
    "    members = np.where(df_bc['cluster ID (full)'] == i)\n",
    "    bc_mean_ooindex[i,0] = bc_ooindex[members, 1].mean() #mean surround\n",
    "    bc_mean_ooindex[i,1] = bc_ooindex[members, 0].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "# Build a 'Greys_r' colormap whose central band is overwritten with a single colour.\n",
    "# NOTE: despite their names, `PiYG` holds the 'Greys_r' colormap here and `white`\n",
    "# is a mid-grey with alpha 0.5.\n",
    "PiYG = cm.get_cmap('Greys_r', 256)\n",
    "new_colors = PiYG(np.linspace(0, 1, 256))\n",
    "white = np.array([0.5, 0.5, 0.5, 0.5])\n",
    "\n",
    "# NOTE(review): the range is taken from `mean_ooindex`, not from `bc_mean_ooindex` -\n",
    "# presumably so that both plots share one colour scale; confirm this is intended.\n",
    "limit = np.amax((np.amax(mean_ooindex), np.abs(np.amin(mean_ooindex))))\n",
    "\n",
    "vmax = limit\n",
    "vmin = -limit\n",
    "\n",
    "# The replaced band spans -1/4*limit .. +1/4*limit\n",
    "# (stray words after this expression used to make the cell a SyntaxError)\n",
    "white_start = 1/4* limit\n",
    "white_stop = -1/4* limit\n",
    "\n",
    "# Convert the band edges from data units into row offsets of the colour table:\n",
    "# white_stop_pixel is counted from the vmin end, white_start_pixel from the vmax end\n",
    "total_range = vmax-vmin\n",
    "len_color = new_colors.shape[0]\n",
    "white_stop_percentage = np.abs(vmin-white_stop)*100/total_range\n",
    "white_stop_pixel = int(white_stop_percentage*len_color*0.01)\n",
    "white_start_percentage = np.abs(vmax-white_start)*100/total_range\n",
    "white_start_pixel = int(white_start_percentage*len_color*0.01)\n",
    "\n",
    "# Overwrite every colour-table row between the two offsets\n",
    "new_colors[white_stop_pixel:-white_start_pixel,:] = white\n",
    "newcmp = ListedColormap(new_colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.collections import PatchCollection\n",
    "from matplotlib.patches import Circle\n",
    "\n",
    "def BIPOLAR_plot_polarityCircles(color_pref, cmap, limit = None, colorbar = False, savefig = False, path = None):\n",
    "    \"\"\"Draw a column of concentric-circle glyphs, one glyph per row of `color_pref`.\n",
    "\n",
    "    Each glyph is made of two circles of radius 0.3 and 0.1 that share the same\n",
    "    centre. The values of `color_pref[k, :]` are mapped through `cmap` and\n",
    "    assigned to the circles in that order (largest circle first).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    color_pref : 2D array with one row per cluster; expected to hold one value per circle.\n",
    "    cmap : matplotlib colormap used for the face colours.\n",
    "    limit : float or None. The colour range is [-limit, limit]; when None it is the\n",
    "        largest absolute value found in `color_pref`.\n",
    "    colorbar : if True, attach a colorbar to the last axis.\n",
    "    savefig : if True, save the figure to `path` at 600 dpi.\n",
    "    path : output file, only used when `savefig` is True.\n",
    "    \"\"\"\n",
    "    center = [0.5, 0.5]\n",
    "    radii = [0.3, 0.1]\n",
    "    nb_clusters = color_pref.shape[0]\n",
    "    if limit is None:\n",
    "        limit = np.amax((np.amax(color_pref), np.abs(np.amin(color_pref))))\n",
    "    # squeeze=False keeps the result indexable when there is a single cluster\n",
    "    # (plt.subplots would otherwise return a bare Axes object).\n",
    "    fig, axes = plt.subplots(nb_clusters, 1, figsize=(1.3,nb_clusters*1.5), squeeze=False)\n",
    "    axes = axes[:, 0]\n",
    "    for k in range(nb_clusters):\n",
    "        patches = [Circle(center, r) for r in radii]\n",
    "        p = PatchCollection(patches, cmap = cmap, edgecolor = 'black', linewidth = 0.5) #'Greys_r'\n",
    "        p.set_array(color_pref[k,:])\n",
    "        p.set_clim([-limit, limit])\n",
    "        axes[k].set_axis_off()\n",
    "        axes[k].add_collection(p)\n",
    "    if colorbar:\n",
    "        fig.colorbar(p, ax=axes[-1])\n",
    "    if savefig:\n",
    "        plt.savefig(path, dpi = 600)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook-level `limit` is passed explicitly, so the colour range is fixed by it rather than derived from the array being plotted\n",
    "BIPOLAR_plot_polarityCircles(bc_mean_ooindex, newcmp, limit=limit, colorbar = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chirp plotting - Supplementary Figure 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# Plot chirp averages\n",
    "def plot_Chirp(dataframe, clustering_name, nb_clusters = None, savefig = False, path_ROIs = None, path_means = None):\n",
    "    \"\"\"Plot the cluster-mean global and local chirp traces with a +/- 1 std band.\n",
    "\n",
    "    One row of two axes is drawn per cluster (left: 'global_chirp_snippets',\n",
    "    right: 'local_chirp_snippets'); all axes share the same y range.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataframe : table holding the two chirp columns and the cluster-ID column.\n",
    "    clustering_name : name of the cluster-ID column; the smallest ID must be 0.\n",
    "    nb_clusters : number of clusters; inferred from the IDs when None.\n",
    "    savefig : if True, save the figure to `path_means` at 600 dpi.\n",
    "    path_ROIs : not used by this function; kept so existing calls keep working.\n",
    "    path_means : output file, only used when `savefig` is True.\n",
    "    \"\"\"\n",
    "    size_chirp = len(dataframe['global_chirp_snippets'].iloc[0])\n",
    "    assert size_chirp == len(dataframe['local_chirp_snippets'].iloc[0])\n",
    "    labels = ['global_chirp', 'local_chirp']\n",
    "    colors = ['black', 'black'] #'#fdbb84', '#e34a33'\n",
    "    cluster_IDs = dataframe[clustering_name].to_numpy()\n",
    "    assert np.amin(cluster_IDs) == 0\n",
    "    if nb_clusters is None:\n",
    "        nb_clusters = np.unique(cluster_IDs).shape[0]\n",
    "    cluster_sizes = np.zeros(nb_clusters)\n",
    "    for index in range(nb_clusters):\n",
    "        cluster_sizes[index] = np.where(cluster_IDs == index)[0].shape[0]\n",
    "    # Every ROI must belong to one of the nb_clusters IDs\n",
    "    assert cluster_sizes.sum() == len(dataframe)\n",
    "    # Global and local chirp traces side by side, one row per ROI\n",
    "    data = np.concatenate((np.vstack(dataframe['global_chirp_snippets'].to_numpy()),\n",
    "                           np.vstack(dataframe['local_chirp_snippets'].to_numpy())), axis = 1)\n",
    "\n",
    "    cluster_means = np.zeros((nb_clusters,size_chirp*2))\n",
    "    cluster_stds = np.zeros((nb_clusters,size_chirp*2))\n",
    "    for current_cluster_ID in range(nb_clusters): # Calculate cluster means\n",
    "        cluster_mask = np.where(cluster_IDs == current_cluster_ID)[0]\n",
    "        current_data = data[cluster_mask,:]\n",
    "        cluster_means[current_cluster_ID,:] = np.mean(current_data, axis = 0)\n",
    "        cluster_stds[current_cluster_ID,:]  = np.std(current_data, axis = 0)\n",
    "\n",
    "    # squeeze=False keeps `ax` two-dimensional even when there is a single cluster\n",
    "    fig, ax = plt.subplots(nb_clusters, 2, sharey='all',figsize=(8,25), squeeze=False)\n",
    "    time_axis = np.arange(size_chirp)\n",
    "    for current_cluster_ID in range(nb_clusters):\n",
    "        for current_kernel_ID in range(2):\n",
    "            start_kernel = current_kernel_ID*size_chirp\n",
    "            stop_kernel = start_kernel + size_chirp\n",
    "            trace_mean = cluster_means[current_cluster_ID,start_kernel:stop_kernel]\n",
    "            trace_std = cluster_stds[current_cluster_ID,start_kernel:stop_kernel]\n",
    "            my_ax = ax[current_cluster_ID, current_kernel_ID]\n",
    "            my_ax.axis('off')\n",
    "            my_ax.axhline(0, color = 'black', linestyle = 'dashed', linewidth = 0.75)\n",
    "            my_ax.plot(trace_mean, color = colors[current_kernel_ID], linewidth = 1.5)\n",
    "            my_ax.fill_between(time_axis, trace_mean - trace_std, trace_mean + trace_std,\n",
    "                               alpha=0.1, color=colors[current_kernel_ID],zorder=5)\n",
    "            if current_cluster_ID == 0:\n",
    "                my_ax.set_title(labels[current_kernel_ID])\n",
    "            if current_kernel_ID == 0:\n",
    "                my_ax.vlines(x=[-1], ymin=[0], ymax=[1], lw=1) # scale bar spanning 0 to 1\n",
    "    if savefig:\n",
    "        plt.savefig(path_means, dpi = 600, transparent=True, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot chirp averages\n",
    "df_chirp = pd.read_pickle('AC_data_unnorm.pkl')  # to include the correct chirp figures\n",
    "name = 'cluster ID (diag)'\n",
    "plot_Chirp(df_chirp, name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pharmacology - Figure 5, Supplementary Figures 3 & 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_pickle('control_drug_data.pkl')\n",
    "df = df.loc[(df['uv_center quality'] == 1) | (df['green_center quality'] == 1) | (df['uv_center_drug quality'] == 1) | (df['green_center_drug quality'] == 1)]\n",
    "df= df.loc[df['y_retinal_location'] != 'no_info']\n",
    "\n",
    "df = df.loc[df['treatment']=='TPMPA;gabazine']\n",
    "# df = df.loc[df['treatment']=='strychnine']\n",
    "# df = df.loc[df['treatment']=='LAP4']\n",
    "# df = df.loc[df['treatment']=='UBP310']\n",
    "\n",
    "df = df.reset_index(drop = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Calculate ROIs' SC\n",
    "\n",
    "spectral_contrast = np.zeros((df['roi'].shape[0], 3))\n",
    "drug_spectral_contrast = np.zeros((df['roi'].shape[0], 3))\n",
    "\n",
    "for i in range(df['roi'].shape[0]):\n",
    "    spectral_contrast[i,0] = (abs(df['green_center amplitude'][i]) - abs(df['uv_center amplitude'][i])) / (abs(df['green_center amplitude'][i]) + abs(df['uv_center amplitude'][i]))\n",
    "    spectral_contrast[i,1] = (abs(df['green_ring amplitude'][i]) - abs(df['uv_ring amplitude'][i])) / (abs(df['green_ring amplitude'][i]) + abs(df['uv_ring amplitude'][i])) \n",
    "    spectral_contrast[i,2] = (abs(df['green_surround amplitude'][i]) - abs(df['uv_surround amplitude'][i])) / (abs(df['green_surround amplitude'][i]) + abs(df['uv_surround amplitude'][i])) \n",
    "    \n",
    "    drug_spectral_contrast[i,0] = (abs(df['green_center_drug amplitude'][i]) - abs(df['uv_center_drug amplitude'][i])) / (abs(df['green_center_drug amplitude'][i]) + abs(df['uv_center_drug amplitude'][i]))\n",
    "    drug_spectral_contrast[i,1] = (abs(df['green_ring_drug amplitude'][i]) - abs(df['uv_ring_drug amplitude'][i])) / (abs(df['green_ring_drug amplitude'][i]) + abs(df['uv_ring_drug amplitude'][i])) \n",
    "    drug_spectral_contrast[i,2] = (abs(df['green_surround_drug amplitude'][i]) - abs(df['uv_surround_drug amplitude'][i])) / (abs(df['green_surround_drug amplitude'][i]) + abs(df['uv_surround_drug amplitude'][i])) \n",
    "\n",
    "df['SC_center_ctrl'] = spectral_contrast[:,0]\n",
    "df['SC_ring_ctrl'] = spectral_contrast[:,1]\n",
    "df['SC_surround_ctrl'] = spectral_contrast[:,2]\n",
    "df['SC_center_drug'] = drug_spectral_contrast[:,0]\n",
    "df['SC_ring_drug'] = drug_spectral_contrast[:,1]\n",
    "df['SC_surround_drug'] = drug_spectral_contrast[:,2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Calculate ROIs' OOi\n",
    "\n",
    "line_duration = 1.6 #in ms\n",
    "baseline_time_ms = int(1000/line_duration)\n",
    "on_time_ms = int(1000/line_duration)\n",
    "off_time_ms = int(1000/line_duration)\n",
    "\n",
    "a=np.zeros(baseline_time_ms)\n",
    "b=np.ones(on_time_ms)\n",
    "c=np.zeros(off_time_ms)\n",
    "\n",
    "stimulus= np.concatenate((a,b,c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "convolved_response_c = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "convolved_response_r = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "convolved_response_s = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "drug_convolved_response_c = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "drug_convolved_response_r = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "drug_convolved_response_s = np.zeros((df['roi'].shape[0],stimulus.shape[0]))\n",
    "\n",
    "line_duration_s = 0.0016\n",
    "kernel_length_line = np.int(np.floor(2/line_duration_s))\n",
    "offset_after = np.int(np.floor(kernel_length_line*.25)) #lines to include into the future (using 1/4 of kernel length)\n",
    "offset_before = kernel_length_line-offset_after\n",
    "kernel_past = offset_before - 1\n",
    "kernel_future = offset_after\n",
    "\n",
    "for i in range(df['roi'].shape[0]):\n",
    "    convolved_response_uv_c = np.convolve(stimulus, np.flip(df['uv_center'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_green_c = np.convolve(stimulus, np.flip(df['green_center'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_uv_r = np.convolve(stimulus, np.flip(df['uv_ring'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_green_r = np.convolve(stimulus, np.flip(df['green_ring'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_uv_s = np.convolve(stimulus, np.flip(df['uv_surround'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    convolved_response_green_s = np.convolve(stimulus, np.flip(df['green_surround'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    \n",
    "    convolved_response_c_avg = (convolved_response_uv_c + convolved_response_green_c)/2\n",
    "    convolved_response_r_avg = (convolved_response_uv_r + convolved_response_green_r)/2\n",
    "    convolved_response_s_avg = (convolved_response_uv_s + convolved_response_green_s)/2\n",
    "    \n",
    "    convolved_response_c[i,:] = convolved_response_c_avg-convolved_response_c_avg.min()\n",
    "    convolved_response_r[i,:] = convolved_response_r_avg-convolved_response_r_avg.min()\n",
    "    convolved_response_s[i,:] = convolved_response_s_avg-convolved_response_s_avg.min()\n",
    "    \n",
    "    drug_convolved_response_uv_c = np.convolve(stimulus, np.flip(df['uv_center_drug'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    drug_convolved_response_green_c = np.convolve(stimulus, np.flip(df['green_center_drug'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    drug_convolved_response_uv_r = np.convolve(stimulus, np.flip(df['uv_ring_drug'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    drug_convolved_response_green_r = np.convolve(stimulus, np.flip(df['green_ring_drug'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    drug_convolved_response_uv_s = np.convolve(stimulus, np.flip(df['uv_surround_drug'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    drug_convolved_response_green_s = np.convolve(stimulus, np.flip(df['green_surround_drug'][i]), mode='full')[kernel_future:-kernel_past]\n",
    "    \n",
    "    drug_convolved_response_c_avg = (drug_convolved_response_uv_c + drug_convolved_response_green_c)/2\n",
    "    drug_convolved_response_r_avg = (drug_convolved_response_uv_r + drug_convolved_response_green_r)/2\n",
    "    drug_convolved_response_s_avg = (drug_convolved_response_uv_s + drug_convolved_response_green_s)/2\n",
    "    \n",
    "    drug_convolved_response_c[i,:] = drug_convolved_response_c_avg-drug_convolved_response_c_avg.min()\n",
    "    drug_convolved_response_r[i,:] = drug_convolved_response_r_avg-drug_convolved_response_r_avg.min()\n",
    "    drug_convolved_response_s[i,:] = drug_convolved_response_s_avg-drug_convolved_response_s_avg.min()    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "window_on_start = 625\n",
    "window_on_end = 1250\n",
    "\n",
    "window_off_start = 1250\n",
    "window_off_end = 1875\n",
    "\n",
    "ooindex = np.zeros((df['roi'].shape[0],3))\n",
    "drug_ooindex = np.zeros((df['roi'].shape[0],3))\n",
    "\n",
    "for i in range(df['roi'].shape[0]):\n",
    "    ooindex[i,0] = (convolved_response_c[i,window_on_start:window_on_end].mean() - convolved_response_c[i,window_off_start:window_off_end].mean())/(convolved_response_c[i,window_on_start:window_on_end].mean() + convolved_response_c[i,window_off_start:window_off_end].mean())\n",
    "    ooindex[i,1] = (convolved_response_r[i,window_on_start:window_on_end].mean() - convolved_response_r[i,window_off_start:window_off_end].mean())/(convolved_response_r[i,window_on_start:window_on_end].mean() + convolved_response_r[i,window_off_start:window_off_end].mean())\n",
    "    ooindex[i,2] = (convolved_response_s[i,window_on_start:window_on_end].mean() - convolved_response_s[i,window_off_start:window_off_end].mean())/(convolved_response_s[i,window_on_start:window_on_end].mean() + convolved_response_s[i,window_off_start:window_off_end].mean())\n",
    "    \n",
    "    drug_ooindex[i,0] = (drug_convolved_response_c[i,window_on_start:window_on_end].mean() - drug_convolved_response_c[i,window_off_start:window_off_end].mean())/(drug_convolved_response_c[i,window_on_start:window_on_end].mean() + drug_convolved_response_c[i,window_off_start:window_off_end].mean())\n",
    "    drug_ooindex[i,1] = (drug_convolved_response_r[i,window_on_start:window_on_end].mean() - drug_convolved_response_r[i,window_off_start:window_off_end].mean())/(drug_convolved_response_r[i,window_on_start:window_on_end].mean() + drug_convolved_response_r[i,window_off_start:window_off_end].mean())\n",
    "    drug_ooindex[i,2] = (drug_convolved_response_s[i,window_on_start:window_on_end].mean() - drug_convolved_response_s[i,window_off_start:window_off_end].mean())/(drug_convolved_response_s[i,window_on_start:window_on_end].mean() + drug_convolved_response_s[i,window_off_start:window_off_end].mean())\n",
    "\n",
    "df['OOi_center_ctrl'] = ooindex[:,0]\n",
    "df['OOi_ring_ctrl'] = ooindex[:,1]\n",
    "df['OOi_surround_ctrl'] = ooindex[:,2]\n",
    "df['OOi_center_drug'] = drug_ooindex[:,0]\n",
    "df['OOi_ring_drug'] = drug_ooindex[:,1]\n",
    "df['OOi_surround_drug'] = drug_ooindex[:,2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calclate Delta SC\n",
    "df['Delta SC Center'] = df['SC_center_drug'] - df['SC_center_ctrl']\n",
    "df['Delta SC Ring'] = df['SC_ring_drug'] - df['SC_ring_ctrl']\n",
    "df['Delta SC Surround'] = df['SC_surround_drug'] - df['SC_surround_ctrl']\n",
    "\n",
    "# Calclate Delta OOi\n",
    "df['Delta OOi Center'] = df['OOi_center_drug'] - df['OOi_center_ctrl']\n",
    "df['Delta OOi Ring'] = df['OOi_ring_drug'] - df['OOi_ring_ctrl']\n",
    "df['Delta OOi Surround'] = df['OOi_surround_drug'] - df['OOi_surround_ctrl']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# On/Off\n",
    "df_on = df.loc[(df['OOi_center_ctrl'] > 0)]\n",
    "df_on = df_on.reset_index(drop = True)\n",
    "df_off = df.loc[df['OOi_center_ctrl'] < 0]\n",
    "df_off = df_off.reset_index(drop = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# kdeplot On/Off \n",
    "fig, ax = plt.subplots(3,2, figsize=(10,7), sharex='all')\n",
    "\n",
    "#center\n",
    "sns.kdeplot(x='Delta OOi Center',\n",
    "        data= df_on, color='darkred', ax=ax[0,0])\n",
    "sns.kdeplot(x='Delta OOi Center',\n",
    "        data= df_off, color='darkblue', ax=ax[0,0])\n",
    "sns.kdeplot(x='Delta SC Center',\n",
    "        data= df_on, color='darkred', ax=ax[0,1],label='On')\n",
    "sns.kdeplot(x='Delta SC Center',\n",
    "        data= df_off, color='darkblue', ax=ax[0,1],label='Off')\n",
    "\n",
    "#ring\n",
    "sns.kdeplot(x='Delta OOi Ring',\n",
    "        data= df_on, color='darkred', ax=ax[1,0])\n",
    "sns.kdeplot(x='Delta OOi Ring',\n",
    "        data= df_off, color='darkblue', ax=ax[1,0])\n",
    "sns.kdeplot(x='Delta SC Ring',\n",
    "        data= df_on, color='darkred', ax=ax[1,1])\n",
    "sns.kdeplot(x='Delta SC Ring',\n",
    "        data= df_off, color='darkblue', ax=ax[1,1])\n",
    "\n",
    "#surround\n",
    "sns.kdeplot(x='Delta OOi Surround',\n",
    "        data= df_on, color='darkred', ax=ax[2,0])\n",
    "sns.kdeplot(x='Delta OOi Surround',\n",
    "        data= df_off, color='darkblue', ax=ax[2,0])\n",
    "sns.kdeplot(x='Delta SC Surround',\n",
    "        data= df_on, color='darkred', ax=ax[2,1])\n",
    "sns.kdeplot(x='Delta SC Surround',\n",
    "        data= df_off, color='darkblue', ax=ax[2,1])\n",
    "\n",
    "for current_row in range(3):\n",
    "    for current_column in range(2):\n",
    "        my_ax = ax[current_row,current_column] \n",
    "        my_ax.axvline(0, color='gray', linestyle='dashed')\n",
    "        my_ax.set_xlim(-2,2)\n",
    "        my_ax.get_yaxis().set_visible(False)\n",
    "        \n",
    "ax[0,0].set_title('On Off Index')\n",
    "ax[0,1].set_title('Spectral Cotrast')\n",
    "ax[2,0].set_xlabel('ΔSC')\n",
    "ax[2,1].set_xlabel('ΔOOi')\n",
    "ax[0,0].text(-2.2,0.8,'Center', size=12,verticalalignment='center',rotation='90')\n",
    "ax[1,0].text(-2.2,1.2,'Ring', size=12,verticalalignment='center',rotation='90')\n",
    "ax[2,0].text(-2.2,0.4,'Surround', size=12,verticalalignment='center',rotation='90')\n",
    "\n",
    "ax[0,1].legend()\n",
    "sns.despine()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dorsal/Ventral\n",
    "df_dorsal= df.loc[df['y_retinal_location'] == 'dorsal']\n",
    "df_dorsal = df_dorsal.reset_index(drop = True)\n",
    "\n",
    "df_ventral= df.loc[df['y_retinal_location'] == 'ventral']\n",
    "df_ventral = df_ventral.reset_index(drop = True)\n",
    "\n",
    "# On/Off & Dorsal/Ventral\n",
    "\n",
    "df_dorsal_on = df_dorsal.loc[df_dorsal['OOi_center_ctrl'] > 0]\n",
    "df_dorsal_on = df_dorsal_on.reset_index(drop = True)\n",
    "df_dorsal_off = df_dorsal.loc[df_dorsal['OOi_center_ctrl'] < 0]\n",
    "df_dorsal_off = df_dorsal_off.reset_index(drop = True)\n",
    "\n",
    "df_ventral_on = df_ventral.loc[df_ventral['OOi_center_ctrl'] > 0]\n",
    "df_ventral_on = df_ventral_on.reset_index(drop = True)\n",
    "df_ventral_off = df_ventral.loc[df_ventral['OOi_center_ctrl'] < 0]\n",
    "df_ventral_off = df_ventral_off.reset_index(drop = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# lineplots all\n",
    "fig, ax = plt.subplots(6,4, figsize=(10,15), sharex='all',sharey='all')\n",
    "pos=[1,2]\n",
    "label=['Ctrl','Drug']\n",
    "\n",
    "#OOi_dorsal_on\n",
    "for i in range(df_dorsal_on.shape[0]):\n",
    "    ax[0,0].scatter(pos,[df_dorsal_on['OOi_center_ctrl'].values[i],df_dorsal_on['OOi_center_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[0,0].plot(pos,[df_dorsal_on['OOi_center_ctrl'].values[i],df_dorsal_on['OOi_center_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[2,0].scatter(pos,[df_dorsal_on['OOi_ring_ctrl'].values[i],df_dorsal_on['OOi_ring_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[2,0].plot(pos,[df_dorsal_on['OOi_ring_ctrl'].values[i],df_dorsal_on['OOi_ring_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[4,0].scatter(pos,[df_dorsal_on['OOi_surround_ctrl'].values[i],df_dorsal_on['OOi_surround_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[4,0].plot(pos,[df_dorsal_on['OOi_surround_ctrl'].values[i],df_dorsal_on['OOi_surround_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "ax[0,0].scatter(pos,[np.mean(df_dorsal_on['OOi_center_ctrl'].values),np.mean(df_dorsal_on['OOi_center_drug'].values)], color='k')\n",
    "ax[0,0].plot(pos,[np.mean(df_dorsal_on['OOi_center_ctrl'].values),np.mean(df_dorsal_on['OOi_center_drug'].values)], '-', color='k')\n",
    "ax[2,0].scatter(pos,[np.mean(df_dorsal_on['OOi_ring_ctrl'].values),np.mean(df_dorsal_on['OOi_ring_drug'].values)], color='k')\n",
    "ax[2,0].plot(pos,[np.mean(df_dorsal_on['OOi_ring_ctrl'].values),np.mean(df_dorsal_on['OOi_ring_drug'].values)], '-', color='k')\n",
    "ax[4,0].scatter(pos,[np.mean(df_dorsal_on['OOi_surround_ctrl'].values),np.mean(df_dorsal_on['OOi_surround_drug'].values)], color='k')\n",
    "ax[4,0].plot(pos,[np.mean(df_dorsal_on['OOi_surround_ctrl'].values),np.mean(df_dorsal_on['OOi_surround_drug'].values)], '-', color='k')\n",
    "#OOi_ventral_on\n",
    "for i in range(df_ventral_on.shape[0]):\n",
    "    ax[1,0].scatter(pos,[df_ventral_on['OOi_center_ctrl'].values[i],df_ventral_on['OOi_center_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[1,0].plot(pos,[df_ventral_on['OOi_center_ctrl'].values[i],df_ventral_on['OOi_center_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[3,0].scatter(pos,[df_ventral_on['OOi_ring_ctrl'].values[i],df_ventral_on['OOi_ring_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[3,0].plot(pos,[df_ventral_on['OOi_ring_ctrl'].values[i],df_ventral_on['OOi_ring_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[5,0].scatter(pos,[df_ventral_on['OOi_surround_ctrl'].values[i],df_ventral_on['OOi_surround_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[5,0].plot(pos,[df_ventral_on['OOi_surround_ctrl'].values[i],df_ventral_on['OOi_surround_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "ax[1,0].scatter(pos,[np.mean(df_ventral_on['OOi_center_ctrl'].values),np.mean(df_ventral_on['OOi_center_drug'].values)], color='k')\n",
    "ax[1,0].plot(pos,[np.mean(df_ventral_on['OOi_center_ctrl'].values),np.mean(df_ventral_on['OOi_center_drug'].values)], '-', color='k')\n",
    "ax[3,0].scatter(pos,[np.mean(df_ventral_on['OOi_ring_ctrl'].values),np.mean(df_ventral_on['OOi_ring_drug'].values)], color='k')\n",
    "ax[3,0].plot(pos,[np.mean(df_ventral_on['OOi_ring_ctrl'].values),np.mean(df_ventral_on['OOi_ring_drug'].values)], '-', color='k')\n",
    "ax[5,0].scatter(pos,[np.mean(df_ventral_on['OOi_surround_ctrl'].values),np.mean(df_ventral_on['OOi_surround_drug'].values)], color='k')\n",
    "ax[5,0].plot(pos,[np.mean(df_ventral_on['OOi_surround_ctrl'].values),np.mean(df_ventral_on['OOi_surround_drug'].values)], '-', color='k')\n",
    "# OOi_dorsal_off\n",
    "for i in range(df_dorsal_off.shape[0]):\n",
    "    ax[0,1].scatter(pos,[df_dorsal_off['OOi_center_ctrl'].values[i],df_dorsal_off['OOi_center_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[0,1].plot(pos,[df_dorsal_off['OOi_center_ctrl'].values[i],df_dorsal_off['OOi_center_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[2,1].scatter(pos,[df_dorsal_off['OOi_ring_ctrl'].values[i],df_dorsal_off['OOi_ring_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[2,1].plot(pos,[df_dorsal_off['OOi_ring_ctrl'].values[i],df_dorsal_off['OOi_ring_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[4,1].scatter(pos,[df_dorsal_off['OOi_surround_ctrl'].values[i],df_dorsal_off['OOi_surround_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[4,1].plot(pos,[df_dorsal_off['OOi_surround_ctrl'].values[i],df_dorsal_off['OOi_surround_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "ax[0,1].scatter(pos,[np.mean(df_dorsal_off['OOi_center_ctrl'].values),np.mean(df_dorsal_off['OOi_center_drug'].values)], color='k')\n",
    "ax[0,1].plot(pos,[np.mean(df_dorsal_off['OOi_center_ctrl'].values),np.mean(df_dorsal_off['OOi_center_drug'].values)], '-', color='k')\n",
    "ax[2,1].scatter(pos,[np.mean(df_dorsal_off['OOi_ring_ctrl'].values),np.mean(df_dorsal_off['OOi_ring_drug'].values)], color='k')\n",
    "ax[2,1].plot(pos,[np.mean(df_dorsal_off['OOi_ring_ctrl'].values),np.mean(df_dorsal_off['OOi_ring_drug'].values)], '-', color='k')\n",
    "ax[4,1].scatter(pos,[np.mean(df_dorsal_off['OOi_surround_ctrl'].values),np.mean(df_dorsal_off['OOi_surround_drug'].values)], color='k')\n",
    "ax[4,1].plot(pos,[np.mean(df_dorsal_off['OOi_surround_ctrl'].values),np.mean(df_dorsal_off['OOi_surround_drug'].values)], '-', color='k')\n",
    "# OOi_ventral_off\n",
    "for i in range(df_ventral_off.shape[0]):\n",
    "    ax[1,1].scatter(pos,[df_ventral_off['OOi_center_ctrl'].values[i],df_ventral_off['OOi_center_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[1,1].plot(pos,[df_ventral_off['OOi_center_ctrl'].values[i],df_ventral_off['OOi_center_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[3,1].scatter(pos,[df_ventral_off['OOi_ring_ctrl'].values[i],df_ventral_off['OOi_ring_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[3,1].plot(pos,[df_ventral_off['OOi_ring_ctrl'].values[i],df_ventral_off['OOi_ring_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[5,1].scatter(pos,[df_ventral_off['OOi_surround_ctrl'].values[i],df_ventral_off['OOi_surround_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[5,1].plot(pos,[df_ventral_off['OOi_surround_ctrl'].values[i],df_ventral_off['OOi_surround_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "ax[1,1].scatter(pos,[np.mean(df_ventral_off['OOi_center_ctrl'].values),np.mean(df_ventral_off['OOi_center_drug'].values)], color='k')\n",
    "ax[1,1].plot(pos,[np.mean(df_ventral_off['OOi_center_ctrl'].values),np.mean(df_ventral_off['OOi_center_drug'].values)], '-', color='k')\n",
    "ax[3,1].scatter(pos,[np.mean(df_ventral_off['OOi_ring_ctrl'].values),np.mean(df_ventral_off['OOi_ring_drug'].values)], color='k')\n",
    "ax[3,1].plot(pos,[np.mean(df_ventral_off['OOi_ring_ctrl'].values),np.mean(df_ventral_off['OOi_ring_drug'].values)], '-', color='k')\n",
    "ax[5,1].scatter(pos,[np.mean(df_ventral_off['OOi_surround_ctrl'].values),np.mean(df_ventral_off['OOi_surround_drug'].values)], color='k')\n",
    "ax[5,1].plot(pos,[np.mean(df_ventral_off['OOi_surround_ctrl'].values),np.mean(df_ventral_off['OOi_surround_drug'].values)], '-', color='k')\n",
    "\n",
    "\n",
    "#SC_dorsal_on\n",
    "for i in range(df_dorsal_on.shape[0]):\n",
    "    ax[0,2].scatter(pos,[df_dorsal_on['SC_center_ctrl'].values[i],df_dorsal_on['SC_center_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[0,2].plot(pos,[df_dorsal_on['SC_center_ctrl'].values[i],df_dorsal_on['SC_center_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[2,2].scatter(pos,[df_dorsal_on['SC_ring_ctrl'].values[i],df_dorsal_on['SC_ring_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[2,2].plot(pos,[df_dorsal_on['SC_ring_ctrl'].values[i],df_dorsal_on['SC_ring_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[4,2].scatter(pos,[df_dorsal_on['SC_surround_ctrl'].values[i],df_dorsal_on['SC_surround_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[4,2].plot(pos,[df_dorsal_on['SC_surround_ctrl'].values[i],df_dorsal_on['SC_surround_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "ax[0,2].scatter(pos,[np.mean(df_dorsal_on['SC_center_ctrl'].values),np.mean(df_dorsal_on['SC_center_drug'].values)], color='k')\n",
    "ax[0,2].plot(pos,[np.mean(df_dorsal_on['SC_center_ctrl'].values),np.mean(df_dorsal_on['SC_center_drug'].values)], '-', color='k')\n",
    "ax[2,2].scatter(pos,[np.mean(df_dorsal_on['SC_ring_ctrl'].values),np.mean(df_dorsal_on['SC_ring_drug'].values)], color='k')\n",
    "ax[2,2].plot(pos,[np.mean(df_dorsal_on['SC_ring_ctrl'].values),np.mean(df_dorsal_on['SC_ring_drug'].values)], '-', color='k')\n",
    "ax[4,2].scatter(pos,[np.mean(df_dorsal_on['SC_surround_ctrl'].values),np.mean(df_dorsal_on['SC_surround_drug'].values)], color='k')\n",
    "ax[4,2].plot(pos,[np.mean(df_dorsal_on['SC_surround_ctrl'].values),np.mean(df_dorsal_on['SC_surround_drug'].values)], '-', color='k')\n",
    "#SC_ventral_on\n",
    "for i in range(df_ventral_on.shape[0]):\n",
    "    ax[1,2].scatter(pos,[df_ventral_on['SC_center_ctrl'].values[i],df_ventral_on['SC_center_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[1,2].plot(pos,[df_ventral_on['SC_center_ctrl'].values[i],df_ventral_on['SC_center_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[3,2].scatter(pos,[df_ventral_on['SC_ring_ctrl'].values[i],df_ventral_on['SC_ring_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[3,2].plot(pos,[df_ventral_on['SC_ring_ctrl'].values[i],df_ventral_on['SC_ring_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "    ax[5,2].scatter(pos,[df_ventral_on['SC_surround_ctrl'].values[i],df_ventral_on['SC_surround_drug'].values[i]], color='darkred',alpha=0.05)\n",
    "    ax[5,2].plot(pos,[df_ventral_on['SC_surround_ctrl'].values[i],df_ventral_on['SC_surround_drug'].values[i]], '-', color='darkred',alpha=0.05)\n",
    "ax[1,2].scatter(pos,[np.mean(df_ventral_on['SC_center_ctrl'].values),np.mean(df_ventral_on['SC_center_drug'].values)], color='k')\n",
    "ax[1,2].plot(pos,[np.mean(df_ventral_on['SC_center_ctrl'].values),np.mean(df_ventral_on['SC_center_drug'].values)], '-', color='k')\n",
    "ax[3,2].scatter(pos,[np.mean(df_ventral_on['SC_ring_ctrl'].values),np.mean(df_ventral_on['SC_ring_drug'].values)], color='k')\n",
    "ax[3,2].plot(pos,[np.mean(df_ventral_on['SC_ring_ctrl'].values),np.mean(df_ventral_on['SC_ring_drug'].values)], '-', color='k',alpha=1)\n",
    "ax[5,2].scatter(pos,[np.mean(df_ventral_on['SC_surround_ctrl'].values),np.mean(df_ventral_on['SC_surround_drug'].values)], color='k')\n",
    "ax[5,2].plot(pos,[np.mean(df_ventral_on['SC_surround_ctrl'].values),np.mean(df_ventral_on['SC_surround_drug'].values)], '-', color='k')\n",
    "# SC_dorsal_off\n",
    "for i in range(df_dorsal_off.shape[0]):\n",
    "    ax[0,3].scatter(pos,[df_dorsal_off['SC_center_ctrl'].values[i],df_dorsal_off['SC_center_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[0,3].plot(pos,[df_dorsal_off['SC_center_ctrl'].values[i],df_dorsal_off['SC_center_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[2,3].scatter(pos,[df_dorsal_off['SC_ring_ctrl'].values[i],df_dorsal_off['SC_ring_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[2,3].plot(pos,[df_dorsal_off['SC_ring_ctrl'].values[i],df_dorsal_off['SC_ring_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[4,3].scatter(pos,[df_dorsal_off['SC_surround_ctrl'].values[i],df_dorsal_off['SC_surround_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[4,3].plot(pos,[df_dorsal_off['SC_surround_ctrl'].values[i],df_dorsal_off['SC_surround_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "ax[0,3].scatter(pos,[np.mean(df_dorsal_off['SC_center_ctrl'].values),np.mean(df_dorsal_off['SC_center_drug'].values)], color='k')\n",
    "ax[0,3].plot(pos,[np.mean(df_dorsal_off['SC_center_ctrl'].values),np.mean(df_dorsal_off['SC_center_drug'].values)], '-', color='k')\n",
    "ax[2,3].scatter(pos,[np.mean(df_dorsal_off['SC_ring_ctrl'].values),np.mean(df_dorsal_off['SC_ring_drug'].values)], color='k')\n",
    "ax[2,3].plot(pos,[np.mean(df_dorsal_off['SC_ring_ctrl'].values),np.mean(df_dorsal_off['SC_ring_drug'].values)], '-', color='k')\n",
    "ax[4,3].scatter(pos,[np.mean(df_dorsal_off['SC_surround_ctrl'].values),np.mean(df_dorsal_off['SC_surround_drug'].values)], color='k')\n",
    "ax[4,3].plot(pos,[np.mean(df_dorsal_off['SC_surround_ctrl'].values),np.mean(df_dorsal_off['SC_surround_drug'].values)], '-', color='k')\n",
    "#SC_ventral_off\n",
    "for i in range(df_ventral_off.shape[0]):\n",
    "    ax[1,3].scatter(pos,[df_ventral_off['SC_center_ctrl'].values[i],df_ventral_off['SC_center_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[1,3].plot(pos,[df_ventral_off['SC_center_ctrl'].values[i],df_ventral_off['SC_center_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[3,3].scatter(pos,[df_ventral_off['SC_ring_ctrl'].values[i],df_ventral_off['SC_ring_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[3,3].plot(pos,[df_ventral_off['SC_ring_ctrl'].values[i],df_ventral_off['SC_ring_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "    ax[5,3].scatter(pos,[df_ventral_off['SC_surround_ctrl'].values[i],df_ventral_off['SC_surround_drug'].values[i]], color='darkblue',alpha=0.05)\n",
    "    ax[5,3].plot(pos,[df_ventral_off['SC_surround_ctrl'].values[i],df_ventral_off['SC_surround_drug'].values[i]], '-', color='darkblue',alpha=0.05)\n",
    "ax[1,3].scatter(pos,[np.mean(df_ventral_off['SC_center_ctrl'].values),np.mean(df_ventral_off['SC_center_drug'].values)], color='k')\n",
    "ax[1,3].plot(pos,[np.mean(df_ventral_off['SC_center_ctrl'].values),np.mean(df_ventral_off['SC_center_drug'].values)], '-', color='k')\n",
    "ax[3,3].scatter(pos,[np.mean(df_ventral_off['SC_ring_ctrl'].values),np.mean(df_ventral_off['SC_ring_drug'].values)], color='k')\n",
    "ax[3,3].plot(pos,[np.mean(df_ventral_off['SC_ring_ctrl'].values),np.mean(df_ventral_off['SC_ring_drug'].values)], '-', color='k')\n",
    "ax[5,3].scatter(pos,[np.mean(df_ventral_off['SC_surround_ctrl'].values),np.mean(df_ventral_off['SC_surround_drug'].values)], color='k')\n",
    "ax[5,3].plot(pos,[np.mean(df_ventral_off['SC_surround_ctrl'].values),np.mean(df_ventral_off['SC_surround_drug'].values)], '-', color='k')\n",
    "\n",
    "for current_row in range(6):\n",
    "    for current_column in range(4):\n",
    "        my_ax = ax[current_row,current_column] \n",
    "        my_ax.get_xaxis().set_tick_params(direction='out')\n",
    "        my_ax.xaxis.set_ticks_position('bottom')\n",
    "        my_ax.set_xticks(np.arange(1, len(label) + 1))\n",
    "        my_ax.set_xticklabels(label,fontsize=12)\n",
    "        my_ax.set_xlim(0.8, len(label) + 0.2)\n",
    "        my_ax.axhline(0, color='gray', linestyle='dashed')\n",
    "        my_ax.set_ylim(-1,1)\n",
    "        \n",
    "ax[0,0].set_title('On')\n",
    "ax[0,1].set_title('Off')\n",
    "ax[0,2].set_title('On')\n",
    "ax[0,3].set_title('Off')\n",
    "ax[5,0].set_xlabel('OOi')\n",
    "ax[5,1].set_xlabel('OOi')\n",
    "ax[5,2].set_xlabel('SC')\n",
    "ax[5,3].set_xlabel('SC')\n",
    "ax[0,0].text(0.2,0,'Center DORSAL', size=12,verticalalignment='center',rotation='90')\n",
    "ax[1,0].text(0.2,0,'Center VENTRAL', size=12,verticalalignment='center',rotation='90')\n",
    "\n",
    "ax[2,0].text(0.2,0,'Ring DORSAL', size=12,verticalalignment='center',rotation='90')\n",
    "ax[3,0].text(0.2,0,'Ring VENTRAL', size=12,verticalalignment='center',rotation='90')\n",
    "\n",
    "ax[4,0].text(0.2,0,'Surround DORSAL', size=12,verticalalignment='center',rotation='90')\n",
    "ax[5,0].text(0.2,0,'Surround VENTRAL', size=12,verticalalignment='center',rotation='90')\n",
    "\n",
    "sns.despine()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "#make uncorrected p-values array\n",
    "from scipy import stats\n",
    "p_vals = []\n",
    "\n",
    "#SC\n",
    "#TPMPA\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_center_ctrl'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine'], df_dorsal_on['SC_center_drug'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_center_ctrl'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine'], df_dorsal_off['SC_center_drug'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_center_ctrl'].loc[df_ventral_on['treatment']=='TPMPA;gabazine'], df_ventral_on['SC_center_drug'].loc[df_ventral_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_center_ctrl'].loc[df_ventral_off['treatment']=='TPMPA;gabazine'], df_ventral_off['SC_center_drug'].loc[df_ventral_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_ring_ctrl'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine'], df_dorsal_on['SC_ring_drug'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_ring_ctrl'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine'], df_dorsal_off['SC_ring_drug'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_ring_ctrl'].loc[df_ventral_on['treatment']=='TPMPA;gabazine'], df_ventral_on['SC_ring_drug'].loc[df_ventral_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_ring_ctrl'].loc[df_ventral_off['treatment']=='TPMPA;gabazine'], df_ventral_off['SC_ring_drug'].loc[df_ventral_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_surround_ctrl'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine'], df_dorsal_on['SC_surround_drug'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_surround_ctrl'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine'], df_dorsal_off['SC_surround_drug'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_surround_ctrl'].loc[df_ventral_on['treatment']=='TPMPA;gabazine'], df_ventral_on['SC_surround_drug'].loc[df_ventral_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_surround_ctrl'].loc[df_ventral_off['treatment']=='TPMPA;gabazine'], df_ventral_off['SC_surround_drug'].loc[df_ventral_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "#strychnine\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_center_ctrl'].loc[df_dorsal_on['treatment']=='strychnine'], df_dorsal_on['SC_center_drug'].loc[df_dorsal_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_center_ctrl'].loc[df_dorsal_off['treatment']=='strychnine'], df_dorsal_off['SC_center_drug'].loc[df_dorsal_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_center_ctrl'].loc[df_ventral_on['treatment']=='strychnine'], df_ventral_on['SC_center_drug'].loc[df_ventral_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_center_ctrl'].loc[df_ventral_off['treatment']=='strychnine'], df_ventral_off['SC_center_drug'].loc[df_ventral_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_ring_ctrl'].loc[df_dorsal_on['treatment']=='strychnine'], df_dorsal_on['SC_ring_drug'].loc[df_dorsal_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_ring_ctrl'].loc[df_dorsal_off['treatment']=='strychnine'], df_dorsal_off['SC_ring_drug'].loc[df_dorsal_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_ring_ctrl'].loc[df_ventral_on['treatment']=='strychnine'], df_ventral_on['SC_ring_drug'].loc[df_ventral_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_ring_ctrl'].loc[df_ventral_off['treatment']=='strychnine'], df_ventral_off['SC_ring_drug'].loc[df_ventral_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_surround_ctrl'].loc[df_dorsal_on['treatment']=='strychnine'], df_dorsal_on['SC_surround_drug'].loc[df_dorsal_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_surround_ctrl'].loc[df_dorsal_off['treatment']=='strychnine'], df_dorsal_off['SC_surround_drug'].loc[df_dorsal_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_surround_ctrl'].loc[df_ventral_on['treatment']=='strychnine'], df_ventral_on['SC_surround_drug'].loc[df_ventral_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_surround_ctrl'].loc[df_ventral_off['treatment']=='strychnine'], df_ventral_off['SC_surround_drug'].loc[df_ventral_off['treatment']=='strychnine']).pvalue)\n",
    "#L-AP4\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_center_ctrl'].loc[df_dorsal_on['treatment']=='LAP4'], df_dorsal_on['SC_center_drug'].loc[df_dorsal_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_center_ctrl'].loc[df_dorsal_off['treatment']=='LAP4'], df_dorsal_off['SC_center_drug'].loc[df_dorsal_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_center_ctrl'].loc[df_ventral_on['treatment']=='LAP4'], df_ventral_on['SC_center_drug'].loc[df_ventral_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_center_ctrl'].loc[df_ventral_off['treatment']=='LAP4'], df_ventral_off['SC_center_drug'].loc[df_ventral_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_ring_ctrl'].loc[df_dorsal_on['treatment']=='LAP4'], df_dorsal_on['SC_ring_drug'].loc[df_dorsal_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_ring_ctrl'].loc[df_dorsal_off['treatment']=='LAP4'], df_dorsal_off['SC_ring_drug'].loc[df_dorsal_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_ring_ctrl'].loc[df_ventral_on['treatment']=='LAP4'], df_ventral_on['SC_ring_drug'].loc[df_ventral_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_ring_ctrl'].loc[df_ventral_off['treatment']=='LAP4'], df_ventral_off['SC_ring_drug'].loc[df_ventral_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_surround_ctrl'].loc[df_dorsal_on['treatment']=='LAP4'], df_dorsal_on['SC_surround_drug'].loc[df_dorsal_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_surround_ctrl'].loc[df_dorsal_off['treatment']=='LAP4'], df_dorsal_off['SC_surround_drug'].loc[df_dorsal_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_surround_ctrl'].loc[df_ventral_on['treatment']=='LAP4'], df_ventral_on['SC_surround_drug'].loc[df_ventral_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_surround_ctrl'].loc[df_ventral_off['treatment']=='LAP4'], df_ventral_off['SC_surround_drug'].loc[df_ventral_off['treatment']=='LAP4']).pvalue)\n",
    "#UBP 310\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_center_ctrl'].loc[df_dorsal_on['treatment']=='UBP310'], df_dorsal_on['SC_center_drug'].loc[df_dorsal_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_center_ctrl'].loc[df_dorsal_off['treatment']=='UBP310'], df_dorsal_off['SC_center_drug'].loc[df_dorsal_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_center_ctrl'].loc[df_ventral_on['treatment']=='UBP310'], df_ventral_on['SC_center_drug'].loc[df_ventral_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_center_ctrl'].loc[df_ventral_off['treatment']=='UBP310'], df_ventral_off['SC_center_drug'].loc[df_ventral_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_ring_ctrl'].loc[df_dorsal_on['treatment']=='UBP310'], df_dorsal_on['SC_ring_drug'].loc[df_dorsal_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_ring_ctrl'].loc[df_dorsal_off['treatment']=='UBP310'], df_dorsal_off['SC_ring_drug'].loc[df_dorsal_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_ring_ctrl'].loc[df_ventral_on['treatment']=='UBP310'], df_ventral_on['SC_ring_drug'].loc[df_ventral_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_ring_ctrl'].loc[df_ventral_off['treatment']=='UBP310'], df_ventral_off['SC_ring_drug'].loc[df_ventral_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['SC_surround_ctrl'].loc[df_dorsal_on['treatment']=='UBP310'], df_dorsal_on['SC_surround_drug'].loc[df_dorsal_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['SC_surround_ctrl'].loc[df_dorsal_off['treatment']=='UBP310'], df_dorsal_off['SC_surround_drug'].loc[df_dorsal_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['SC_surround_ctrl'].loc[df_ventral_on['treatment']=='UBP310'], df_ventral_on['SC_surround_drug'].loc[df_ventral_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['SC_surround_ctrl'].loc[df_ventral_off['treatment']=='UBP310'], df_ventral_off['SC_surround_drug'].loc[df_ventral_off['treatment']=='UBP310']).pvalue)\n",
    "\n",
    "#OOi\n",
    "#TPMPA\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_center_ctrl'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine'], df_dorsal_on['OOi_center_drug'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_center_ctrl'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine'], df_dorsal_off['OOi_center_drug'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_center_ctrl'].loc[df_ventral_on['treatment']=='TPMPA;gabazine'], df_ventral_on['OOi_center_drug'].loc[df_ventral_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_center_ctrl'].loc[df_ventral_off['treatment']=='TPMPA;gabazine'], df_ventral_off['OOi_center_drug'].loc[df_ventral_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_ring_ctrl'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine'], df_dorsal_on['OOi_ring_drug'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_ring_ctrl'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine'], df_dorsal_off['OOi_ring_drug'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_ring_ctrl'].loc[df_ventral_on['treatment']=='TPMPA;gabazine'], df_ventral_on['OOi_ring_drug'].loc[df_ventral_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_ring_ctrl'].loc[df_ventral_off['treatment']=='TPMPA;gabazine'], df_ventral_off['OOi_ring_drug'].loc[df_ventral_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_surround_ctrl'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine'], df_dorsal_on['OOi_surround_drug'].loc[df_dorsal_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_surround_ctrl'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine'], df_dorsal_off['OOi_surround_drug'].loc[df_dorsal_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_surround_ctrl'].loc[df_ventral_on['treatment']=='TPMPA;gabazine'], df_ventral_on['OOi_surround_drug'].loc[df_ventral_on['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_surround_ctrl'].loc[df_ventral_off['treatment']=='TPMPA;gabazine'], df_ventral_off['OOi_surround_drug'].loc[df_ventral_off['treatment']=='TPMPA;gabazine']).pvalue)\n",
    "\n",
    "#strychnine\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_center_ctrl'].loc[df_dorsal_on['treatment']=='strychnine'], df_dorsal_on['OOi_center_drug'].loc[df_dorsal_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_center_ctrl'].loc[df_dorsal_off['treatment']=='strychnine'], df_dorsal_off['OOi_center_drug'].loc[df_dorsal_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_center_ctrl'].loc[df_ventral_on['treatment']=='strychnine'], df_ventral_on['OOi_center_drug'].loc[df_ventral_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_center_ctrl'].loc[df_ventral_off['treatment']=='strychnine'], df_ventral_off['OOi_center_drug'].loc[df_ventral_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_ring_ctrl'].loc[df_dorsal_on['treatment']=='strychnine'], df_dorsal_on['OOi_ring_drug'].loc[df_dorsal_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_ring_ctrl'].loc[df_dorsal_off['treatment']=='strychnine'], df_dorsal_off['OOi_ring_drug'].loc[df_dorsal_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_ring_ctrl'].loc[df_ventral_on['treatment']=='strychnine'], df_ventral_on['OOi_ring_drug'].loc[df_ventral_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_ring_ctrl'].loc[df_ventral_off['treatment']=='strychnine'], df_ventral_off['OOi_ring_drug'].loc[df_ventral_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_surround_ctrl'].loc[df_dorsal_on['treatment']=='strychnine'], df_dorsal_on['OOi_surround_drug'].loc[df_dorsal_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_surround_ctrl'].loc[df_dorsal_off['treatment']=='strychnine'], df_dorsal_off['OOi_surround_drug'].loc[df_dorsal_off['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_surround_ctrl'].loc[df_ventral_on['treatment']=='strychnine'], df_ventral_on['OOi_surround_drug'].loc[df_ventral_on['treatment']=='strychnine']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_surround_ctrl'].loc[df_ventral_off['treatment']=='strychnine'], df_ventral_off['OOi_surround_drug'].loc[df_ventral_off['treatment']=='strychnine']).pvalue)\n",
    "\n",
    "#L-AP4\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_center_ctrl'].loc[df_dorsal_on['treatment']=='LAP4'], df_dorsal_on['OOi_center_drug'].loc[df_dorsal_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_center_ctrl'].loc[df_dorsal_off['treatment']=='LAP4'], df_dorsal_off['OOi_center_drug'].loc[df_dorsal_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_center_ctrl'].loc[df_ventral_on['treatment']=='LAP4'], df_ventral_on['OOi_center_drug'].loc[df_ventral_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_center_ctrl'].loc[df_ventral_off['treatment']=='LAP4'], df_ventral_off['OOi_center_drug'].loc[df_ventral_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_ring_ctrl'].loc[df_dorsal_on['treatment']=='LAP4'], df_dorsal_on['OOi_ring_drug'].loc[df_dorsal_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_ring_ctrl'].loc[df_dorsal_off['treatment']=='LAP4'], df_dorsal_off['OOi_ring_drug'].loc[df_dorsal_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_ring_ctrl'].loc[df_ventral_on['treatment']=='LAP4'], df_ventral_on['OOi_ring_drug'].loc[df_ventral_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_ring_ctrl'].loc[df_ventral_off['treatment']=='LAP4'], df_ventral_off['OOi_ring_drug'].loc[df_ventral_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_surround_ctrl'].loc[df_dorsal_on['treatment']=='LAP4'], df_dorsal_on['OOi_surround_drug'].loc[df_dorsal_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_surround_ctrl'].loc[df_dorsal_off['treatment']=='LAP4'], df_dorsal_off['OOi_surround_drug'].loc[df_dorsal_off['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_surround_ctrl'].loc[df_ventral_on['treatment']=='LAP4'], df_ventral_on['OOi_surround_drug'].loc[df_ventral_on['treatment']=='LAP4']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_surround_ctrl'].loc[df_ventral_off['treatment']=='LAP4'], df_ventral_off['OOi_surround_drug'].loc[df_ventral_off['treatment']=='LAP4']).pvalue)\n",
    "\n",
    "#UBP 310\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_center_ctrl'].loc[df_dorsal_on['treatment']=='UBP310'], df_dorsal_on['OOi_center_drug'].loc[df_dorsal_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_center_ctrl'].loc[df_dorsal_off['treatment']=='UBP310'], df_dorsal_off['OOi_center_drug'].loc[df_dorsal_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_center_ctrl'].loc[df_ventral_on['treatment']=='UBP310'], df_ventral_on['OOi_center_drug'].loc[df_ventral_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_center_ctrl'].loc[df_ventral_off['treatment']=='UBP310'], df_ventral_off['OOi_center_drug'].loc[df_ventral_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_ring_ctrl'].loc[df_dorsal_on['treatment']=='UBP310'], df_dorsal_on['OOi_ring_drug'].loc[df_dorsal_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_ring_ctrl'].loc[df_dorsal_off['treatment']=='UBP310'], df_dorsal_off['OOi_ring_drug'].loc[df_dorsal_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_ring_ctrl'].loc[df_ventral_on['treatment']=='UBP310'], df_ventral_on['OOi_ring_drug'].loc[df_ventral_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_ring_ctrl'].loc[df_ventral_off['treatment']=='UBP310'], df_ventral_off['OOi_ring_drug'].loc[df_ventral_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_on['OOi_surround_ctrl'].loc[df_dorsal_on['treatment']=='UBP310'], df_dorsal_on['OOi_surround_drug'].loc[df_dorsal_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_dorsal_off['OOi_surround_ctrl'].loc[df_dorsal_off['treatment']=='UBP310'], df_dorsal_off['OOi_surround_drug'].loc[df_dorsal_off['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_on['OOi_surround_ctrl'].loc[df_ventral_on['treatment']=='UBP310'], df_ventral_on['OOi_surround_drug'].loc[df_ventral_on['treatment']=='UBP310']).pvalue)\n",
    "p_vals.append(stats.ttest_rel(df_ventral_off['OOi_surround_ctrl'].loc[df_ventral_off['treatment']=='UBP310'], df_ventral_off['OOi_surround_drug'].loc[df_ventral_off['treatment']=='UBP310']).pvalue)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statsmodels.stats.multitest\n",
    "\n",
    "reject, p_corr, _, alpha_corr = statsmodels.stats.multitest.multipletests(p_vals, alpha=0.05, method='bonferroni')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list\n",
    "drug = ['TPMPA;','strychnine;','LAP4;','UBP310;']\n",
    "polarity = ['ON', 'OFF']\n",
    "space = ['Centre;', 'Ring;', 'Surround;']\n",
    "retina = ['Dorsal;', 'Ventral;']\n",
    "order_tests = []\n",
    "\n",
    "# SC\n",
    "for drug_i in drug:\n",
    "    for space_i in space:\n",
    "        for retina_i in retina:\n",
    "            for polarity_i in polarity:\n",
    "                order_tests.append('SC: ' + drug_i +' ' + space_i +' ' + retina_i +' ' + polarity_i)\n",
    "# OOi\n",
    "for drug_i in drug:\n",
    "    for space_i in space:\n",
    "        for retina_i in retina:\n",
    "            for polarity_i in polarity:\n",
    "                order_tests.append('OOi: ' + drug_i +' ' + space_i +' ' + retina_i +' ' + polarity_i)\n",
    "\n",
    "\n",
    "np.array(order_tests)[(np.where(reject)[0])]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}