diff --git a/python/text_utils/api/processor.py b/python/text_utils/api/processor.py
index d54f857..fc46db7 100644
--- a/python/text_utils/api/processor.py
+++ b/python/text_utils/api/processor.py
@@ -28,10 +28,6 @@ class TextProcessor:
     pretrained: bool = False
     devices: list[torch.device]
 
-    @classmethod
-    def _task_upper(cls) -> str:
-        return cls.task.upper()
-
     @classmethod
     def available_models(cls) -> list[ModelInfo]:
         raise NotImplementedError
@@ -52,7 +48,7 @@ def _model_url(cls, model: str) -> str:
 
     @classmethod
     def download_dir(cls) -> str:
-        task_name = cls._task_upper().replace(" ", "_")
+        task_name = cls.task.upper().replace(" ", "_")
         return os.environ.get(
             f"{task_name}_DOWNLOAD_DIR",
             os.path.join(os.path.dirname(__file__), ".download", task_name),
@@ -60,7 +56,7 @@ def download_dir(cls) -> str:
 
     @classmethod
     def cache_dir(cls) -> str:
-        task_name = cls._task_upper().replace(" ", "_")
+        task_name = cls.task.upper().replace(" ", "_")
         return os.environ.get(
             f"{task_name}_CACHE_DIR",
             os.path.join(os.path.dirname(__file__), ".cache", task_name),
@@ -86,7 +82,7 @@ def from_pretrained(
                 f"{pprint.pformat(cls.available_models())}"
             )
 
-        logger = logging.get_logger(f"{cls._task_upper()} DOWNLOAD")
+        logger = logging.get_logger(f"{cls.task.upper()} DOWNLOAD")
         model_url = cls._model_url(model)
         if download_dir is None:
             download_dir = cls.download_dir()
@@ -150,7 +146,7 @@ def __init__(
         self, model: nn.Module, cfg: dict[str, Any], device: Device = "cuda"
     ) -> None:
         self.cfg = cfg
-        self.logger = logging.get_logger(self._task_upper())
+        self.logger = logging.get_logger(self.task.upper())
         self.logger.debug(f"Got config:\n{self.cfg}")
 
         torch.set_num_threads(len(os.sched_getaffinity(0)))
diff --git a/python/text_utils/api/server.py b/python/text_utils/api/server.py
index 9b1a3b3..2361128 100644
--- a/python/text_utils/api/server.py
+++ b/python/text_utils/api/server.py
@@ -73,7 +73,10 @@ def from_config(
 
     def __init__(self, config: dict[str, Any], log_level: str | int | None = None):
         self.config = config
-        self.logger = get_logger(f"{self.text_processor_cls.task} server", log_level)
+        self.logger = get_logger(
+            f"{self.text_processor_cls.task.upper()} SERVER",
+            log_level,
+        )
         self.logger.info(f"Loaded server config:\n{yaml.dump(config)}")
 
         self.port = int(self.config.get("port", 40000))
diff --git a/python/text_utils/inference/__init__.py b/python/text_utils/inference/__init__.py
index ac70b42..cfb2158 100644
--- a/python/text_utils/inference/__init__.py
+++ b/python/text_utils/inference/__init__.py
@@ -194,7 +194,7 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]:
             torch.arange(b), decoder_lengths_tensor - 1
         ]
 
-        org_log_probs = torch.log_softmax(decoder_outputs, dim=-1)
+        raw_log_probs = torch.log_softmax(decoder_outputs, dim=-1)
 
         # apply logit functions
         for logit_fn in logit_fns or []:
@@ -202,16 +202,14 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]:
 
         log_probs = torch.log_softmax(decoder_outputs, dim=-1)
 
-        for idx, (log_probs, org_log_probs) in enumerate(
-            zip(
-                torch.split(log_probs, num_beams), torch.split(org_log_probs, num_beams)
-            )
-        ):
+        raw_log_probs = torch.split(raw_log_probs, num_beams)
+        log_probs = torch.split(log_probs, num_beams)
+        for idx, (raw_log_prob, log_prob) in enumerate(zip(raw_log_probs, log_probs)):
             candidates: list[Beam] = []
             for beam_idx, beam in enumerate(current_beams[idx]):
-                for token_id in sample_fn(log_probs[beam_idx], beam_width).tolist():
+                for token_id in sample_fn(log_prob[beam_idx], beam_width).tolist():
                     candidate = beam.clone()
-                    candidate.add(token_id, org_log_probs[beam_idx, token_id].item())
+                    candidate.add(token_id, raw_log_prob[beam_idx, token_id].item())
                     candidates.append(candidate)
 
             # reset current beams and fill with best candidates
diff --git a/python/text_utils/inference/utils.py b/python/text_utils/inference/utils.py
index 9b361c6..2b01df9 100644
--- a/python/text_utils/inference/utils.py
+++ b/python/text_utils/inference/utils.py
@@ -1,7 +1,6 @@
-from typing import Callable, Any
+from typing import Any, Callable, Protocol
 
 import torch
-
 from grammar_utils.constrain import Constraint
 
 # maps from token ids, length, and other kwargs to distribution over next token id and other info
@@ -140,7 +139,8 @@ def __repr__(self) -> str:
 
 
 # takes in a beam (and optional length) and returns a scalar score
-ScoreFn = Callable[[Beam, int | None], float]
+class ScoreFn(Protocol):
+    def __call__(self, beam: Beam, length: int | None = None) -> float: ...
 
 
 def log_likelihood_score(normalize: bool = True, alpha: float = 1.0) -> ScoreFn:
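
Note on the ScoreFn change in inference/utils.py (illustration, not part of the diff): typing.Callable cannot declare a default value for a parameter, so a scorer meant to be called as score(beam) does not type-check against Callable[[Beam, int | None], float]; a callback Protocol with a defaulted __call__ parameter can express exactly that. A minimal sketch under stated assumptions: the Beam stub and the body of log_likelihood_score below are hypothetical stand-ins, only the signatures come from the diff.

    from typing import Protocol


    class Beam:
        # hypothetical stand-in for text_utils' Beam; only log_probs is used here
        def __init__(self, log_probs: list[float]) -> None:
            self.log_probs = log_probs


    class ScoreFn(Protocol):
        # the callback protocol introduced in the diff
        def __call__(self, beam: Beam, length: int | None = None) -> float: ...


    def log_likelihood_score(normalize: bool = True, alpha: float = 1.0) -> ScoreFn:
        # assumed body: sum of per-token log-probs, optionally length-normalized
        def score(beam: Beam, length: int | None = None) -> float:
            n = length if length is not None else len(beam.log_probs)
            total = sum(beam.log_probs[:n])
            return total / (n**alpha) if normalize and n > 0 else total

        return score


    score: ScoreFn = log_likelihood_score()
    print(score(Beam([-0.1, -0.3])))  # the length argument can now be omitted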
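
Note on the inference/__init__.py change (illustration, not part of the diff): the rename makes explicit that candidates are sampled from the post-logit-fn distribution but scored with the pre-logit-fn ("raw") log-probs, and it stops the loop variables from shadowing the outer log_probs/org_log_probs tensors. A standalone sketch of that split-and-score pattern, assuming a fixed integer num_beams and plain top-k in place of the real sample_fn:

    import torch

    batch, num_beams, beam_width, vocab = 2, 3, 2, 5

    # one row of decoder logits per (input, beam) pair
    decoder_outputs = torch.randn(batch * num_beams, vocab)
    raw_log_probs = torch.log_softmax(decoder_outputs, dim=-1)

    # a logit function, here masking token 0: selection uses the processed
    # distribution, scoring uses the raw one
    decoder_outputs[:, 0] = float("-inf")
    log_probs = torch.log_softmax(decoder_outputs, dim=-1)

    # split along dim 0 into one (num_beams, vocab) chunk per input
    for raw_log_prob, log_prob in zip(
        torch.split(raw_log_probs, num_beams), torch.split(log_probs, num_beams)
    ):
        for beam_idx in range(num_beams):
            for token_id in torch.topk(log_prob[beam_idx], beam_width).indices.tolist():
                # score with the raw distribution, unaffected by the mask
                print(token_id, raw_log_prob[beam_idx, token_id].item())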