@@ -151,7 +151,7 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None,
151
151
self .net .load_model (self .pretrained_model , device = self .device )
152
152
153
153
154
- def eval (self , x , batch_size = 8 , resample = None , channels = None , channel_axis = None ,
154
+ def eval (self , x , batch_size = 8 , resample = True , channels = None , channel_axis = None ,
155
155
z_axis = None , normalize = True , invert = False , rescale = None , diameter = None ,
156
156
flow_threshold = 0.4 , cellprob_threshold = 0.0 , do_3D = False , anisotropy = None ,
157
157
flow3D_smooth = 0 , stitch_threshold = 0.0 ,
@@ -165,7 +165,6 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
165
165
batch_size (int, optional): number of 256x256 patches to run simultaneously on the GPU
166
166
(can make smaller or bigger depending on GPU memory usage). Defaults to 64.
167
167
resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries).
168
- deprecated in v4.0.1+, resample is not used
169
168
channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
170
169
if None, channels dimension is attempted to be automatically determined. Defaults to None.
171
170
z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
@@ -327,7 +326,9 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
327
326
328
327
if resample :
329
328
# upsample flows before computing them:
330
- raise NotImplementedError
329
+ dP = self ._resize_gradients (dP , to_y_size = Ly_0 , to_x_size = Lx_0 , to_z_size = Lz_0 )
330
+ cellprob = self ._resize_cellprob (cellprob , to_x_size = Lx_0 , to_y_size = Ly_0 , to_z_size = Lz_0 )
331
+
331
332
332
333
if compute_masks :
333
334
niter0 = 200
@@ -343,6 +344,10 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
343
344
344
345
# undo resizing:
345
346
if image_scaling is not None or anisotropy is not None :
347
+
348
+ dP = self ._resize_gradients (dP , to_y_size = Ly_0 , to_x_size = Lx_0 , to_z_size = Lz_0 ) # works for 2 or 3D:
349
+ cellprob = self ._resize_cellprob (cellprob , to_x_size = Lx_0 , to_y_size = Ly_0 , to_z_size = Lz_0 )
350
+
346
351
if do_3D :
347
352
if compute_masks :
348
353
# Rescale xy then xz:
@@ -351,29 +356,96 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
351
356
masks = transforms .resize_image (masks , Ly = Lz_0 , Lx = Lx_0 , no_channels = True , interpolation = cv2 .INTER_NEAREST )
352
357
masks = masks .transpose (1 , 0 , 2 )
353
358
354
- # cellprob is the same
355
- cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
356
- cellprob = cellprob .transpose (1 , 0 , 2 )
357
- cellprob = transforms .resize_image (cellprob , Ly = Lz_0 , Lx = Lx_0 , no_channels = True )
358
- cellprob = cellprob .transpose (1 , 0 , 2 )
359
-
360
- # dP has gradients that can be treated as channels:
361
- dP = dP .transpose (1 , 2 , 3 , 0 ) # move gradients last:
362
- dP = transforms .resize_image (dP , Ly = Ly_0 , Lx = Lx_0 , no_channels = False )
363
- dP = dP .transpose (1 , 0 , 2 , 3 ) # switch axes to resize again
364
- dP = transforms .resize_image (dP , Ly = Lz_0 , Lx = Lx_0 , no_channels = False )
365
- dP = dP .transpose (3 , 1 , 0 , 2 ) # undo transposition
366
-
367
359
else :
368
360
# 2D or 3D stitching case:
369
361
if compute_masks :
370
362
masks = transforms .resize_image (masks , Ly = Ly_0 , Lx = Lx_0 , no_channels = True , interpolation = cv2 .INTER_NEAREST )
371
- cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
372
- dP = np .moveaxis (dP , 0 , - 1 ) # Put gradients last
373
- dP = transforms .resize_image (dP , Ly = Ly_0 , Lx = Lx_0 , no_channels = False )
374
- dP = np .moveaxis (dP , - 1 , 0 ) # Put gradients first
375
363
376
364
return masks , [plot .dx_to_circ (dP ), dP , cellprob ], styles
365
+
366
+
367
+ def _resize_cellprob (self , prob : np .ndarray , to_y_size : int , to_x_size : int , to_z_size : int = None ) -> np .ndarray :
368
+ """
369
+ Resize cellprob array to specified dimensions for either 2D or 3D.
370
+
371
+ Parameters:
372
+ prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
373
+ to_y_size (int): The target size along the Y-axis.
374
+ to_x_size (int): The target size along the X-axis.
375
+ to_z_size (int, optional): The target size along the Z-axis. Required
376
+ for 3D cellprobs.
377
+
378
+ Returns:
379
+ numpy.ndarray: The resized cellprobs array with the same number of dimensions
380
+ as the input.
381
+
382
+ Raises:
383
+ ValueError: If the input cellprobs array does not have 3 or 4 dimensions.
384
+ """
385
+ prob_shape = prob .shape
386
+ prob = prob .squeeze ()
387
+ squeeze_happened = prob .shape != prob_shape
388
+ prob_shape = np .array (prob_shape )
389
+
390
+ if prob .ndim == 2 :
391
+ # 2D case:
392
+ prob = transforms .resize_image (prob , Ly = to_y_size , Lx = to_x_size , no_channels = True )
393
+ if squeeze_happened :
394
+ prob = np .expand_dims (prob , int (np .argwhere (prob_shape == 1 ))) # add back empty axis for compatibility
395
+ elif prob .ndim == 3 :
396
+ # 3D case:
397
+ prob = transforms .resize_image (prob , Ly = to_y_size , Lx = to_x_size , no_channels = True )
398
+ prob = prob .transpose (1 , 0 , 2 )
399
+ prob = transforms .resize_image (prob , Ly = to_z_size , Lx = to_x_size , no_channels = True )
400
+ prob = prob .transpose (1 , 0 , 2 )
401
+ else :
402
+ raise ValueError (f'gradients have incorrect dimension after squeezing. Should be 2 or 3, prob shape: { prob .shape } ' )
403
+
404
+ return prob
405
+
406
+
407
+ def _resize_gradients (self , grads : np .ndarray , to_y_size : int , to_x_size : int , to_z_size : int = None ) -> np .ndarray :
408
+ """
409
+ Resize gradient arrays to specified dimensions for either 2D or 3D gradients.
410
+
411
+ Parameters:
412
+ grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
413
+ to_y_size (int): The target size along the Y-axis.
414
+ to_x_size (int): The target size along the X-axis.
415
+ to_z_size (int, optional): The target size along the Z-axis. Required
416
+ for 3D gradients.
417
+
418
+ Returns:
419
+ numpy.ndarray: The resized gradient array with the same number of dimensions
420
+ as the input.
421
+
422
+ Raises:
423
+ ValueError: If the input gradient array does not have 3 or 4 dimensions.
424
+ """
425
+ grads_shape = grads .shape
426
+ grads = grads .squeeze ()
427
+ squeeze_happened = grads .shape != grads_shape
428
+ grads_shape = np .array (grads_shape )
429
+
430
+ if grads .ndim == 3 :
431
+ # 2D case, with XY flows in 2 channels:
432
+ grads = np .moveaxis (grads , 0 , - 1 ) # Put gradients last
433
+ grads = transforms .resize_image (grads , Ly = to_y_size , Lx = to_x_size , no_channels = False )
434
+ grads = np .moveaxis (grads , - 1 , 0 ) # Put gradients first
435
+
436
+ if squeeze_happened :
437
+ grads = np .expand_dims (grads , int (np .argwhere (grads_shape == 1 ))) # add back empty axis for compatibility
438
+ elif grads .ndim == 4 :
439
+ # dP has gradients that can be treated as channels:
440
+ grads = grads .transpose (1 , 2 , 3 , 0 ) # move gradients last:
441
+ grads = transforms .resize_image (grads , Ly = to_y_size , Lx = to_x_size , no_channels = False )
442
+ grads = grads .transpose (1 , 0 , 2 , 3 ) # switch axes to resize again
443
+ grads = transforms .resize_image (grads , Ly = to_z_size , Lx = to_x_size , no_channels = False )
444
+ grads = grads .transpose (3 , 1 , 0 , 2 ) # undo transposition
445
+ else :
446
+ raise ValueError (f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: { grads .shape } ' )
447
+
448
+ return grads
377
449
378
450
379
451
def _run_net (self , x ,
0 commit comments