Skip to content

Commit

Permalink
what
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Oct 14, 2024
1 parent 7e278b5 commit 82e1cce
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/levanter/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]],

def _fill_queue_with_batches(self):
with local_cpu_mesh():
print("sub", jax.devices())
super()._fill_queue_with_batches()


Expand Down
3 changes: 3 additions & 0 deletions src/levanter/utils/background_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import threading
from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, TypeVar, Union

import jax
import tblib

from levanter.utils.thread_utils import AsyncIteratorWrapper
Expand Down Expand Up @@ -92,6 +93,7 @@ def _fill_queue_with_batches(self):
if isinstance(iterator, Iterator):
self._produce_batches_sync(iterator)
else:
print("asyncio", jax.devices())
asyncio.run(self._produce_batches_async(iterator))
except Exception:
self.q.put(_ExceptionWrapper(sys.exc_info()))
Expand Down Expand Up @@ -121,6 +123,7 @@ def _produce_batches_sync(self, iterator):
async def _produce_batches_async(self, iterator):
try:
async for batch in iterator:
print(jax.devices())
while not self._stop_event.is_set():
try:
self.q.put(batch, block=True, timeout=1)
Expand Down

0 comments on commit 82e1cce

Please sign in to comment.