
Tile mode makes inference crash every time I use it for 3D volumes #107

Open

QuentinRapilly opened this issue Oct 30, 2024 · 3 comments

@QuentinRapilly
I want to use Omnipose to segment 3D volumes.
I trained my own model and had no issues with that part.
But at inference, every time I try to segment an image that is "too big" (a volume of size ~200x500x300), the code crashes with the error printed below. When I activate the "tile" option, it still crashes on these images, and it also crashes on smaller images (~100x500x300) that it used to process correctly without the "tile" mode on.
Has anyone had the same issue using the "tile" option with 3D volumes, or does anyone know how to solve it?

2024-10-30 15:08:04,062	[INFO]	core    _use...torch()	 line 74	** TORCH GPU version installed and working. **
>>> GPU activated? 1
2024-10-30 15:08:04,064	[INFO]	                    	 line 74	** TORCH GPU version installed and working. **
2024-10-30 15:08:04,064	[INFO]	        assi...evice()	 line 85	>>>> using GPU
2024-10-30 15:08:06,879	[INFO]	models  eval........()	 line 709	Evaluating with flow_threshold 0.00, mask_threshold -4.00
2024-10-30 15:08:06,880	[INFO]	                    	 line 711	using omni model, cluster False
2024-10-30 15:08:06,880	[INFO]	                    	 line 1095	using dataparallel
2024-10-30 15:08:06,880	[INFO]	                    	 line 1107	network initialized.
2024-10-30 15:08:06,942	[INFO]	                    	 line 1114	shape before transforms.convert_image(): (84, 509, 319)
multi-stack tiff read in as having 84 planes 1 channels
2024-10-30 15:08:06,942	[INFO]	models  eval........()	 line 1122	shape after transforms.convert_image(): (84, 509, 319, 1)
2024-10-30 15:08:06,942	[INFO]	                    	 line 1128	shape now (1, 84, 509, 319, 1)
Running on tiles. Now normalizing each tile separately.
Traceback (most recent call last):
  File "/home/qr211/rds/code/omnipose/inference.py", line 62, in <module>
    masks_om, flows_om, _ = model.eval(img, channels=chans, rescale=rescale,
  File "/root/omnipose/cellpose_omni/models.py", line 1134, in eval
    masks, styles, dP, cellprob, p, bd, tr, affinity, bounds  = self._run_cp(x,
  File "/root/omnipose/cellpose_omni/models.py", line 1262, in _run_cp
    yf, style = self._run_nets(img, net_avg=net_avg,
  File "/root/omnipose/cellpose_omni/core.py", line 419, in _run_nets
    y, style = self._run_net(img, augment=augment, tile=tile, tile_overlap=tile_overlap,
  File "/root/omnipose/cellpose_omni/core.py", line 514, in _run_net
    y, style = self._run_tiled(imgs, augment=augment, bsize=bsize,
  File "/root/omnipose/cellpose_omni/core.py", line 624, in _run_tiled
    y0, style = self.network(IMG[irange], return_conv=return_conv)
  File "/root/omnipose/cellpose_omni/core.py", line 367, in network
    y, style = self.net(X)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/omnipose/cellpose_omni/resnet_torch.py", line 275, in forward
    T0 = self.upsample(style, T0, self.mkldnn)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/omnipose/cellpose_omni/resnet_torch.py", line 219, in forward
    x = cp.checkpoint(self.upsampling,x) if self.checkpoint else self.upsampling(x) # doesn't do much
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/upsampling.py", line 156, in forward
    return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 3985, in interpolate
    return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
RuntimeError: Expected output.numel() <= std::numeric_limits<int32_t>::max() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
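For context, the failing assertion is PyTorch's 32-bit element-count limit in `upsample_nearest3d`: the output tensor's `numel()` must fit in an `int32`, independent of how much VRAM is free. A quick sanity check can estimate whether a given batched volume trips it; the batch, channel, and upsampling numbers below are illustrative guesses, not Omnipose's actual internals:

```python
# PyTorch's upsample_nearest3d requires output.numel() <= int32 max.
INT32_MAX = 2**31 - 1

def numel(shape):
    """Total element count of a tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n

# Hypothetical batched (B, C, D, H, W) tensor after a 2x upsampling step,
# using the volume shape from the log above (84, 509, 319):
out_shape = (8, 32, 84 * 2, 509 * 2, 319 * 2)
print(numel(out_shape), "elements; overflows int32:", numel(out_shape) > INT32_MAX)
```

Even a modest batch of feature maps over a full volume of this size lands far past the 2,147,483,647-element limit, which would explain a crash that VRAM alone does not.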
@kevinjohncutler
Owner
@QuentinRapilly sorry I didn't see this. What is your GPU? I'm pretty sure this is limited VRAM, but you can rerun on much smaller tiles to work around it. I'll try to monitor this thread, but please feel free to email me as well.

@QuentinRapilly
Author

Hi @kevinjohncutler, I am running my code on an Nvidia A100 with 80 GB, which is why I could process most of my images without the tile option. What is really strange is that once I activate the tile option, the code crashes with this error every time, no matter the size of the image, even on the smaller ones that used to be processed fine without the option.
And I did not find where to set the size of the tiles for evaluation:

def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
             invert=False, normalize=True, diameter=30., do_3D=False, anisotropy=None,
             net_avg=True, augment=False, tile=True, tile_overlap=0.1, resample=True, 
             interp=True, cluster=False, boundary_seg=False, affinity_seg=False, despur=True,
             flow_threshold=0.4, mask_threshold=0.0, 
             cellprob_threshold=None, dist_threshold=None, diam_threshold=12., min_size=15, max_size=None,
             stitch_threshold=0.0, rescale=None, progress=None, omni=False, verbose=False,
             transparency=False, model_loaded=False)

The only two parameters referring to tiles are the overlap and the boolean selecting whether to use the tile option at all. But I might have missed something?
Thanks for your help.
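One lever the signature above does expose is `batch_size` (default 8): the 32-bit limit applies to the whole batched tile tensor inside the network, so fewer tiles per forward pass means a smaller tensor at the upsampling layer. A back-of-the-envelope sketch, with purely illustrative tile size, feature-channel count, and depth (these are assumptions, not Omnipose's real internals):

```python
# Rough estimate of the batched-tile tensor size reaching a 3D upsampling
# layer. Tile size (224), channel count (32), and depth (84) are guesses
# for illustration only.
INT32_MAX = 2**31 - 1

def upsample_numel(batch_size, channels, depth, tile_hw, scale=2):
    """Element count of a (B, C, D, H, W) tile batch after `scale`x upsampling."""
    return batch_size * channels * (depth * scale) * (tile_hw * scale) ** 2

for bs in (8, 4, 1):
    n = upsample_numel(bs, channels=32, depth=84, tile_hw=224)
    print(f"batch_size={bs}: {n:,} elements (fits int32: {n <= INT32_MAX})")
```

Under these assumed shapes, only the single-tile batch stays under the limit, so passing a smaller `batch_size` to `eval` may be worth trying even though no explicit tile-size parameter is exposed.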

@QuentinRapilly
Author

QuentinRapilly commented Dec 18, 2024 via email
