Skip to content

Commit

Permalink
Fix Ingestion Pipeline, ready for review
Browse files Browse the repository at this point in the history
  • Loading branch information
yassinsws committed Apr 22, 2024
1 parent 7a6270b commit bc69e96
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 44 deletions.
12 changes: 2 additions & 10 deletions app/domain/iris_message.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
from enum import Enum

from pydantic import BaseModel
from typing import List, Optional
from .pyris_image import PyrisImage


class IrisMessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"


class IrisMessage(BaseModel):
text: str = ""
role: IrisMessageRole
images: list[PyrisImage] | None

def __init__(
self, role: IrisMessageRole, text: str, images: list[PyrisImage] | None = None
):
super().__init__(role=role, text=text)
self.images = images
images: Optional[List[PyrisImage]] = None

def __str__(self):
return f"{self.role.lower()}: {self.text}"
10 changes: 4 additions & 6 deletions app/domain/pyris_image.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from pydantic import BaseModel
from datetime import datetime

from pydantic import BaseModel
from typing import Optional

class PyrisImage(BaseModel):
prompt: str
base64: str
timestamp: datetime
mime_type: str = "jpeg"
prompt: Optional[str] = None
mime_type: Optional[str] = "jpeg"

class Config:
schema_extra = {
"example": {
"prompt": "Example prompt",
"base64": "base64EncodedString==",
"timestamp": "2023-01-01T12:00:00Z",
"mime_type": "jpeg",
}
}
18 changes: 11 additions & 7 deletions app/llm/external/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,27 @@


def convert_to_open_ai_messages(
messages: list[IrisMessage],
messages: list[IrisMessage],
) -> list[dict[str, Any]]:
"""
Convert IrisMessages to OpenAI messages
"""
openai_messages = []
for message in messages:
if message.images:
content = [{"type": "text", "content": message.text}]
content = [{"type": "text", "text": message.text}]
for image in message.images:
content.append(
{
"type": "image_url",
"image_url": f"data:image/{image.mime_type};base64,{image.base64}",
"detail": "high",
"image_url": {
"url": f"data:image/{image.mime_type};base64,{image.base64}",
"detail": "high",
}
}
)
else:
content = message.text
content = [{"type": "text", "text": message.text}]
openai_message = {"role": message.role.value, "content": content}
openai_messages.append(openai_message)
return openai_messages
Expand All @@ -43,14 +48,13 @@ class OpenAIChatModel(ChatModel):
_client: OpenAI

def chat(
self, messages: list[IrisMessage], arguments: CompletionArguments
self, messages: list[IrisMessage], arguments: CompletionArguments
) -> IrisMessage:
response = self._client.chat.completions.create(
model=self.model,
messages=convert_to_open_ai_messages(messages),
temperature=arguments.temperature,
max_tokens=arguments.max_tokens,
stop=arguments.stop,
)
return convert_to_iris_message(response.choices[0].message)

Expand Down
48 changes: 29 additions & 19 deletions app/pipeline/lecture_ingestion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
super().__init__()
self.collection = init_lecture_schema(client)
self.dto = dto
self.llm_image = BasicRequestHandler("gptvision")
self.llm_vision = BasicRequestHandler("gptvision")
self.llm = BasicRequestHandler("gpt35")
self.llm_embedding = BasicRequestHandler("ada")

Expand All @@ -32,14 +32,14 @@ def __call__(self) -> bool:
self.delete_lecture_unit(
lecture_unit.lecture_id, lecture_unit.lecture_unit_id
)
pdf_path = self.save_pdf(lecture_unit)
pdf_path = self.save_pdf(lecture_unit.pdf_file_base64)
chunks = self.chunk_data(
lecture_path=pdf_path, lecture_unit_dto=lecture_unit
)
with self.collection.batch.dynamic() as batch:
for index, chunk in enumerate(chunks):
                    # embed the chunk (page text plus its image description)
