Commit 12d5b7d

update server and inference
bastiscode committed Dec 20, 2024
1 parent 53409f2 commit 12d5b7d
Showing 4 changed files with 17 additions and 20 deletions.
12 changes: 4 additions & 8 deletions python/text_utils/api/processor.py
@@ -28,10 +28,6 @@ class TextProcessor:
     pretrained: bool = False
     devices: list[torch.device]
 
-    @classmethod
-    def _task_upper(cls) -> str:
-        return cls.task.upper()
-
     @classmethod
     def available_models(cls) -> list[ModelInfo]:
         raise NotImplementedError
@@ -52,15 +48,15 @@ def _model_url(cls, model: str) -> str:
 
     @classmethod
     def download_dir(cls) -> str:
-        task_name = cls._task_upper().replace(" ", "_")
+        task_name = cls.task.upper().replace(" ", "_")
         return os.environ.get(
             f"{task_name}_DOWNLOAD_DIR",
             os.path.join(os.path.dirname(__file__), ".download", task_name),
         )
 
     @classmethod
     def cache_dir(cls) -> str:
-        task_name = cls._task_upper().replace(" ", "_")
+        task_name = cls.task.upper().replace(" ", "_")
         return os.environ.get(
             f"{task_name}_CACHE_DIR",
             os.path.join(os.path.dirname(__file__), ".cache", task_name),
@@ -86,7 +82,7 @@ def from_pretrained(
                 f"{pprint.pformat(cls.available_models())}"
             )
 
-        logger = logging.get_logger(f"{cls._task_upper()} DOWNLOAD")
+        logger = logging.get_logger(f"{cls.task.upper()} DOWNLOAD")
         model_url = cls._model_url(model)
         if download_dir is None:
             download_dir = cls.download_dir()
@@ -150,7 +146,7 @@ def __init__(
         self, model: nn.Module, cfg: dict[str, Any], device: Device = "cuda"
     ) -> None:
         self.cfg = cfg
-        self.logger = logging.get_logger(self._task_upper())
+        self.logger = logging.get_logger(self.task.upper())
         self.logger.debug(f"Got config:\n{self.cfg}")
 
         torch.set_num_threads(len(os.sched_getaffinity(0)))
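The processor change is a pure refactor: the one-line _task_upper helper is
dropped and cls.task.upper() is inlined at each call site. A minimal sketch of
the resulting directory resolution, using "text correction" as a hypothetical
task name (the real value is a class attribute set by each TextProcessor
subclass, not shown in this diff):

    import os

    task = "text correction"  # hypothetical; set by the subclass in practice
    task_name = task.upper().replace(" ", "_")  # -> "TEXT_CORRECTION"

    # the env var TEXT_CORRECTION_DOWNLOAD_DIR, if set, overrides the default
    download_dir = os.environ.get(
        f"{task_name}_DOWNLOAD_DIR",
        os.path.join(os.path.dirname(__file__), ".download", task_name),
    )
    print(download_dir)

So setting TEXT_CORRECTION_DOWNLOAD_DIR=/data/models redirects downloads
without touching code; cache_dir works the same way with _CACHE_DIR.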
5 changes: 4 additions & 1 deletion python/text_utils/api/server.py
@@ -73,7 +73,10 @@ def from_config(
 
     def __init__(self, config: dict[str, Any], log_level: str | int | None = None):
         self.config = config
-        self.logger = get_logger(f"{self.text_processor_cls.task} server", log_level)
+        self.logger = get_logger(
+            f"{self.text_processor_cls.task.upper()} SERVER",
+            log_level,
+        )
         self.logger.info(f"Loaded server config:\n{yaml.dump(config)}")
         self.port = int(self.config.get("port", 40000))
 
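The server logger is renamed to match the uppercase convention the processor
loggers now use (e.g. "... DOWNLOAD"), and the call is wrapped for line
length. A small sketch of the naming, again with a hypothetical task name and
logging.getLogger standing in for the project's get_logger:

    import logging

    task = "text correction"  # hypothetical task name
    # before: "text correction server"; after: "TEXT CORRECTION SERVER"
    logger = logging.getLogger(f"{task.upper()} SERVER")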
14 changes: 6 additions & 8 deletions python/text_utils/inference/__init__.py
@@ -194,24 +194,22 @@ def get_outputs(intermediate: bool) -> list[list[Beam]]:
            torch.arange(b), decoder_lengths_tensor - 1
        ]
 
-        org_log_probs = torch.log_softmax(decoder_outputs, dim=-1)
+        raw_log_probs = torch.log_softmax(decoder_outputs, dim=-1)
 
         # apply logit functions
         for logit_fn in logit_fns or []:
             decoder_outputs = logit_fn(decoder_token_ids, decoder_outputs, beams)
 
         log_probs = torch.log_softmax(decoder_outputs, dim=-1)
 
-        for idx, (log_probs, org_log_probs) in enumerate(
-            zip(
-                torch.split(log_probs, num_beams), torch.split(org_log_probs, num_beams)
-            )
-        ):
+        raw_log_probs = torch.split(raw_log_probs, num_beams)
+        log_probs = torch.split(log_probs, num_beams)
+        for idx, (raw_log_prob, log_prob) in enumerate(zip(raw_log_probs, log_probs)):
             candidates: list[Beam] = []
             for beam_idx, beam in enumerate(current_beams[idx]):
-                for token_id in sample_fn(log_probs[beam_idx], beam_width).tolist():
+                for token_id in sample_fn(log_prob[beam_idx], beam_width).tolist():
                     candidate = beam.clone()
-                    candidate.add(token_id, org_log_probs[beam_idx, token_id].item())
+                    candidate.add(token_id, raw_log_prob[beam_idx, token_id].item())
                     candidates.append(candidate)
 
             # reset current beams and fill with best candidates
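The inference change renames org_log_probs to raw_log_probs, hoists the
torch.split calls out of the loop header, and uses singular loop variable
names, which reads more easily and avoids shadowing the split lists. The
behavior it keeps: tokens are sampled from the distribution *after* logit
functions are applied, but each candidate is scored with the *raw* log prob,
so constraints and penalties steer sampling without distorting beam scores. A
self-contained sketch of that split, with made-up shapes and top-k selection
standing in for the real sample_fn:

    import torch

    num_beams, vocab_size, beam_width = 2, 5, 2
    logits = torch.randn(num_beams, vocab_size)

    # raw distribution, before any logit functions
    raw_log_probs = torch.log_softmax(logits, dim=-1)

    # a logit function, e.g. a grammar constraint masking token 3
    logits[:, 3] = float("-inf")
    log_probs = torch.log_softmax(logits, dim=-1)

    for beam_idx in range(num_beams):
        # select from the adjusted distribution ...
        top = torch.topk(log_probs[beam_idx], beam_width).indices
        for token_id in top.tolist():
            # ... but record the raw log prob as the candidate's score
            score = raw_log_probs[beam_idx, token_id].item()
            print(f"beam {beam_idx}: token {token_id}, score {score:.3f}")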
6 changes: 3 additions & 3 deletions python/text_utils/inference/utils.py
@@ -1,7 +1,6 @@
-from typing import Callable, Any
+from typing import Any, Callable, Protocol
 
 import torch
 
-from grammar_utils.constrain import Constraint
 
 # maps from token ids, length, and other kwargs to distribution over next token id and other info
@@ -140,7 +139,8 @@ def __repr__(self) -> str:
 
 
 # takes in a beam (and optional length) and returns a scalar score
-ScoreFn = Callable[[Beam, int | None], float]
+class ScoreFn(Protocol):
+    def __call__(self, beam: Beam, length: int | None = None) -> float: ...
 
 
 def log_likelihood_score(normalize: bool = True, alpha: float = 1.0) -> ScoreFn:
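The ScoreFn rewrite is the reason Protocol joins the typing import above. A
Callable[[Beam, int | None], float] type forces callers to always pass the
second argument and cannot express a default; the Protocol form lets both
score_fn(beam) and score_fn(beam, length) type-check. A minimal sketch with a
stand-in Beam (the real class is defined in this module):

    from typing import Protocol


    class Beam:  # stand-in for the real Beam class
        log_prob: float = -1.5


    class ScoreFn(Protocol):
        def __call__(self, beam: Beam, length: int | None = None) -> float: ...


    def score(beam: Beam, length: int | None = None) -> float:
        # normalize by length when one is given
        return beam.log_prob / length if length else beam.log_prob


    fn: ScoreFn = score
    print(fn(Beam()))      # length omitted: allowed by the Protocol
    print(fn(Beam(), 3))   # length given explicitly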
