diff --git a/adaptor/objectives/objective_base.py b/adaptor/objectives/objective_base.py
index ce5ebcf..a45bfe7 100644
--- a/adaptor/objectives/objective_base.py
+++ b/adaptor/objectives/objective_base.py
@@ -1,6 +1,7 @@
 import abc
 import itertools
 import logging
+import os.path
 from functools import partial
 from typing import List, Union, Optional, Iterable, Tuple, Dict, Sequence, Any, Iterator
 
@@ -21,7 +22,7 @@ class Objective(abc.ABC):
     """
 
     compatible_head: Head
-    given_id: Optional[str]
+    given_id: Optional[str] = ""
     epoch: int
     num_steps: int
     last_input: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]]
@@ -135,26 +136,30 @@ def __init__(self,
         self.dataset_length["eval"] = self._compute_data_source_length(val_texts_or_path)
 
     def _check_supported_data_source_format(self, path: str) -> None:
+        if not os.path.exists(path):
+            raise FileNotFoundError("Objective %s: Given path '%s' does not exist" % (self, path))
+
         # when the passed data source is a file, we check that it is in a supported format:
         # we support .txt and .tar.gz files
         supported_file_formats = ['.txt', '.gz']
         if not any(path.endswith(suffix) for suffix in supported_file_formats):
-            raise ValueError("Objective %s's given {val_}texts_or_path `%s` is not a List "
-                             "and does not end with one of supported suffixes: ['.txt', '.tar.gz']."
-                             "If you want to use a file data source, please pass it in a supported format."
-                             % (self, path))
+            logger.warning("Objective %s's given {val_}texts_or_path `%s` is not a List "
+                           "and does not end with one of supported suffixes: ['.txt', '.gz']."
+                           "We'll assume that the file is a line-separated plaintext file." % (self, path))
 
     def _compute_data_source_length(self, texts_or_path: Union[str, List[str]]) -> int:
         if isinstance(texts_or_path, str):
-            if texts_or_path.endswith('.txt'):
-                with open(self.texts_path, "rb") as f:
-                    return sum(1 for _ in f)  # more efficient line count
+            if texts_or_path.endswith('.gz'):
                 import io
                 import gzip
                 with io.TextIOWrapper(io.BufferedReader(gzip.open(texts_or_path))) as f:
                     return sum(1 for _ in f)  # more efficient line count
+            else:
+                with open(texts_or_path, "rb") as f:
+                    return sum(1 for _ in f)  # more efficient line count
+
         elif isinstance(texts_or_path, list):
             return len(texts_or_path)
         else:
@@ -169,6 +174,9 @@ def per_objective_log(self, split: str) -> Dict[str, float]:
         :return: Dict of the format {split + objective_name + evaluator_name: evaluator_value}
         """
         out_logs = {}
+        if split == "eval" and self.val_texts is None and self.val_texts_path is None:
+            logger.warning("Skipping evaluation for %s" % self)
+            return out_logs
         # aggregate per-progress_bar-steps, or per-evaluation-steps, keep the results of unprocessed evaluations
         loss_history = self.loss_history[split][-self.max_samples_per_log[split]:]
         mean_loss = sum(loss_history) / len(loss_history) if len(loss_history) else float("inf")
diff --git a/adaptor/utils.py b/adaptor/utils.py
index 0049ff7..3e767bb 100644
--- a/adaptor/utils.py
+++ b/adaptor/utils.py
@@ -48,16 +48,17 @@ def iter_text_file_per_line(path: str) -> Iterable[str]:
     At this point, `path` is checked to be of a supported format.
     :param path: file path
     """
-    if path.endswith(".txt"):
-        with open(path) as f:
-            for line in f:
-                yield line.strip()
-    elif path.endswith(".gz"):
+    if path.endswith(".gz"):
         import gzip
         import io
         with io.TextIOWrapper(io.BufferedReader(gzip.open(path))) as file:
             for line in file:
                 yield line.strip()
+    else:
+        # assumes plain, newline-separated text file
+        with open(path) as f:
+            for line in f:
+                yield line.strip()
 
 
 class TransformerAdaptationDataset(AdaptationDataset):
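
A minimal, standalone sketch of the behaviour the two touched functions now share; the read_lines and count_lines names are illustrative and not part of adaptor. Paths ending in .gz are streamed through gzip, and any other path is treated as a plain, newline-separated text file whose lines are counted without loading it into memory.

import gzip
import io
from typing import Iterator


def read_lines(path: str) -> Iterator[str]:
    # .gz files are decompressed on the fly; anything else is read as plain text
    if path.endswith(".gz"):
        with io.TextIOWrapper(io.BufferedReader(gzip.open(path))) as f:
            for line in f:
                yield line.strip()
    else:
        with open(path) as f:
            for line in f:
                yield line.strip()


def count_lines(path: str) -> int:
    # line count used as the dataset length, computed by streaming the file
    return sum(1 for _ in read_lines(path))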