-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
= Enea_Gore
committed
Nov 17, 2024
1 parent
571a0e1
commit 07e9717
Showing
6 changed files
with
182 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
modules/text/module_text_llm/module_text_llm/retrieval_augmented_generation/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from module_text_llm.approach_config import ApproachConfig | ||
from pydantic import Field | ||
from typing import Literal | ||
|
||
|
||
from module_text_llm.retrieval_augmented_generation.prompt_generate_suggestions import GenerateSuggestionsPrompt | ||
|
||
class RAGApproachConfig(ApproachConfig):
    """Configuration for the retrieval-augmented-generation (RAG) feedback approach.

    Selected via the discriminator field ``type`` (value ``'rag'``) and carries
    the prompt used to generate feedback suggestions.
    """
    # Discriminator value used to select this approach in the config union.
    type: Literal['rag'] = 'rag'
    # default_factory creates a fresh prompt object per config instance instead
    # of evaluating a single shared default at class-definition time (pydantic's
    # recommended idiom for model-typed defaults).
    generate_suggestions_prompt: GenerateSuggestionsPrompt = Field(default_factory=GenerateSuggestionsPrompt)
|
97 changes: 97 additions & 0 deletions
97
...xt/module_text_llm/module_text_llm/retrieval_augmented_generation/generate_suggestions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import List | ||
|
||
from athena import emit_meta | ||
from athena.text import Exercise, Submission, Feedback | ||
from athena.logger import logger | ||
from llm_core.utils.llm_utils import ( | ||
get_chat_prompt_with_formatting_instructions, | ||
check_prompt_length_and_omit_features_if_necessary, | ||
num_tokens_from_prompt, | ||
) | ||
from athena.text import Exercise, Submission, Feedback | ||
from llm_core.utils.predict_and_parse import predict_and_parse | ||
|
||
from module_text_llm.config import BasicApproachConfig | ||
from module_text_llm.helpers.utils import add_sentence_numbers, get_index_range_from_line_range, format_grading_instructions | ||
from module_text_llm.basic_approach.prompt_generate_suggestions import AssessmentModel | ||
|
||
async def generate_suggestions(exercise: Exercise, submission: Submission, config: BasicApproachConfig, debug: bool) -> List[Feedback]:
    """Generate graded feedback suggestions for a student's text submission.

    Builds a chat prompt from the exercise data and the sentence-numbered
    submission, trims omittable context if the prompt exceeds the token
    budget, queries the model, and maps the parsed assessment back onto
    character index ranges of the submission.

    Returns an empty list when the prompt is still too long after trimming
    or when the model produces no parseable result.
    """
    llm = config.model.get_model()  # type: ignore[attr-defined]

    # Values substituted into the prompt templates.
    input_values = {
        "max_points": exercise.max_points,
        "bonus_points": exercise.bonus_points,
        "grading_instructions": format_grading_instructions(exercise.grading_instructions, exercise.grading_criteria),
        "problem_statement": exercise.problem_statement or "No problem statement.",
        "example_solution": exercise.example_solution,
        "submission": add_sentence_numbers(submission.text)
    }

    prompt = get_chat_prompt_with_formatting_instructions(
        model=llm,
        system_message=config.generate_suggestions_prompt.system_message,
        human_message=config.generate_suggestions_prompt.human_message,
        pydantic_object=AssessmentModel
    )

    # Drop optional context (least important first) whenever the prompt would
    # exceed the configured input-token budget.
    input_values, should_run = check_prompt_length_and_omit_features_if_necessary(
        prompt=prompt,
        prompt_input=input_values,
        max_input_tokens=config.max_input_tokens,
        omittable_features=["example_solution", "problem_statement", "grading_instructions"],
        debug=debug
    )

    # Even with all omittable features removed the prompt may still be too long.
    if not should_run:
        logger.warning("Input too long. Skipping.")
        if debug:
            emit_meta("prompt", prompt.format(**input_values))
            emit_meta("error", f"Input too long {num_tokens_from_prompt(prompt, input_values)} > {config.max_input_tokens}")
        return []

    assessment = await predict_and_parse(
        model=llm,
        chat_prompt=prompt,
        prompt_input=input_values,
        pydantic_object=AssessmentModel,
        tags=[
            f"exercise-{exercise.id}",
            f"submission-{submission.id}",
        ],
        use_function_calling=True
    )

    if debug:
        emit_meta("generate_suggestions", {
            "prompt": prompt.format(**input_values),
            "result": assessment.dict() if assessment is not None else None
        })

    if assessment is None:
        return []

    # IDs of structured grading instructions that actually exist for this
    # exercise; any other ID the model invents is dropped below.
    known_instruction_ids = {
        instruction.id
        for criterion in exercise.grading_criteria or []
        for instruction in criterion.structured_grading_instructions
    }

    suggestions: List[Feedback] = []
    for item in assessment.feedbacks:
        start, end = get_index_range_from_line_range(item.line_start, item.line_end, submission.text)
        valid_instruction_id = item.grading_instruction_id if item.grading_instruction_id in known_instruction_ids else None
        suggestions.append(Feedback(
            exercise_id=exercise.id,
            submission_id=submission.id,
            title=item.title,
            description=item.description,
            index_start=start,
            index_end=end,
            credits=item.credits,
            structured_grading_instruction_id=valid_instruction_id,
            meta={}
        ))

    return suggestions
