Commit 729a8d7: Update dataloader

manuelburger committed Nov 6, 2024
1 parent e91c17d
Showing 3 changed files with 494 additions and 28 deletions.
26 changes: 20 additions & 6 deletions petagraph/run_train.py
@@ -46,7 +46,7 @@
 hf_hub_version = None
 tf_version = None
 
-from nanotron.data.petagraph_dataset import PetaGraphStreamDataset
+from nanotron.data.petagraph_dataset import PetaGraphStreamDataset, PetaGraphStreamDatasetV2
 
 logger = logging.get_logger(__name__)
 
@@ -126,8 +126,14 @@ def get_dataloader_from_data_stage(
     # else:
     #     raise ValueError("Data path must contain either 'unitig' or 'contig'")
 
-    contig_format = "s3://logan-pub/c/{accession}/{accession}.contigs.fa.zst"
-    unitig_format = "s3://logan-pub/u/{accession}/{accession}.unitigs.fa.zst"
+    # ----- URL FORMAT -----
+    # contig_format = "s3://logan-pub/c/{accession}/{accession}.contigs.fa.zst"
+    # unitig_format = "s3://logan-pub/u/{accession}/{accession}.unitigs.fa.zst"
+
+    unitig_format = "https://s3.amazonaws.com/logan-pub/u/{accession}/{accession}.unitigs.fa.zst"
+    contig_format = "https://s3.amazonaws.com/logan-pub/c/{accession}/{accession}.contigs.fa.zst"
+    log_rank(f"Contig format: {contig_format}", logger=logger, level=logging.INFO, rank=0)
+    # ----------------------
 
     assert data.all_sequences_resources_path is not None, "all_sequences_resources_path must be provided"
     all_sequences_resources_path = Path(data.all_sequences_resources_path)
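
Note: this hunk switches from s3:// URIs to plain HTTPS URLs for the same public logan-pub bucket, so objects can be fetched with any HTTP client and no AWS credentials or S3 SDK. A minimal sketch of consuming one of these URLs, not part of the commit, assuming the requests and zstandard packages and an illustrative accession:

    # Sketch only: fetch one unitig file over HTTPS and decompress on the fly.
    # Assumes `requests` and `zstandard` are installed; neither import appears
    # in this commit, and the accession value is illustrative.
    import io

    import requests
    import zstandard

    UNITIG_FORMAT = "https://s3.amazonaws.com/logan-pub/u/{accession}/{accession}.unitigs.fa.zst"

    def stream_fasta_lines(accession: str):
        """Yield decoded FASTA lines for one accession without a full download."""
        url = UNITIG_FORMAT.format(accession=accession)
        with requests.get(url, stream=True, timeout=60) as resp:
            resp.raise_for_status()
            # Wrap the raw HTTP body in a streaming zstd decompressor.
            reader = zstandard.ZstdDecompressor().stream_reader(resp.raw)
            for line in io.TextIOWrapper(reader, encoding="utf-8"):
                yield line.rstrip("\n")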
@@ -226,14 +232,13 @@ def get_dataloader_from_data_stage(
 
         else:
             global_rank = trainer.parallel_context.world_pg.rank()
-            train_dataset = PetaGraphStreamDataset(
+            train_dataset = PetaGraphStreamDatasetV2(
                 logger=logger,
                 url_list=train_sequence_files,
                 vocabulary=VOCABULARY,
                 from_cloud=True, # not mock_data,
                 maxlen=trainer.sequence_length + 1,
                 create_attention_mask=True,
-                prefetch_sequences=data.prefetch_buffer_seq_size,
                 log_directory=trainer.config.checkpoints.checkpoints_path,
                 rank=global_rank,
                 packed=True
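
Note: PetaGraphStreamDatasetV2 no longer takes the dataset-internal prefetch_sequences argument; prefetching moves into the DataLoader below, and the diff shows the V2 dataset exposing a worker_init_fn instead. The class itself is not part of this diff; one common pattern for a worker_init_fn on a streaming IterableDataset is to shard the URL list across DataLoader workers, sketched here with hypothetical names:

    # Illustrative pattern only: a hypothetical stand-in, not the real
    # PetaGraphStreamDatasetV2, whose source is not shown in this diff.
    from torch.utils.data import IterableDataset, get_worker_info

    class URLShardingStream(IterableDataset):
        def __init__(self, url_list):
            self.url_list = url_list
            self.worker_urls = url_list  # re-sliced per worker in worker_init_fn

        @staticmethod
        def worker_init_fn(worker_id: int):
            # Runs inside each DataLoader worker process: hand every worker a
            # disjoint slice of the URL list so no file is streamed twice.
            info = get_worker_info()
            info.dataset.worker_urls = info.dataset.url_list[info.id::info.num_workers]

        def __iter__(self):
            for url in self.worker_urls:
                yield url  # real code would download, decompress, and tokenize here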
@@ -253,14 +258,23 @@
 
         log_rank(f"Using {num_dl_workers} dataloader workers", logger=logger, level=logging.INFO, rank=0)
 
+        prefetch_factor = None
+        worker_init_fn = None
+        if num_dl_workers > 0:
+            prefetch_factor = data.prefetch_buffer_seq_size
+            if isinstance(train_dataset, PetaGraphStreamDatasetV2):
+                worker_init_fn = train_dataset.worker_init_fn
+
+        log_rank(f"Prefetch factor: {prefetch_factor}", logger=logger, level=logging.INFO, rank=0)
         return DataLoader(
             train_dataset,
             batch_size=trainer.micro_batch_size,
             collate_fn=data_collator,
             drop_last=True,
+            prefetch_factor=prefetch_factor,
             num_workers=num_dl_workers,
             pin_memory=True,
-            worker_init_fn=get_dataloader_worker_init(dp_rank=trainer.parallel_context.dp_pg.rank()),
+            worker_init_fn=worker_init_fn,
         )
 
     else:
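
Note: the num_dl_workers > 0 guard keeps the call valid, since recent PyTorch (2.x) raises a ValueError when a non-None prefetch_factor is combined with num_workers == 0. Two side effects worth flagging: prefetch_factor counts batches per worker, so reusing data.prefetch_buffer_seq_size changes its unit from sequences to batches, and the dp-rank-based get_dataloader_worker_init is replaced by the dataset's own worker_init_fn (or None for in-process loading). A standalone sketch of the constraint:

    # Sketch of the prefetch_factor rule, separate from the training code above.
    from torch.utils.data import DataLoader

    dataset = list(range(16))

    # OK: each of the 2 workers keeps up to 4 batches ready ahead of consumption.
    loader = DataLoader(dataset, batch_size=4, num_workers=2, prefetch_factor=4)

    # Raises ValueError in recent PyTorch: prefetch_factor needs num_workers > 0.
    # DataLoader(dataset, batch_size=4, num_workers=0, prefetch_factor=4)

    # Hence the guard pattern used in the hunk above:
    num_workers = 0
    prefetch_factor = 4 if num_workers > 0 else None
    loader = DataLoader(dataset, batch_size=4, num_workers=num_workers, prefetch_factor=prefetch_factor)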
