diff --git a/llments/eval/factscore/abstain_detection.py b/llments/eval/factscore/abstain_detection.py
index b9cb8ac..ab1e174 100644
--- a/llments/eval/factscore/abstain_detection.py
+++ b/llments/eval/factscore/abstain_detection.py
@@ -1,7 +1,7 @@
 """Abstain Detection Module."""
 import numpy as np
 import re
-from typing import List
+from typing import List, cast
 
 invalid_ppl_mentions: List[str] = [
     "I could not find any information",
@@ -40,7 +40,7 @@ def is_invalid_ppl(text: str) -> bool:
     Returns:
         bool: True if the text starts with any invalid phrase, False otherwise.
     """
-    return np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions])
+    return cast(bool, np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions]))
 
 def is_invalid_paragraph_ppl(text: str) -> bool:
     """Determine if a paragraph is invalid based on its content.
diff --git a/llments/eval/factscore/clm.py b/llments/eval/factscore/clm.py
index 74a5d8a..eb022e6 100644
--- a/llments/eval/factscore/clm.py
+++ b/llments/eval/factscore/clm.py
@@ -12,8 +12,8 @@ from transformers import AutoModelForCausalLM
 from transformers import LlamaTokenizer
 
-from factscore.utils import convert_model_to_int8_on_gpu
-from factscore.lm import LM
+from llments.eval.factscore.utils import convert_model_to_int8_on_gpu
+from llments.eval.factscore.lm import LM
 
 
 class CLM(LM):
     """CLM (Causal Language Model) Class.
@@ -101,7 +101,7 @@ def _generate(
         if verbose:
             input_ids = tqdm(input_ids)
 
-        generations = []
+        generations: List[str] = []
         scores = []
         for curr_input_ids in input_ids:
             if len(curr_input_ids) > max_sequence_length - max_output_length:
diff --git a/llments/eval/factscore/download_data.py b/llments/eval/factscore/download_data.py
index 650c81b..099d30f 100644
--- a/llments/eval/factscore/download_data.py
+++ b/llments/eval/factscore/download_data.py
@@ -5,7 +5,7 @@ import torch
 import tqdm
 import transformers
-from typing import Tuple
+from typing import Tuple, Dict
 
 
 def download_file(_id: str, dest: str, cache_dir: str) -> None:
     """Download a file from a given URL or Google Drive ID to the specified destination.
@@ -59,7 +59,7 @@ def download_file(_id: str, dest: str, cache_dir: str) -> None:
     print("Unzip {} ... [Success]".format(dest))
 
 def smart_tokenizer_and_embedding_resize(
-    special_tokens_dict: dict,
+    special_tokens_dict: Dict[str, str],
     tokenizer: transformers.PreTrainedTokenizer,
     model: transformers.PreTrainedModel,
 ) -> None:
diff --git a/llments/eval/factscore/lm.py b/llments/eval/factscore/lm.py
index 76dc114..312d167 100644
--- a/llments/eval/factscore/lm.py
+++ b/llments/eval/factscore/lm.py
@@ -36,7 +36,7 @@ def load_model(self) -> None:
         """
         raise NotImplementedError()
 
-    def generate(
+    def _generate(
         self,
         prompt: str,
         sample_idx: int = 0,
@@ -84,7 +84,7 @@ def save_cache(self) -> None:
         with open(self.cache_file, "wb") as f:
             pickle.dump(self.cache_dict, f)
 
-    def load_cache(self, allow_retry: bool = True) -> Dict[str, Any]:
+    def load_cache(self, allow_retry: bool = True) -> Any:
         """Load the cache from the cache file.
 
         Args:
@@ -92,7 +92,7 @@ def load_cache(self, allow_retry: bool = True) -> Dict[str, Any]:
                 Defaults to True.
 
         Returns:
-            Dict[str, Any]: The loaded cache dictionary.
+            Any: The loaded cache dictionary.
 
         Raises:
             Exception: Propagates the exception if `allow_retry` is False and loading fails.
diff --git a/llments/eval/factscore/openai_lm.py b/llments/eval/factscore/openai_lm.py
index 3deb944..75e2b12 100644
--- a/llments/eval/factscore/openai_lm.py
+++ b/llments/eval/factscore/openai_lm.py
@@ -1,5 +1,5 @@
 """OpenAI Model Module."""
-from factscore.lm import LM
+from llments.eval.factscore.lm import LM
 import openai
 import sys
 import time
@@ -108,7 +108,7 @@ def call_ChatGPT(
     max_len: int = 1024,
     temp: float = 0.7,
     verbose: bool = False
-) -> Dict[str, Any]:
+) -> Any | None:
     """Call the OpenAI ChatCompletion API to generate a response based on the input message.
 
     Args:
@@ -119,7 +119,7 @@ def call_ChatGPT(
         verbose (bool, optional): If True, print detailed error information. Defaults to False.
 
     Returns:
-        Dict[str, Any]: The raw response from the OpenAI ChatCompletion API.
+        Any: The raw response from the OpenAI ChatCompletion API.
 
     Raises:
         AssertionError: If an InvalidRequestError occurs, such as when the prompt is too long.
@@ -156,7 +156,7 @@ def call_GPT3(
     num_log_probs: int = 0,
     echo: bool = False,
     verbose: bool = False
-) -> Dict[str, Any]:
+) -> Any | None:
     """Call the OpenAI GPT-3 API to generate a response based on the input prompt.
 
     This function handles API rate limits by implementing an exponential backoff retry mechanism.
@@ -172,7 +172,7 @@ def call_GPT3(
         verbose (bool, optional): If True, print detailed error information. Defaults to False.
 
     Returns:
-        Dict[str, Any]: The raw response from the OpenAI GPT-3 API.
+        Any: The raw response from the OpenAI GPT-3 API.
 
     Raises:
         AssertionError: If an InvalidRequestError occurs, such as when the prompt is too long.
diff --git a/llments/eval/factscore/utils.py b/llments/eval/factscore/utils.py
index 7618c35..ac6c7da 100644
--- a/llments/eval/factscore/utils.py
+++ b/llments/eval/factscore/utils.py
@@ -148,16 +148,16 @@ def convert_model_to_int8_on_gpu(model: nn.Module, device: str) -> nn.Module:
     model.half()
 
-    memory_before_quantization = get_memory_footprint(model) # without lm_head
+    memory_before_quantization: float = get_memory_footprint(model) # without lm_head
 
     replace_linear_with_int8linear(model)  # replace `Linear` with `QuantizedLinearInt8`
     model.to(device=device)
 
-    memory_after_quantization = get_memory_footprint(model) # without lm_head
+    memory_after_quantization: float = get_memory_footprint(model) # without lm_head
 
     saving = round(100 * memory_after_quantization/memory_before_quantization)
-    memory_before_quantization: float = round(memory_before_quantization / 2**30, 2) # rounding for printing
-    memory_after_quantization: float = round(memory_after_quantization / 2**30, 2) # rounding for printing
+    memory_before_quantization = round(memory_before_quantization / 2**30, 2) # rounding for printing
+    memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing
 
     print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)')
     return model
\ No newline at end of file
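
Note: a minimal, illustrative sketch (not part of the patch) of why `abstain_detection.py` now wraps `np.any` in `typing.cast`. `np.any` returns `numpy.bool_`, which mypy does not treat as the builtin `bool` declared in the signature, so the cast narrows the type for the checker without changing runtime behavior:

    import numpy as np
    from typing import List, cast

    invalid_ppl_mentions: List[str] = ["I could not find any information"]

    def is_invalid_ppl(text: str) -> bool:
        # np.any(...) yields numpy.bool_; cast(bool, ...) satisfies the -> bool annotation.
        return cast(bool, np.any([text.lower().startswith(m.lower()) for m in invalid_ppl_mentions]))

    print(is_invalid_ppl("I could not find any information about this person."))  # True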