Value expectation and 1st order CKY #93
base: master
Conversation
Hmm, for some reason the tests are not running on this. Trying to figure out why.
```python
    .sum(2)
)
assert torch.isclose(
    E_val, log_probs.exp().unsqueeze(-1).mul(struct_vals).sum(0)
```
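For context, this hunk appears to assert that the computed expectation matches an explicit enumeration over structures. A self-contained sketch of that identity, where everything besides `E_val`, `log_probs`, and `struct_vals` is an assumption of mine rather than code from the PR:

```python
import torch

# Hypothetical shapes: S enumerated structures, batch B, value dimension V.
S, B, V = 6, 2, 3
log_probs = torch.randn(S, B).log_softmax(0)   # log p(z) for each structure z
struct_vals = torch.randn(S, B, V)             # additive value f(z) per structure

# Reference computation of E[f(z)] by direct enumeration.
E_val = torch.einsum("sb,sbv->bv", log_probs.exp(), struct_vals)

# The hunk's formulation of the same sum.
assert torch.isclose(
    E_val, log_probs.exp().unsqueeze(-1).mul(struct_vals).sum(0)
).all()
```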
Just curious. Why not just make this the implementation of expected value? It seems just as good and perhaps more efficient.
Sorry, maybe I'm confused, but isn't this enumerating over all possible structures explicitly?
Oh sorry, my comment is confusing.
I think a valid way of computing an expectation over any "part-level value" is to first compute the marginals (.marginals()), then do an elementwise multiply (.mul) with the values, and then sum. Doesn't that give you the same thing as the semiring?
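This works because f decomposes additively over parts: E[f(z)] = E[Σ_p f_p · 1{p ∈ z}] = Σ_p f_p · μ_p, where μ_p is the marginal probability of part p. A minimal sketch of that recipe (sizes are illustrative; shapes mirror the benchmark below):

```python
import torch
from torch_struct import LinearChainCRF

B, N, C, V = 4, 20, 5, 3                 # illustrative sizes
phis = torch.randn(B, N, C, C)           # edge log potentials
vals = torch.randn(B, N, C, C, V)        # value assigned to each edge part

marg = LinearChainCRF(phis).marginals    # marginal probability of each part
E_f = marg.unsqueeze(-1).mul(vals).reshape(B, -1, V).sum(1)   # E[f(z)], shape (B, V)
```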
Oh wow, I didn't realize this! I just tested it out and it appears to be more efficient for larger structure sizes. I guess this is due to the fast log semiring implementation? I'll update things to use this approach instead.
Yeah, I think that is right... I haven't thought about this too much, but my guess is that this is just better on GPU hardware since the expectation is batched at the end. But it seems worth understanding when this works. I don't think you can compute Entropy this way? (but I might be wrong)
Makes sense. I also don't think entropy can be done this way -- I just tested it out and the results didn't match the semiring. I will switch to this implementation in the latest commit and get rid of the value semiring.
Fwiw I ran a quick speed comparison you might be interested in:

```python
import torch
from torch_struct import LinearChainCRF

B, N, C = 4, 200, 10
phis = torch.randn(B, N, C, C).cuda()
vals = torch.randn(B, N, C, C, 10).cuda()
```

Results from running w/ genbmm:

```python
%%timeit
LinearChainCRF(phis).expected_value(vals)
>>> 100 loops, best of 3: 6.34 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B, -1, vals.shape[-1]).sum(1)
>>> 100 loops, best of 3: 5.64 ms per loop
```

Results from running w/o genbmm:

```python
%%timeit
LinearChainCRF(phis).expected_value(vals)
>>> 100 loops, best of 3: 9.67 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B, -1, vals.shape[-1]).sum(1)
>>> 100 loops, best of 3: 8.83 ms per loop
```
torch_struct/distributions.py (Outdated)
""" | ||
Compute expectated value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z. | ||
|
||
Params: |
This should be "Parameters:"
torch_struct/distributions.py (Outdated)
```python
Compute expected value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z.

Params:
* values (*batch_shape x *event_shape, *value_shape): torch.FloatTensor that assigns a value to each part
```
Let's put the types in the first parens, and use :class:`torch.FloatTensor`.
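If I'm reading the suggestion right, the entry would become something like this (exact wording is my guess, not from the PR):

```
Parameters:
    values (:class:`torch.FloatTensor`): tensor of shape (*batch_shape, *event_shape, *value_shape) that assigns a value to each part
```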
```diff
 samples = []
 for k in range(nsamples):
-    if k % 10 == 0:
+    if k % batch_size == 0:
```
Oh yeah, sorry this is my fault. 10 is a global constant. Let's put it on MultiSampledSemiring.
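Presumably something like the following, with the constant hoisted onto the semiring class (the attribute name and surrounding code are illustrative, not the actual torch_struct implementation):

```python
class MultiSampledSemiring:
    # Hypothetical: named class attribute instead of a bare global constant.
    batch_size = 10

nsamples, samples = 25, []
for k in range(nsamples):
    if k % MultiSampledSemiring.batch_size == 0:
        samples.append(k)  # placeholder for the real per-batch sampling work
```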
torch_struct/distributions.py (Outdated)
```python
Implementation uses width-batched, forward-pass only

* Parallel Time: :math:`O(N)` parallel merges.
* Forward Memory: :math:`O(N^2)`
```
This can't be right... isn't the event shape O(N^3) alone?
Oops, yeah, that's left over from modifying the CKYCRF class.
torch_struct/full_cky_crf.py (Outdated)
```diff
@@ -0,0 +1,114 @@
+import torch
+from .helpers import _Struct, Chart
+from tqdm import tqdm
```
Be sure to run `python setup.py style` to run flake8. It will catch these errors.
torch_struct/helpers.py (Outdated)
```python
Returns:
    v (torch.Tensor): the resulting output of the dynamic program
    edges (List[torch.Tensor]): the log edge potentials of the model.
```
Changing this to `logpotentials` throughout.
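Applied to the hunk above, the renamed entry would presumably read (my phrasing):

```python
Returns:
    v (torch.Tensor): the resulting output of the dynamic program
    logpotentials (List[torch.Tensor]): the log potentials of the model.
```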
torch_struct/helpers.py (Outdated)
```python
[scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`.
An exceptional case is the `CKY` struct, which takes log potential parameters from production rules
for a PCFG, which are by definition independent of position in the sequence.
charts: Optional[List[Chart]] = None, the charts used in computing the dp. They are needed if we want to run the
```
Going to remove this for simplicity.
```python
for k in range(v.shape[0]):
    obj = v[k].sum(dim=0)

    with torch.autograd.enable_grad():  # in case input potentials don't have grads enabled.
```
Cool
torch_struct/semirings/semirings.py (Outdated)
```python
    return xs


def ValueExpectationSemiring(k):
```
Are you sure we don't have this already? Could have sworn someone added it.
I'm not 100% sure; I looked and didn't see it anywhere in master, so I went ahead with it. Maybe it's in another branch? There's the entropy semiring, which is very similar.
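For intuition, here is a toy, scalar, non-log-space version of the expectation-semiring algebra that both the entropy semiring and a value-expectation semiring build on. The PR's actual `ValueExpectationSemiring` is vectorized and works in log space; this sketch is mine, not from the codebase:

```python
# Elements are pairs (p, r): unnormalized mass p and mass-weighted value r.

def splus(a, b):
    # Semiring "+" combines alternative derivations: masses and weighted values add.
    return (a[0] + b[0], a[1] + b[1])

def stimes(a, b):
    # Semiring "*" combines parts of one derivation: product rule on the value term.
    return (a[0] * b[0], a[0] * b[1] + a[1] * b[0])

ZERO, ONE = (0.0, 0.0), (1.0, 0.0)

# Two derivations with masses 0.25 / 0.75 carrying values 2.0 / 4.0:
total = splus(stimes((0.25, 0.0), (1.0, 2.0)),
              stimes((0.75, 0.0), (1.0, 4.0)))
expected_value = total[1] / total[0]   # 0.25*2.0 + 0.75*4.0 = 3.5
```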
Thanks for the PR. Lots of nice stuff in here.
Quick dev question: when I try running […]
Interesting, yeah, not sure how to run those automatically, I will look into it.

Changes are: