diff --git a/trax/optimizers/trainer.py b/trax/optimizers/trainer.py index 41707701f..4d4c4bd9d 100644 --- a/trax/optimizers/trainer.py +++ b/trax/optimizers/trainer.py @@ -445,9 +445,11 @@ def _free_accelerators(self, exceptions=(), keep_constants=True): logging.info('Deleting %d live buffers.', len(live_buffers)) exceptions_buffers = [] for x in fastmath.tree_flatten(exceptions): - if hasattr(x, 'device_buffer'): # DeviceArray + if hasattr(x, 'addressable_shards'): # Array + exceptions_buffers.extend(shard.data for shard in x.addressable_shards) + elif hasattr(x, 'device_buffer'): # DeviceArray exceptions_buffers.append(x.device_buffer) - if hasattr(x, 'device_buffers'): # ShardedDeviceArray + elif hasattr(x, 'device_buffers'): # ShardedDeviceArray exceptions_buffers.extend(x.device_buffers) for b in live_buffers: should_delete = True