PyTorch ver 1.0: keras 3.+ with pytorch backend for nn_models.py and … #1515
base: dev
Changes from all commits
2f2cf7d
1379329
018edb6
a7ee5e5
3f43ff0
636c518
caiman/pytorch_model_arch.py (new file)
@@ -0,0 +1,52 @@
#!/usr/bin/env python
"""
Contains the model architecture for cnn_model.pt and cnn_model_online.pt. The files
cnn_model.pt and cnn_model_online.pt contain the model weights. The weight files are
used to load the weights into the model architecture.
"""

import torch

import torch.nn as nn
import torch.nn.functional as F

class PyTorchCNN(nn.Module):
    def __init__(self):
        super(PyTorchCNN, self).__init__()
        # First convolutional block
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(p=0.25)

        # Second convolutional block
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same')
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout(p=0.25)

        # Flattening and fully connected layers
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=6400, out_features=512)
        self.dropout3 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=512, out_features=2)

    def forward(self, x):
        # Convolutional block 1
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool1(x)
        x = self.dropout1(x)

        # Convolutional block 2
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool2(x)
        x = self.dropout2(x)

        # Flattening and fully connected layers
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout3(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)
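For reference, a minimal sketch (not part of this diff) of how saved weights could be loaded into this architecture and applied to a batch of candidate patches. The weight-file name is only an example; the 50x50 patch size matches the patch_size used by the online pipeline below.

    import torch
    from caiman.pytorch_model_arch import PyTorchCNN

    model = PyTorchCNN()
    # "cnn_model_online.pt" is an assumed example path; point this at the real weight file.
    model.load_state_dict(torch.load("cnn_model_online.pt", map_location="cpu"))
    model.eval()

    # 50x50 single-channel patches in NCHW layout; with this input size the
    # flattened feature map is 64 * 10 * 10 = 6400, matching fc1 above.
    patches = torch.randn(4, 1, 50, 50)
    with torch.no_grad():
        probs = model(patches)   # shape (4, 2); column 0 is the score thresholded downstream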
caiman/source_extraction/cnmf/online_cnmf.py
@@ -27,7 +27,8 @@
from sklearn.decomposition import NMF
from skimage.morphology import disk
from sklearn.preprocessing import normalize
import tensorflow as tf
import torch
from torch.utils.data import DataLoader, TensorDataset
from time import time

import caiman

@@ -39,6 +40,7 @@
        high_pass_filter_space, sliding_window,
        register_translation_3d, apply_shifts_dft)
import caiman.paths
from caiman.pytorch_model_arch import PyTorchCNN
from caiman.source_extraction.cnmf.cnmf import CNMF
from caiman.source_extraction.cnmf.estimates import Estimates
from caiman.source_extraction.cnmf.initialization import imblur, initialize_components, hals, downscale

@@ -50,14 +52,13 @@
import caiman.summary_images
from caiman.utils.nn_models import (fit_NL_model, create_LN_model, quantile_loss, rate_scheduler)
from caiman.utils.stats import pd_solve
from caiman.utils.utils import save_dict_to_hdf5, load_dict_from_hdf5, parmap, load_graph
from caiman.utils.utils import save_dict_to_hdf5, load_dict_from_hdf5, parmap

try:
    cv2.setNumThreads(0)
except():
    pass

#FIXME ???
try:
    profile
except:
@@ -357,34 +358,13 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
        if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
            loaded_model = None
            self.params.set('online', {'sniper_mode': False})
            self.tf_in = None
            self.tf_out = None
        else:
            try:
                from tensorflow.keras.models import model_from_json
                logger.info('Using Keras')
                use_keras = True
            except(ModuleNotFoundError):
                use_keras = False
                logger.info('Using Tensorflow')
            if use_keras:
                path = self.params.get('online', 'path_to_model').split(".")[:-1]
                json_path = ".".join(path + ["json"])
                model_path = ".".join(path + ["h5"])
                json_file = open(json_path, 'r')
                loaded_model_json = json_file.read()
                json_file.close()
                loaded_model = model_from_json(loaded_model_json)
                loaded_model.load_weights(model_path)
                self.tf_in = None
                self.tf_out = None
            else:
                path = self.params.get('online', 'path_to_model').split(".")[:-1]
                model_path = '.'.join(path + ['h5', 'pb'])
                loaded_model = load_graph(model_path)
                self.tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_1_input:0')
                self.tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
                loaded_model = tf.Session(graph=loaded_model)
            logger.info('Using Torch')
            path = self.params.get('online', 'path_to_model').split(".")[:-1]
            model_path = '.'.join(path + ['pt'])
Comment on lines +363 to +364
Reviewer: Is there a clearer way to write this?
Author: It does not look like it, and it is cleaner than the current TensorFlow syntax. Maybe remove the logger.info, but that is it.
            loaded_model = PyTorchCNN()
            loaded_model.load_state_dict(torch.load(model_path))

        self.loaded_model = loaded_model

        if self.is1p:
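As a side note on the lines discussed above: the path handling simply swaps the configured model path's extension for .pt. A small illustration, with an assumed example filename:

    # Example only: the configured path_to_model has its extension replaced by ".pt".
    path_to_model = "model/cnn_model_online.h5"
    stem = path_to_model.split(".")[:-1]      # ['model/cnn_model_online']
    model_path = ".".join(stem + ["pt"])      # 'model/cnn_model_online.pt'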
@@ -585,7 +565,6 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
                sniper_mode=self.params.get('online', 'sniper_mode'),
                use_peak_max=self.params.get('online', 'use_peak_max'),
                mean_buff=self.estimates.mean_buff,
                tf_in=self.tf_in, tf_out=self.tf_out,
                ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
                b0=self.estimates.b0 if self.is1p else None,
                corr_img=self.estimates.corr_img if use_corr else None,
@@ -1252,6 +1231,13 @@ def fit_online(self, **kwargs):
                                + str(self.estimates.Ab.shape[-1] - self.params.get('init', 'nb')))
                    old_comps = self.N

                if np.isnan(np.sum(frame)):
                    raise Exception(f'Frame {frame_count} contains NaN')
                if t % 500 == 0:
                    logger.info(f'Epoch: {iter + 1}. {t} frames have been processed. '
                                f'{self.N - old_comps} new components were added. Total: {self.N}')
                    old_comps = self.N

                # Downsample and normalize
                frame_ = frame.copy().astype(np.float32)
                if self.params.get('online', 'ds_factor') > 1:
@@ -2040,8 +2026,7 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
                             gHalf=(5, 5), sniper_mode=True, rval_thr=0.85,
Reviewer: I know it's not necessarily code you added, but as I assume you've had to understand this function to rework it, any docs you could add to code-dense parts of this function that might help with understanding would be really useful.
Author: Sounds good. I have some notes, so I could add them in.
                             patch_size=50, loaded_model=None, test_both=False,
                             thresh_CNN_noisy=0.5, use_peak_max=False,
                             thresh_std_peak_resid = 1, mean_buff=None,
                             tf_in=None, tf_out=None):
                             thresh_std_peak_resid = 1, mean_buff=None):
    """
    Extract new candidate components from the residual buffer and test them
    using space correlation or the CNN classifier. The function runs the CNN
@@ -2122,11 +2107,23 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
        Ain2 /= np.std(Ain2,axis=1)[:,None]
        Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F')
        Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2])
        if tf_in is None:
            predictions = loaded_model.predict(Ain2[:,:,:,np.newaxis], batch_size=min_num_trial, verbose=0)
        else:
            predictions = loaded_model.run(tf_out, feed_dict={tf_in: Ain2[:, :, :, np.newaxis]})
        keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])

        final_crops = Ain2[:, :, :, np.newaxis]
        final_crops_tensor = torch.tensor(final_crops, dtype=torch.float32).permute(0, 3, 1, 2)

        # Create DataLoader for batching
        dataset = TensorDataset(final_crops_tensor)
        loader = DataLoader(dataset, batch_size=int(min_num_trial), shuffle=False)

        loaded_model.eval()
        all_predictions = []
        with torch.no_grad():
            for batch in loader:
                outputs = loaded_model(batch[0])
                all_predictions.append(outputs)

        predictions = torch.cat(all_predictions).cpu().numpy()
        keep_cnn = list(np.where(predictions[:,0] > thresh_CNN_noisy)[0])
        cnn_pos = Ain2[keep_cnn]
    else:
        keep_cnn = []  # list(range(len(Ain_cnn)))
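A note on the permute above: the candidate crops arrive in Keras-style NHWC layout (channel last), while PyTorch's Conv2d expects NCHW. A minimal sketch of that conversion, with assumed shapes:

    import numpy as np
    import torch

    crops = np.zeros((8, 50, 50, 1), dtype=np.float32)  # N, H, W, C, as produced above
    t = torch.tensor(crops).permute(0, 3, 1, 2)          # -> N, C, H, W for nn.Conv2d
    assert t.shape == (8, 1, 50, 50)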
@@ -2175,8 +2172,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
                          mean_buff=None, ssub_B=1, W=None, b0=None,
                          corr_img=None, first_moment=None, second_moment=None,
                          crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
                          max_img=None, downscale_matrix=None, upscale_matrix=None,
                          tf_in=None, tf_out=None):
                          max_img=None, downscale_matrix=None, upscale_matrix=None):
    """
    Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
    """

@@ -2205,8 +2201,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
                              min_num_trial=min_num_trial, gSig=gSig, gHalf=gHalf,
                              sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
                              loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
                              use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
                              tf_in=tf_in, tf_out=tf_out)
                              use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff)

    ind_new_all = ijsig_all

@@ -2596,4 +2591,4 @@ def load_OnlineCNMF(filename, dview = None):
    return new_obj

def inv_mat_vec(A):
    return np.linalg.solve(A[0], A[1])
    return np.linalg.solve(A[0], A[1])
@@ -1 +1,7 @@
#!/usr/bin/env python

from . import config
from . import model
from . import neurons
from . import utils
from . import visualize