diff --git a/lotus/models/lm.py b/lotus/models/lm.py index b525a63a..42a1dfb0 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -30,9 +30,7 @@ def __init__( self.stats: LMStats = LMStats() - def __call__( - self, messages: list[dict[str, str]] | list[list[dict[str, str]]], **kwargs: dict[str, Any] - ) -> LMOutput: + def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput: all_kwargs = {**self.kwargs, **kwargs} # Set top_logprobs if logprobs requested