Skip to content

Commit

Permalink
Add shuffling and length sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelburger committed Nov 5, 2024
1 parent 814e21a commit 402162d
Showing 1 changed file with 84 additions and 7 deletions.
91 changes: 84 additions & 7 deletions src/nanotron/data/petagraph_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
# =============================================================================
from torchdata.datapipes.iter import IterableWrapper, S3FileLoader, \
FileOpener, Mapper, StreamReader, FSSpecFileOpener, Prefetcher
FileOpener, Mapper, StreamReader, FSSpecFileOpener, Prefetcher, Shuffler

from functools import partial
import logging
Expand Down Expand Up @@ -37,10 +37,24 @@ def cyclic_iter(iter):
yield x


# =============================================================================
# Constants
# =============================================================================
KMER_LENGTH = 31 # overlap

# =============================================================================
# Dataset class
# =============================================================================
class PetaGraphStreamDataset(torch.utils.data.IterableDataset):
"""Training dataset to stream from Logan
Parameters
----------
sampling_seq_len_inflection : int
The sequence length at which to switch from sampling to keeping the sequence
below the inflection point we only keep the sequence with a probability proportional
to its length. Above the inflection point we always keep the sequence.
"""

def __init__(self,
logger,
Expand All @@ -56,13 +70,15 @@ def __init__(self,
prefetch_fasta_parsing: int = 10,
log_directory: Path = None,
rank: int = 0,
packed: bool = False
packed: bool = False,
sampling_seq_len_inflection: int = 1024
):

self.samples_per_epoch = samples_per_epoch
self.maxlen = maxlen
self.create_attention_mask = create_attention_mask
self.debug = debug
self.sampling_seq_len_inflection = sampling_seq_len_inflection

self.logger = logger
self.logging_func = partial(log_rank, logger=logger, level=logging.INFO, rank=0)
Expand Down Expand Up @@ -168,8 +184,12 @@ def __init__(self,
sequences_unbatched = sequences_batched.unbatch()
self.prefetch_sequences = prefetch_sequences
if self.prefetch_sequences > 0:
self.logging_func(f"Prefetching {self.prefetch_sequences} unbatched sequences")
sequences_unbatched = sequences_unbatched.prefetch(self.prefetch_sequences)

# self.logging_func(f"Prefetching {self.prefetch_sequences} unbatched sequences")
# sequences_unbatched = sequences_unbatched.prefetch(self.prefetch_sequences)

self.logging_func(f"Prefetching and shuffling {self.prefetch_sequences} unbatched sequences")
sequences_unbatched = Shuffler(sequences_unbatched, buffer_size=self.prefetch_sequences)

# sequences_crop = Mapper(sequences_unbatched, self.crop_maxlen)
# sequences_tokenized = Mapper(sequences_crop, self.tokenize_and_pad)
Expand Down Expand Up @@ -221,7 +241,14 @@ def load_restart_consumed_files(restart_file: Path):

return latest_epoch, latest_files

def decompression_func(self, input_data):
def decompression_func(self, input_data: Tuple[str, bytes]):
"""Decompress the data
Parameters
----------
input_data : Tuple[str, bytes]
The path and the data to decompress
"""
path, data = input_data
try:
dctx = zstandard.ZstdDecompressor()
Expand All @@ -231,8 +258,49 @@ def decompression_func(self, input_data):
return path, None

return path, decompressed_data

def chop_at_first_repeated_kmer(sequence: str, k: int) -> str:
    """Chop the sequence at the first repeated k-mer.

    Scans the k-mers of ``sequence`` left to right; as soon as a k-mer is
    seen a second time, the sequence is truncated just before the final
    character of that repeated k-mer, so the duplicate k-mer is broken.

    Python implementation of:
    https://gitlab.pasteur.fr/rchikhi_pasteur/logan-circles/-/blob/master/fix_repeated_31kmers.cpp?ref_type=heads

    NOTE(review): declared without ``self`` — if this function lives inside
    the dataset class it must be marked ``@staticmethod`` or moved to module
    level; confirm against the call sites.

    Parameters
    ----------
    sequence : str
        The sequence to chop.
    k : int
        The k-mer length (e.g. KMER_LENGTH = 31).

    Returns
    -------
    str
        The possibly truncated sequence. Returned unchanged when no k-mer
        repeats, including the trivial case ``len(sequence) < k`` where no
        full k-mer exists.
    """
    seen = set()
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        if kmer in seen:
            # Cut before the last character of the repeated k-mer so the
            # kept prefix no longer contains the repeat in full.
            return sequence[:i + k - 1]
        seen.add(kmer)
    return sequence


def length_sampling_filter(self, sequence: str) -> bool:
    """Decide whether to keep ``sequence`` for training.

    Sequences at or above ``self.sampling_seq_len_inflection`` are always
    kept. Shorter sequences are kept with probability proportional to their
    length (``len(sequence) / inflection``), so very short sequences are
    sampled rarely and the empty sequence is never kept.

    Bug fix: the original had a bare ``True`` expression instead of
    ``return True`` on the winning sampling branch, so short sequences were
    always dropped regardless of the random draw.

    Parameters
    ----------
    sequence : str
        The sequence to filter.

    Returns
    -------
    bool
        True if the sequence should be kept for training.
    """
    seq_len = len(sequence)
    if seq_len >= self.sampling_seq_len_inflection:
        return True
    # Keep with probability seq_len / inflection.
    # NOTE(review): uses the global NumPy RNG; seed np.random (or switch to
    # a per-worker Generator) for reproducible sampling across workers.
    return bool(np.random.rand() < (seq_len / self.sampling_seq_len_inflection))


def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
"""Parse the fasta data and return the sequences
Parameters
----------
input_data : Tuple[str, bytes]
The path and the data to parse
"""
path, data = input_data
if data is None:
return [[]]
Expand All @@ -241,7 +309,16 @@ def fasta_parsing_func(self, input_data):
decoded_lines = data.decode()
sequences = [(path, str(s.seq)) for s in SeqIO.parse(StringIO(decoded_lines), "fasta")]

return sequences
# Chop sequences in preparation for graph traversal
# TODO

# Construct sequence graph and perform random walks
# TODO

# Sample sequences for training
keep_sequences = list(filter(self.length_sampling_filter, sequences))

return keep_sequences

def crop_maxlen(self, input_sequence: str, maxlen: int = None):
# path, input_sequence = input_data
Expand Down

0 comments on commit 402162d

Please sign in to comment.