From bc69e969f7d1aa81120c4e758b35d121ed2d333f Mon Sep 17 00:00:00 2001 From: Yassine Souissi Date: Tue, 23 Apr 2024 00:36:52 +0200 Subject: [PATCH] Fix Ingestion Pipeline, ready for review --- app/domain/iris_message.py | 12 +----- app/domain/pyris_image.py | 10 ++--- app/llm/external/openai_chat.py | 18 ++++---- app/pipeline/lecture_ingestion_pipeline.py | 48 +++++++++++++--------- app/vector_database/lectureschema.py | 5 ++- 5 files changed, 49 insertions(+), 44 deletions(-) diff --git a/app/domain/iris_message.py b/app/domain/iris_message.py index a7468f7a..d4add334 100644 --- a/app/domain/iris_message.py +++ b/app/domain/iris_message.py @@ -1,25 +1,17 @@ from enum import Enum - from pydantic import BaseModel +from typing import List, Optional from .pyris_image import PyrisImage - class IrisMessageRole(str, Enum): USER = "user" ASSISTANT = "assistant" SYSTEM = "system" - class IrisMessage(BaseModel): text: str = "" role: IrisMessageRole - images: list[PyrisImage] | None - - def __init__( - self, role: IrisMessageRole, text: str, images: list[PyrisImage] | None = None - ): - super().__init__(role=role, text=text) - self.images = images + images: Optional[List[PyrisImage]] = None def __str__(self): return f"{self.role.lower()}: {self.text}" diff --git a/app/domain/pyris_image.py b/app/domain/pyris_image.py index 4f292ba9..2555a22c 100644 --- a/app/domain/pyris_image.py +++ b/app/domain/pyris_image.py @@ -1,19 +1,17 @@ -from pydantic import BaseModel from datetime import datetime - +from pydantic import BaseModel +from typing import Optional class PyrisImage(BaseModel): - prompt: str base64: str - timestamp: datetime - mime_type: str = "jpeg" + prompt: Optional[str] = None + mime_type: Optional[str] = "jpeg" class Config: schema_extra = { "example": { "prompt": "Example prompt", "base64": "base64EncodedString==", - "timestamp": "2023-01-01T12:00:00Z", "mime_type": "jpeg", } } diff --git a/app/llm/external/openai_chat.py b/app/llm/external/openai_chat.py index 351caf72..bff72a00 100644 --- a/app/llm/external/openai_chat.py +++ b/app/llm/external/openai_chat.py @@ -10,22 +10,27 @@ def convert_to_open_ai_messages( - messages: list[IrisMessage], + messages: list[IrisMessage], ) -> list[dict[str, Any]]: + """ + Convert IrisMessages to OpenAI messages + """ openai_messages = [] for message in messages: if message.images: - content = [{"type": "text", "content": message.text}] + content = [{"type": "text", "text": message.text}] for image in message.images: content.append( { "type": "image_url", - "image_url": f"data:image/{image.mime_type};base64,{image.base64}", - "detail": "high", + "image_url": { + "url": f"data:image/{image.mime_type};base64,{image.base64}", + "detail": "high", + } } ) else: - content = message.text + content = [{"type": "text", "text": message.text}] openai_message = {"role": message.role.value, "content": content} openai_messages.append(openai_message) return openai_messages @@ -43,14 +48,13 @@ class OpenAIChatModel(ChatModel): _client: OpenAI def chat( - self, messages: list[IrisMessage], arguments: CompletionArguments + self, messages: list[IrisMessage], arguments: CompletionArguments ) -> IrisMessage: response = self._client.chat.completions.create( model=self.model, messages=convert_to_open_ai_messages(messages), temperature=arguments.temperature, max_tokens=arguments.max_tokens, - stop=arguments.stop, ) return convert_to_iris_message(response.choices[0].message) diff --git a/app/pipeline/lecture_ingestion_pipeline.py b/app/pipeline/lecture_ingestion_pipeline.py index a9b44d52..88753962 100644 --- a/app/pipeline/lecture_ingestion_pipeline.py +++ b/app/pipeline/lecture_ingestion_pipeline.py @@ -22,7 +22,7 @@ def __init__( super().__init__() self.collection = init_lecture_schema(client) self.dto = dto - self.llm_image = BasicRequestHandler("gptvision") + self.llm_vision = BasicRequestHandler("gptvision") self.llm = BasicRequestHandler("gpt35") self.llm_embedding = BasicRequestHandler("ada") @@ -32,14 +32,14 @@ def __call__(self) -> bool: self.delete_lecture_unit( lecture_unit.lecture_id, lecture_unit.lecture_unit_id ) - pdf_path = self.save_pdf(lecture_unit) + pdf_path = self.save_pdf(lecture_unit.pdf_file_base64) chunks = self.chunk_data( lecture_path=pdf_path, lecture_unit_dto=lecture_unit ) with self.collection.batch.dynamic() as batch: for index, chunk in enumerate(chunks): # embed the - embed_chunk = self.llm_embedding.embed_query( + embed_chunk = self.llm_embedding.embed( chunk[LectureSchema.PAGE_TEXT_CONTENT] + "\n" + chunk[LectureSchema.PAGE_IMAGE_DESCRIPTION] @@ -60,8 +60,8 @@ def delete(self): logger.error(f"Error deleting lecture unit: {e}") return False - def save_pdf(self, lecture_unit): - binary_data = base64.b64decode(lecture_unit.pdf_file_base64) + def save_pdf(self, pdf_file_base64): + binary_data = base64.b64decode(pdf_file_base64) fd, temp_pdf_file_path = tempfile.mkstemp(suffix=".pdf") os.close(fd) with open(temp_pdf_file_path, "wb") as temp_pdf_file: @@ -95,14 +95,18 @@ def chunk_data( lecture_unit_dto.lecture_name, ) page_content = page.get_text() - data.append( - { - LectureSchema.LECTURE_ID: lecture_unit_dto.lecture_id, - LectureSchema.LECTURE_UNIT_NAME: lecture_unit_dto.unit_name, - LectureSchema.PAGE_TEXT_CONTENT: page_content, - LectureSchema.PAGE_IMAGE_DESCRIPTION: image_interpretation, - LectureSchema.PAGE_BASE64: img_base64, - LectureSchema.PAGE_NUMBER: page_num + 1, + data.append({ + LectureSchema.LECTURE_ID: lecture_unit_dto.lecture_id, + LectureSchema.LECTURE_NAME: lecture_unit_dto.lecture_name, + LectureSchema.LECTURE_UNIT_ID: lecture_unit_dto.lecture_unit_id, + LectureSchema.LECTURE_UNIT_NAME: lecture_unit_dto.lecture_unit_name, + LectureSchema.COURSE_ID: lecture_unit_dto.course_id, + LectureSchema.COURSE_NAME: lecture_unit_dto.course_name, + LectureSchema.COURSE_DESCRIPTION: lecture_unit_dto.course_description, + LectureSchema.PAGE_NUMBER: page_num + 1, + LectureSchema.PAGE_TEXT_CONTENT: page_content, + LectureSchema.PAGE_IMAGE_DESCRIPTION: image_interpretation, + LectureSchema.PAGE_BASE64: img_base64, } ) @@ -110,10 +114,16 @@ def chunk_data( page_content = page.get_text() data.append( { + LectureSchema.LECTURE_ID: lecture_unit_dto.lecture_id, + LectureSchema.LECTURE_NAME: lecture_unit_dto.lecture_name, + LectureSchema.LECTURE_UNIT_ID: lecture_unit_dto.lecture_unit_id, + LectureSchema.LECTURE_UNIT_NAME: lecture_unit_dto.lecture_unit_name, + LectureSchema.COURSE_ID: lecture_unit_dto.course_id, + LectureSchema.COURSE_NAME: lecture_unit_dto.course_name, + LectureSchema.COURSE_DESCRIPTION: lecture_unit_dto.course_description, + LectureSchema.PAGE_NUMBER: page_num + 1, LectureSchema.PAGE_TEXT_CONTENT: page_content, LectureSchema.PAGE_IMAGE_DESCRIPTION: "", - LectureSchema.PAGE_NUMBER: page_num + 1, - LectureSchema.LECTURE_NAME: lecture_unit_dto.lecture_name, LectureSchema.PAGE_BASE64: "", } ) @@ -149,11 +159,11 @@ def interpret_image( f" Here is the content of the page before the one you need to interpret:" f" {last_page_content}" ) + image = PyrisImage(base64=img_base64) iris_message = IrisMessage( - role=IrisMessageRole.SYSTEM, text=image_interpretation_prompt + role=IrisMessageRole.SYSTEM, text=image_interpretation_prompt, images=[image] ) - image = PyrisImage(base64=img_base64) - response = self.llm_image.chat( - [iris_message, image], CompletionArguments(temperature=0.2, max_tokens=1000) + response = self.llm_vision.chat( + [iris_message], CompletionArguments(temperature=0.2, max_tokens=1000) ) return response.text diff --git a/app/vector_database/lectureschema.py b/app/vector_database/lectureschema.py index fda74e96..6d76ee63 100644 --- a/app/vector_database/lectureschema.py +++ b/app/vector_database/lectureschema.py @@ -15,6 +15,7 @@ class LectureSchema: COLLECTION_NAME = "LectureSlides" COURSE_NAME = "course_name" + COURSE_DESCRIPTION = "course_description" COURSE_ID = "course_id" LECTURE_ID = "lecture_id" LECTURE_NAME = "lecture_name" @@ -53,8 +54,8 @@ def init_lecture_schema(client: WeaviateClient) -> Collection: data_type=wvc.config.DataType.TEXT, ), wvc.config.Property( - name=LectureSchema.LECTURE_DESCRIPTION, - description="The description of the lecture", + name=LectureSchema.COURSE_DESCRIPTION, + description="The description of the COURSE", data_type=wvc.config.DataType.TEXT, ), wvc.config.Property(