diff --git a/app/common/PipelineEnum.py b/app/common/PipelineEnum.py index fc439a65..b6f84a80 100644 --- a/app/common/PipelineEnum.py +++ b/app/common/PipelineEnum.py @@ -14,4 +14,5 @@ class PipelineEnum(str, Enum): IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE" IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE" IRIS_LECTURE_INGESTION = "IRIS_LECTURE_INGESTION" + IRIS_REWRITING_PIPELINE = "IRIS_REWRITING_PIPELINE" NOT_SET = "NOT_SET" diff --git a/app/domain/rewriting_pipeline_execution_dto.py b/app/domain/rewriting_pipeline_execution_dto.py new file mode 100644 index 00000000..2a25690e --- /dev/null +++ b/app/domain/rewriting_pipeline_execution_dto.py @@ -0,0 +1,7 @@ +from pydantic import Field, BaseModel +from . import PipelineExecutionDTO + + +class RewritingPipelineExecutionDTO(BaseModel): + execution: PipelineExecutionDTO + to_be_rewritten: str = Field(alias="toBeRewritten") diff --git a/app/domain/status/rewriting_status_update_dto.py b/app/domain/status/rewriting_status_update_dto.py new file mode 100644 index 00000000..f4351342 --- /dev/null +++ b/app/domain/status/rewriting_status_update_dto.py @@ -0,0 +1,5 @@ +from app.domain.status.status_update_dto import StatusUpdateDTO + + +class RewritingStatusUpdateDTO(StatusUpdateDTO): + result: str = "" diff --git a/app/pipeline/prompts/faq_rewriting.py b/app/pipeline/prompts/faq_rewriting.py new file mode 100644 index 00000000..da88b427 --- /dev/null +++ b/app/pipeline/prompts/faq_rewriting.py @@ -0,0 +1,32 @@ +system_prompt_faq = """\ +:You are an excellent tutor with expertise in computer science and its practical applications, teaching at a university +level. Your task is to proofread and enhance the given FAQ text. Please follow these guidelines: + +1. Correct all spelling and grammatical errors. +2. Ensure the text is written in simple and clear language, making it easy to understand for students. +3. Preserve the original meaning and intent of the text while maintaining clarity. +4. Ensure that the response is always written in complete sentences. If you are given a list of bullet points, \ +convert them into complete sentences. +5. Make sure to use the original language of the input text. +6. Avoid repeating any information that is already present in the text. +7. Make sure to keep the markdown formatting intact and add formatting for the most important information. +8. If someone does input a very short text, that does not resemble to be an answer to a potential question please make. +sure to respond accordingly. Also, if the input text is too short, please point this out. + +Additionally for Short Inputs: If the input text is too short and does not resemble an answer to a potential question, \ +respond appropriately and point this out. +Your output will be used as an answer to a frequently asked question (FAQ) on the Artemis platform. +Ensure it is clear, concise, and well-structured. + +Exclude the start and end markers from your response and provide only the improved content. + +The markers are defined as following: +Start of the text: ###START### +End of the text: ###END### + +The text that has to be rewritten starts now: + +###START### +{rewritten_text} +###END###\ +""" diff --git a/app/pipeline/rewriting_pipeline.py b/app/pipeline/rewriting_pipeline.py new file mode 100644 index 00000000..07b0d575 --- /dev/null +++ b/app/pipeline/rewriting_pipeline.py @@ -0,0 +1,61 @@ +import logging +from typing import Optional + +from langchain.output_parsers import PydanticOutputParser +from langchain_core.prompts import ( + ChatPromptTemplate, +) + +from app.common.PipelineEnum import PipelineEnum +from app.common.pyris_message import PyrisMessage, IrisMessageRole +from app.domain.data.text_message_content_dto import TextMessageContentDTO +from app.domain.rewriting_pipeline_execution_dto import RewritingPipelineExecutionDTO +from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments +from app.pipeline import Pipeline +from app.pipeline.prompts.faq_rewriting import system_prompt_faq +from app.web.status.status_update import RewritingCallback + +logger = logging.getLogger(__name__) + + +class RewritingPipeline(Pipeline): + callback: RewritingCallback + request_handler: CapabilityRequestHandler + output_parser: PydanticOutputParser + + def __init__(self, callback: Optional[RewritingCallback] = None): + super().__init__(implementation_id="rewriting_pipeline_reference_impl") + self.callback = callback + self.request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=4.5, + context_length=16385, + ) + ) + self.tokens = [] + + def __call__( + self, + dto: RewritingPipelineExecutionDTO, + prompt: Optional[ChatPromptTemplate] = None, + **kwargs, + ): + if not dto.to_be_rewritten: + raise ValueError("You need to provide a text to rewrite") + + # + prompt = system_prompt_faq.format( + rewritten_text=dto.to_be_rewritten, + ) + prompt = PyrisMessage( + sender=IrisMessageRole.SYSTEM, + contents=[TextMessageContentDTO(text_content=prompt)], + ) + + response = self.request_handler.chat( + [prompt], CompletionArguments(temperature=0.4) + ) + self._append_tokens(response.token_usage, PipelineEnum.IRIS_REWRITING_PIPELINE) + response = response.contents[0].text_content + final_result = response + self.callback.done(final_result=final_result, tokens=self.tokens) diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py index 13f78ec6..3431c196 100644 --- a/app/web/routers/pipelines.py +++ b/app/web/routers/pipelines.py @@ -11,17 +11,20 @@ CourseChatPipelineExecutionDTO, CompetencyExtractionPipelineExecutionDTO, ) +from app.domain.rewriting_pipeline_execution_dto import RewritingPipelineExecutionDTO from app.pipeline.chat.exercise_chat_agent_pipeline import ExerciseChatAgentPipeline from app.domain.chat.lecture_chat.lecture_chat_pipeline_execution_dto import ( LectureChatPipelineExecutionDTO, ) from app.pipeline.chat.lecture_chat_pipeline import LectureChatPipeline +from app.pipeline.rewriting_pipeline import RewritingPipeline from app.web.status.status_update import ( ExerciseChatStatusCallback, ChatGPTWrapperStatusCallback, CourseChatStatusCallback, CompetencyExtractionCallback, LectureChatCallback, + RewritingCallback, ) from app.pipeline.chat.course_chat_pipeline import CourseChatPipeline from app.dependencies import TokenValidator @@ -223,6 +226,28 @@ def run_competency_extraction_pipeline_worker( callback.error("Fatal error.", exception=e) +def run_rewriting_pipeline_worker(dto: RewritingPipelineExecutionDTO, _variant: str): + try: + callback = RewritingCallback( + run_id=dto.execution.settings.authentication_token, + base_url=dto.execution.settings.artemis_base_url, + initial_stages=dto.execution.initial_stages, + ) + pipeline = RewritingPipeline(callback=callback) + except Exception as e: + logger.error(f"Error preparing rewriting pipeline: {e}") + logger.error(traceback.format_exc()) + capture_exception(e) + return + + try: + pipeline(dto=dto) + except Exception as e: + logger.error(f"Error running rewriting extraction pipeline: {e}") + logger.error(traceback.format_exc()) + callback.error("Fatal error.", exception=e) + + @router.post( "/competency-extraction/{variant}/run", status_code=status.HTTP_202_ACCEPTED, @@ -237,6 +262,17 @@ def run_competency_extraction_pipeline( thread.start() +@router.post( + "/rewriting/{variant}/run", + status_code=status.HTTP_202_ACCEPTED, + dependencies=[Depends(TokenValidator())], +) +def run_rewriting_pipeline(variant: str, dto: RewritingPipelineExecutionDTO): + logger.info(f"Rewriting pipeline started with variant: {variant} and dto: {dto}") + thread = Thread(target=run_rewriting_pipeline_worker, args=(dto, variant)) + thread.start() + + def run_chatgpt_wrapper_pipeline_worker( dto: ExerciseChatPipelineExecutionDTO, _variant: str ): @@ -323,10 +359,18 @@ def get_pipeline(feature: str): description="Default lecture chat variant.", ) ] + case "REWRITING": + return [ + FeatureDTO( + id="rewriting", + name="Default Variant", + description="Default rewriting variant.", + ) + ] case "CHAT_GPT_WRAPPER": return [ FeatureDTO( - id="default", + id="chat_gpt_wrapper", name="Default Variant", description="Default ChatGPT wrapper variant.", ) diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index bbddc716..8404c330 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -16,6 +16,7 @@ from app.domain.status.lecture_chat_status_update_dto import ( LectureChatStatusUpdateDTO, ) +from app.domain.status.rewriting_status_update_dto import RewritingStatusUpdateDTO from app.domain.status.stage_state_dto import StageStateEnum from app.domain.status.stage_dto import StageDTO from app.domain.status.text_exercise_chat_status_update_dto import ( @@ -295,6 +296,27 @@ def __init__( super().__init__(url, run_id, status, stage, len(stages) - 1) +class RewritingCallback(StatusCallback): + def __init__( + self, + run_id: str, + base_url: str, + initial_stages: List[StageDTO], + ): + url = f"{base_url}/api/public/pyris/pipelines/rewriting/runs/{run_id}/status" + stages = initial_stages or [] + stages.append( + StageDTO( + weight=10, + state=StageStateEnum.NOT_STARTED, + name="Generating Rewritting", + ) + ) + status = RewritingStatusUpdateDTO(stages=stages) + stage = stages[-1] + super().__init__(url, run_id, status, stage, len(stages) - 1) + + class LectureChatCallback(StatusCallback): def __init__( self, diff --git a/docker/pyris-dev.yml b/docker/pyris-dev.yml index 7d1a956d..cfb995ea 100644 --- a/docker/pyris-dev.yml +++ b/docker/pyris-dev.yml @@ -14,6 +14,8 @@ services: - ../llm_config.local.yml:/config/llm_config.yml:ro networks: - pyris + ports: + - 8000:8000 weaviate: extends: