Skip to content

Commit

Permalink
missed key sharding in MixtureDataset (#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored May 29, 2024
1 parent ed3c6f1 commit a54ad00
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/levanter/data/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
self,
datasets: Mapping[str, ShardableDataset[T]],
weights: Dict[str, float],
key: int | PRNGKeyArray,
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
key: int | PRNGKeyArray = 0,
):
self.datasets = datasets
self.weights = MixtureDataset._normalize_weights(weights)
Expand All @@ -64,7 +64,8 @@ def _normalize_weights(weights: Dict[str, float]):
def shard(self, shard_id: int, num_shards: int) -> "MixtureDataset":
"""Return a MixtureDataset with the sharded datasets"""
sharded = {name: dset.shard(shard_id, num_shards) for name, dset in self.datasets.items()}
return MixtureDataset(datasets=sharded, weights=self.weights, stop_strategy=self.stop_strategy)
my_key = int(jax.random.randint(jax.random.PRNGKey(self.key), (num_shards,), 0, 2**20)[shard_id])
return MixtureDataset(datasets=sharded, weights=self.weights, stop_strategy=self.stop_strategy, key=my_key)

def __iter__(self) -> Iterator[np.ndarray]:
iterators = {name: iter(dataset) for name, dataset in self.datasets.items()}
Expand Down
5 changes: 4 additions & 1 deletion src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ class LMMixtureDatasetConfig(LMTaskConfig):
train_weights: Dict[str, float] = field(default_factory=dict)
""" weights for each dataset source. They will be normalized to sum to 1. """
stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)
seed: int = field(default=0)

def __post_init__(self):
if len(self.configs) == 0:
Expand All @@ -737,7 +738,9 @@ def train_set(
) -> ShardableDataset[np.ndarray]:
doc_caches = self.build_caches("train", monitors=monitors)
token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()}
return MixtureDataset(datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy)
return MixtureDataset(
datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy, key=self.seed
)

def training_sets(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
Expand Down
6 changes: 6 additions & 0 deletions tests/test_data_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_stop_strategies():
datasets=datasets,
weights={"1": 1.0, "2": 0.0},
stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
key=0,
)
counter = 0
for batch in mixture_1_only:
Expand All @@ -50,13 +51,15 @@ def test_stop_strategies():
datasets=datasets,
weights={"1": 0.5, "2": 0.5},
stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
key=0,
)
counter_first = sum([1 for _ in mixture_balanced_first])

mixture_balanced_all = MixtureDataset(
datasets=datasets,
weights={"1": 0.5, "2": 0.5},
stop_strategy=StopStrategy.ALL_STOP_STRATEGY,
key=0,
)
counter_all = sum([1 for _ in mixture_balanced_all])
assert counter_first < counter_all
Expand All @@ -66,6 +69,7 @@ def test_stop_strategies():
datasets=datasets,
weights={"1": 2.0, "2": 2.0},
stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
key=0,
)
assert mixture_normalized.weights["1"] == mixture_normalized.weights["2"] == 0.5

Expand All @@ -81,6 +85,7 @@ def test_restart_strategy_gets_the_right_average():
datasets=datasets, # type: ignore
weights={"1": 0.6, "2": 0.4},
stop_strategy=StopStrategy.RESTART_STRATEGY,
key=0,
)

# ensure we get the right long run average
Expand Down Expand Up @@ -108,6 +113,7 @@ def test_restart_strategy_gets_the_right_average():
datasets=datasets, # type: ignore
weights={"1": 0.6, "2": 0.4},
stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
key=0,
)

for i, ex in enumerate(mixture_balanced_first):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_doremi.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def init_model():

datasets = {"d1": ds1, "d2": ds2, "d3": ds3}

ref_model, ref_loss = fit_to_dataset(MixtureDataset(datasets, weights={k: 1 / 3.0 for k in datasets.keys()}))
ref_model, ref_loss = fit_to_dataset(
MixtureDataset(datasets, weights={k: 1 / 3.0 for k in datasets.keys()}, key=next(keys))
)

# let's see the loss on each dataset
l1_ref = eval_loss_loop(
Expand Down

0 comments on commit a54ad00

Please sign in to comment.