embed_chunk = self.llm_embedding.embed_query(
embed_chunk = self.llm_embedding.embed(
chunk[LectureSchema.PAGE_TEXT_CONTENT]
+ "\n"
+ chunk[LectureSchema.PAGE_IMAGE_DESCRIPTION]
Expand All @@ -60,8 +60,8 @@ def delete(self):
logger.error(f"Error deleting lecture unit: {e}")
return False

def save_pdf(self, lecture_unit):
binary_data = base64.b64decode(lecture_unit.pdf_file_base64)
def save_pdf(self, pdf_file_base64):
binary_data = base64.b64decode(pdf_file_base64)
fd, temp_pdf_file_path = tempfile.mkstemp(suffix=".pdf")
os.close(fd)
with open(temp_pdf_file_path, "wb") as temp_pdf_file:
Expand Down Expand Up @@ -95,25 +95,35 @@ def chunk_data(
lecture_unit_dto.lecture_name,
)
page_content = page.get_text()
data.append(
{
LectureSchema.LECTURE_ID: lecture_unit_dto.lecture_id,
LectureSchema.LECTURE_UNIT_NAME: lecture_unit_dto.unit_name,
LectureSchema.PAGE_TEXT_CONTENT: page_content,
LectureSchema.PAGE_IMAGE_DESCRIPTION: image_interpretation,
LectureSchema.PAGE_BASE64: img_base64,
LectureSchema.PAGE_NUMBER: page_num + 1,
data.append({
LectureSchema.LECTURE_ID: lecture_unit_dto.lecture_id,
LectureSchema.LECTURE_NAME: lecture_unit_dto.lecture_name,
LectureSchema.LECTURE_UNIT_ID: lecture_unit_dto.lecture_unit_id,
LectureSchema.LECTURE_UNIT_NAME: lecture_unit_dto.lecture_unit_name,
LectureSchema.COURSE_ID: lecture_unit_dto.course_id,
LectureSchema.COURSE_NAME: lecture_unit_dto.course_name,
LectureSchema.COURSE_DESCRIPTION: lecture_unit_dto.course_description,
LectureSchema.PAGE_NUMBER: page_num + 1,
LectureSchema.PAGE_TEXT_CONTENT: page_content,
LectureSchema.PAGE_IMAGE_DESCRIPTION: image_interpretation,
LectureSchema.PAGE_BASE64: img_base64,
}
)

else:
page_content = page.get_text()
data.append(
{
LectureSchema.LECTURE_ID: lecture_unit_dto.lecture_id,
LectureSchema.LECTURE_NAME: lecture_unit_dto.lecture_name,
LectureSchema.LECTURE_UNIT_ID: lecture_unit_dto.lecture_unit_id,
LectureSchema.LECTURE_UNIT_NAME: lecture_unit_dto.lecture_unit_name,
LectureSchema.COURSE_ID: lecture_unit_dto.course_id,
LectureSchema.COURSE_NAME: lecture_unit_dto.course_name,
LectureSchema.COURSE_DESCRIPTION: lecture_unit_dto.course_description,
LectureSchema.PAGE_NUMBER: page_num + 1,
LectureSchema.PAGE_TEXT_CONTENT: page_content,
LectureSchema.PAGE_IMAGE_DESCRIPTION: "",
LectureSchema.PAGE_NUMBER: page_num + 1,
LectureSchema.LECTURE_NAME: lecture_unit_dto.lecture_name,
LectureSchema.PAGE_BASE64: "",
}
)
Expand Down Expand Up @@ -149,11 +159,11 @@ def interpret_image(
f" Here is the content of the page before the one you need to interpret:"
f" {last_page_content}"
)
image = PyrisImage(base64=img_base64)
iris_message = IrisMessage(
role=IrisMessageRole.SYSTEM, text=image_interpretation_prompt
role=IrisMessageRole.SYSTEM, text=image_interpretation_prompt, images=[image]
)
image = PyrisImage(base64=img_base64)
response = self.llm_image.chat(
[iris_message, image], CompletionArguments(temperature=0.2, max_tokens=1000)
response = self.llm_vision.chat(
[iris_message], CompletionArguments(temperature=0.2, max_tokens=1000)
)
return response.text
5 changes: 3 additions & 2 deletions app/vector_database/lectureschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class LectureSchema:

COLLECTION_NAME = "LectureSlides"
COURSE_NAME = "course_name"
COURSE_DESCRIPTION = "course_description"
COURSE_ID = "course_id"
LECTURE_ID = "lecture_id"
LECTURE_NAME = "lecture_name"
Expand Down Expand Up @@ -53,8 +54,8 @@ def init_lecture_schema(client: WeaviateClient) -> Collection:
data_type=wvc.config.DataType.TEXT,
),
wvc.config.Property(
name=LectureSchema.LECTURE_DESCRIPTION,
description="The description of the lecture",
name=LectureSchema.COURSE_DESCRIPTION,
            description="The description of the course",
data_type=wvc.config.DataType.TEXT,
),
wvc.config.Property(
Expand Down

0 comments on commit bc69e96

Please sign in to comment.