diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 826367c2e..320e8266d 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -263,14 +263,6 @@ def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedS return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype) -def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: - if isinstance(shape_spec, ShapeSpec): # type: ignore - batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources) - return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1))) - else: - return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore - - @functools.partial(jax.jit, static_argnums=(0,)) def _stack_tree(batch_name, individual_datums): def _stack_leaves_unchecked(*leaves):