Skip to content

Commit

Permalink
Ingester & Retriever: Add support for Weaviate (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
yassinsws authored May 8, 2024
1 parent 33fa03b commit 659b64b
Show file tree
Hide file tree
Showing 14 changed files with 347 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
WEAVIATE_HOST=
WEAVIATE_PORT=
15 changes: 8 additions & 7 deletions app/domain/data/lecture_unit_dto.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field


class LectureUnitDTO(BaseModel):
    """
    Data transfer object describing a lecture unit.

    Fields declared with ``Field(alias=...)`` are populated from the given
    camelCase key of the incoming payload.
    """

    # NOTE(review): this listing was taken from a rendered diff without +/-
    # markers, so it may interleave removed and added fields — confirm the
    # field set against the actual file.
    id: int
    to_update: bool = Field(alias="toUpdate")
    pdf_file_base64: str = Field(alias="pdfFile")
    lecture_unit_id: int = Field(alias="lectureUnitId")
    lecture_unit_name: str = Field(alias="lectureUnitName")
    lecture_id: int = Field(alias="lectureId")
    # Optional: defaults to None when the payload carries no release date.
    release_date: Optional[datetime] = Field(alias="releaseDate", default=None)
    name: Optional[str] = None
    attachment_version: int = Field(alias="attachmentVersion")
    lecture_name: str = Field(alias="lectureName")
    course_id: int = Field(alias="courseId")
    course_name: str = Field(alias="courseName")
    course_description: str = Field(alias="courseDescription")
Empty file added app/ingestion/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions app/ingestion/abstract_ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from typing import List, Dict


class AbstractIngestion(ABC):
    """
    Interface for ingesting repositories into a database.

    Concrete subclasses must implement every method below; the base
    implementations do nothing and return None.
    """

    @abstractmethod
    def chunk_data(self, path: str) -> List[Dict[str, str]]:
        """Chunk the code files found in the root directory given by ``path``."""

    @abstractmethod
    def ingest(self, path: str) -> bool:
        """Ingest the repositories at ``path`` into the database."""

    @abstractmethod
    def update(self, path: str):
        """Update a repository that is already stored in the database."""
2 changes: 1 addition & 1 deletion app/pipeline/chat/tutor_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _add_student_repository_to_prompt(
for file in selected_files:
if file in student_repository:
self.prompt += SystemMessagePromptTemplate.from_template(
f"For reference, we have access to the student's '{file}' file:"
f"For reference, we have access to the student's '{file}' file: "
)
self.prompt += HumanMessagePromptTemplate.from_template(
student_repository[file].replace("{", "{{").replace("}", "}}")
Expand Down
Empty file added app/retrieval/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions app/retrieval/abstract_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from typing import List


class AbstractRetrieval(ABC):
    """
    Interface for retrieving data from a database.

    Concrete subclasses must implement ``retrieve``; the base implementation
    does nothing and returns None.
    """

    @abstractmethod
    def retrieve(self, path: str, hybrid_factor: float, result_limit: int) -> List[str]:
        """Retrieve data from the database."""
43 changes: 43 additions & 0 deletions app/retrieval/lecture_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from abc import ABC
from typing import List

from weaviate import WeaviateClient
from weaviate.classes.query import Filter

from app.retrieval.abstract_retrieval import AbstractRetrieval
from app.vector_database.lecture_schema import init_lecture_schema, LectureSchema


class LectureRetrieval(AbstractRetrieval, ABC):
    """
    Class for retrieving lecture data from the database.
    """

    def __init__(self, client: WeaviateClient):
        """
        Bind this retriever to the lecture collection obtained from
        ``init_lecture_schema`` for the given client.
        """
        self.collection = init_lecture_schema(client)

    def retrieve(
        self,
        user_message: str,
        hybrid_factor: float,
        result_limit: int,
        lecture_id: int = None,
        message_vector: List[float] = None,
    ) -> List[str]:
        """
        Run a hybrid query against the lecture collection.

        :param user_message: Query text passed to the hybrid search.
        :param hybrid_factor: Forwarded as the hybrid query's ``alpha``.
        :param result_limit: Maximum number of results requested.
        :param lecture_id: When given, restrict results to this lecture.
        :param message_vector: Optional query vector forwarded as ``vector``.
        :return: The response of ``collection.query.hybrid``, unchanged.
        """
        # Compare against None explicitly: a plain truthiness check would
        # silently drop the filter for a lecture whose ID is 0.
        lecture_filter = (
            Filter.by_property(LectureSchema.LECTURE_ID.value).equal(lecture_id)
            if lecture_id is not None
            else None
        )
        # NOTE(review): the annotation says List[str], but the raw query
        # response object is returned as-is — confirm what callers expect.
        response = self.collection.query.hybrid(
            query=user_message,
            filters=lecture_filter,
            alpha=hybrid_factor,
            vector=message_vector,
            return_properties=[
                LectureSchema.PAGE_TEXT_CONTENT.value,
                LectureSchema.PAGE_IMAGE_DESCRIPTION.value,
                LectureSchema.COURSE_NAME.value,
            ],
            limit=result_limit,
        )
        return response
45 changes: 45 additions & 0 deletions app/retrieval/repositories_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import List

from weaviate import WeaviateClient
from weaviate.classes.query import Filter

from app.retrieval.abstract_retrieval import AbstractRetrieval
from app.vector_database.repository_schema import (
init_repository_schema,
RepositorySchema,
)


class RepositoryRetrieval(AbstractRetrieval):
    """
    Class for retrieving repository code from the vector database.
    """

    def __init__(self, client: WeaviateClient):
        """
        Bind this retriever to the repository collection obtained from
        ``init_repository_schema`` for the given client.
        """
        self.collection = init_repository_schema(client)

    def retrieve(
        self,
        user_message: str,
        result_limit: int,
        repository_id: int = None,
    ) -> List[str]:
        """
        Run a near-text query against the repository collection.

        :param user_message: Query text for the near-text search.
        :param result_limit: Maximum number of results requested.
        :param repository_id: When given, restrict results to this repository.
        :return: The response of ``collection.query.near_text``, unchanged.
        """
        # Compare against None explicitly: a plain truthiness check would
        # silently drop the filter for a repository whose ID is 0.
        repository_filter = (
            Filter.by_property(RepositorySchema.REPOSITORY_ID.value).equal(
                repository_id
            )
            if repository_id is not None
            else None
        )
        # The v4 client names the search-text parameter `query`; passing it as
        # `near_text=` raises a TypeError before any request is sent.
        # NOTE(review): near-text search presumably needs a vectorizer
        # configured on the collection — confirm against the collection setup.
        # NOTE(review): the annotation says List[str], but the raw query
        # response object is returned as-is — confirm what callers expect.
        response = self.collection.query.near_text(
            query=user_message,
            filters=repository_filter,
            return_properties=[
                RepositorySchema.REPOSITORY_ID.value,
                RepositorySchema.COURSE_ID.value,
                RepositorySchema.CONTENT.value,
                RepositorySchema.EXERCISE_ID.value,
                RepositorySchema.FILEPATH.value,
            ],
            limit=result_limit,
        )
        return response
Empty file added app/vector_database/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions app/vector_database/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
import os
import weaviate
from .lecture_schema import init_lecture_schema
from .repository_schema import init_repository_schema
import weaviate.classes as wvc

logger = logging.getLogger(__name__)


class VectorDatabase:
    """
    Class to interact with the Weaviate vector database.

    Construction connects through ``weaviate.connect_to_wcs`` and binds the
    repository and lecture collections returned by the schema initializers.
    """

    def __init__(self):
        # NOTE(review): the connection is configured from WEAVIATE_CLUSTER_URL
        # and WEAVIATE_AUTH_KEY — confirm these are the variables documented
        # for deployment.
        self.client = weaviate.connect_to_wcs(
            cluster_url=os.getenv("WEAVIATE_CLUSTER_URL"),
            auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_AUTH_KEY")),
        )
        self.repositories = init_repository_schema(self.client)
        self.lectures = init_lecture_schema(self.client)

    def __del__(self):
        # __del__ also runs when __init__ raised before `client` was assigned,
        # so only close a connection that actually exists.
        client = getattr(self, "client", None)
        if client is not None:
            client.close()

    def delete_collection(self, collection_name):
        """
        Delete a collection from the database.

        Does nothing when the collection does not exist. Errors raised by the
        client propagate to the caller.

        :param collection_name: Name of the collection to delete.
        """
        collections = self.client.collections
        if not collections.exists(collection_name):
            return
        # The delete call's return value does not signal success, so testing
        # it would always report a failure. Re-check existence to decide what
        # to log.
        collections.delete(collection_name)
        if collections.exists(collection_name):
            logger.error(f"Collection {collection_name} failed to delete")
        else:
            logger.info(f"Collection {collection_name} deleted")

    def delete_object(self, collection_name, property_name, object_property):
        """
        Delete objects from a collection inside the database.

        Every object whose ``property_name`` equals ``object_property`` is
        removed via ``delete_many``.

        :param collection_name: Name of the collection to delete from.
        :param property_name: Property to filter on.
        :param object_property: Value the property must equal.
        """
        collection = self.client.collections.get(collection_name)
        collection.data.delete_many(
            where=wvc.query.Filter.by_property(property_name).equal(object_property)
        )
97 changes: 97 additions & 0 deletions app/vector_database/lecture_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from enum import Enum

from weaviate.classes.config import Property
from weaviate import WeaviateClient
from weaviate.collections import Collection
from weaviate.collections.classes.config import Configure, VectorDistances, DataType


class LectureSchema(Enum):
    """
    Schema for the lecture slides.

    Each member's value is the name of the collection or of one of its
    properties.
    """

    # Name of the collection itself.
    COLLECTION_NAME = "LectureSlides"
    # Course-level properties.
    COURSE_NAME = "course_name"
    COURSE_DESCRIPTION = "course_description"
    COURSE_ID = "course_id"
    # Lecture and lecture-unit properties.
    LECTURE_ID = "lecture_id"
    LECTURE_NAME = "lecture_name"
    LECTURE_UNIT_ID = "lecture_unit_id"
    LECTURE_UNIT_NAME = "lecture_unit_name"
    # Per-page (slide) properties.
    PAGE_TEXT_CONTENT = "page_text_content"
    # The stored property name says "explanation" although the member is
    # named "description".
    PAGE_IMAGE_DESCRIPTION = "page_image_explanation"
    PAGE_BASE64 = "page_base64"
    PAGE_NUMBER = "page_number"


def init_lecture_schema(client: WeaviateClient) -> Collection:
    """
    Return the lecture slides collection, creating it if it does not exist.

    A new collection is created without a vectorizer and with an HNSW index
    using cosine distance.
    """
    collection_name = LectureSchema.COLLECTION_NAME.value
    if client.collections.exists(collection_name):
        return client.collections.get(collection_name)

    # (schema member, description, data type), in declaration order.
    property_specs = [
        (LectureSchema.COURSE_ID, "The ID of the course", DataType.INT),
        (LectureSchema.COURSE_NAME, "The name of the course", DataType.TEXT),
        (
            LectureSchema.COURSE_DESCRIPTION,
            "The description of the COURSE",
            DataType.TEXT,
        ),
        (LectureSchema.LECTURE_ID, "The ID of the lecture", DataType.INT),
        (LectureSchema.LECTURE_NAME, "The name of the lecture", DataType.TEXT),
        (LectureSchema.LECTURE_UNIT_ID, "The ID of the lecture unit", DataType.INT),
        (
            LectureSchema.LECTURE_UNIT_NAME,
            "The name of the lecture unit",
            DataType.TEXT,
        ),
        (
            LectureSchema.PAGE_TEXT_CONTENT,
            "The original text content from the slide",
            DataType.TEXT,
        ),
        (
            LectureSchema.PAGE_IMAGE_DESCRIPTION,
            "The description of the slide if the slide contains an image",
            DataType.TEXT,
        ),
        (
            LectureSchema.PAGE_BASE64,
            "The base64 encoded image of the slide if the slide contains an image",
            DataType.TEXT,
        ),
        (LectureSchema.PAGE_NUMBER, "The page number of the slide", DataType.INT),
    ]
    return client.collections.create(
        name=collection_name,
        vectorizer_config=Configure.Vectorizer.none(),
        vector_index_config=Configure.VectorIndex.hnsw(
            distance_metric=VectorDistances.COSINE
        ),
        properties=[
            Property(name=member.value, description=description, data_type=data_type)
            for member, description, data_type in property_specs
        ],
    )
60 changes: 60 additions & 0 deletions app/vector_database/repository_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from enum import Enum
from weaviate.classes.config import Property
from weaviate import WeaviateClient
from weaviate.collections import Collection
from weaviate.collections.classes.config import Configure, VectorDistances, DataType


class RepositorySchema(Enum):
    """
    Schema for the student repository.

    Each member's value is the name of the collection or of one of its
    properties.
    """

    # Name of the collection itself.
    COLLECTION_NAME = "StudentRepository"
    # Property names of the objects stored in the collection.
    CONTENT = "content"
    COURSE_ID = "course_id"
    EXERCISE_ID = "exercise_id"
    REPOSITORY_ID = "repository_id"
    FILEPATH = "filepath"


def init_repository_schema(client: WeaviateClient) -> Collection:
    """
    Return the student repository collection, creating it if it does not exist.

    A new collection is created without a vectorizer and with an HNSW index
    using cosine distance.
    """
    collection_name = RepositorySchema.COLLECTION_NAME.value
    if client.collections.exists(collection_name):
        return client.collections.get(collection_name)

    # (schema member, description, data type), in declaration order.
    property_specs = [
        (RepositorySchema.CONTENT, "The content of this chunk of code", DataType.TEXT),
        (RepositorySchema.COURSE_ID, "The ID of the course", DataType.INT),
        (RepositorySchema.EXERCISE_ID, "The ID of the exercise", DataType.INT),
        (RepositorySchema.REPOSITORY_ID, "The ID of the repository", DataType.INT),
        (RepositorySchema.FILEPATH, "The filepath of the code", DataType.TEXT),
    ]
    return client.collections.create(
        name=collection_name,
        vectorizer_config=Configure.Vectorizer.none(),
        vector_index_config=Configure.VectorIndex.hnsw(
            distance_metric=VectorDistances.COSINE
        ),
        properties=[
            Property(name=member.value, description=description, data_type=data_type)
            for member, description, data_type in property_specs
        ],
    )
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ ollama==0.1.9
openai==1.25.2
pre-commit==3.7.0
pydantic==2.7.1
PyMuPDF==1.23.22
PyYAML==6.0.1
requests~=2.31.0
uvicorn==0.29.0
requests~=2.31.0
weaviate-client==4.5.4

0 comments on commit 659b64b

Please sign in to comment.