
Commit

mypy changes
AakritiKinra authored Dec 19, 2024
1 parent 5fd9304 commit efefd37
Showing 6 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions llments/eval/factscore/abstain_detection.py
@@ -1,7 +1,7 @@
"""Abstain Detection Module."""
import numpy as np
import re
from typing import List
from typing import List, cast

invalid_ppl_mentions: List[str] = [
"I could not find any information",
@@ -40,7 +40,7 @@ def is_invalid_ppl(text: str) -> bool:
Returns:
bool: True if the text starts with any invalid phrase, False otherwise.
"""
return np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions])
return cast(bool, np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions]))

def is_invalid_paragraph_ppl(text: str) -> bool:
"""Determine if a paragraph is invalid based on its content.
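A note on the `cast` above: `np.any` evaluates to `np.bool_` rather than a built-in `bool`, so mypy typically flags the uncast value against the declared `-> bool` return type. A minimal, self-contained sketch of the same pattern (the phrase list and function name here are illustrative, not from the repository):

    import numpy as np
    from typing import List, cast

    invalid_phrases: List[str] = ["I could not find any information"]

    def starts_with_invalid_phrase(text: str) -> bool:
        # np.any(...) returns np.bool_; the cast keeps the declared built-in
        # bool return type acceptable to mypy.
        return cast(bool, np.any([text.lower().startswith(p.lower()) for p in invalid_phrases]))

    print(starts_with_invalid_phrase("I could not find any information on that."))  # True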
6 changes: 3 additions & 3 deletions llments/eval/factscore/clm.py
@@ -12,8 +12,8 @@
from transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

from factscore.utils import convert_model_to_int8_on_gpu
from factscore.lm import LM
from llments.eval.factscore.utils import convert_model_to_int8_on_gpu
from llments.eval.factscore.lm import LM

class CLM(LM):
"""CLM (Causal Language Model) Class.
@@ -101,7 +101,7 @@ def _generate(
if verbose:
input_ids = tqdm(input_ids)

generations = []
generations: List[str] = []
scores = []
for curr_input_ids in input_ids:
if len(curr_input_ids) > max_sequence_length - max_output_length:
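The `generations` annotation above follows the usual mypy rule that an empty list literal carries no element type, so the variable needs an explicit annotation. A small illustrative sketch (the function and names are hypothetical):

    from typing import List

    def collect_generations(n: int) -> List[str]:
        # Annotating the empty list gives mypy the element type up front,
        # avoiding a "need type annotation" (var-annotated) error.
        generations: List[str] = []
        for i in range(n):
            generations.append(f"generation {i}")
        return generations

    print(collect_generations(2))  # ['generation 0', 'generation 1']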
4 changes: 2 additions & 2 deletions llments/eval/factscore/download_data.py
@@ -5,7 +5,7 @@
import torch
import tqdm
import transformers
from typing import Tuple
from typing import Tuple, Dict

def download_file(_id: str, dest: str, cache_dir: str) -> None:
"""Download a file from a given URL or Google Drive ID to the specified destination.
@@ -59,7 +59,7 @@ def download_file(_id: str, dest: str, cache_dir: str) -> None:
print("Unzip {} ... [Success]".format(dest))

def smart_tokenizer_and_embedding_resize(
special_tokens_dict: dict,
special_tokens_dict: Dict[str, str],
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
) -> None:
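The signature change above swaps the bare `dict` annotation for `Dict[str, str]`, letting mypy check both keys and values at call sites. A tiny hypothetical sketch of the difference (not the repository's resize logic):

    from typing import Dict

    def describe_special_tokens(special_tokens_dict: Dict[str, str]) -> None:
        # With Dict[str, str], mypy verifies that callers pass string keys
        # and string values, which a bare `dict` annotation would not.
        for name, token in special_tokens_dict.items():
            print(f"{name} -> {token}")

    describe_special_tokens({"pad_token": "[PAD]"})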
6 changes: 3 additions & 3 deletions llments/eval/factscore/lm.py
@@ -36,7 +36,7 @@ def load_model(self) -> None:
"""
raise NotImplementedError()

def generate(
def _generate(
self,
prompt: str,
sample_idx: int = 0,
@@ -84,15 +84,15 @@ def save_cache(self) -> None:
with open(self.cache_file, "wb") as f:
pickle.dump(self.cache_dict, f)

def load_cache(self, allow_retry: bool = True) -> Dict[str, Any]:
def load_cache(self, allow_retry: bool = True) -> Any:
"""Load the cache from the cache file.
Args:
allow_retry (bool, optional): Whether to retry loading the cache in case of errors.
Defaults to True.
Returns:
Dict[str, Any]: The loaded cache dictionary.
Any: The loaded cache dictionary.
Raises:
Exception: Propagates the exception if `allow_retry` is False and loading fails.
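Widening `load_cache` to return `Any` matches `pickle.load`, which is itself typed as returning `Any`; keeping the `Dict[str, Any]` annotation would have required an extra cast. A minimal sketch of the pattern (the file handling here is illustrative):

    import pickle
    from typing import Any

    def load_cache(cache_file: str) -> Any:
        # pickle.load is annotated as Any in typeshed, so returning Any
        # avoids an unchecked narrowing to Dict[str, Any].
        with open(cache_file, "rb") as f:
            return pickle.load(f)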
10 changes: 5 additions & 5 deletions llments/eval/factscore/openai_lm.py
@@ -1,5 +1,5 @@
"""OpenAI Model Module."""
from factscore.lm import LM
from llments.eval.factscore.lm import LM
import openai
import sys
import time
@@ -108,7 +108,7 @@ def call_ChatGPT(
max_len: int = 1024,
temp: float = 0.7,
verbose: bool = False
) -> Dict[str, Any]:
) -> Any | None:
"""Call the OpenAI ChatCompletion API to generate a response based on the input message.
Args:
@@ -119,7 +119,7 @@ def call_GPT3(
verbose (bool, optional): If True, print detailed error information. Defaults to False.
Returns:
Dict[str, Any]: The raw response from the OpenAI ChatCompletion API.
Any: The raw response from the OpenAI ChatCompletion API.
Raises:
AssertionError: If an InvalidRequestError occurs, such as when the prompt is too long.
@@ -156,7 +156,7 @@ def call_GPT3(
num_log_probs: int = 0,
echo: bool = False,
verbose: bool = False
) -> Dict[str, Any]:
) -> Any | None:
"""Call the OpenAI GPT-3 API to generate a response based on the input prompt.
This function handles API rate limits by implementing an exponential backoff retry mechanism.
@@ -172,7 +172,7 @@ def call_GPT3(
verbose (bool, optional): If True, print detailed error information. Defaults to False.
Returns:
Dict[str, Any]: The raw response from the OpenAI GPT-3 API.
Any: The raw response from the OpenAI GPT-3 API.
Raises:
AssertionError: If an InvalidRequestError occurs, such as when the prompt is too long.
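The `Any | None` return types above presumably reflect that these helpers return the untyped API response on success and may return `None` once retries are exhausted; the `X | Y` union syntax assumes Python 3.10+. A rough, hypothetical sketch of the shape, not the actual OpenAI calls:

    import time
    from typing import Any, Callable

    def call_with_backoff(request: Callable[[], Any], max_attempts: int = 5) -> Any | None:
        # The response object from the client library is effectively untyped
        # (Any), and the loop can give up entirely, hence Any | None.
        for attempt in range(max_attempts):
            try:
                return request()
            except Exception:
                time.sleep(2 ** attempt)  # exponential backoff between retries
        return None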
8 changes: 4 additions & 4 deletions llments/eval/factscore/utils.py
@@ -148,16 +148,16 @@ def convert_model_to_int8_on_gpu(model: nn.Module, device: str) -> nn.Module:

model.half()

memory_before_quantization = get_memory_footprint(model) # without lm_head
memory_before_quantization: float = get_memory_footprint(model) # without lm_head

replace_linear_with_int8linear(model) # replace `Linear` with `QuantizedLinearInt8`

model.to(device=device)
memory_after_quantization = get_memory_footprint(model) # without lm_head
memory_after_quantization: float = get_memory_footprint(model) # without lm_head

saving = round(100 * memory_after_quantization/memory_before_quantization)
memory_before_quantization: float = round(memory_before_quantization / 2**30, 2) # rounding for printing
memory_after_quantization: float = round(memory_after_quantization / 2**30, 2) # rounding for printing
memory_before_quantization = round(memory_before_quantization / 2**30, 2) # rounding for printing
memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing

print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)')
return model

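The last hunk drops the `: float` annotations from the two reassignments; mypy generally rejects annotating a name that already has an annotation (a redefinition-style error), while a plain reassignment type-checks cleanly. A tiny sketch of the rule, with made-up values:

    memory_gb: float = 3.2197265625
    # Writing `memory_gb: float = round(...)` here would re-annotate the name
    # and trip mypy; a plain reassignment is fine.
    memory_gb = round(memory_gb, 2)
    print(memory_gb)  # 3.22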