Skip to content

Commit

Permalink
Merge branch 'huggingface:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
mdiazmel authored Jan 23, 2025
2 parents dc25712 + 0ab63d0 commit 17c42e5
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 3 deletions.
215 changes: 214 additions & 1 deletion community_tasks/arabic_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
"""
import random
import re
from typing import Any, Dict, List, Optional, Union

from lighteval.metrics.metrics import Metrics
from lighteval.metrics.llm_as_judge import JudgeLM
from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
from lighteval.metrics.utils.metric_utils import MetricUseCase
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
Expand Down Expand Up @@ -832,6 +835,215 @@ def __init__(
]


class JudgeMetricWrapper(Metric):
"""Wrapper class for LLM-based judge metric implementation."""

def __init__(self, judge: JudgeLM):
"""
Initializes the judge metric wrapper.
Args:
judge (JudgeLM): The LLM judge instance to use for evaluation.
"""
self.judge = judge
self.metric_name = "llm_as_judge"
self.category = MetricCategory.LLM_AS_JUDGE
self.corpus_level_fn = self.aggregate_scores
self.sample_level_fn = self._sample_level_fn
self.higher_is_better = True # Fixed tuple syntax
self.use_case = MetricUseCase.NONE

def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> dict[str, float]:
"""
Computes evaluation scores using the judge's evaluate_answer method.
Args:
responses (list[str]): The predicted answers
formatted_docs (list[Doc]): Documents containing questions and gold answers
Returns:
dict[str, float]: Dictionary containing evaluation scores
"""
results = []
for i, doc in enumerate(formatted_docs):
question = doc.query
gold = doc.choices[doc.gold_index] if doc.gold_index is not None else None
answer = responses[i][0].result[0]

score, _, _ = self.judge.evaluate_answer(question=question, answer=answer, options=None, gold=gold)
results.append({self.metric_name: score})

return results

def aggregate_scores(self, scores: list[dict]) -> float:
return sum(scores) / len(scores) if scores else 0.0

def _sample_level_fn(self):
return None


def parse_candidates(candidates: Union[List[str], str]) -> List[str]:
"""
Parses and validates candidate answers from either list or string format.
Args:
candidates: Either a list of candidate answers or a newline-separated string
Returns:
List[str]: List of validated candidate answers
Raises:
ValueError: If candidates cannot be parsed or are empty
"""
try:
if isinstance(candidates, list):
parsed_candidates = [str(c).strip() for c in candidates if c]
else:
parsed_candidates = [c.strip() for c in str(candidates).split("\n") if c.strip()]

if not parsed_candidates:
raise ValueError("No valid candidates found after parsing")

return parsed_candidates
except Exception as e:
raise ValueError(f"Failed to parse candidates: {str(e)}")


def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc:
"""
Formats the prompt for Arabic question answering with candidates.
Args:
line: Dictionary containing question and candidate information
task_name: Optional name for the task
Returns:
Doc: Formatted document for evaluation
Raises:
ValueError: If required fields are missing or invalid
"""
try:
# Validates and extracts the question
if not isinstance(line.get("question"), str):
raise ValueError("Question must be a string")
question = line["question"]

# Processes candidate answers
candidates = parse_candidates(line["candidates"])

# Validates gold answer
if "gold_answer" not in line:
raise ValueError("Gold answer is required")
gold_answer = str(line["gold_answer"])

# Constructs the prompt
instruction = "بناءً على السياقات المقترحة التالية، اجب عن السؤال التالي"
query = f"{instruction}\n\nالسؤال:\n{question}\n\nالسياقات المقترحة:\n{', '.join(candidates)}\n"

return Doc(
task_name=task_name or "alrage",
query=query,
instruction=instruction,
choices=[gold_answer], # Gold answer is used as the only valid choice
gold_index=0, # Index of the correct answer in choices
)
except Exception as e:
raise ValueError(f"Failed to create QA prompt: {str(e)}")


def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
"""
Template for the Arabic judge prompt.
System prompt translation:
You are a neutral expert evaluator. Your tasks are:
1. Evaluate the answer's accuracy compared to the correct answer
2. Verify that the answer is supported by the provided context
3. Evaluate the quality and comprehensiveness of the answer
Rate the answer on a scale from 0 to 10.
Args:
question: The question being evaluated
answer: The provided answer
gold: The correct answer
options: Optional list of answer choices
Returns:
List[Dict[str, str]]: Formatted messages for the judge
"""
messages = [
{
"role": "system",
"content": """أنت مقيّم محايد خبير باللغة العربية. يجب عليك:
1. تقييم دقة الإجابة مقارنة بالإجابة الصحيحة
2. التحقق من أن الإجابة مدعومة بالسياق المقدم
3. تقييم جودة وشمولية الإجابة
مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""",
},
{
"role": "user",
"content": f"""السؤال: {question}
الإجابة المقدمة: {answer}
الإجابة الصحيحة: {gold}
أعط تقييماً من 0 إلى 10:
0-2: إجابة خاطئة تماماً
3-4: إجابة جزئية مع أخطاء
5-6: إجابة متوسطة
7-8: إجابة جيدة
9-10: إجابة ممتازة
اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""",
},
]
return messages


def process_judge_response(response) -> float:
"""Process the judge's response to extract the score"""
# If response is a list, extract the content from the user role
if isinstance(response, list):
response_content = " ".join(item["content"] for item in response if item["role"] == "user")
else:
response_content = response # If it's not a list, use it directly

