RuntimeError: output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192] #81

Open
steven5clu884 opened this issue Nov 17, 2021 · 4 comments

Comments

@steven5clu884

I cloned the repo, created an Anaconda environment, and then ran python test.py.
I have the following checkpoints:

    ├── GMM
    │   └── gmm_final.pth
    └── TOM
        └── tom_final.pth

My traceback looks like this:

    Traceback (most recent call last):
      File "test.py", line 229, in <module>
        main()
      File "test.py", line 215, in main
        test_gmm(opt, test_loader, model, board)
      File "test.py", line 86, in test_gmm
        for step, inputs in enumerate(test_loader.data_loader):
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
        data = self._next_data()
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
        return self._process_data(data)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    normalize1: torch.Size([3, 256, 192])
        data.reraise()
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
        raise exception
    RuntimeError: Caught RuntimeError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
        data = fetcher.fetch(index)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/home/wizzerking/develop/cp-vton-plus/cp_dataset.py", line 150, in __getitem__
        shape_ori = self.transform(parse_shape_ori) # [-1,1]
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
        img = t(img)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 226, in forward
        return F.normalize(tensor, self.mean, self.std, self.inplace)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 352, in normalize
        print("normalize1:",tensor.sub_(mean).div_(std).shape)
    RuntimeError: output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192]
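The final error is a broadcasting failure in an in-place subtraction. Note that the failing call in the traceback is self.transform(parse_shape_ori) on line 150, not the cloth or person images. Recent torchvision versions appear to reshape mean into a [C, 1, 1] tensor (C = len(mean)) and then call tensor.sub_(mean); this sketch assumes that behavior. A 1-channel input of shape [1, 256, 192] (presumably a grayscale mask here) would then broadcast against a [3, 1, 1] mean to [3, 256, 192], and that result cannot be written back into the smaller input. The pure-Python sketch below models only this shape rule. broadcast_shape and check_inplace are hypothetical helpers, not PyTorch functions.

```python
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: broadcast two shapes using NumPy/PyTorch rules."""
    n = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes are aligned on the right.
    pa = (1,) * (n - len(a)) + tuple(a)
    pb = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(pa, pb):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {list(a)} and {list(b)} are not broadcastable")
        # A size-1 dimension stretches to match the other one.
        out.append(y if x == 1 else x)
    return tuple(out)


def check_inplace(tensor_shape: Sequence[int], other_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: mimic the shape check of an in-place op like tensor.sub_(other)."""
    out = broadcast_shape(tensor_shape, other_shape)
    # An in-place op writes its result back into the original tensor,
    # so the broadcast result must have exactly that tensor's shape.
    if out != tuple(tensor_shape):
        raise RuntimeError(
            f"output with shape {list(tensor_shape)} doesn't match "
            f"the broadcast shape {list(out)}"
        )
    return out


mask = (1, 256, 192)  # 1-channel tensor, e.g. ToTensor() of a grayscale image
rgb = (3, 256, 192)   # 3-channel tensor
mean_rgb = (3, 1, 1)  # mean (0.5, 0.5, 0.5) reshaped to [C, 1, 1]
mean_any = (1, 1, 1)  # mean (0.5,) reshaped to [C, 1, 1]

try:
    check_inplace(mask, mean_rgb)
except RuntimeError as err:
    print(err)  # output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192]

print(check_inplace(mask, mean_any))  # (1, 256, 192)
print(check_inplace(rgb, mean_any))   # (3, 256, 192)
```

Under this model, a single-value mean works for both 1-channel and 3-channel inputs.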

Things I have tried: in cp_dataset.py, I added .convert('RGB') after each call to Image.open, for instance:

    if self.stage == 'GMM':
        c = Image.open(osp.join(self.data_path, 'cloth', c_name))
        c = c.convert('RGB')

        cm = Image.open(osp.join(self.data_path, 'cloth-mask', c_name)).convert('L')
        cm = cm.convert('RGB')

This does not help. Adding the same conversion to the person image:

    im = Image.open(osp.join(self.data_path, 'image', im_name))
    im = im.convert('RGB')

does not help either.

@rocketeerli

You need to change your PyTorch version to 0.4.
Alternatively, change the transform in the CPDataset class in cp_dataset.py from

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

to

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

@slaifan

slaifan commented Feb 5, 2022

Hi,

I am having the same problem. I couldn't find torch==0.4; I believe it is deprecated (please correct me if I'm wrong).
Changing the normalization code didn't work either, as I am still getting the same error.

Thanks in advance!

@zakmicallef

If you updated your PyTorch, .view needs to be replaced with .reshape in network.py on line 135.

@niranjanakella

In cp_dataset.py, line 30

Replace:

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

With these lines:

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

Also, in network.py, line 35, replace:

    x = x.view(x.size(0), -1)

With this line:

    x = x.reshape(x.size(0), -1)
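Neither comment explains why .view fails after upgrading. One common cause, assumed here rather than confirmed for this repo, is that .view only succeeds when the new shape is compatible with the tensor's existing memory layout. It fails, for example, on a non-contiguous tensor produced by transpose, while .reshape falls back to copying. Here is a minimal sketch with a made-up tensor:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)
# transpose returns a view with swapped strides, so it is no longer contiguous.
xt = x.transpose(1, 2)  # shape (2, 4, 3)
print(xt.is_contiguous())

try:
    xt.view(xt.size(0), -1)  # same call pattern as the original line
except RuntimeError as err:
    print(".view failed:", err)

# .reshape returns a view when it can and copies the data when it cannot.
flat = xt.reshape(xt.size(0), -1)
print(flat.shape)
```

The trade-off is that .reshape may silently copy, whereas .view guarantees that the result shares memory with its input.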


5 participants