Skip to content

Commit

Permalink
add a restart strategy that restarts iterators forever (#602)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored May 28, 2024
1 parent 81ba8c0 commit 2bb1252
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 4 deletions.
10 changes: 7 additions & 3 deletions src/levanter/data/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
class StopStrategy(metaclass=StringHolderEnum):
FIRST_STOP_STRATEGY = "first_exhausted"
ALL_STOP_STRATEGY = "all_exhausted"
RESTART_STRATEGY = "restart"


class MixtureDataset(ShardableDataset[T]):
Expand All @@ -25,23 +26,24 @@ class MixtureDataset(ShardableDataset[T]):
Args:
datasets: A dict of datasets, where the key is the name of the dataset and the value is the dataset itself
weights: weights for each dataset
stop_strategy: strategy for stopping the iteration, by default FIRST_STOP_STRATEGY
stop_strategy: strategy for stopping the iteration, by default RESTART_STRATEGY
- FIRST_STOP_STRATEGY: stop when one dataset has been exhausted
- ALL_STOP_STRATEGY: stop when all datasets have been exhausted
- RESTART_STRATEGY: restart the dataset when it has been exhausted
key: random key for datasets sampling
"""

def __init__(
self,
datasets: Mapping[str, ShardableDataset[T]],
weights: Dict[str, float],
stop_strategy: str = StopStrategy.FIRST_STOP_STRATEGY,
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
key: int | PRNGKeyArray = 0,
):
self.datasets = datasets
self.weights = MixtureDataset._normalize_weights(weights)

if stop_strategy not in [StopStrategy.FIRST_STOP_STRATEGY, StopStrategy.ALL_STOP_STRATEGY]:
if stop_strategy not in StopStrategy: # type: ignore
raise ValueError(f"Stop strategy {stop_strategy} is not supported.")

self.stop_strategy = stop_strategy
Expand Down Expand Up @@ -76,6 +78,8 @@ def __iter__(self) -> Iterator[np.ndarray]:
yield item
except StopIteration:
match self.stop_strategy:
case StopStrategy.RESTART_STRATEGY:
iterators[dataset_name] = iter(self.datasets[dataset_name])
case StopStrategy.FIRST_STOP_STRATEGY:
break
case StopStrategy.ALL_STOP_STRATEGY:
Expand Down
2 changes: 1 addition & 1 deletion src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ class LMMixtureDatasetConfig(LMTaskConfig):
""" configuration of each dataset source (urls, hf dataset id, etc.) """
train_weights: Dict[str, float] = field(default_factory=dict)
""" weights for each dataset source. They will be normalized to sum to 1. """
stop_strategy: str = field(default=StopStrategy.FIRST_STOP_STRATEGY)
stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)

def __post_init__(self):
if len(self.configs) == 0:
Expand Down
59 changes: 59 additions & 0 deletions tests/test_data_mixture.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import tempfile

import tiny_test_corpus
from levanter.data import Dataset
from levanter.data.mixture import MixtureDataset, StopStrategy
from levanter.data.text import TokenSeqDataset


class ListDataset(Dataset[list]):
def __init__(self, data: list):
self.data = data

def __iter__(self):
return iter(self.data)


def test_stop_strategies():
seq_len = 10

Expand Down Expand Up @@ -59,3 +68,53 @@ def test_stop_strategies():
stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
)
assert mixture_normalized.weights["1"] == mixture_normalized.weights["2"] == 0.5


def test_restart_strategy_gets_the_right_average():

num_docs_1, num_docs_2 = 10, 20
ds1 = ListDataset([1 for _ in range(num_docs_1)])
ds2 = ListDataset([2 for _ in range(num_docs_2)])

datasets = {"1": ds1, "2": ds2}
mixture_balanced_restart = MixtureDataset(
datasets=datasets, # type: ignore
weights={"1": 0.6, "2": 0.4},
stop_strategy=StopStrategy.RESTART_STRATEGY,
)

# ensure we get the right long run average
NUM_SAMPLES = 2300

# variance of a bernoulli distribution is p(1-p) ≈ 0.24
# to get a 95% confidence interval of 0.02, we need ~2300 samples

# we expect to get roughly 60% 1s and 40% 2s
num_ones = 0
for i, ex in enumerate(mixture_balanced_restart):
if ex == 1:
num_ones += 1
if i >= NUM_SAMPLES:
break

assert 0.58 < num_ones / NUM_SAMPLES < 0.62

# now just to verify, stop_first won't give us the same average

num_total = 0
num_ones = 0

mixture_balanced_first = MixtureDataset(
datasets=datasets, # type: ignore
weights={"1": 0.6, "2": 0.4},
stop_strategy=StopStrategy.FIRST_STOP_STRATEGY,
)

for i, ex in enumerate(mixture_balanced_first):
if ex == 1:
num_ones += 1
num_total += 1

assert num_total < 30
assert num_ones == num_docs_1
assert num_ones / num_total < 0.55 or num_ones / num_total > 0.65

0 comments on commit 2bb1252

Please sign in to comment.