Skip to content

Commit

Permalink
Data loading: support for arbitrary prefix of data files
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanik12 committed Mar 27, 2024
1 parent eb1f952 commit 7504fb0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
24 changes: 16 additions & 8 deletions adaptor/objectives/objective_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import itertools
import logging
import os.path
from functools import partial
from typing import List, Union, Optional, Iterable, Tuple, Dict, Sequence, Any, Iterator

Expand All @@ -21,7 +22,7 @@ class Objective(abc.ABC):
"""

compatible_head: Head
given_id: Optional[str]
given_id: Optional[str] = ""
epoch: int
num_steps: int
last_input: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]]
Expand Down Expand Up @@ -135,26 +136,30 @@ def __init__(self,
self.dataset_length["eval"] = self._compute_data_source_length(val_texts_or_path)

def _check_supported_data_source_format(self, path: str) -> None:
if not os.path.exists(path):
raise FileNotFoundError("Objective %s: Given path '%s' does not exist" % (self, path))

# when the passed data source is a file, we check that it is in a supported format:
# we support .txt and .gz files
supported_file_formats = ['.txt', '.gz']

if not any(path.endswith(suffix) for suffix in supported_file_formats):
raise ValueError("Objective %s's given {val_}texts_or_path `%s` is not a List "
"and does not end with one of supported suffixes: ['.txt', '.tar.gz']."
"If you want to use a file data source, please pass it in a supported format."
% (self, path))
logger.warning("Objective %s's given {val_}texts_or_path `%s` is not a List "
"and does not end with one of supported suffixes: ['.txt', '.gz']."
"We'll assume that the file is a line-separated plaintext file." % (self, path))

def _compute_data_source_length(self, texts_or_path: Union[str, List[str]]) -> int:
if isinstance(texts_or_path, str):
if texts_or_path.endswith('.txt'):
with open(self.texts_path, "rb") as f:
return sum(1 for _ in f) # more efficient line count

if texts_or_path.endswith('.gz'):
import io
import gzip
with io.TextIOWrapper(io.BufferedReader(gzip.open(texts_or_path))) as f:
return sum(1 for _ in f) # more efficient line count
else:
with open(self.texts_path, "rb") as f:
return sum(1 for _ in f) # more efficient line count

elif isinstance(texts_or_path, list):
return len(texts_or_path)
else:
Expand All @@ -169,6 +174,9 @@ def per_objective_log(self, split: str) -> Dict[str, float]:
:return: Dict of the format {split + objective_name + evaluator_name: evaluator_value}
"""
out_logs = {}
if split == "eval" and self.val_texts is None and self.val_texts_path is None:
logger.warning("Skipping evaluation for %s" % self)
return out_logs
# aggregate per-progress_bar-steps, or per-evaluation-steps, keep the results of unprocessed evaluations
loss_history = self.loss_history[split][-self.max_samples_per_log[split]:]
mean_loss = sum(loss_history) / len(loss_history) if len(loss_history) else float("inf")
Expand Down
11 changes: 6 additions & 5 deletions adaptor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,17 @@ def iter_text_file_per_line(path: str) -> Iterable[str]:
At this point, `path` is checked to be of a supported format.
:param path: file path
"""
if path.endswith(".txt"):
with open(path) as f:
for line in f:
yield line.strip()
elif path.endswith(".gz"):
if path.endswith(".gz"):
import gzip
import io
with io.TextIOWrapper(io.BufferedReader(gzip.open(path))) as file:
for line in file:
yield line.strip()
else:
# assumes plain, newline-separated text file
with open(path) as f:
for line in f:
yield line.strip()


class TransformerAdaptationDataset(AdaptationDataset):
Expand Down

0 comments on commit 7504fb0

Please sign in to comment.