Pytorch dev #1396

Open · wants to merge 4 commits into base: dev
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -27,7 +27,7 @@ pipeline {
sh 'mamba env create -q -f environment.yml -p $CONDA_ENV'
sh '''#!/bin/bash -ex
source activate $CONDA_ENV
export KERAS_BACKEND=tensorflow
export KERAS_BACKEND=torch
pip install .
TEMPDIR=$(mktemp -d)
export CAIMAN_DATA=$TEMPDIR/caiman_data
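The KERAS_BACKEND switch above relies on Keras 3 selecting its backend from the environment variable at first import; a minimal sketch of that mechanism, assuming Keras 3 is installed (not part of this diff):

```python
import os

# Must be set before keras is first imported anywhere in the process.
os.environ["KERAS_BACKEND"] = "torch"

import keras

# With Keras 3, this reports the active backend chosen from the env var.
print(keras.backend.backend())  # expected: "torch"
```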
3 changes: 3 additions & 0 deletions caiman/base/__init__.py
@@ -0,0 +1,3 @@
#!/usr/bin/env python

from caiman.base.timeseries import timeseries
9 changes: 6 additions & 3 deletions caiman/base/movies.py
@@ -37,12 +37,15 @@
import caiman.utils.sbx_utils
import caiman.utils.visualization

from caiman.base.timeseries import timeseries
from caiman.base.traces import trace

try:
cv2.setNumThreads(0)
except:
pass

class movie(caiman.base.timeseries.timeseries):
class movie(timeseries):
"""
Class representing a movie. This class subclasses timeseries,
which in turn subclasses ndarray
@@ -895,7 +898,7 @@ def partition_FOV_KMeans(self,
fovs = cv2.resize(np.uint8(fovs), (w1, h1), 1. / fx, 1. / fy, interpolation=cv2.INTER_NEAREST)
return np.uint8(fovs), mcoef, distanceMatrix

def extract_traces_from_masks(self, masks: np.ndarray) -> caiman.base.traces.trace:
def extract_traces_from_masks(self, masks: np.ndarray) -> trace:
"""
Args:
masks: array, 3D with each 2D slice being a mask (integer or fractional)
@@ -914,7 +917,7 @@ def extract_traces_from_masks(self, masks: np.ndarray) -> caiman.base.traces.tra

pixelsA = np.sum(A, axis=1)
A = A / pixelsA[:, None] # obtain average over ROI
traces = caiman.base.traces.trace(np.dot(A, np.transpose(Y)).T, **self.__dict__)
traces = trace(np.dot(A, np.transpose(Y)).T, **self.__dict__)
return traces

def resize(self, fx=1, fy=1, fz=1, interpolation=cv2.INTER_AREA):
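For reference, extract_traces_from_masks computes a masked spatial average: each mask is flattened, normalized to unit sum, and applied to the pixel-by-time matrix with a dot product. A standalone sketch of the same arithmetic on toy shapes (NumPy only, names illustrative):

```python
import numpy as np

T, h, w, n = 100, 8, 8, 3                  # frames, height, width, number of masks
Y = np.random.randn(T, h, w)               # movie: time x height x width
masks = (np.random.rand(n, h, w) > 0.5).astype(float)

A = masks.reshape(n, -1)                   # n x (h*w), one flattened mask per row
A /= A.sum(axis=1, keepdims=True)          # normalize so each row averages its ROI
traces = (A @ Y.reshape(T, -1).T).T        # T x n: one average trace per mask
```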
4 changes: 2 additions & 2 deletions caiman/base/timeseries.py
@@ -33,7 +33,7 @@
pass


class timeseries(np.ndarray):
"""
Class representing a time series.
"""
@@ -88,7 +88,7 @@ def __array_prepare__(self, out_arr, context=None):
if context is not None:
inputs = context[1]
for inp in inputs:
if isinstance(inp, timeseries):
if frRef is None:
frRef = inp.fr
else:
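The fr consistency check above works because timeseries subclasses np.ndarray and carries the frame rate on the array itself. A minimal sketch of that subclassing pattern (the attribute name fr comes from the diff; the rest is illustrative):

```python
import numpy as np

class timeseries(np.ndarray):
    def __new__(cls, data, fr=30.0):
        obj = np.asarray(data).view(cls)    # view-cast the data into the subclass
        obj.fr = fr                         # frame rate travels with the array
        return obj

    def __array_finalize__(self, obj):
        # Called on slicing and ufunc results; propagate fr when present.
        self.fr = getattr(obj, 'fr', None)

ts = timeseries(np.zeros(10), fr=15.0)
print(ts[2:5].fr)   # 15.0 - metadata survives slicing
```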
3 changes: 2 additions & 1 deletion caiman/base/traces.py
@@ -8,6 +8,7 @@
plt.ion()

import caiman.base.timeseries
from caiman.base.timeseries import timeseries

try:
cv2.setNumThreads(0)
@@ -18,7 +19,7 @@
# This holds the trace class, which is a specialised Caiman timeseries class.


class trace(caiman.base.timeseries.timeseries):
class trace(timeseries):
"""
Class representing a trace.

77 changes: 36 additions & 41 deletions caiman/components_evaluation.py
@@ -6,7 +6,7 @@
import numpy as np
import os
import peakutils
import tensorflow as tf
import torch
import scipy
from scipy.sparse import csc_matrix
from scipy.stats import norm
@@ -273,42 +273,37 @@ def evaluate_components_CNN(A,
if not isGPU and 'CAIMAN_ALLOW_GPU' not in os.environ:
print("GPU run not requested, disabling use of GPUs")
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
try:
os.environ["KERAS_BACKEND"] = "tensorflow"
from tensorflow.keras.models import model_from_json
use_keras = True
logger.info('Using Keras')
except (ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')
# try:
# os.environ["KERAS_BACKEND"] = "torch"
# from keras.models import model_load
# use_keras = True
# logging.info('Using Keras')
# except (ModuleNotFoundError):
# use_keras = False
logging.info('Using Torch')

if loaded_model is None:
if use_keras:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".json")):
model_file = os.path.join(caiman_datadir(), model_name + ".json")
model_weights = os.path.join(caiman_datadir(), model_name + ".h5")
elif os.path.isfile(model_name + ".json"):
model_file = model_name + ".json"
model_weights = model_name + ".h5"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
with open(model_file, 'r') as json_file:
print(f"USING MODEL (keras API): {model_file}")
loaded_model_json = json_file.read()

loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_name + '.h5')
# if use_keras:
# if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".keras")):
# model_file = os.path.join(caiman_datadir(), model_name + ".keras")
# elif os.path.isfile(model_name + ".keras"):
# model_file = model_name + ".keras"
# else:
# raise FileNotFoundError(f"File for requested model {model_name} not found")
#
# print(f"USING MODEL (keras API): {model_file}")
# loaded_model = model_load(model_file)
#else:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".pt")):
model_file = os.path.join(caiman_datadir(), model_name + ".pt")
elif os.path.isfile(model_name + ".pt"):
model_file = model_name + ".pt"
else:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".h5.pb")):
model_file = os.path.join(caiman_datadir(), model_name + ".h5.pb")
elif os.path.isfile(model_name + ".h5.pb"):
model_file = model_name + ".h5.pb"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
print(f"USING MODEL (tensorflow API): {model_file}")
loaded_model = caiman.utils.utils.load_graph(model_file)
raise FileNotFoundError(f"File for requested model {model_name} not found")
print(f"USING MODEL (PyTorch API): {model_file}")
loaded_model = torch.load(model_file)

logger.debug("Loaded model from disk")
logging.debug("Loaded model from disk")

half_crop = np.minimum(gSig[0] * 4 + 1, patch_size), np.minimum(gSig[1] * 4 + 1, patch_size)
dims = np.array(dims)
@@ -320,14 +315,14 @@ def evaluate_components_CNN(A,
half_crop[1]:com[1] + half_crop[1]] for mm, com in zip(A.tocsc().T, coms)
]
final_crops = np.array([cv2.resize(im / np.linalg.norm(im), (patch_size, patch_size)) for im in crop_imgs])
if use_keras:
predictions = loaded_model.predict(final_crops[:, :, :, np.newaxis], batch_size=32, verbose=1)
else:
tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_20_input:0')
tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
with tf.Session(graph=loaded_model) as sess:
predictions = sess.run(tf_out, feed_dict={tf_in: final_crops[:, :, :, np.newaxis]})
sess.close()
# if use_keras:
# predictions = loaded_model.predict(final_crops[:, :, :, np.newaxis], batch_size=32, verbose=1)
# else:
final_crops = torch.tensor(final_crops, dtype=torch.float32)  # N x patch x patch
with torch.no_grad():
    predictions = loaded_model(final_crops[:, None, :, :])  # add channel axis -> N x 1 x patch x patch

return predictions, final_crops

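The new inference path follows the standard PyTorch pattern: build a float tensor, add a channel axis so the batch is NCHW, and evaluate under torch.no_grad(). A self-contained sketch with a stand-in classifier (the Sequential below is illustrative; the real model comes from torch.load above):

```python
import torch
import torch.nn as nn

# Stand-in for the loaded CNN classifier; the architecture is illustrative only.
model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(),
                      nn.Flatten(), nn.LazyLinear(2), nn.Softmax(dim=1))

crops = torch.rand(32, 50, 50)          # N x patch x patch, as built above
batch = crops[:, None, :, :]            # add channel axis -> N x 1 x H x W (NCHW)

with torch.no_grad():                   # inference only, no autograd bookkeeping
    predictions = model(batch)          # N x 2 scores per component
```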
82 changes: 43 additions & 39 deletions caiman/source_extraction/cnmf/online_cnmf.py
@@ -13,6 +13,9 @@
imaging data in real time. In Advances in Neural Information Processing Systems
(pp. 2381-2391).
@url http://papers.nips.cc/paper/6832-onacid-online-analysis-of-calcium-imaging-data-in-real-time

Implemented in PyTorch
Date: January 7th, 2025
"""

import cv2
@@ -26,7 +29,7 @@
from scipy.stats import norm
from sklearn.decomposition import NMF
from sklearn.preprocessing import normalize
import tensorflow as tf
import torch
from time import time

import caiman
@@ -320,34 +323,27 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
loaded_model = None
self.params.set('online', {'sniper_mode': False})
self.tf_in = None
self.tf_out = None
# self.tf_in = None
# self.tf_out = None
else:
try:
from tensorflow.keras.models import model_from_json
logger.info('Using Keras')
use_keras = True
except(ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')
if use_keras:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
json_path = ".".join(path + ["json"])
model_path = ".".join(path + ["h5"])
json_file = open(json_path, 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_path)
self.tf_in = None
self.tf_out = None
else:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
model_path = '.'.join(path + ['h5', 'pb'])
loaded_model = load_graph(model_path)
self.tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_1_input:0')
self.tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
loaded_model = tf.Session(graph=loaded_model)
# try:
# from keras.models import load_model
# use_keras = True
# logging.info('Using Keras')
# use_keras = True
# except(ModuleNotFoundError):
# use_keras = False
logging.info('Using Torch')

path = self.params.get('online', 'path_to_model').split(".")[:-1]
# if use_keras:
# model_path = ".".join(path + ["keras"])
# loaded_model = model_load(model_path)

model_path = '.'.join(path + ['pt'])
loaded_model = torch.load(model_path)

self.loaded_model = loaded_model

if self.is1p:
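The loader above swaps the configured model's extension for .pt and unpickles the module directly; a minimal sketch of that pattern, assuming the .pt file holds a whole pickled nn.Module (the path is illustrative):

```python
import torch

path_to_model = "/home/user/caiman_data/model/cnn_model_online.h5"   # illustrative
model_path = ".".join(path_to_model.split(".")[:-1] + ["pt"])        # swap extension -> .pt

loaded_model = torch.load(model_path)   # assumes a pickled nn.Module at that path
loaded_model.eval()                     # freeze dropout/batch-norm for inference
```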
@@ -548,7 +544,7 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
sniper_mode=self.params.get('online', 'sniper_mode'),
use_peak_max=self.params.get('online', 'use_peak_max'),
mean_buff=self.estimates.mean_buff,
tf_in=self.tf_in, tf_out=self.tf_out,
# tf_in=self.tf_in, tf_out=self.tf_out,
ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
b0=self.estimates.b0 if self.is1p else None,
corr_img=self.estimates.corr_img if use_corr else None,
@@ -2002,8 +1998,9 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
gHalf=(5, 5), sniper_mode=True, rval_thr=0.85,
patch_size=50, loaded_model=None, test_both=False,
thresh_CNN_noisy=0.5, use_peak_max=False,
thresh_std_peak_resid = 1, mean_buff=None,
tf_in=None, tf_out=None):
thresh_std_peak_resid = 1, mean_buff=None #,
): # tf_in=None, tf_out=None):

"""
Extract new candidate components from the residual buffer and test them
using space correlation or the CNN classifier. The function runs the CNN
@@ -2084,12 +2081,18 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
Ain2 /= np.std(Ain2,axis=1)[:,None]
Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F')
Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2])
if tf_in is None:
predictions = loaded_model.predict(Ain2[:,:,:,np.newaxis], batch_size=min_num_trial, verbose=0)
else:
predictions = loaded_model.run(tf_out, feed_dict={tf_in: Ain2[:, :, :, np.newaxis]})
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])
cnn_pos = Ain2[keep_cnn]
# if use_torch is None:
# predictions = loaded_model.predict(Ain2[:,:,:,np.newaxis], batch_size=min_num_trial, verbose=0)
# keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])
# else:
final_crops = torch.tensor(Ain2, dtype=torch.float32)  # N x patch x patch
with torch.no_grad():
    predictions = loaded_model(final_crops[:, None, :, :])  # batch as N x 1 x patch x patch
keep_cnn = torch.where(predictions[:, 0] > thresh_CNN_noisy)[0].tolist()
cnn_pos = Ain2[keep_cnn]
else:
keep_cnn = [] # list(range(len(Ain_cnn)))
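The acceptance test above keeps only candidates whose CNN score clears thresh_CNN_noisy; the same filter in isolation, on toy scores:

```python
import torch

thresh_CNN_noisy = 0.5
predictions = torch.tensor([[0.9, 0.1],   # passes
                            [0.3, 0.7],   # rejected
                            [0.6, 0.4]])  # passes

keep_cnn = torch.where(predictions[:, 0] > thresh_CNN_noisy)[0].tolist()
print(keep_cnn)   # [0, 2] - indices of surviving candidate components
```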

@@ -2138,7 +2141,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
corr_img=None, first_moment=None, second_moment=None,
crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
max_img=None, downscale_matrix=None, upscale_matrix=None,
tf_in=None, tf_out=None):
): # tf_in=None, tf_out=None):

"""
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
"""
Expand Down Expand Up @@ -2168,7 +2172,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
tf_in=tf_in, tf_out=tf_out)
) # tf_in=tf_in, tf_out=tf_out)

ind_new_all = ijsig_all

22 changes: 22 additions & 0 deletions caiman/tests/test_mrcnn_pytorch.py
@@ -0,0 +1,22 @@
#!/usr/bin/env python

import numpy as np
import os
import torch

import caiman as cm
from caiman.paths import caiman_datadir
from caiman.utils.utils import download_model, download_demo
from caiman.source_extraction.volpy.mrcnn import neurons
import caiman.source_extraction.volpy.mrcnn.model as modellib

def mrcnn(img, size_range, weights_path):
    # Placeholder: Mask R-CNN inference still to be ported to PyTorch;
    # must return an array of ROIs for the assertion below.
    return

def test_mrcnn():
weights_path = download_model('mask_rcnn')
summary_images = cm.load(download_demo('demo_voltage_imaging_summary_images.tif'))
ROIs = mrcnn(img=summary_images.transpose([1, 2, 0]), size_range=[5, 22],
weights_path=weights_path)
assert ROIs.shape[0] == 14, 'failed to infer correct number of neurons'
42 changes: 42 additions & 0 deletions caiman/tests/test_pytorch.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python

import numpy as np
import os

from caiman.paths import caiman_datadir
from caiman.utils.utils import load_graph

import torch

def test_torch():
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

try:
model_name = os.path.join(caiman_datadir(), 'model', 'cnn_model')
# if use_keras:
# model_file = model_name + ".keras"
# print('USING MODEL:' + model_file)
#
# loaded_model = load_model(model_file)
# loaded_model.compile('sgd', 'mse')
# elif use_keras == True:
model_file = model_name + ".pth"
loaded_model = torch.load(model_file)
except Exception as e:
    raise Exception('NN model could not be loaded.') from e

A = np.random.randn(10, 50, 50, 1)
try:
# if use_keras == False:
# predictions = loaded_model.predict(A, batch_size=32)
# elif use_keras == True:
A = torch.tensor(A, dtype=torch.float32)
A = torch.reshape(A, (-1, A.shape[-1], A.shape[1], A.shape[2]))  # NHWC -> NCHW; valid here since C == 1
with torch.no_grad():
predictions = loaded_model(A)
# pass
except Exception as e:
    raise Exception('NN model could not be deployed.') from e

if __name__ == "__main__":
test_torch()
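For this test to pass, a serialized model must already exist at model/cnn_model.pth; a minimal sketch of producing such a file (the architecture here is a placeholder, not CaImAn's actual classifier):

```python
import torch
import torch.nn as nn

# Placeholder architecture; CaImAn's shipped CNN classifier differs.
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                      nn.Flatten(), nn.LazyLinear(2))
model(torch.zeros(1, 1, 50, 50))        # materialize lazy layers before saving
torch.save(model, "cnn_model.pth")      # whole-module pickle; torch.load restores it
```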