try:
# Extract the score from the response content
score = float(next(num for num in response_content.split() if num.replace(".", "", 1).isdigit()))
return min(max(score / 10.0, 0.0), 1.0)
except (StopIteration, ValueError):
return 0.0


judge = JudgeLM(
model="Qwen/Qwen2.5-72B-Instruct",
templates=judge_template,
process_judge_response=process_judge_response,
judge_backend="vllm",
)

wrapped_judge = JudgeMetricWrapper(judge)

# Task configuration
alrage_qa_task = LightevalTaskConfig(
name="alrage_qa",
prompt_function=qa_prompt_arabic,
suite=["community"],
hf_repo="OALL/ALRAGE",
hf_subset=None,
hf_avail_splits=["train"],
evaluation_splits=["train"],
metric=[wrapped_judge],
trust_dataset=True,
generation_size=200,
stop_sequence=[],
version=0,
)

TASKS_TABLE = (
ARABIC_MMLU_TASKS
+ ARABIC_MMLU_HT_TASKS
Expand All @@ -852,4 +1064,5 @@ def __init__(
+ [hellaswag_okapi_ar_task]
+ [toxigen_ar_task]
+ [sciq_ar_task]
+ [alrage_qa_task]
)
1 change: 1 addition & 0 deletions examples/tasks/OALL_v2_tasks.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,4 @@ community|arabic_mmlu_ht:sociology|0|0
community|arabic_mmlu_ht:us_foreign_policy|0|0
community|arabic_mmlu_ht:virology|0|0
community|arabic_mmlu_ht:world_religions|0|0
community|alrage_qa|0|0
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ dependencies = [
"transformers>=4.38.0",
"accelerate",
"huggingface_hub>=0.23.0",
"torch>=2.0,<2.5",
"torch>=2.0,<3.0",
"GitPython>=3.1.41", # for logging
"datasets>=2.14.0",
"numpy<2", # pinned to avoid incompatibilities
Expand Down Expand Up @@ -109,7 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.0"]
math = ["latex2sympy2_extended>=0.9.1"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
47 changes: 47 additions & 0 deletions src/lighteval/logging/evaluation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,28 @@ def __init__(

self.public = public

@property
def results(self):
config_general = asdict(self.general_config_logger)
# We remove the config from logging, which contains context/accelerator objects
config_general.pop("config")
results = {
"config_general": config_general,
"results": self.metrics_logger.metric_aggregated,
"versions": self.versions_logger.versions,
"config_tasks": self.task_config_logger.tasks_configs,
"summary_tasks": self.details_logger.compiled_details,
"summary_general": asdict(self.details_logger.compiled_details_over_all_tasks),
}
return results

@property
def details(self):
return {
task_name: [asdict(detail) for detail in task_details]
for task_name, task_details in self.details_logger.details.items()
}

def save(self) -> None:
"""Saves the experiment information and results to files, and to the hub if requested."""
logger.info("Saving experiment tracker")
Expand Down Expand Up @@ -281,6 +303,31 @@ def push_to_hub(

self.recreate_metadata_card(repo_id)

def push_results_to_hub(self, repo_id: str, path_in_repo: str, private: bool | None = None):
repo_id = repo_id if "/" in repo_id else f"{self.hub_results_org}/{repo_id}"
private = private if private is not None else not self.public
self.api.create_repo(repo_id, private=private, repo_type="dataset", exist_ok=True)
results_json = json.dumps(self.results, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)
self.api.upload_file(
repo_id=repo_id,
path_or_fileobj=results_json.encode(),
path_in_repo=path_in_repo,
repo_type="dataset",
)

def push_details_to_hub(self, repo_id: str, path_in_repo: str, private: bool | None = None):
repo_id = repo_id if "/" in repo_id else f"{self.hub_results_org}/{repo_id}"
private = private if private is not None else not self.public
self.api.create_repo(repo_id, private=private, repo_type="dataset", exist_ok=True)
for task_name, details in self.details:
details_json = "\n".join([json.dumps(detail) for detail in details])
self.api.upload_file(
repo_id=repo_id,
path_or_fileobj=details_json.encode(),
path_in_repo=path_in_repo.format(task_name=task_name),
repo_type="dataset",
)

def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901
"""Fully updates the details repository metadata card for the currently evaluated model
Expand Down
27 changes: 27 additions & 0 deletions tests/metrics/test_extractive_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,34 @@ def test_math_extraction_edge_cases(gold, pred, expected):
r"To find the product \( ab \) where \( a = 2012_3 \) and \( b = 201_3 \), we first convert these base-three numbers to base ten. For \( a = 2012_3 \): \[ a = 2 \cdot 3^3 + 0 \cdot 3^2 + 1 \cdot 3^1 + 2 \cdot 3^0 = 2 \cdot 27 + 0 \cdot 9 + 1 \cdot 3 + 2 \cdot 1 = 54 + 0 + 3 + 2 = 59_{10} \] For \( b = 201_3 \): \[ b = 2 \cdot 3^2 + 0 \cdot 3^1 + 1 \cdot 3^0 = 2 \cdot 9 + 0 \cdot 3 + 1 \cdot 1 = 18 + 0 + 1 = 19_{10} \] Now, calculate the product in base ten: \[ ab = 59 \times 19 \] Perform the multiplication: \[ 59 \times 19 = 59 \times (20 - 1) = 59 \times 20 - 59 \times 1 = 1180 - 59 = 1121 \] Next, convert \( 1121_{10} \) to base three. We do this by dividing by 3 and recording the remainders: \[ 1121 \div 3 = 373 \quad \text{remainder } 2 \] \[ 373 \div 3 = 124 \quad \text{remainder } 1 \] \[ 124 \div 3 = 41 \quad \text{remainder } 1 \] \[ 41 \div 3 = 13 \quad \text{remainder } 2 \] \[ 13 \div 3 = 4 \quad \text{remainder } 1 \] \[ 4 \div 3 = 1 \quad \text{remainder } 1 \] \[ 1 \div 3 = 0 \quad \text{remainder } 1 \] Reading the remainders from last to first, we find: \[ 1121_{10} = 1112122_3 \] Thus, the product \( ab \) expressed in the base-three number system is \(\boxed{1112122_3}\).",
0,
),
(
r"\(\boxed{\text{C}}\).",
r"$\boxed{\text{(C)}}.$",
1,
),
(
r" So the answer is: \[ \boxed{11111111100} \]",
r"is $\boxed{11,\! 111,\! 111,\! 100}$",
1,
),
(
r" So the answer is: \[ \boxed{32349} \]",
r"is $\boxed{32,\! 349}$",
1,
),
(
r"Thus, the domain of the function \( f(x) \) is: \[ \boxed{(2, 12) \cup (12, 102)} \]",
r"Thus, the answer is $x \in \boxed{(2,12) \cup (12,102)}$",
1,
),
],
)
def test_math_extraction_additional_cases(gold, pred, expected):
assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected


# text{C} Qwen correct
# 11111111100 Qwen correct
# Interval(2, oo) qwen incorrect
# text{west} qwen incorrect
# 32349, 32,\!348 qwen incorrect

0 comments on commit 17c42e5

Please sign in to comment.