From 82e1cce4700cfa092de939ed1d33f00e2b5ac659 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 14 Oct 2024 09:42:00 -0700 Subject: [PATCH] what --- src/levanter/data/loader.py | 1 + src/levanter/utils/background_iterable.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index c93043dec..ca5793d12 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -264,6 +264,7 @@ def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], def _fill_queue_with_batches(self): with local_cpu_mesh(): + print("sub", jax.devices()) super()._fill_queue_with_batches() diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 11a80f8ec..593cc40fb 100644 --- a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -4,6 +4,7 @@ import threading from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, TypeVar, Union +import jax import tblib from levanter.utils.thread_utils import AsyncIteratorWrapper @@ -92,6 +93,7 @@ def _fill_queue_with_batches(self): if isinstance(iterator, Iterator): self._produce_batches_sync(iterator) else: + print("asyncio", jax.devices()) asyncio.run(self._produce_batches_async(iterator)) except Exception: self.q.put(_ExceptionWrapper(sys.exc_info())) @@ -121,6 +123,7 @@ def _produce_batches_sync(self, iterator): async def _produce_batches_async(self, iterator): try: async for batch in iterator: + print(jax.devices()) while not self._stop_event.is_set(): try: self.q.put(batch, block=True, timeout=1)