Skip to content

Commit

Permalink
Conditional import of Prism base
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanik12 committed Apr 8, 2024
1 parent 2eacbbc commit f9253cb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions adaptor/evaluators/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from transformers import PreTrainedTokenizer, BatchEncoding

from .evaluator_base import EvaluatorBase
from .prism import Prism
from ..utils import Head, AdaptationDataset

logger = logging.getLogger()
Expand Down Expand Up @@ -191,6 +190,7 @@ def __init__(self,
probability: Optional[bool] = False,
model_dir: str = "prism/model_dir",
**kwargs):
from .prism import Prism
# language must be set, see prism.py: MODELS['langs'] for a list of supported langs
super().__init__(**kwargs)
self.probability = probability
Expand Down Expand Up @@ -249,7 +249,7 @@ def JS_divergence(self, probs_real: List[float], probs_model: List[float], base:
probs_joined = [(prob_r + prob_m) / 2 for prob_r, prob_m in zip(probs_real, probs_model)]

return (self.KL_divergence(probs_real, probs_joined) + self.KL_divergence(probs_model, probs_joined)) / \
(2 * np.log2(base))
(2 * np.log2(base))

def evaluate_str(self, expected_list: Sequence[str], actual_list: Sequence[str]) -> float:
# we use PRISM for paraphrase evaluation by default
Expand Down

0 comments on commit f9253cb

Please sign in to comment.