
Using fcn_resnet50_unet-bcss for transfer learning #884

Closed

Himanshunitrr opened this issue Nov 25, 2024 · 2 comments

Comments
@Himanshunitrr

I went through an earlier similar issue. I want to use the semantic segmentation model (fcn_resnet50_unet-bcss) for transfer learning.

The parameters I am using to create a UNetModel object, as mentioned for fcn_resnet50_unet-bcss, are:

import torch

pretrained_weights = "/data/hmaurya/hmaurya/tissue_mask_model.pth"
model = UNetModel(num_input_channels=3, num_output_channels=5, decoder_block=(3, 3))
saved_state_dict = torch.load(pretrained_weights, map_location="cpu")

print(saved_state_dict.keys())
model.load_state_dict(saved_state_dict, strict=False)

which gives me:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 6
      3 saved_state_dict = torch.load(pretrained_weights, map_location="cpu")
      5 print(saved_state_dict.keys())
----> 6 model.load_state_dict(saved_state_dict, strict=False)

File /data/hmaurya/hmaurya/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:2152, in Module.load_state_dict(self, state_dict, strict, assign)

RuntimeError: Error(s) in loading state_dict for UNetModel:
	size mismatch for clf.weight: copying a param with shape torch.Size([2, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([5, 64, 1, 1]).
	size mismatch for clf.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([5]).

I wanted to know which parameters I am missing, or what I am doing wrong.
Thanks!

@measty
Collaborator

measty commented Nov 28, 2024

It looks like you have weights for a tissue mask model (num_output_channels=2, i.e. tissue or not), but you are creating the UNet model with num_output_channels=5, which is for a tissue-type segmentation task over 5 different tissue types. You need to match the num_output_channels argument in your model definition with the weights you want to load.

What is the task you are trying to accomplish? Tissue masking, or tissue type segmentation?
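If the goal really is transfer learning from the 2-channel tissue-mask weights to a 5-class task, a common PyTorch pattern is to load only the shape-compatible keys and let the classifier head stay freshly initialised. This is a minimal sketch of that pattern: `TinyUNet` is a hypothetical stand-in (not tiatoolbox's real `UNetModel`), and the checkpoint here is synthetic rather than the actual `tissue_mask_model.pth`.

```python
# Sketch: reuse backbone weights from a 2-class checkpoint in a 5-class model
# by filtering out keys whose shapes disagree, then loading with strict=False.
# TinyUNet is a hypothetical stand-in for tiatoolbox's UNetModel.
import torch
import torch.nn as nn


class TinyUNet(nn.Module):
    def __init__(self, num_output_channels):
        super().__init__()
        self.backbone = nn.Conv2d(3, 64, 3, padding=1)    # shared feature extractor
        self.clf = nn.Conv2d(64, num_output_channels, 1)  # task-specific head


# Stand-in for a checkpoint trained for 2-class tissue masking:
checkpoint = TinyUNet(num_output_channels=2).state_dict()

# New model for 5-class tissue-type segmentation:
model = TinyUNet(num_output_channels=5)

# Keep only keys whose shapes match the new model (drops clf.weight / clf.bias):
model_sd = model.state_dict()
compatible = {k: v for k, v in checkpoint.items()
              if k in model_sd and v.shape == model_sd[k].shape}

missing, unexpected = model.load_state_dict(compatible, strict=False)
print(sorted(missing))  # the head left at its fresh initialisation
```

The `missing` list reports exactly the head parameters (`clf.weight`, `clf.bias`) that were not loaded, which you would then train on the new 5-class data.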

@Himanshunitrr
Author

Ok, the ordering of the pretrained weights listed there is quite confusing: https://github.com/TissueImageAnalytics/tiatoolbox/blob/develop/examples/06-semantic-segmentation.ipynb

[screenshot: pretrained weight files listed in the notebook]

I downloaded the correct weights and it works.
