Commit f171401

Merge commit: 2 parents d0df00c + 79b0fcb

8 files changed (+156, -34 lines)

cellpose/__main__.py

Lines changed: 1 addition & 2 deletions

@@ -98,8 +98,6 @@ def main():
         logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used")
     if args.chan2_restore:
         logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used")
-    if not args.no_resample:
-        logger.warning("the '--no_resample' flag is deprecated in v4.0.1+ and no longer used")
     if args.diam_mean:
         logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used")
     if args.train_size:
@@ -228,6 +226,7 @@ def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, nor
             min_size=args.min_size,
             batch_size=args.batch_size,
             bsize=args.bsize,
+            resample=not args.no_resample,
             normalize=normalize,
             channel_axis=channel_axis,
             z_axis=z_axis,

cellpose/cli.py

Lines changed: 2 additions & 2 deletions

@@ -134,8 +134,8 @@ def get_arg_parser():
 
     # TODO: remove deprecated in future version
     algorithm_args.add_argument(
-        "--no_resample", action="store_true", help=
-        'Deprecated in v4.0.1+, not used. ')
+        "--no_resample", action="store_true",
+        help="disables flows/cellprob resampling to the original image size before computing masks. Using this flag will make masks more jagged at larger diameter settings.")
     algorithm_args.add_argument(
         "--no_interp", action="store_true",
         help="do not interpolate when running dynamics (was default)")

cellpose/models.py

Lines changed: 92 additions & 20 deletions

@@ -151,7 +151,7 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None,
         self.net.load_model(self.pretrained_model, device=self.device)
 
 
-    def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
+    def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
             z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
             flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
             flow3D_smooth=0, stitch_threshold=0.0,
@@ -165,7 +165,6 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
            batch_size (int, optional): number of 256x256 patches to run simultaneously on the GPU
                (can make smaller or bigger depending on GPU memory usage). Defaults to 64.
            resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries).
-               deprecated in v4.0.1+, resample is not used
            channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
                if None, channels dimension is attempted to be automatically determined. Defaults to None.
            z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
@@ -327,7 +326,9 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
 
         if resample:
             # upsample flows before computing them:
-            raise NotImplementedError
+            dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
+            cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
+
 
         if compute_masks:
             niter0 = 200
@@ -343,6 +344,10 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
 
         # undo resizing:
         if image_scaling is not None or anisotropy is not None:
+
+            dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)  # works for 2D or 3D
+            cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
+
             if do_3D:
                 if compute_masks:
                     # Rescale xy then xz:
@@ -351,29 +356,96 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
                     masks = transforms.resize_image(masks, Ly=Lz_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
                     masks = masks.transpose(1, 0, 2)
 
-                # cellprob is the same
-                cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
-                cellprob = cellprob.transpose(1, 0, 2)
-                cellprob = transforms.resize_image(cellprob, Ly=Lz_0, Lx=Lx_0, no_channels=True)
-                cellprob = cellprob.transpose(1, 0, 2)
-
-                # dP has gradients that can be treated as channels:
-                dP = dP.transpose(1, 2, 3, 0)  # move gradients last
-                dP = transforms.resize_image(dP, Ly=Ly_0, Lx=Lx_0, no_channels=False)
-                dP = dP.transpose(1, 0, 2, 3)  # switch axes to resize again
-                dP = transforms.resize_image(dP, Ly=Lz_0, Lx=Lx_0, no_channels=False)
-                dP = dP.transpose(3, 1, 0, 2)  # undo transposition
-
             else:
                 # 2D or 3D stitching case:
                 if compute_masks:
                     masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
-                cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
-                dP = np.moveaxis(dP, 0, -1)  # Put gradients last
-                dP = transforms.resize_image(dP, Ly=Ly_0, Lx=Lx_0, no_channels=False)
-                dP = np.moveaxis(dP, -1, 0)  # Put gradients first
 
         return masks, [plot.dx_to_circ(dP), dP, cellprob], styles
+
+
+    def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
+        """
+        Resize a cellprob array to the specified dimensions, for either 2D or 3D.
+
+        Parameters:
+            prob (numpy.ndarray): The cellprobs to resize, either 2D or 3D. Returns the same ndim as provided.
+            to_y_size (int): The target size along the Y axis.
+            to_x_size (int): The target size along the X axis.
+            to_z_size (int, optional): The target size along the Z axis. Required for 3D cellprobs.
+
+        Returns:
+            numpy.ndarray: The resized cellprob array with the same number of dimensions as the input.
+
+        Raises:
+            ValueError: If the cellprob array does not have 2 or 3 dimensions after squeezing.
+        """
+        prob_shape = prob.shape
+        prob = prob.squeeze()
+        squeeze_happened = prob.shape != prob_shape
+        prob_shape = np.array(prob_shape)
+
+        if prob.ndim == 2:
+            # 2D case:
+            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
+            if squeeze_happened:
+                prob = np.expand_dims(prob, int(np.argwhere(prob_shape == 1)))  # add back empty axis for compatibility
+        elif prob.ndim == 3:
+            # 3D case:
+            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
+            prob = prob.transpose(1, 0, 2)
+            prob = transforms.resize_image(prob, Ly=to_z_size, Lx=to_x_size, no_channels=True)
+            prob = prob.transpose(1, 0, 2)
+        else:
+            raise ValueError(f'cellprob has incorrect dimension after squeezing. Should be 2 or 3, prob shape: {prob.shape}')
+
+        return prob
+
+
+    def _resize_gradients(self, grads: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
+        """
+        Resize gradient arrays to the specified dimensions, for either 2D or 3D gradients.
+
+        Parameters:
+            grads (numpy.ndarray): The gradients to resize, either 2D or 3D. Returns the same ndim as provided.
+            to_y_size (int): The target size along the Y axis.
+            to_x_size (int): The target size along the X axis.
+            to_z_size (int, optional): The target size along the Z axis. Required for 3D gradients.
+
+        Returns:
+            numpy.ndarray: The resized gradient array with the same number of dimensions as the input.
+
+        Raises:
+            ValueError: If the gradient array does not have 3 or 4 dimensions after squeezing.
+        """
+        grads_shape = grads.shape
+        grads = grads.squeeze()
+        squeeze_happened = grads.shape != grads_shape
+        grads_shape = np.array(grads_shape)
+
+        if grads.ndim == 3:
+            # 2D case, with XY flows in 2 channels:
+            grads = np.moveaxis(grads, 0, -1)  # Put gradients last
+            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
+            grads = np.moveaxis(grads, -1, 0)  # Put gradients first
+            if squeeze_happened:
+                grads = np.expand_dims(grads, int(np.argwhere(grads_shape == 1)))  # add back empty axis for compatibility
+        elif grads.ndim == 4:
+            # 3D case; dP has gradients that can be treated as channels:
+            grads = grads.transpose(1, 2, 3, 0)  # move gradients last
+            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
+            grads = grads.transpose(1, 0, 2, 3)  # switch axes to resize again
+            grads = transforms.resize_image(grads, Ly=to_z_size, Lx=to_x_size, no_channels=False)
+            grads = grads.transpose(3, 1, 0, 2)  # undo transposition
+        else:
+            raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: {grads.shape}')
+
+        return grads
 
 
     def _run_net(self, x,
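
A hedged usage sketch of the restored behaviour (the image filename is hypothetical; the eval arguments follow the v4 signature shown in the diff above):

    from cellpose import models, io

    img = io.imread("example_2D.tif")                 # hypothetical input image
    model = models.CellposeModel(gpu=True)            # loads the pretrained cpsam weights
    # resample=True (now the default) upsamples flows/cellprob back to the original
    # image size before masks are computed; resample=False is faster but gives coarser boundaries.
    masks, flows, styles = model.eval(img, diameter=40, resample=True)
    masks_fast, _, _ = model.eval(img, diameter=40, resample=False)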

docs/faq.rst

Lines changed: 46 additions & 0 deletions

@@ -66,3 +66,49 @@ FAQ
 using `torch.set_num_threads <https://pytorch.org/docs/stable/generated/torch.set_num_threads.html>`_ or through the environment
 variables ``OMP_NUM_THREADS`` or ``MKL_NUM_THREADS`` as described
 `here <https://pytorch.org/docs/stable/threading_environment_variables.html>`_.
+
+
+**Q: How does HITL work?**
+
+In Cellpose, HITL (human-in-the-loop) training always starts from a pretrained model but incorporates more training
+data with each iteration. To start, only a single image is used as training data.
+After an iteration, another image is included in the training data. Since there is more
+training data, the model should be more accurate on subsequent images.
+
+The goal of HITL training is to produce a model that is finetuned on your data and also generalist
+enough to segment new images not in the training set. One of the problems with annotation
+is that it can be time-consuming to annotate enough images to produce a finetuned model.
+Cellpose circumvents this tedium by using the already generalist-trained model to predict
+your image segmentation. This prediction will be better than nothing, and it will get some
+segmentation masks correct. That is helpful because you can accept the correct masks, and add
+or edit the incorrect ones. Now you have a new image that can be used for training a new finetuned
+model. This new finetuned model can then predict segmentation for the next image in your dataset,
+and, since it is finetuned on your data, will do somewhat better than the 'base' Cellpose model.
+You can repeat these steps (predict using the latest model, correct the predictions, train,
+and predict again) until you have a model that performs well enough on your data.
+
+
+**Q: What is a 'model'?**
+
+A model is the neural network architecture plus the parameters (fitted numbers) in that architecture.
+The CPSAM model we distribute is a 'model', and you can make another 'model' by finetuning
+on your data. These models are similar because they have the same architecture, but distinct
+because they have different weights.
+
+
+**Q: How can I do HITL without the GUI? (I don't have GPU hardware on my machine, but I want to use
+colab/a cluster)**
+
+You can do the following steps (see the sketch after this list):
+
+1. Load the images onto the remote machine.
+
+2. Use a script to segment the images using the pretrained model.
+
+3. Download the predicted masks and annotate them with the cellpose GUI.
+
+4. Load the annotated masks onto the remote machine and train a model with all the images in the folder (only 1 at first).
+
+5. Evaluate the trained model on the next image.
+
+6. Repeat steps 3-5 until you have a working finetuned model.
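
A minimal sketch of steps 2 and 5 as a script on the remote machine (the folder name is hypothetical; the training step is only indicated in a comment, since its exact signature should be checked against cellpose.train):

    from pathlib import Path
    from cellpose import models, io

    data_dir = Path("hitl_images")                    # hypothetical folder of images to annotate
    model = models.CellposeModel(gpu=True)            # pretrained model, or a previously finetuned one

    for img_file in sorted(data_dir.glob("*.tif")):
        img = io.imread(str(img_file))
        masks, flows, styles = model.eval(img, normalize=True)
        # save predictions so they can be downloaded and corrected in the GUI (step 3)
        io.imsave(str(img_file.with_name(img_file.stem + "_cp_masks.png")), masks)

    # after uploading the corrected masks, retrain on all annotated images (step 4),
    # e.g. via the functions in cellpose.train, then repeat with the new model.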

docs/index.rst

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ can install it as ``pip install cellpose[gui]``.
 
 - run Cellpose-SAM in the cloud (no install) at `Hugging Face <https://huggingface.co/spaces/mouseland/cellpose>`_.
 - `paper <https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1>`_ on biorxiv
-- talk
+- `talk <https://www.youtube.com/watch?v=KIdYXgQemcI>`_
 
docs/installation.rst

Lines changed: 2 additions & 0 deletions

@@ -62,6 +62,7 @@ Common issues
 If you receive an issue with Qt "xcb", you may need to install xcb libraries, e.g.:
 
 ::
+
     sudo apt install libxcb-cursor0
     sudo apt install libxcb-xinerama0
 
@@ -90,6 +91,7 @@ If you are having other issues with the graphical interface and QT, see some adv
 If you have errors related to OpenMP and libiomp5, then try
 
 ::
+
     conda install nomkl
 
 If you receive an error associated with **matplotlib**, try upgrading

docs/settings.rst

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ See the :ref:`cpmclass` for all run options.
 
 .. warning::
    Cellpose 3 used ``models.Cellpose`` class which has been removed in Cellpose 4. Users should
-   now only use the ``models.CellposeModel``` class.
+   now only use the ``models.CellposeModel`` class.
 
 Here is an example of calling the ``CellposeModel`` class and
 running a list of images for reference:
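
The example referenced in that trailing context is not shown in this diff; a hedged sketch of the pattern (filenames are illustrative, arguments follow the v4 eval signature above) would be:

    from cellpose import models, io

    files = ["img0.tif", "img1.tif"]                  # hypothetical image files
    imgs = [io.imread(f) for f in files]
    model = models.CellposeModel(gpu=True)
    # eval accepts a list of images and returns per-image masks, flows, and styles
    masks, flows, styles = model.eval(imgs, batch_size=8, flow_threshold=0.4,
                                      cellprob_threshold=0.0, resample=True)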

tests/test_output.py

Lines changed: 11 additions & 8 deletions

@@ -38,25 +38,28 @@ def clear_output(data_dir, image_names):
         os.remove(npy_output)
 
 
-@pytest.mark.parametrize('compute_masks, resample',
+@pytest.mark.parametrize('compute_masks, resample, diameter',
                          [
-                             (True, True),
-                             (False, True),
-                             (False, False,)
+                             (True, True, 40),
+                             (True, True, None),
+                             (False, True, None),
+                             (False, False, None)
                          ]
                          )
-def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer, compute_masks, resample):
+def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer, compute_masks, resample, diameter):
     clear_output(data_dir, image_names)
 
     img_file = data_dir / '2D' / image_names[0]
 
     img = io.imread_2D(img_file)
     # flowps = io.imread(img_file.parent / (img_file.stem + "_cp4_gt_flowps.tif"))
 
-    masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True, compute_masks=compute_masks, resample=resample)
+    masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True, compute_masks=compute_masks, resample=resample, diameter=diameter)
 
-    if not compute_masks:
-        return  # just check that not compute_masks works
+    if not compute_masks or diameter:
+        # compute_masks=False returns no masks, so there is nothing to check;
+        # a non-default diameter gives different masks, so they can't be compared to the saved output
+        return
 
     io.imsave(data_dir / '2D' / (img_file.stem + "_cp_masks.png"), masks_pred)
     # flowsp_pred = np.concatenate([flows_pred[1], flows_pred[2][None, ...]], axis=0)
