Skip to content

PyTorch ver 1.0: keras 3.+ with pytorch backend for nn_models.py and … #1515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions caiman/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from caiman.base.timeseries import concatenate
from caiman.cluster import start_server, stop_server
from caiman.mmapping import load_memmap, save_memmap, save_memmap_each, save_memmap_join
from caiman.pytorch_model_arch import PyTorchCNN
from caiman.summary_images import local_correlations

__version__ = importlib.metadata.version('caiman')
67 changes: 22 additions & 45 deletions caiman/components_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
import numpy as np
import os
import peakutils
import tensorflow as tf
import scipy
from scipy.sparse import csc_matrix
from scipy.stats import norm
import torch
from typing import Any, Union
import warnings

import caiman
from caiman.paths import caiman_datadir
from caiman.pytorch_model_arch import PyTorchCNN
import caiman.utils.stats
import caiman.utils.utils

try:
cv2.setNumThreads(0)
Expand Down Expand Up @@ -270,45 +270,22 @@ def evaluate_components_CNN(A,
then this code will try not to use a GPU. Otherwise it will use one if it finds it.
"""
logger = logging.getLogger("caiman")

# TODO: Find a less ugly way to do this
if not isGPU and 'CAIMAN_ALLOW_GPU' not in os.environ:
print("GPU run not requested, disabling use of GPUs")
logger.info("GPU run not requested, disabling use of GPUs")
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
try:
os.environ["KERAS_BACKEND"] = "tensorflow"
from tensorflow.keras.models import model_from_json
use_keras = True
logger.info('Using Keras')
except (ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')

if loaded_model is None:
if use_keras:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".json")):
model_file = os.path.join(caiman_datadir(), model_name + ".json")
model_weights = os.path.join(caiman_datadir(), model_name + ".h5")
elif os.path.isfile(model_name + ".json"):
model_file = model_name + ".json"
model_weights = model_name + ".h5"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
with open(model_file, 'r') as json_file:
print(f"USING MODEL (keras API): {model_file}")
loaded_model_json = json_file.read()
logger.info('Using Torch')

loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_name + '.h5')
if loaded_model is None:
if os.path.isfile(os.path.join(caiman_datadir(), 'model', 'pytorch-models', model_name + ".pt")):
model_file = os.path.join(caiman_datadir(), 'model', 'pytorch-models', model_name + ".pt")
elif os.path.isfile(model_name + ".pt"):
model_file = model_name + ".pt"
else:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".h5.pb")):
model_file = os.path.join(caiman_datadir(), model_name + ".h5.pb")
elif os.path.isfile(model_name + ".h5.pb"):
model_file = model_name + ".h5.pb"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
print(f"USING MODEL (tensorflow API): {model_file}")
loaded_model = caiman.utils.utils.load_graph(model_file)
raise FileNotFoundError(f"File for requested model {model_name} not found")
logger.info(f"Using model: {model_file}")
loaded_model = PyTorchCNN()
loaded_model.load_state_dict(torch.load(model_file))

logger.debug("Loaded model from disk")

Expand All @@ -322,16 +299,16 @@ def evaluate_components_CNN(A,
half_crop[1]:com[1] + half_crop[1]] for mm, com in zip(A.tocsc().T, coms)
]
final_crops = np.array([cv2.resize(im / np.linalg.norm(im), (patch_size, patch_size)) for im in crop_imgs])
if use_keras:
predictions = loaded_model.predict(final_crops[:, :, :, np.newaxis], batch_size=32, verbose=1)
else:
tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_20_input:0')
tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
with tf.Session(graph=loaded_model) as sess:
predictions = sess.run(tf_out, feed_dict={tf_in: final_crops[:, :, :, np.newaxis]})
sess.close()

# Numpy to PyTorch and add a channel dimension using unsqueeze
final_crops = torch.tensor(final_crops, dtype=torch.float32).unsqueeze(1)

# Pass the preprocessed image crops through the model to get predictions
with torch.no_grad():
predictions = loaded_model(final_crops)

return predictions, final_crops
predictions_numpy = predictions.cpu().numpy()
return predictions_numpy, final_crops

def evaluate_components(Y: np.ndarray,
traces: np.ndarray,
Expand Down
52 changes: 52 additions & 0 deletions caiman/pytorch_model_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python
"""
Contains the model architecture for cnn_model.pt and cnn_model_online.pt. The files
cnn_model.pt and cnn_model_online.pt contain the model weights. The weight files are
used to load the weights into the model architecture.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class PyTorchCNN(nn.Module):
def __init__(self):
super(PyTorchCNN, self).__init__()
# First convolutional block
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.dropout1 = nn.Dropout(p=0.25)

# Second convolutional block
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same')
self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.dropout2 = nn.Dropout(p=0.25)

# Flattening and fully connected layers
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(in_features=6400, out_features=512)
self.dropout3 = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(in_features=512, out_features=2)

def forward(self, x):
# Convolutional Block 1
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = self.pool1(x)
x = self.dropout1(x)

# Convolutional block 2
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = self.pool2(x)
x = self.dropout2(x)

# Flattening and in_features layers
x = self.flatten(x)
x = F.relu(self.fc1(x))
x = self.dropout3(x)
x = self.fc2(x)
return F.softmax(x, dim=1)

81 changes: 38 additions & 43 deletions caiman/source_extraction/cnmf/online_cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sklearn.decomposition import NMF
from skimage.morphology import disk
from sklearn.preprocessing import normalize
import tensorflow as tf
import torch
from torch.utils.data import DataLoader, TensorDataset
from time import time

import caiman
Expand All @@ -39,6 +40,7 @@
high_pass_filter_space, sliding_window,
register_translation_3d, apply_shifts_dft)
import caiman.paths
from caiman.pytorch_model_arch import PyTorchCNN
from caiman.source_extraction.cnmf.cnmf import CNMF
from caiman.source_extraction.cnmf.estimates import Estimates
from caiman.source_extraction.cnmf.initialization import imblur, initialize_components, hals, downscale
Expand All @@ -50,14 +52,13 @@
import caiman.summary_images
from caiman.utils.nn_models import (fit_NL_model, create_LN_model, quantile_loss, rate_scheduler)
from caiman.utils.stats import pd_solve
from caiman.utils.utils import save_dict_to_hdf5, load_dict_from_hdf5, parmap, load_graph
from caiman.utils.utils import save_dict_to_hdf5, load_dict_from_hdf5, parmap

