From 10a95cc0d087fb2b25fc1108e3ba00b450fa722f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Dec 2023 15:01:57 -0800 Subject: [PATCH] Replace references to deprecated jax array attributes device_buffer and device_buffers PiperOrigin-RevId: 588553845 --- trax/optimizers/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trax/optimizers/trainer.py b/trax/optimizers/trainer.py index 41707701f..4d4c4bd9d 100644 --- a/trax/optimizers/trainer.py +++ b/trax/optimizers/trainer.py @@ -445,9 +445,11 @@ def _free_accelerators(self, exceptions=(), keep_constants=True): logging.info('Deleting %d live buffers.', len(live_buffers)) exceptions_buffers = [] for x in fastmath.tree_flatten(exceptions): - if hasattr(x, 'device_buffer'): # DeviceArray + if hasattr(x, 'addressable_shards'): # Array + exceptions_buffers.extend(shard.data for shard in x.addressable_shards) + elif hasattr(x, 'device_buffer'): # DeviceArray exceptions_buffers.append(x.device_buffer) - if hasattr(x, 'device_buffers'): # ShardedDeviceArray + elif hasattr(x, 'device_buffers'): # ShardedDeviceArray exceptions_buffers.extend(x.device_buffers) for b in live_buffers: should_delete = True