From fef836a091750369afddfe626c89d28668482240 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 14 Oct 2024 00:11:33 -0700 Subject: [PATCH 1/6] fix crash in data loader caused by using stale array --- src/levanter/data/loader.py | 201 +++++++++++++++++++----------------- 1 file changed, 106 insertions(+), 95 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index ab97e0827..fdecfa245 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -21,7 +21,8 @@ from levanter.data.utils import batched from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape from levanter.utils.background_iterable import BackgroundIterable -from levanter.utils.thread_utils import blocking_wait +from levanter.utils.jax_utils import local_cpu_mesh +from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait Ex = TypeVar("Ex") @@ -98,6 +99,8 @@ def __iter__(self): return self.iter_from_step(None) def iter_from_step(self, start_from_batch: Optional[int] = None): + # sometimes we pass in an array for the start_from_batch, so we need to check for that + start_from_batch = int(start_from_batch) if start_from_batch is not None else None return DataLoaderIterator(self, start_from_batch=start_from_batch) @@ -109,115 +112,131 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = No if self.mapping is None: self.mapping = hax.partitioning.current_thread_local_mapping() - # TODO: bring back non-prefetching version buffered_batches = self.dl.max_buffered_batches - self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches)) + if buffered_batches == 0: + self._batches = AsyncIteratorWrapper(self._produce_batches()) + else: + self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches)) def __next__(self): time_start = time.time() - out = next(self._batches) + individual_data_batch = next(self._batches) + data_for_this_batch = {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)} + batch = self._batchify_local_data(data_for_this_batch) + time_end = time.time() if (time_end - time_start) > 0.5: logger.info(f"Prefetch wasn't fast enough: {time_end - time_start:.3f}") - return out + return batch async def _produce_batches(self): batch_number = self._start_from_batch or 0 - total_ex_loaded = 0 done = False while not done: - next_batch_numbers = [] - for i in range(self.dl.prefetch_size): - if self.dl.data_store.is_finite(): - next_end = (batch_number + 1) * self.dl.batch_size - available_len = await self.dl.data_store.wait_until_len_at_least(next_end) - if available_len < next_end: - done = True - break - - next_batch_numbers.append(batch_number) - batch_number += 1 + target_next_batch_number = batch_number + self.dl.prefetch_size + max_achievable_batch_number = await self._dataset_get_available_batch_number(target_next_batch_number) + if max_achievable_batch_number < target_next_batch_number: + done = True + + next_batch_numbers = list(range(batch_number, min(target_next_batch_number, max_achievable_batch_number))) + + if len(next_batch_numbers) == 0: + break + + batch_number = next_batch_numbers[-1] + 1 async for batch in self._retrieve_batches(next_batch_numbers): yield batch - total_ex_loaded += self.dl.batch_size * len(next_batch_numbers) + async def _dataset_get_available_batch_number(self, target_max_batch_number: int) -> int: + if self.dl.data_store.is_finite(): + next_end = (target_max_batch_number + 1) * 
self.dl.batch_size + available_len = await self.dl.data_store.wait_until_len_at_least(next_end) + max_achievable_batch_number = available_len // self.dl.batch_size - async def _retrieve_batches(self, batch_numbers: list[int]): - with hax.axis_mapping(self.mapping), self.dl.mesh: - indices_for_this_batch_of_batches: list[int] = [] - for bn in batch_numbers: - indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1) - indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices] - indices_for_this_batch_of_batches.extend(indices_this_batch_this_process) + return max_achievable_batch_number + + return target_max_batch_number + async def _retrieve_batches(self, batch_numbers: list[int]): + with local_cpu_mesh(): time_start = time.time() - individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches) + individual_datums_for_each_batch = await self._do_retrieve_batch_of_batches(batch_numbers) + # reshape to be per batch time_end = time.time() logger.debug(f"Time to get {len(batch_numbers)} batches: {time_end - time_start:.3f}") - time_start = time.time() - # reshape to be per batch - individual_datums = list(batched(individual_datums, len(self.dl._local_indices))) - - # below we're gonna get the indices relative to this batch (i.e. 0 to batch_size) - index_to_datum = [ - {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)} - for individual_data_batch in individual_datums - ] - - def get_local_batch(bn: int, begin: int, end: int) -> list: - # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example - # which will require support from the datastore (i.e. tensorstore) - device_batch = _stack_tree(self.dl.Batch.name, [index_to_datum[bn][i] for i in range(begin, end)]) - batch_leaves = hax.tree_util.tree_leaves(device_batch) - return batch_leaves - - def get_local_data_for_leaf(bn, indices: _TensorSliceIndex, leaf_index: int) -> Array: - batch_slice = indices[0] - begin, end, stride = batch_slice.indices(self.dl.batch_size) - if stride != 1: - raise ValueError("Stride must be 1") - - leaf_data = (get_local_batch(bn, begin, end))[leaf_index] - - if isinstance(leaf_data, hax.NamedArray): - # select out the batch axis - batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes) - new_indices = list(indices) - new_indices[batch_index] = slice(None) - return leaf_data.array[tuple(new_indices)] + for data in individual_datums_for_each_batch: + yield data + + def _batchify_local_data(self, data_for_this_batch: dict[int, Array]): + cache: dict[tuple[int, int], list[Array | hax.NamedArray]] = {} + + def get_local_batch(begin: int, end: int) -> list: + if (begin, end) in cache: + return cache[(begin, end)] + + # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example + # which will require support from the datastore (i.e. 
tensorstore) + device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)]) + batch_leaves = hax.tree_util.tree_leaves(device_batch) + + cache[(begin, end)] = batch_leaves + + return batch_leaves + + def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Array: + batch_slice = indices[0] + begin, end, stride = batch_slice.indices(self.dl.batch_size) + if stride != 1: + raise ValueError("Stride must be 1") + + leaf_data = get_local_batch(begin, end)[leaf_index] + + if isinstance(leaf_data, hax.NamedArray): + # select out the batch axis + batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes) + new_indices = list(indices) + new_indices[batch_index] = slice(None) + return leaf_data.array[tuple(new_indices)] + else: + other_indices = indices[1:] + if all(idx == slice(None) for idx in other_indices): + return leaf_data else: - other_indices = indices[1:] - if all(idx == slice(None) for idx in other_indices): - return leaf_data - else: - # TODO: this doesn't work with named axes - return leaf_data[(..., *other_indices)] - - for batch_offset, bn in enumerate(batch_numbers): - - def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec): - def get_data(indices): - return get_local_data_for_leaf(batch_offset, indices, leaf_index) - - raw_array = jax.make_array_from_callback( - to_raw_shape(item_leaf_shape), - jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)), - get_data, - ) - if isinstance(item_leaf_shape, NamedShapeSpec): - return hax.NamedArray(raw_array, item_leaf_shape.shape) - else: - return raw_array - - gda_leaves = [ - make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf)) - for leaf_index, item_leaf in enumerate(self.dl._ex_leaves) - ] - - gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves) - yield gda_tree + # TODO: this doesn't work with named axes + return leaf_data[(..., *other_indices)] + + def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec): + def get_data(indices): + return get_local_data_for_leaf(indices, leaf_index) + + raw_array = jax.make_array_from_callback( + to_raw_shape(item_leaf_shape), + jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)), + get_data, + ) + if isinstance(item_leaf_shape, NamedShapeSpec): + return hax.NamedArray(raw_array, item_leaf_shape.shape) + else: + return raw_array + + gda_leaves = [ + make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf)) + for leaf_index, item_leaf in enumerate(self.dl._ex_leaves) + ] + gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves) + return gda_tree + + async def _do_retrieve_batch_of_batches(self, batch_numbers): + indices_for_this_batch_of_batches: list[int] = [] + for bn in batch_numbers: + indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1) + indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices] + indices_for_this_batch_of_batches.extend(indices_this_batch_this_process) + individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches) + individual_datums_for_each_batch = list(batched(individual_datums, len(self.dl._local_indices))) + return individual_datums_for_each_batch def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: if isinstance(shape_spec, ShapeSpec): # type: ignore @@ -246,14 +265,6 @@ def 
_batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedS return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype) -def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: - if isinstance(shape_spec, ShapeSpec): # type: ignore - batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources) - return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1))) - else: - return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore - - @functools.partial(jax.jit, static_argnums=(0,)) def _stack_tree(batch_name, individual_datums): def _stack_leaves_unchecked(*leaves): From 7e278b5f250abf861c2338b49d1bcd42f6e82e76 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 14 Oct 2024 08:59:36 -0700 Subject: [PATCH 2/6] try this --- src/levanter/data/loader.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index fdecfa245..c93043dec 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -2,7 +2,7 @@ import logging import time from collections import defaultdict -from typing import Iterable, Iterator, Optional, Tuple, TypeVar +from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar import jax from jax import Array @@ -20,7 +20,7 @@ from levanter.data.dataset import AsyncDataset from levanter.data.utils import batched from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape -from levanter.utils.background_iterable import BackgroundIterable +from levanter.utils.background_iterable import BackgroundIterator from levanter.utils.jax_utils import local_cpu_mesh from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait @@ -113,10 +113,11 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = No self.mapping = hax.partitioning.current_thread_local_mapping() buffered_batches = self.dl.max_buffered_batches + self._batches: Iterator[Ex] if buffered_batches == 0: self._batches = AsyncIteratorWrapper(self._produce_batches()) else: - self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches)) + self._batches = _JaxCpuBackgroundIterator(self._produce_batches, max_capacity=buffered_batches) def __next__(self): time_start = time.time() @@ -246,18 +247,6 @@ def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: return hax.partitioning.pspec_for_axis(shape_spec.shape, self.dl.axis_resources) # type: ignore -def _abstractify(x): - def _abstractify_array(x): - if isinstance(x, jax.numpy.ndarray): - return ShapeSpec(x.shape, x.dtype) - elif isinstance(x, hax.NamedArray): - return NamedShapeSpec(x.axes, x.dtype) - - return x - - return hax.tree_util.tree_map(_abstractify_array, x) - - def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec: if is_named_array(leaf): return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype) @@ -265,6 +254,19 @@ def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedS return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype) +class _JaxCpuBackgroundIterator(BackgroundIterator[Ex]): + """ + We want the thread to only use the CPU device. 
+ """ + + def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], max_capacity: Optional[int]): + super().__init__(producer_fn, max_capacity) + + def _fill_queue_with_batches(self): + with local_cpu_mesh(): + super()._fill_queue_with_batches() + + @functools.partial(jax.jit, static_argnums=(0,)) def _stack_tree(batch_name, individual_datums): def _stack_leaves_unchecked(*leaves): From 82e1cce4700cfa092de939ed1d33f00e2b5ac659 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 14 Oct 2024 09:42:00 -0700 Subject: [PATCH 3/6] what --- src/levanter/data/loader.py | 1 + src/levanter/utils/background_iterable.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index c93043dec..ca5793d12 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -264,6 +264,7 @@ def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], def _fill_queue_with_batches(self): with local_cpu_mesh(): + print("sub", jax.devices()) super()._fill_queue_with_batches() diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 11a80f8ec..593cc40fb 100644 --- a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -4,6 +4,7 @@ import threading from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, TypeVar, Union +import jax import tblib from levanter.utils.thread_utils import AsyncIteratorWrapper @@ -92,6 +93,7 @@ def _fill_queue_with_batches(self): if isinstance(iterator, Iterator): self._produce_batches_sync(iterator) else: + print("asyncio", jax.devices()) asyncio.run(self._produce_batches_async(iterator)) except Exception: self.q.put(_ExceptionWrapper(sys.exc_info())) @@ -121,6 +123,7 @@ def _produce_batches_sync(self, iterator): async def _produce_batches_async(self, iterator): try: async for batch in iterator: + print(jax.devices()) while not self._stop_event.is_set(): try: self.q.put(batch, block=True, timeout=1) From f31d9947b36c5325c17dd9e08095a7937c5ca062 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 14 Oct 2024 11:08:10 -0700 Subject: [PATCH 4/6] ok i feel like we have it? 
--- src/levanter/data/loader.py | 1 - src/levanter/utils/background_iterable.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index ca5793d12..c93043dec 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -264,7 +264,6 @@ def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], def _fill_queue_with_batches(self): with local_cpu_mesh(): - print("sub", jax.devices()) super()._fill_queue_with_batches() diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 593cc40fb..11a80f8ec 100644 --- a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -4,7 +4,6 @@ import threading from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, TypeVar, Union -import jax import tblib from levanter.utils.thread_utils import AsyncIteratorWrapper @@ -93,7 +92,6 @@ def _fill_queue_with_batches(self): if isinstance(iterator, Iterator): self._produce_batches_sync(iterator) else: - print("asyncio", jax.devices()) asyncio.run(self._produce_batches_async(iterator)) except Exception: self.q.put(_ExceptionWrapper(sys.exc_info())) @@ -123,7 +121,6 @@ def _produce_batches_sync(self, iterator): async def _produce_batches_async(self, iterator): try: async for batch in iterator: - print(jax.devices()) while not self._stop_event.is_set(): try: self.q.put(batch, block=True, timeout=1) From f46690a991c8960e1c49e36d1ada813055ae6eb0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 14 Oct 2024 12:51:43 -0700 Subject: [PATCH 5/6] pr --- src/levanter/data/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index c93043dec..70af3b6f4 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -205,7 +205,6 @@ def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Arra if all(idx == slice(None) for idx in other_indices): return leaf_data else: - # TODO: this doesn't work with named axes return leaf_data[(..., *other_indices)] def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec): From 45a03f623e502a8c50cb8dce4d11a807a0e8e7fe Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 14 Oct 2024 13:20:10 -0700 Subject: [PATCH 6/6] think i got it --- src/levanter/data/loader.py | 9 +++--- tests/test_doremi.py | 58 ++++++++++++++++++++----------------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 70af3b6f4..928c9456c 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -63,10 +63,11 @@ def __init__( self.mesh = mesh self.Batch = Batch - def _exemplar_shape(): - return blocking_wait(self.data_store.getitem_async(0)) - - self._ex_leaves, self._ex_structure = jax.tree_flatten(_exemplar_shape(), is_leaf=is_named_array) + with local_cpu_mesh(): + # It's important that all data loading happens CPU side. We might relax this one day. 
+ self._ex_leaves, self._ex_structure = jax.tree_flatten( + blocking_wait(self.data_store.getitem_async(0)), is_leaf=is_named_array + ) local_device_indices, local_indices = self._compute_local_device_indices() diff --git a/tests/test_doremi.py b/tests/test_doremi.py index 8600c9c8b..d2cf8b590 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -15,7 +15,7 @@ from levanter.data import AsyncDataset from levanter.data.mixture import MixtureDataset from levanter.trainer import Trainer, TrainerConfig -from levanter.utils.jax_utils import key_iterator +from levanter.utils.jax_utils import key_iterator, local_cpu_mesh from levanter.utils.py_utils import non_caching_cycle @@ -27,6 +27,15 @@ class Example(equinox.Module): Block = hax.Axis("Block", 1024) +def platform_of_array(x): + if isinstance(x, jax.Array): + return set(d.platform for d in x.devices()) + elif isinstance(x, hax.NamedArray): + return platform_of_array(x.array) + else: + return "cpu" + + class LogitDataset(AsyncDataset[Example]): def __init__(self, W, noise, x_mask, x_bias, *, key): self.W = W @@ -52,17 +61,12 @@ def _gen_block_data(block_id): self._gen_block_data = _gen_block_data - def __iter__(self): - key_iter = key_iterator(self.key) - Dim = self.W.axes[0] - while True: - kk = next(key_iter) - this_key_iter = key_iterator(kk) - x_block = hax.random.normal(next(this_key_iter), (Block, Dim)) * self.x_mask + self.x_bias - noise = hax.random.normal(next(this_key_iter), (Block,)) * self.noise - y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float) - for i in range(Block.size): - yield self._make_example(x_block, y_block, i) + def _make_block(self, Dim, kk): + this_key_iter = key_iterator(kk) + x_block = hax.random.normal(next(this_key_iter), (Block, Dim)) * self.x_mask + self.x_bias + noise = hax.random.normal(next(this_key_iter), (Block,)) * self.noise + y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float) + return x_block, y_block async def async_len(self) -> int: raise ValueError("Infinitely long dataset") @@ -106,21 +110,21 @@ def test_estimate_mixture_weights(): Dim = hax.Axis("Dim", 5) Batch = hax.Axis("Batch", 32) - keys = key_iterator(0) - - # W = hax.random.normal(next(keys), (Dim,)) - W1 = hax.named([0.0, 0.5, 0.5, 0.0, 0.0], (Dim,)) - x1_mask = hax.named([0.0, 1.0, 1.0, 0.0, 0.0], (Dim,)) - W2 = hax.named([0.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) - x2_mask = hax.named([0.0, 0.0, 0.0, 1.0, 1.0], (Dim,)) - W3 = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) - x3_mask = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) - x3_bias = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) - - # y = sigmoid(Wx + b + N(0, noise^2)) > 0.5 - ds1 = LogitDataset(W1, 0.1, x1_mask, 0.0, key=next(keys)) - ds2 = LogitDataset(W2, 2.0, x2_mask, 0.0, key=next(keys)) - ds3 = LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys)) + # data loading needs to take place on CPU + with local_cpu_mesh(): + keys = key_iterator(0) + W1 = hax.named([0.0, 0.5, 0.5, 0.0, 0.0], (Dim,)) + x1_mask = hax.named([0.0, 1.0, 1.0, 0.0, 0.0], (Dim,)) + W2 = hax.named([0.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x2_mask = hax.named([0.0, 0.0, 0.0, 1.0, 1.0], (Dim,)) + W3 = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x3_mask = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x3_bias = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + + # y = sigmoid(Wx + b + N(0, noise^2)) > 0.5 + ds1 = LogitDataset(W1, 0.1, x1_mask, 0.0, key=next(keys)) + ds2 = LogitDataset(W2, 2.0, x2_mask, 0.0, key=next(keys)) + ds3 = 
LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys)) # TODO: remove key as a requirement for models def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None):
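
Taken together, the series moves all batch materialization onto CPU (via local_cpu_mesh) and only builds device-resident, sharded arrays at the last moment with jax.make_array_from_callback. The sketch below is a minimal, self-contained illustration of that combined pattern, not Levanter code: produce_host_batches, to_global_array, BATCH, FEATURE, and the bounded queue are all illustrative names, and jax.default_device stands in for Levanter's local_cpu_mesh() helper.

    import queue
    import threading

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    BATCH = 8 * jax.device_count()  # global batch size, divisible across devices
    FEATURE = 16


    def produce_host_batches(q: queue.Queue, n_batches: int) -> None:
        # Keep all loading work on CPU so nothing in this thread captures (and later
        # reuses) stale accelerator buffers; jax.default_device is a stand-in for
        # Levanter's local_cpu_mesh() helper.
        with jax.default_device(jax.devices("cpu")[0]):
            for step in range(n_batches):
                q.put(np.full((BATCH, FEATURE), step, dtype=np.float32))
            q.put(None)  # sentinel: producer is done


    def to_global_array(host_batch: np.ndarray, mesh: Mesh) -> jax.Array:
        # Shard the leading (batch) dimension across the compute mesh.
        sharding = NamedSharding(mesh, PartitionSpec("batch"))

        def data_callback(index):
            # `index` is a tuple of slices selecting one device's shard; return
            # only that slice of the host-side batch.
            return host_batch[index]

        return jax.make_array_from_callback(host_batch.shape, sharding, data_callback)


    if __name__ == "__main__":
        mesh = Mesh(np.array(jax.devices()), ("batch",))
        q: queue.Queue = queue.Queue(maxsize=2)  # bounded, like max_buffered_batches
        threading.Thread(target=produce_host_batches, args=(q, 3), daemon=True).start()
        while (batch := q.get()) is not None:
            global_batch = to_global_array(batch, mesh)
            print(global_batch.shape, global_batch.sharding)

The bounded queue plays the role of max_buffered_batches in the loader: the producer thread blocks once the buffer is full, so host memory stays bounded while the main thread consumes batches and turns each one into a sharded global array on the compute mesh.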