Releases: google-deepmind/kfac-jax
v0.0.6
What's Changed
- Adding logging for the number of parameters and optimizer state. by @copybara-service in #125
- Adding automatic cross-device averaging of auxiliary loss/models stats to optimizer. by @copybara-service in #139
- Add `rel_grad_norm` and `rel_update_norm` stats logging by @copybara-service in #147
- Fixing bug that would sometimes cause an exception for networks with scalar-valued parameters. by @copybara-service in #151
- [JAX] Migrate XlaBuilder users to emit direct stablehlo MLIR lowerings. by @copybara-service in #161
- Still fixing docs requirements dependencies. by @copybara-service in #166
- Still fixing docs requirements dependencies. by @copybara-service in #169
- Still fixing docs requirements dependencies. by @copybara-service in #171
- Still fixing docs requirements dependencies. by @copybara-service in #173
- Adding capability to pass custom arguments to the registration functions, and call them in a custom module, for standard losses in the example code. by @copybara-service in #175
- Fix or ignore some pytype errors. by @copybara-service in #177
- [LSC] Ignore incorrect type annotations related to jax.numpy APIs by @copybara-service in #176
- Adding a `sum_of_objects`. by @copybara-service in #190
- Adding Polyak averaging feature to example experiments codebase. by @copybara-service in #195
- Adding precon_damping_mult feature to optimizer. by @copybara-service in #196
- Reland jax-ml/jax#10573. by @copybara-service in #199
- minor refactoring by @copybara-service in #201
- Fixing issue where loss_registered_reldiff was not computed properly in multi-device settings. by @copybara-service in #202
- Adding a new schedule and applying some fixes to existing ones in the examples codebase. by @copybara-service in #204
- Remove gradient normalization from the preconditioning function by @copybara-service in #206
Full Changelog: v0.0.5...v0.0.6
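The Polyak averaging feature mentioned in #195 can be illustrated with a minimal sketch. This assumes the exponential-moving-average variant of Polyak averaging; the release notes do not say which variant the examples codebase uses, and the names `polyak_update` and `decay` are hypothetical, not taken from kfac-jax.

```python
from typing import Dict

Params = Dict[str, float]


def polyak_update(avg: Params, new: Params, decay: float) -> Params:
    """Return an exponential moving average of the parameters.

    Hypothetical helper for illustration only, not the kfac-jax API.
    """
    return {name: decay * avg[name] + (1.0 - decay) * new[name] for name in avg}


params = {"w": 0.0, "b": 1.0}
avg_params = dict(params)  # start the average at the initial parameters

for _ in range(3):
    # Stand-in for one optimizer step: shift every parameter by 1.
    params = {name: value + 1.0 for name, value in params.items()}
    avg_params = polyak_update(avg_params, params, decay=0.5)
```

The averaged parameters lag behind the raw ones, which is the point: evaluation on `avg_params` smooths out step-to-step noise in `params`.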
kfac-jax 0.0.5.
What's Changed
- Modifying the structure of the experiments class by separating the Supervised experiment from the Jaxline experiment inheritance. by @copybara-service in #122
- Introducing a more general KroneckerFactored block class. by @copybara-service in #123
- Bumping up the package version due to changes in the curvature block classes. by @copybara-service in #124
Full Changelog: v0.0.4...v0.0.5
kfac_jax 0.0.4
What's Changed
- Adding `use_exact_inverses` argument to optimizer by @copybara-service in #60
- Better PSM tests. by @copybara-service in #66
- Minor fix of pi-adjusted-inverse by @copybara-service in #71
- Fixing minor type error. by @copybara-service in #76
- Minor fix of how jax scopes are named (':' is not valid). by @copybara-service in #80
- Improving docstring for optimizer. In particular regarding the damping parameter and LR/momentum/damping adaptation methods. by @copybara-service in #84
- Changing examples code so that dataset functions directly return iterators instead of TF datasets. by @copybara-service in #86
- Removing unused put_stop_grad_on_loss_factor argument in multiply_fisher_factor by @copybara-service in #92
- Use jax.device_put_replicated to broadcast to local devices by @copybara-service in #89
- feat(ci): enable `pip` caching in CI by @SauravMaheshkar in #94
- Updates the code to always create variables and computations of the same dtype as its inputs. Previously, if float64 was enabled, some of the results would be (potentially incorrectly) promoted to higher precision. by @copybara-service in #82
- Adding data seen to the reported statistics on the evaluator in the examples. by @copybara-service in #101
- Modifying examples to only use label smoothing and L2 reg loss when training by @copybara-service in #102
- Minor fixes/improvements to optimizer docstrings by @copybara-service in #108
- Replaces references to jax.numpy.DeviceArray with jax.Array. by @copybara-service in #110
- Suppress pytype errors. by @copybara-service in #112
- Reenable some types that were previously disabled due to pytype crashes. by @copybara-service in #116
New Contributors
- @SauravMaheshkar made their first contribution in #94
Full Changelog: v0.0.3...v0.0.4
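The change in #102 restricts label smoothing to training. As a rough sketch of what that means, the snippet below applies the standard label smoothing formula, (1 - ε)·y + ε/K for K classes, only when an `is_training` flag is set. The function names and the flag are hypothetical and are not taken from the kfac-jax examples.

```python
from typing import List


def smooth_labels(one_hot: List[float], epsilon: float) -> List[float]:
    """Standard label smoothing: mix the one-hot target with a uniform one."""
    num_classes = len(one_hot)
    return [(1.0 - epsilon) * y + epsilon / num_classes for y in one_hot]


def loss_targets(one_hot: List[float], epsilon: float, is_training: bool) -> List[float]:
    """Smooth targets during training only; evaluation keeps the raw labels.

    Hypothetical helper for illustration only.
    """
    return smooth_labels(one_hot, epsilon) if is_training else list(one_hot)


label = [0.0, 1.0, 0.0, 0.0]
train_targets = loss_targets(label, epsilon=0.1, is_training=True)
eval_targets = loss_targets(label, epsilon=0.1, is_training=False)
```

Keeping evaluation targets unsmoothed means the reported evaluation loss is the plain cross-entropy, so it stays comparable across runs with different smoothing settings.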
kfac-jax 0.0.3
What's Changed
- Changing the version in the citation text in the README. by @copybara-service in #29
- Adding attributes for the number of training and evaluation devices. by @copybara-service in #31
- Adding some methods to ImplicitExactCurvature by @copybara-service in #32
- Adding "put_stop_grad_on_loss_factor" argument to 'multiply_fisher_factor'. by @copybara-service in #36
- Making ScaleAndShift blocks capable of having parameters that are broadcast by construction, e.g. batch norm with scale parameters [1, 1, 1, d]. by @copybara-service in #33
- Changing `jax.tree_map` -> `jax.tree_util.tree_map` and related due to recent deprecation. by @copybara-service in #37
- Removed unused precedence argument from GraphPattern. by @copybara-service in #38
- Fix a small bug where we don't check in the jaxpr constvars. by @copybara-service in #39
- Adding an `estimator` attribute to the optimizer. by @copybara-service in #34
- Updating the docs to correctly refer to `update_cache`. by @copybara-service in #40
- Compare with slightly less numerical precision. by @copybara-service in #41
- Revamping the graph matching code to be able to detect layers and register tag in arbitrary higher-order Jax primitives. by @copybara-service in #42
- Revising docstring for optimizer class. Now contains missing details about value_and_grad_func. by @copybara-service in #43
- Internal change. by @copybara-service in #44
- Make LossTag return only the parameter-dependent arrays. by @copybara-service in #46
- Improving LossTags to be able to deal correctly with None arguments, by passing in argument names. by @copybara-service in #47
- Minor fix to a bug introduced in the previous commit. by @copybara-service in #48
- Correcting issues with docstring for optimizer. by @copybara-service in #45
- Fixing a bug in the graph matcher introduced in a recent CL. by @copybara-service in #49
- Removing unneeded jax.jit in get_mean and get_sum. by @copybara-service in #50
- Adding per-parameter norm stats to optimizer by @copybara-service in #51
- Allowing the pi-adjusted psd inverse to accept diagonal factors. by @copybara-service in #55
- Fixing wrong type annotation of pmap_axis_name. by @copybara-service in #56
- Adding optional offloading of `eigh` computation to the host because of a bug in CUDA 11.7.0 cuSOLVER library. by @copybara-service in #57
Full Changelog: v0.0.2...v0.0.3
kfac-jax 0.0.2.
What's Changed
- Moving .github to top-level directory for CI. by @copybara-service in #1
- Updated documentation for state classes. by @copybara-service in #2
- Changing the name on PyPi to kfac-jax. by @copybara-service in #3
- Making the tracer test in float64. by @copybara-service in #4
- Allowing graph patterns with multiple broadcast to be merged without dangling equations. by @copybara-service in #5
- Adding README for the examples. by @copybara-service in #7
- Changing deprecated `tree_multimap` to `tree_map`. by @copybara-service in #8
- Fixing small error introduced due to updates to chex. by @copybara-service in #11
- Fixing typo "drop_reminder" by @copybara-service in #13
- Adding an argument to set the reduction ratio thresholds for automatic damping adjustment. by @copybara-service in #12
- Adding "modifiable_attribute_exceptions" argument to optimizer by @copybara-service in #14
- Changing Imagenet dataset in examples to use a seed for file shuffling to achieve determinism. by @copybara-service in #17
- Small fix to a doc reference bug. by @copybara-service in #16
- Making WeightedMovingAverage work with arbitrary structures. by @copybara-service in #19
- Minor typos. by @copybara-service in #20
- Correct buffer donation of Optimizer._step. by @copybara-service in #21
- Replacing `yield from` with direct iteration. by @copybara-service in #24
- Adding stepwise schedule option to examples. by @copybara-service in #18
- Publishing a new version to PyPi. by @copybara-service in #28
New Contributors
- @copybara-service made their first contribution in #1
Full Changelog: https://github.com/deepmind/kfac-jax/commits/v0.0.2
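For readers unfamiliar with the reduction ratio thresholds mentioned in #12, here is a minimal sketch of a Levenberg-Marquardt style damping rule. It assumes the common convention that the reduction ratio `rho` compares the actual loss decrease with the decrease predicted by the quadratic model. The function name, argument names, and default values below are hypothetical and should not be read as the kfac-jax API or its defaults.

```python
def adjust_damping(
    damping: float,
    rho: float,
    lower_threshold: float = 0.25,  # assumed value, for illustration
    upper_threshold: float = 0.75,  # assumed value, for illustration
    decay: float = 0.9,             # assumed value, for illustration
) -> float:
    """Levenberg-Marquardt style damping update (hypothetical sketch).

    rho is the reduction ratio: actual loss change / predicted loss change.
    """
    if rho < lower_threshold:
        # The quadratic model was too optimistic: trust it less.
        return damping / decay
    if rho > upper_threshold:
        # The quadratic model was accurate: trust it more.
        return damping * decay
    return damping


increased = adjust_damping(1.0, rho=0.1)
decreased = adjust_damping(1.0, rho=0.9)
unchanged = adjust_damping(1.0, rho=0.5)
```

Exposing the two thresholds as arguments lets a user widen or narrow the band of `rho` values in which the damping is left alone.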