diff --git a/metaseq/tasks/streaming_language_modeling.py b/metaseq/tasks/streaming_language_modeling.py
index f89962015..e21d90840 100644
--- a/metaseq/tasks/streaming_language_modeling.py
+++ b/metaseq/tasks/streaming_language_modeling.py
@@ -35,6 +35,7 @@
 from metaseq.data.document_to_sequence import DocumentToSequenceDataset
 from metaseq.data.cm3_dataset import CausalMaskedDocumentToSequenceDataset
 from metaseq.dataclass import ChoiceEnum
+import json
 
 try:
     from tokenizers import ByteLevelBPETokenizer, Tokenizer
@@ -126,6 +127,12 @@ class StreamingLanguageModelingConfig(MetaseqDataclass):
         default=DEFAULT_MULTICORPUS_MAX,
         metadata={"help": "Maximum size for example proportional sampling"},
     )
+    data_weights: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Proportion of each finetuning benchmark to use. Use only during finetuning"
+        },
+    )
     data_subshard_count: int = field(
         default=1,
         metadata={
@@ -173,7 +180,8 @@ class StreamingLanguageModelingConfig(MetaseqDataclass):
     cm3_percent_full_document_rotation: float = field(
         default=0.0,
         metadata={
-            "help": "What percent of the time to rotate full documents while still abiding by the number of sentinel tokens used."
+            "help": "What percent of the time to rotate full documents while still abiding "
+            "by the number of sentinel tokens used."
         },
     )
     num_retrieved_doc: int = field(
@@ -317,10 +325,10 @@ def tokenize_cm3_v2(self, json):
             raise ValueError(f"dataset not valid: {json['dataset_name']}")
 
     def _tokenize_text_json(self, json):
-        if 'text' in json:
+        if "text" in json:
             text = json["text"]
-        elif 'content' in json:
-            text = json['content']
+        elif "content" in json:
+            text = json["content"]
         else:
             text = str(json)
         return torch.LongTensor(
@@ -425,7 +433,15 @@ def _alpha_sampling(self, datasets, corpora, epoch=1):
             dtype=float,
         )
         logger.info(f"loaded total {dataset_lengths.sum()} blocks for all corpora")
-        sample_probs = self._get_sample_prob(dataset_lengths)
+        data_weights = json.loads(self.args.data_weights or "{}")
+        for cp in corpora:
+            if cp not in data_weights:
+                data_weights[cp] = 1
+        dataset_lengths_weighted = np.array(
+            [len(d) * data_weights[cp] for d, cp in zip(datasets, corpora)],
+            dtype=float,
+        )
+        sample_probs = self._get_sample_prob(dataset_lengths_weighted)
 
         logger.info(
             "Sample probability by corpus: %s",
@@ -595,8 +611,13 @@ def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
                 corpora.append(os.path.splitext(file)[0])
         assert len(datasets) > 0
 
-        if self.args.multicorpus_sampling_alpha != 1:
-            datasets = self._alpha_sampling(datasets, corpora, epoch)
+        if split == "train":
+            # Let's not change validation data at all
+            if (
+                self.args.data_weights is not None
+                or self.args.multicorpus_sampling_alpha != 1
+            ):
+                datasets = self._alpha_sampling(datasets, corpora, epoch)
 
         dataset = torch.utils.data.ConcatDataset(datasets)
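
For context, a standalone sketch (not part of the patch) of how the new data_weights value behaves: it is parsed as a JSON string mapping corpus names to multiplicative weights, corpora not listed fall back to a weight of 1, and the reweighted per-corpus block counts are what _get_sample_prob then receives instead of the raw lengths. The corpus names, block counts, and the --data-weights flag spelling below are illustrative assumptions, not taken from the patch.

```python
# Illustrative sketch only; corpus names, block counts, and the CLI flag
# spelling are assumptions for demonstration.
import json

import numpy as np

corpora = ["squad", "boolq"]                   # hypothetical corpus names
block_counts = {"squad": 1000, "boolq": 4000}  # hypothetical per-corpus block counts

# JSON string as it might be passed via something like --data-weights '{"squad": 3.0}'
data_weights = json.loads('{"squad": 3.0}')
for cp in corpora:
    if cp not in data_weights:
        data_weights[cp] = 1                   # unlisted corpora keep weight 1

# Same reweighting the patch applies before alpha sampling
weighted_lengths = np.array(
    [block_counts[cp] * data_weights[cp] for cp in corpora], dtype=float
)
print(weighted_lengths / weighted_lengths.sum())  # [0.42857143 0.57142857]
```

In the patched _alpha_sampling these weighted lengths simply replace the raw lengths as the input to _get_sample_prob, so the existing alpha smoothing still applies on top of the user-supplied weights.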