
Linear assignment layer training #6

Open
PhanTask opened this issue Jul 29, 2022 · 3 comments
@PhanTask

PhanTask commented Jul 29, 2022

Hi, thanks so much for open-sourcing this solid work! I'm trying to understand how the linear assignment layer is trained and would appreciate your help.

In the class AttentionalGNN, the linear assignment score is computed and post-processed as follows:

        scores = log_optimal_transport(scores.log_softmax(dim=-2), self.bin_score, iters=5)[:, :-1, :-1].view(unreachable.shape)
        score_min = scores.min() - scores.max()
        scores = scores + (score_min - 40) * invalid.float() + (score_min - 20) * unreachable.float()

        return scores * 15

Could you please explain this code a bit more? I am not quite sure how to interpret invalid.float(), unreachable.float(), or the constants used here (e.g., 40, 20, 15).

I would also like to use such a differentiable linear assignment layer in my own work, where I perform bipartite matching to associate the same robots across consecutive frames (t0 <-> t1, t1 <-> t2, t2 <-> t3, ...) to build matching traces, and then use a discriminator to evaluate the overall quality of the traces and produce a loss. Overall it is closer to a GAN structure. How should I use this loss to update the linear assignment layer? Thanks!

@illusive-chase
Collaborator

Since scores represents the "log probability of frontiers being sampled", we add a bias like (score_min - 40) * invalid.float() to make these INVALID frontiers exp(40) ≈ 2.4e17 times less likely to be sampled.

Similarly, we add (score_min - 20) * unreachable.float(). We use 20 instead of 40 because, if all frontiers are either INVALID or UNREACHABLE, we would rather sample an UNREACHABLE one than an INVALID one.
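A small numeric check of this masking trick may help; the 1x4 score matrix and masks below are made up for illustration:

```python
import torch

# Hypothetical log-scores for 4 frontiers; frontier 1 is INVALID,
# frontier 2 is UNREACHABLE (these masks are illustrative, not from the repo).
scores = torch.tensor([[0.5, -0.3, 0.1, 0.2]])
invalid = torch.tensor([[0, 1, 0, 0]], dtype=torch.bool)
unreachable = torch.tensor([[0, 0, 1, 0]], dtype=torch.bool)

# score_min is a non-positive offset, so the added bias is always negative.
score_min = scores.min() - scores.max()
scores = scores + (score_min - 40) * invalid.float() \
                + (score_min - 20) * unreachable.float()

probs = torch.distributions.Categorical(logits=scores).probs
# The INVALID frontier ends up roughly exp(20) times less likely than the
# UNREACHABLE one, which in turn is far less likely than the valid frontiers.
```

So if every frontier were masked, sampling would still prefer UNREACHABLE over INVALID, matching the intent described above.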

Also, the linear assignment layer itself is non-learning and does not need to be trained: it can be viewed as a differentiable version of the Hungarian algorithm. The GNN evaluates the cost between nodes, and the linear assignment layer performs the assignment based on that cost.
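For intuition, here is a minimal sketch of such a differentiable assignment via Sinkhorn iterations in log space. The repo's log_optimal_transport follows the same idea but additionally augments the score matrix with a learnable dustbin row/column (self.bin_score) for unmatched entries:

```python
import torch

def log_sinkhorn(log_scores, iters=5):
    """Alternately normalize rows and columns of exp(log_scores) in log
    space, yielding the log of a (near) doubly-stochastic soft-assignment
    matrix. A sketch only; not the repo's exact log_optimal_transport."""
    Z = log_scores
    for _ in range(iters):
        Z = Z - torch.logsumexp(Z, dim=2, keepdim=True)  # row-normalize
        Z = Z - torch.logsumexp(Z, dim=1, keepdim=True)  # column-normalize
    return Z

torch.manual_seed(0)
cost = torch.randn(1, 4, 4, requires_grad=True)   # GNN-produced costs (dummy here)
P = log_sinkhorn(cost, iters=50).exp()            # soft assignment matrix
P.diagonal(dim1=1, dim2=2).sum().backward()       # gradients reach the cost matrix
```

Because every step is a differentiable normalization, gradients flow back through the assignment into whatever network produced the costs, which is why the layer itself needs no trainable parameters.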

Hope that helps :)

@JiayunjieJYJ


Could you please further explain the meaning of:
return scores * 15

And how does this matrix influence the final output action?

Thank you very much!

@illusive-chase
Collaborator

Similarly, scores * 15 sharpens the distribution of output actions: scaling the logits up acts like a low softmax temperature.

This code may help you understand better.

>>> import torch
>>> C = torch.distributions.Categorical
>>> a=torch.randn(10)
>>> a
tensor([-0.9821,  2.3382, -0.2364, -1.0046, -1.2925,  0.1279,  0.4886, -1.5084,
        -0.8465,  0.5353])
>>> C(logits=a).probs
tensor([0.0217, 0.5993, 0.0457, 0.0212, 0.0159, 0.0657, 0.0943, 0.0128, 0.0248,
        0.0988])
>>> C(logits=a*15).probs
tensor([2.3463e-22, 1.0000e+00, 1.6899e-17, 1.6732e-22, 2.2271e-24, 3.9920e-15,
        8.9233e-13, 8.7412e-26, 1.7918e-21, 1.8004e-12])