try:
cv2.setNumThreads(0)
except():
pass

#FIXME ???
try:
profile
except:
Expand Down Expand Up @@ -357,34 +358,13 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
loaded_model = None
self.params.set('online', {'sniper_mode': False})
self.tf_in = None
self.tf_out = None
else:
try:
from tensorflow.keras.models import model_from_json
logger.info('Using Keras')
use_keras = True
except(ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')
if use_keras:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
json_path = ".".join(path + ["json"])
model_path = ".".join(path + ["h5"])
json_file = open(json_path, 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_path)
self.tf_in = None
self.tf_out = None
else:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
model_path = '.'.join(path + ['h5', 'pb'])
loaded_model = load_graph(model_path)
self.tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_1_input:0')
self.tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
loaded_model = tf.Session(graph=loaded_model)
logger.info('Using Torch')
path = self.params.get('online', 'path_to_model').split(".")[:-1]
model_path = '.'.join(path + ['pt'])
Comment on lines +363 to +364
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a clearer way to write this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not look like it, and it is cleaner given the tensorflow syntax currently. Maybe removal the logger.info, but that is it.

loaded_model = PyTorchCNN()
loaded_model.load_state_dict(torch.load(model_path))

self.loaded_model = loaded_model

if self.is1p:
Expand Down Expand Up @@ -585,7 +565,6 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
sniper_mode=self.params.get('online', 'sniper_mode'),
use_peak_max=self.params.get('online', 'use_peak_max'),
mean_buff=self.estimates.mean_buff,
tf_in=self.tf_in, tf_out=self.tf_out,
ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
b0=self.estimates.b0 if self.is1p else None,
corr_img=self.estimates.corr_img if use_corr else None,
Expand Down Expand Up @@ -1252,6 +1231,13 @@ def fit_online(self, **kwargs):
+ str(self.estimates.Ab.shape[-1] - self.params.get('init', 'nb')))
old_comps = self.N

if np.isnan(np.sum(frame)):
raise Exception(f'Frame {frame_count} contains NaN')
if t % 500 == 0:
logger.info(f'Epoch: {iter + 1}. {t} frames have been processed.'
f'{self.N - old_comps} new components were added. Total: {self.N}')
old_comps = self.N

# Downsample and normalize
frame_ = frame.copy().astype(np.float32)
if self.params.get('online', 'ds_factor') > 1:
Expand Down Expand Up @@ -2040,8 +2026,7 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
gHalf=(5, 5), sniper_mode=True, rval_thr=0.85,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it's not necessarily code you added, but as I assume you've had to understand this function to rework it, any docs you could add to code-dense parts of this function that might help with understanding would be really useful

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I have some notes, so I could add them in.

patch_size=50, loaded_model=None, test_both=False,
thresh_CNN_noisy=0.5, use_peak_max=False,
thresh_std_peak_resid = 1, mean_buff=None,
tf_in=None, tf_out=None):
thresh_std_peak_resid = 1, mean_buff=None):
"""
Extract new candidate components from the residual buffer and test them
using space correlation or the CNN classifier. The function runs the CNN
Expand Down Expand Up @@ -2122,11 +2107,23 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
Ain2 /= np.std(Ain2,axis=1)[:,None]
Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F')
Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2])
if tf_in is None:
predictions = loaded_model.predict(Ain2[:,:,:,np.newaxis], batch_size=min_num_trial, verbose=0)
else:
predictions = loaded_model.run(tf_out, feed_dict={tf_in: Ain2[:, :, :, np.newaxis]})
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])

final_crops = Ain2[:, :, :, np.newaxis]
final_crops_tensor = torch.tensor(final_crops, dtype=torch.float32).permute(0, 3, 1, 2)

#Create DataLoader for batching
dataset = TensorDataset(final_crops_tensor)
loader = DataLoader(dataset, batch_size=int(min_num_trial), shuffle=False)

loaded_model.eval()
all_predictions = []
with torch.no_grad():
for batch in loader:
outputs = loaded_model(batch[0])
all_predictions.append(outputs)

predictions = torch.cat(all_predictions).cpu().numpy()
keep_cnn = list(np.where(predictions[:,0] > thresh_CNN_noisy)[0])
cnn_pos = Ain2[keep_cnn]
else:
keep_cnn = [] # list(range(len(Ain_cnn)))
Expand Down Expand Up @@ -2175,8 +2172,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
mean_buff=None, ssub_B=1, W=None, b0=None,
corr_img=None, first_moment=None, second_moment=None,
crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
max_img=None, downscale_matrix=None, upscale_matrix=None,
tf_in=None, tf_out=None):
max_img=None, downscale_matrix=None, upscale_matrix=None):
"""
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
"""
Expand Down Expand Up @@ -2205,8 +2201,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
min_num_trial=min_num_trial, gSig=gSig, gHalf=gHalf,
sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
tf_in=tf_in, tf_out=tf_out)
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff)

ind_new_all = ijsig_all

Expand Down Expand Up @@ -2596,4 +2591,4 @@ def load_OnlineCNMF(filename, dview = None):
return new_obj

def inv_mat_vec(A):
return np.linalg.solve(A[0], A[1])
return np.linalg.solve(A[0], A[1])
6 changes: 6 additions & 0 deletions caiman/source_extraction/volpy/mrcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
#!/usr/bin/env python

from . import config
from . import model
from . import neurons
from . import utils
from . import visualize
Loading