Optimize torch.func full/diag GGN/Fisher-MC #149
Labels: enhancement (New feature or request)
Milestone: torch.func
Currently, we compute the Jacobians explicitly. We can improve this by using vector-Jacobian products (VJPs) instead.
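A minimal sketch of what a VJP-based exact GGN could look like. It assumes cross-entropy loss with `reduction="sum"` and a hypothetical toy MLP classifier; the helper names (`ce_hessian_sqrt`, `per_sample_JTS`, `ggn_full_and_diag`) are illustrative and not an existing API, and this is not the linked reference implementation. The idea: the cross-entropy Hessian with respect to the logits, H = diag(q) − q qᵀ with q = softmax(logits), has a symmetric factor S with H = S Sᵀ. Each per-sample term Jᵀ H J = (Jᵀ S)(Jᵀ S)ᵀ, and Jᵀ S is obtained with C VJPs (one per column of S) instead of forming J and then multiplying by H. The diagonal follows by squaring and summing Jᵀ S, without building the P × P matrix.

```python
import torch
from torch.func import functional_call, vjp, vmap

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

# Hypothetical toy problem: small MLP classifier, cross-entropy with reduction="sum".
N, D_IN, C = 8, 5, 3
model = torch.nn.Sequential(
    torch.nn.Linear(D_IN, 4), torch.nn.Tanh(), torch.nn.Linear(4, C)
)
X = torch.randn(N, D_IN)
params = {name: p.detach() for name, p in model.named_parameters()}


def f(p, x):
    """Logits of a single sample x as a function of the parameters p."""
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)


def ce_hessian_sqrt(logits):
    """S with S @ S.T = diag(q) - q q^T, the cross-entropy Hessian w.r.t. the logits."""
    q = logits.softmax(-1)
    sqrt_q = q.sqrt()
    return torch.diag_embed(sqrt_q) - q.unsqueeze(-1) * sqrt_q.unsqueeze(-2)


def per_sample_JTS(x):
    """Rows are (J^T s_j)^T for the columns s_j of S; shape [C, P]. J is never formed."""
    logits, vjp_fn = vjp(lambda p: f(p, x), params)
    S = ce_hessian_sqrt(logits)  # [C, C]
    # vmap over the columns of S: each column is one VJP cotangent.
    (JTS,) = vmap(vjp_fn, in_dims=1)(S)  # dict of [C, *param.shape]
    return torch.cat([t.reshape(C, -1) for t in JTS.values()], dim=1)


def ggn_full_and_diag(X):
    M = vmap(per_sample_JTS)(X)  # [N, C, P]
    ggn = torch.einsum("ncp,ncq->pq", M, M)  # sum_n J_n^T H_n J_n
    diag = M.pow(2).sum(dim=(0, 1))  # diagonal only, no P x P matrix needed
    return ggn, diag


ggn, ggn_diag = ggn_full_and_diag(X)
print(ggn.shape, ggn_diag.shape)
```

The factorization trick is specific to losses whose Hessian with respect to the model output has a cheap symmetric square root (cross-entropy and squared error both do).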
Reference implementation for the full GGN: https://github.com/f-dangel/curvlinops/blob/5852711aedf2728bc609fabfa95eac00da1beb63/curvlinops/examples/functorch.py#L72-L138
Not a high priority, since KFAC is usually used instead and the (diag/full) EF implementations are already efficient.
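For Fisher-MC, each per-sample gradient g = Jᵀ ∇_logits ℓ(logits, ŷ) is a single VJP, so with labels ŷ sampled from the model's predictive distribution the estimate needs one VJP per MC sample rather than the C VJPs of a full reverse-mode Jacobian. This is the same per-sample-gradient structure as the EF, which uses the true labels instead. A minimal sketch, assuming a hypothetical toy MLP classifier with cross-entropy loss (`reduction="sum"`); `loss`, `per_sample_grads` and `fisher_mc` are illustrative names, not an existing API.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

# Hypothetical toy problem: small MLP classifier, cross-entropy with reduction="sum".
N, D_IN, C = 8, 5, 3
model = torch.nn.Sequential(
    torch.nn.Linear(D_IN, 4), torch.nn.Tanh(), torch.nn.Linear(4, C)
)
X = torch.randn(N, D_IN)
params = {name: p.detach() for name, p in model.named_parameters()}


def loss(p, x, y):
    """Cross-entropy of a single sample x with (scalar, integer) label y."""
    logits = functional_call(model, p, (x.unsqueeze(0),))
    return F.cross_entropy(logits, y.unsqueeze(0), reduction="sum")


def per_sample_grads(X, y):
    """Flattened gradients g[n, m] for input X[n] and label y[n, m]; shape [N, M, P]."""
    # Inner vmap: over the M labels of one sample; outer vmap: over the N samples.
    per_grad = vmap(vmap(grad(loss), in_dims=(None, None, 0)), in_dims=(None, 0, 0))
    G = per_grad(params, X, y)  # dict of [N, M, *param.shape]
    n, m = y.shape
    return torch.cat([t.reshape(n, m, -1) for t in G.values()], dim=-1)


def fisher_mc(X, mc_samples=1):
    with torch.no_grad():
        probs = functional_call(model, params, (X,)).softmax(-1)  # [N, C]
    # Fisher-MC: labels drawn from the model's predictive distribution.
    y_hat = torch.multinomial(probs, mc_samples, replacement=True)  # [N, M]
    g = per_sample_grads(X, y_hat)
    fisher = torch.einsum("nmp,nmq->pq", g, g) / mc_samples
    diag = g.pow(2).sum(dim=(0, 1)) / mc_samples  # no P x P matrix needed
    return fisher, diag


fisher, fisher_diag = fisher_mc(X, mc_samples=4)
print(fisher.shape, fisher_diag.shape)
```

Passing the true labels (shape [N, 1]) to `per_sample_grads` instead of `y_hat` would give the EF from the same code path.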