kfac_jax 0.0.4
What's Changed
-
- Adding `use_exact_inverses` argument to optimizer by @copybara-service in #60
- Adding
- Better PSM tests. by @copybara-service in #66
-
- Minor fix of pi-adjusted-inverse by @copybara-service in #71
- Fixing minor type error. by @copybara-service in #76
-
- Minor fix of how jax scopes are named (':' is not valid). by @copybara-service in #80
-
- Improving docstring for optimizer. In particular regarding the damping parameter and LR/momentum/damping adaptation methods. by @copybara-service in #84
- Changing examples code so that dataset functions directly return iterators instead of TF datasets. by @copybara-service in #86
-
- Removing unused put_stop_grad_on_loss_factor argument in multiply_fisher_factor by @copybara-service in #92
- Use jax.device_put_replicated to broadcast to local devices by @copybara-service in #89
- feat(ci): enable `pip` caching in CI by @SauravMaheshkar in #94
- Updates the code to always create variables and computations of the same dtype as its inputs. Previously, if float64 was enabled, some of the results would be (potentially incorrectly) promoted to higher precision. by @copybara-service in #82
- Adding data seen to the reported statistics on the evaluator in the examples. by @copybara-service in #101
-
- Modifying examples to only use label smoothing and L2 reg loss when training by @copybara-service in #102
-
- Minor fixes/improvements to optimizer docstrings by @copybara-service in #108
- Replaces references to jax.numpy.DeviceArray with jax.Array. by @copybara-service in #110
- Suppress pytype errors. by @copybara-service in #112
- Reenable some types that were previously disabled due to pytype crashes. by @copybara-service in #116
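
The dtype change in #82 can be illustrated with a minimal JAX sketch. This is not the kfac_jax code itself. The variable names are hypothetical, and the example only assumes standard JAX type promotion when float64 is enabled.

```python
import jax
import jax.numpy as jnp

# Enable float64 support; without this flag JAX keeps arrays at 32 bits.
jax.config.update("jax_enable_x64", True)

# Hypothetical float32 input, standing in for a parameter or activation.
x = jnp.ones((3,), dtype=jnp.float32)

# A freshly created array takes the default dtype (float64 once x64 is
# enabled), so combining it with the float32 input promotes the result.
promoted = x * jnp.ones((3,))

# Creating the new array with the input's dtype keeps the result at float32.
preserved = x * jnp.ones((3,), dtype=x.dtype)

print(promoted.dtype, preserved.dtype)
```

Passing `dtype=x.dtype` (or using helpers such as `jnp.zeros_like`) is one common way to keep intermediate values at the same precision as their inputs.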
New Contributors
- @SauravMaheshkar made their first contribution in #94
Full Changelog: v0.0.3...v0.0.4