Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanHB committed Dec 18, 2024
1 parent bae4506 commit 6b0cb60
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 17 deletions.
15 changes: 2 additions & 13 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def eval_docs(self) -> list[Doc]:
return self._docs

def construct_requests(
self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str, system_prompt: str
self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
) -> Dict[RequestType, List[Request]]:
"""
Constructs a list of requests from the task based on the given parameters.
Expand All @@ -365,7 +365,6 @@ def construct_requests(
context=context,
choice=gold,
metric_categories=[MetricCategory.TARGET_PERPLEXITY],
system_prompt=system_prompt,
)
for i, gold in enumerate(golds)
]
Expand All @@ -377,7 +376,6 @@ def construct_requests(
request_index=0,
context=context,
metric_categories=[MetricCategory.PERPLEXITY],
system_prompt=system_prompt,
)
]
if self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]:
Expand All @@ -397,7 +395,6 @@ def construct_requests(
do_sample=True,
use_logits=False,
metric_categories=[MetricCategory.GENERATIVE_SAMPLING],
system_prompt=system_prompt,
)
]
if (
Expand All @@ -424,7 +421,6 @@ def construct_requests(
]
if self.has_metric_category[c]
],
system_prompt=system_prompt,
)
]
if (
Expand All @@ -443,7 +439,6 @@ def construct_requests(
for c in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]
if self.has_metric_category[c]
],
system_prompt=system_prompt,
)
for i, choice in enumerate(formatted_doc.choices)
]
Expand All @@ -460,7 +455,6 @@ def construct_requests(
context=formatted_doc.unconditioned_query,
choice=choice,
metric_categories=[MetricCategory.MULTICHOICE_PMI],
system_prompt=system_prompt,
)
for i, choice in enumerate(formatted_doc.choices)
]
Expand All @@ -473,7 +467,6 @@ def construct_requests(
context=context,
choices=formatted_doc.choices,
metric_categories=[MetricCategory.MULTICHOICE_ONE_TOKEN],
system_prompt=system_prompt,
)
]
if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]:
Expand All @@ -486,7 +479,6 @@ def construct_requests(
stop_sequence=self.stop_sequence,
generation_size=self.generation_size,
metric_categories=[MetricCategory.LLM_AS_JUDGE_MULTI_TURN],
system_prompt=system_prompt,
)
]
if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]:
Expand All @@ -501,7 +493,6 @@ def construct_requests(
generation_grammar=self.generation_grammar,
num_samples=1,
metric_categories=[MetricCategory.LLM_AS_JUDGE],
system_prompt=system_prompt,
)
]

Expand Down Expand Up @@ -661,9 +652,7 @@ def create_requests_from_tasks( # noqa: C901
# Constructing the requests
cur_task_name = f"{task_name}|{num_fewshot}"
docs[SampleUid(cur_task_name, doc_id_seed)] = doc
req_type_reqs_dict = task.construct_requests(
doc, doc.ctx, doc_id_seed, cur_task_name, system_prompt
)
req_type_reqs_dict = task.construct_requests(doc, doc.ctx, doc_id_seed, cur_task_name)
for req_type, reqs in req_type_reqs_dict.items():
requests[req_type].extend(reqs)

Expand Down
9 changes: 6 additions & 3 deletions src/lighteval/tasks/prompt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,12 @@ def _single_turn_context(
if isinstance(self.model, LiteLLMClient):
return output, num_effective_fewshots

return self.model.tokenizer.apply_chat_template(
output, tokenize=False, add_generation_prompt=True
), num_effective_fewshots
elif use_chat_template:
return self.model.tokenizer.apply_chat_template(
output, tokenize=False, add_generation_prompt=True
), num_effective_fewshots

return output, num_effective_fewshots

def get_examples(
self,
Expand Down
1 change: 0 additions & 1 deletion src/lighteval/tasks/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class Request:
request_index: int
context: str
metric_categories: list["MetricCategory"] # noqa F821
system_prompt: Optional[str]


@dataclass
Expand Down

0 comments on commit 6b0cb60

Please sign in to comment.