diff --git a/src/nanotron/data/petagraph_dataset.py b/src/nanotron/data/petagraph_dataset.py
index 4af72fe9..4ffaef62 100644
--- a/src/nanotron/data/petagraph_dataset.py
+++ b/src/nanotron/data/petagraph_dataset.py
@@ -6,7 +6,7 @@
 #
 # =============================================================================
 from torchdata.datapipes.iter import IterableWrapper, S3FileLoader, \
-    FileOpener, Mapper, StreamReader, FSSpecFileOpener, Prefetcher
+    FileOpener, Mapper, StreamReader, FSSpecFileOpener, Prefetcher, Shuffler
 
 from functools import partial
 import logging
@@ -37,10 +37,24 @@ def cyclic_iter(iter):
         yield x
 
 
+# =============================================================================
+# Constants
+# =============================================================================
+KMER_LENGTH = 31  # overlap
+
 # =============================================================================
 # Dataset class
 # =============================================================================
 class PetaGraphStreamDataset(torch.utils.data.IterableDataset):
+    """Training dataset to stream from Logan
+
+    Parameters
+    ----------
+    sampling_seq_len_inflection : int
+        The sequence length at which to switch from probabilistic sampling to
+        always keeping a sequence. Below the inflection point a sequence is kept
+        with probability proportional to its length; above it, it is always kept.
+    """
 
     def __init__(self,
                  logger,
@@ -56,13 +70,15 @@ def __init__(self,
                  prefetch_fasta_parsing: int = 10,
                  log_directory: Path = None,
                  rank: int = 0,
-                 packed: bool = False
+                 packed: bool = False,
+                 sampling_seq_len_inflection: int = 1024
                  ):
 
         self.samples_per_epoch = samples_per_epoch
         self.maxlen = maxlen
         self.create_attention_mask = create_attention_mask
         self.debug = debug
+        self.sampling_seq_len_inflection = sampling_seq_len_inflection
 
         self.logger = logger
         self.logging_func = partial(log_rank, logger=logger, level=logging.INFO, rank=0)
@@ -168,8 +184,12 @@ def __init__(self,
         sequences_unbatched = sequences_batched.unbatch()
         self.prefetch_sequences = prefetch_sequences
         if self.prefetch_sequences > 0:
-            self.logging_func(f"Prefetching {self.prefetch_sequences} unbatched sequences")
-            sequences_unbatched = sequences_unbatched.prefetch(self.prefetch_sequences)
+
+            # self.logging_func(f"Prefetching {self.prefetch_sequences} unbatched sequences")
+            # sequences_unbatched = sequences_unbatched.prefetch(self.prefetch_sequences)
+
+            self.logging_func(f"Prefetching and shuffling {self.prefetch_sequences} unbatched sequences")
+            sequences_unbatched = Shuffler(sequences_unbatched, buffer_size=self.prefetch_sequences)
 
         # sequences_crop = Mapper(sequences_unbatched, self.crop_maxlen)
         # sequences_tokenized = Mapper(sequences_crop, self.tokenize_and_pad)
@@ -221,7 +241,14 @@ def load_restart_consumed_files(restart_file: Path):
 
         return latest_epoch, latest_files
 
-    def decompression_func(self, input_data):
+    def decompression_func(self, input_data: Tuple[str, bytes]):
+        """Decompress the data
+
+        Parameters
+        ----------
+        input_data : Tuple[str, bytes]
+            The path and the data to decompress
+        """
         path, data = input_data
         try:
             dctx = zstandard.ZstdDecompressor()
@@ -231,8 +258,51 @@ def decompression_func(self, input_data):
             return path, None
 
         return path, decompressed_data
+
+    def chop_at_first_repeated_kmer(self, sequence: str, k: int):
+        """Chop the sequence at the first repeated kmer
+
+        Python implementation of:
+        https://gitlab.pasteur.fr/rchikhi_pasteur/logan-circles/-/blob/master/fix_repeated_31kmers.cpp?ref_type=heads
 
-    def fasta_parsing_func(self, input_data):
+        Parameters
+        ----------
+        sequence : str
+            The sequence to chop
+        k : int
+            The kmer length
+        """
+        kmers = set()
+        for i in range(len(sequence) - k + 1):
+            kmer = sequence[i:i+k]
+            if kmer in kmers:
+                return sequence[:i + k - 1]
+            kmers.add(kmer)
+        return sequence
+
+
+    def length_sampling_filter(self, input_data: Tuple[str, str]) -> bool:
+        """Keep a sequence with probability proportional to its length below the inflection point."""
+        _, sequence = input_data
+        seq_len = len(sequence)
+        if seq_len >= self.sampling_seq_len_inflection:
+            return True
+        else:
+            prob = np.random.rand()
+            if prob < (seq_len / self.sampling_seq_len_inflection):
+                return True
+
+        return False
+
+
+    def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
+        """Parse the fasta data and return the sequences
+
+        Parameters
+        ----------
+        input_data : Tuple[str, bytes]
+            The path and the data to parse
+        """
         path, data = input_data
         if data is None:
             return [[]]
@@ -241,7 +311,16 @@ def fasta_parsing_func(self, input_data):
         decoded_lines = data.decode()
 
         sequences = [(path, str(s.seq)) for s in SeqIO.parse(StringIO(decoded_lines), "fasta")]
-        return sequences
+        # Chop sequences in preparation for graph traversal
+        # TODO
+
+        # Construct sequence graph and perform random walks
+        # TODO
+
+        # Sample sequences for training
+        keep_sequences = list(filter(self.length_sampling_filter, sequences))
+
+        return keep_sequences
 
     def crop_maxlen(self, input_sequence: str, maxlen: int = None):
         # path, input_sequence = input_data
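Note on the `Shuffler` swap: unlike the `.prefetch()` call it replaces, `Shuffler` emits elements uniformly at random from a bounded buffer, so the shuffle is only local: an element can appear at most `buffer_size` positions earlier than its input order, and the output is locally rather than globally shuffled. A minimal sketch of the buffer semantics, assuming a torchdata version that still ships datapipes (as the patched file does):

```python
from torchdata.datapipes.iter import IterableWrapper, Shuffler

source = IterableWrapper(range(20))
# With a small buffer, elements can only move a short distance from their
# original position: locally shuffled, not a global permutation.
locally_shuffled = Shuffler(source, buffer_size=4)
print(list(locally_shuffled))
```

A larger `buffer_size` trades memory for better mixing, and keeping the buffer filled preserves the prefetching side effect of the old call.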
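`chop_at_first_repeated_kmer` is not yet called anywhere; its call site is the first TODO in `fasta_parsing_func`. To make the chopping rule concrete, a standalone sketch of the same logic with a toy `k` (the pipeline itself would presumably use `KMER_LENGTH = 31`):

```python
def chop_at_first_repeated_kmer(sequence: str, k: int) -> str:
    """Standalone copy of the method body, for illustration only."""
    kmers = set()
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        if kmer in kmers:
            # Cut before the repeated k-mer completes, so the returned
            # prefix contains every k-mer at most once.
            return sequence[:i + k - 1]
        kmers.add(kmer)
    return sequence

# "ACG" reappears at position 4, so the sequence is cut after index 5:
print(chop_at_first_repeated_kmer("ACGTACGG", k=3))  # -> ACGTAC
```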
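On the length-based sampling: below `sampling_seq_len_inflection`, a sequence of length L survives the filter with probability L / sampling_seq_len_inflection, so the expected keep rate grows linearly with length until it saturates at 1 at the inflection point. A small self-contained sanity check of that rule; `keep_probability` and the lengths are hypothetical, not part of the patch:

```python
import numpy as np

def keep_probability(seq_len: int, inflection: int = 1024) -> float:
    # P(keep) = min(1, seq_len / inflection): linear below the
    # inflection point, certain at or above it.
    return min(1.0, seq_len / inflection)

rng = np.random.default_rng(0)
for seq_len in (32, 256, 1024, 4096):
    kept = sum(rng.random() < keep_probability(seq_len) for _ in range(10_000))
    print(f"len={seq_len:4d}  empirical keep rate={kept / 10_000:.3f}")
# Expected keep rates: ~0.031, ~0.250, 1.000, 1.000
```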