pos_weight should be a Tensor? #4

Open
shionhonda opened this issue May 27, 2019 · 4 comments

Comments

@shionhonda

When I train on the Cora dataset, I get the following error from binary_cross_entropy_with_logits. Shouldn't pos_weight be a Tensor? Thanks!

Traceback (most recent call last):
  File "train.py", line 83, in <module>
    gae_for(args)
  File "train.py", line 62, in gae_for
    norm=norm, pos_weight=pos_weight)
  File "/gae-pytorch/gae/optimizer.py", line 7, in loss_function
    cost = norm * F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
  File "/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 2077, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
TypeError: binary_cross_entropy_with_logits(): argument 'pos_weight' (position 4) must be Tensor, not numpy.float64
@shionhonda
Author

I fixed it by replacing this:

pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()

with this:

pos_weight = torch.Tensor([float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()])
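
For readers who want to try the fix in isolation, here is a minimal sketch. The small `adj` matrix is a hypothetical stand-in for the Cora adjacency matrix, and the random `preds` are placeholders for the model's reconstructed logits; neither comes from the repository. It uses `torch.tensor` with an explicit dtype, which should behave the same as the `torch.Tensor([...])` call above for this purpose.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Hypothetical toy adjacency matrix (4 nodes, 4 nonzero entries),
# standing in for the real Cora graph.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0]], dtype=np.float64)

n = adj.shape[0]
num_pos = adj.sum()

# Ratio of negative to positive entries. NumPy returns numpy.float64 here,
# which is the type the traceback complains about.
ratio = (n * n - num_pos) / num_pos

# Wrap the scalar in a one-element float tensor before passing it on.
pos_weight = torch.tensor([float(ratio)], dtype=torch.float32)

torch.manual_seed(0)
preds = torch.randn(n, n)                 # placeholder logits
labels = torch.from_numpy(adj).float()    # targets must be float tensors

# The one-element pos_weight broadcasts over the (n, n) logits.
loss = F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
```

If the model runs on a GPU, the `pos_weight` tensor would presumably also need to be moved to the same device as `preds`.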

@iamadog3333

iamadog3333 commented Sep 26, 2019

It works, but I noticed that the length of pos_weight is 1.
According to the documentation:
pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.
By that reading, pos_weight's length should be 2.

@iamadog3333

iamadog3333 commented Sep 26, 2019

The docs say: where c is the class number (c>1 for multi-label binary classification, c=1 for single-label binary classification)
So a pos_weight of length 1 is right here.
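
To illustrate the c=1 versus c>1 distinction from the docs, here is a small sketch with made-up shapes and weights (none of these values come from the repository). It checks a one-element `pos_weight` against a hand-written version of the weighted binary cross-entropy, then passes a per-class vector for a multi-label case.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Single-label binary classification: one logit per example, so c = 1.
logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()
w = torch.tensor([3.0])  # one weight, applied to every positive target

loss_single = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=w)

# Hand-written equivalent: -[w * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))],
# averaged over all elements. Note log(1 - sigmoid(x)) == logsigmoid(-x).
manual = -(w * targets * F.logsigmoid(logits)
           + (1 - targets) * F.logsigmoid(-logits)).mean()

# Multi-label classification: three independent binary labels, so c = 3
# and pos_weight carries one weight per label.
logits_ml = torch.randn(8, 3)
targets_ml = torch.randint(0, 2, (8, 3)).float()
w_ml = torch.tensor([1.0, 2.0, 3.0])

loss_multi = F.binary_cross_entropy_with_logits(logits_ml, targets_ml, pos_weight=w_ml)
```

In both cases `pos_weight` is broadcast against the last dimension of the logits, which is why a single-element tensor is enough for the adjacency-reconstruction loss discussed in this issue.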

@Dzhilin

Dzhilin commented Jun 18, 2021

Doc says: where c is the class number (c>1 for multi-label binary classification, c=1 for single-label binary classification)
Hi, I have the same problem. Has your problem been solved?