65 changes: 65 additions & 0 deletions
65
...le_text_llm/module_text_llm/retrieval_augmented_generation/prompt_generate_suggestions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from pydantic import Field, BaseModel | ||
from typing import List, Optional | ||
from pydantic import BaseModel, Field | ||
|
||
# Default system prompt: primes the model as a tutor and injects exercise
# context. Placeholders are filled at prompt-build time; {problem_statement},
# {example_solution}, and {grading_instructions} may be omitted when the
# input is too long (see check_prompt_length_and_omit_features_if_necessary).
system_message = """\
You are an AI tutor for text assessment at a prestigious university.

# Task
Create graded feedback suggestions for a student\'s text submission that a human tutor would accept. \
Meaning, the feedback you provide should be applicable to the submission with little to no modification.

# Style
1. Constructive, 2. Specific, 3. Balanced, 4. Clear and Concise, 5. Actionable, 6. Educational, 7. Contextual

# Problem statement
{problem_statement}

# Example solution
{example_solution}

# Grading instructions
{grading_instructions}
Max points: {max_points}, bonus points: {bonus_points}\

Respond in json.
"""
|
||
# Default human prompt: carries the sentence-numbered submission text
# (produced by add_sentence_numbers) and asks for a JSON response.
human_message = """\
Student\'s submission to grade (with sentence numbers <number>: <sentence>):
Respond in json.

\"\"\"
{submission}
\"\"\"\
"""
|
||
# Input Prompt | ||
class GenerateSuggestionsPrompt(BaseModel):
    """\
Features available: **{problem_statement}**, **{example_solution}**, **{grading_instructions}**, **{max_points}**, **{bonus_points}**, **{submission}**

_Note: **{problem_statement}**, **{example_solution}**, or **{grading_instructions}** might be omitted if the input is too long._\
"""
    # System-turn template; defaults to the module-level system_message above.
    # NOTE: the class docstring and Field descriptions are surfaced in the
    # configuration schema — change them only deliberately.
    system_message: str = Field(default=system_message,
                                description="Message for priming AI behavior and instructing it what to do.")
    # Human-turn template; defaults to the module-level human_message above.
    human_message: str = Field(default=human_message,
                               description="Message from a human. The input on which the AI is supposed to act.")
# Output Object | ||
class FeedbackModel(BaseModel):
    # Short label / category for the feedback item.
    title: str = Field(description="Very short title, i.e. feedback category or similar", example="Logic Error")
    # Feedback body text shown to the student/tutor.
    description: str = Field(description="Feedback description")
    # Referenced sentence-number range in the numbered submission; None when
    # the feedback is unreferenced. (Presumably 1-based, matching
    # add_sentence_numbers — TODO confirm.)
    line_start: Optional[int] = Field(description="Referenced line number start, or empty if unreferenced")
    line_end: Optional[int] = Field(description="Referenced line number end, or empty if unreferenced")
    # Points awarded (positive) or deducted (negative); defaults to 0.0.
    credits: float = Field(0.0, description="Number of points received/deducted")
    # Must match an existing structured grading instruction; unknown IDs are
    # filtered out to None by the generate_suggestions caller.
    grading_instruction_id: Optional[int] = Field(
        description="ID of the grading instruction that was used to generate this feedback, or empty if no grading instruction was used"
    )
|
||
|
||
class AssessmentModel(BaseModel):
    """Collection of feedbacks making up an assessment"""

    # Individual feedback suggestions parsed from the model's response.
    feedbacks: List[FeedbackModel] = Field(description="Assessment feedbacks")
|