Skip to content

Commit

Permalink
Replace references to deprecated device_buffer attributes
Browse files Browse the repository at this point in the history
`jax.Array.device_buffer` and `jax.Array.device_buffers` will be deprecated as of jax version 0.4.22; see jax-ml/jax#18844.

PiperOrigin-RevId: 588553845
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Dec 7, 2023
1 parent d72bd65 commit a90fe64
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions trax/optimizers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,11 @@ def _free_accelerators(self, exceptions=(), keep_constants=True):
logging.info('Deleting %d live buffers.', len(live_buffers))
exceptions_buffers = []
for x in fastmath.tree_flatten(exceptions):
if hasattr(x, 'device_buffer'): # DeviceArray
if hasattr(x, 'addressable_shards'): # Array
exceptions_buffers.extend(shard.data for shard in x.addressable_shards)
elif hasattr(x, 'device_buffer'): # DeviceArray
exceptions_buffers.append(x.device_buffer)
if hasattr(x, 'device_buffers'): # ShardedDeviceArray
elif hasattr(x, 'device_buffers'): # ShardedDeviceArray
exceptions_buffers.extend(x.device_buffers)
for b in live_buffers:
should_delete = True
Expand Down

0 comments on commit a90fe64

Please sign in to comment.