-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ingester & Retriever
: Add support for Weaviate (#64)
- Loading branch information
Showing
14 changed files
with
347 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
WEAVIATE_HOST= | ||
WEAVIATE_PORT= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
from datetime import datetime | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class LectureUnitDTO(BaseModel): | ||
id: int | ||
to_update: bool = Field(alias="toUpdate") | ||
pdf_file_base64: str = Field(alias="pdfFile") | ||
lecture_unit_id: int = Field(alias="lectureUnitId") | ||
lecture_unit_name: str = Field(alias="lectureUnitName") | ||
lecture_id: int = Field(alias="lectureId") | ||
release_date: Optional[datetime] = Field(alias="releaseDate", default=None) | ||
name: Optional[str] = None | ||
attachment_version: int = Field(alias="attachmentVersion") | ||
lecture_name: str = Field(alias="lectureName") | ||
course_id: int = Field(alias="courseId") | ||
course_name: str = Field(alias="courseName") | ||
course_description: str = Field(alias="courseDescription") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List, Dict | ||
|
||
|
||
class AbstractIngestion(ABC): | ||
""" | ||
Abstract class for ingesting repositories into a database. | ||
""" | ||
|
||
@abstractmethod | ||
def chunk_data(self, path: str) -> List[Dict[str, str]]: | ||
""" | ||
Abstract method to chunk code files in the root directory. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def ingest(self, path: str) -> bool: | ||
""" | ||
Abstract method to ingest repositories into the database. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def update(self, path: str): | ||
""" | ||
Abstract method to update a repository in the database. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
|
||
class AbstractRetrieval(ABC): | ||
""" | ||
Abstract class for retrieving data from a database. | ||
""" | ||
|
||
@abstractmethod | ||
def retrieve(self, path: str, hybrid_factor: float, result_limit: int) -> List[str]: | ||
""" | ||
Abstract method to retrieve data from the database. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from abc import ABC | ||
from typing import List | ||
|
||
from weaviate import WeaviateClient | ||
from weaviate.classes.query import Filter | ||
|
||
from app.retrieval.abstract_retrieval import AbstractRetrieval | ||
from app.vector_database.lecture_schema import init_lecture_schema, LectureSchema | ||
|
||
|
||
class LectureRetrieval(AbstractRetrieval, ABC): | ||
""" | ||
Class for retrieving lecture data from the database. | ||
""" | ||
|
||
def __init__(self, client: WeaviateClient): | ||
self.collection = init_lecture_schema(client) | ||
|
||
def retrieve( | ||
self, | ||
user_message: str, | ||
hybrid_factor: float, | ||
result_limit: int, | ||
lecture_id: int = None, | ||
message_vector: [float] = None, | ||
) -> List[str]: | ||
response = self.collection.query.hybrid( | ||
query=user_message, | ||
filters=( | ||
Filter.by_property(LectureSchema.LECTURE_ID.value).equal(lecture_id) | ||
if lecture_id | ||
else None | ||
), | ||
alpha=hybrid_factor, | ||
vector=message_vector, | ||
return_properties=[ | ||
LectureSchema.PAGE_TEXT_CONTENT.value, | ||
LectureSchema.PAGE_IMAGE_DESCRIPTION.value, | ||
LectureSchema.COURSE_NAME.value, | ||
], | ||
limit=result_limit, | ||
) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from typing import List | ||
|
||
from weaviate import WeaviateClient | ||
from weaviate.classes.query import Filter | ||
|
||
from app.retrieval.abstract_retrieval import AbstractRetrieval | ||
from app.vector_database.repository_schema import ( | ||
init_repository_schema, | ||
RepositorySchema, | ||
) | ||
|
||
|
||
class RepositoryRetrieval(AbstractRetrieval): | ||
""" | ||
Class for Retrieving repository code for from the vector database. | ||
""" | ||
|
||
def __init__(self, client: WeaviateClient): | ||
self.collection = init_repository_schema(client) | ||
|
||
def retrieve( | ||
self, | ||
user_message: str, | ||
result_limit: int, | ||
repository_id: int = None, | ||
) -> List[str]: | ||
response = self.collection.query.near_text( | ||
near_text=user_message, | ||
filters=( | ||
Filter.by_property(RepositorySchema.REPOSITORY_ID.value).equal( | ||
repository_id | ||
) | ||
if repository_id | ||
else None | ||
), | ||
return_properties=[ | ||
RepositorySchema.REPOSITORY_ID.value, | ||
RepositorySchema.COURSE_ID.value, | ||
RepositorySchema.CONTENT.value, | ||
RepositorySchema.EXERCISE_ID.value, | ||
RepositorySchema.FILEPATH.value, | ||
], | ||
limit=result_limit, | ||
) | ||
return response |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import logging | ||
import os | ||
import weaviate | ||
from .lecture_schema import init_lecture_schema | ||
from .repository_schema import init_repository_schema | ||
import weaviate.classes as wvc | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class VectorDatabase: | ||
""" | ||
Class to interact with the Weaviate vector database | ||
""" | ||
|
||
def __init__(self): | ||
self.client = weaviate.connect_to_wcs( | ||
cluster_url=os.getenv("WEAVIATE_CLUSTER_URL"), | ||
auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_AUTH_KEY")), | ||
) | ||
self.repositories = init_repository_schema(self.client) | ||
self.lectures = init_lecture_schema(self.client) | ||
|
||
def __del__(self): | ||
self.client.close() | ||
|
||
def delete_collection(self, collection_name): | ||
""" | ||
Delete a collection from the database | ||
""" | ||
if self.client.collections.exists(collection_name): | ||
if self.client.collections.delete(collection_name): | ||
logger.info(f"Collection {collection_name} deleted") | ||
else: | ||
logger.error(f"Collection {collection_name} failed to delete") | ||
|
||
def delete_object(self, collection_name, property_name, object_property): | ||
""" | ||
Delete an object from the collection inside the databse | ||
""" | ||
collection = self.client.collections.get(collection_name) | ||
collection.data.delete_many( | ||
where=wvc.query.Filter.by_property(property_name).equal(object_property) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from enum import Enum | ||
|
||
from weaviate.classes.config import Property | ||
from weaviate import WeaviateClient | ||
from weaviate.collections import Collection | ||
from weaviate.collections.classes.config import Configure, VectorDistances, DataType | ||
|
||
|
||
class LectureSchema(Enum): | ||
""" | ||
Schema for the lecture slides | ||
""" | ||
|
||
COLLECTION_NAME = "LectureSlides" | ||
COURSE_NAME = "course_name" | ||
COURSE_DESCRIPTION = "course_description" | ||
COURSE_ID = "course_id" | ||
LECTURE_ID = "lecture_id" | ||
LECTURE_NAME = "lecture_name" | ||
LECTURE_UNIT_ID = "lecture_unit_id" | ||
LECTURE_UNIT_NAME = "lecture_unit_name" | ||
PAGE_TEXT_CONTENT = "page_text_content" | ||
PAGE_IMAGE_DESCRIPTION = "page_image_explanation" | ||
PAGE_BASE64 = "page_base64" | ||
PAGE_NUMBER = "page_number" | ||
|
||
|
||
def init_lecture_schema(client: WeaviateClient) -> Collection: | ||
""" | ||
Initialize the schema for the lecture slides | ||
""" | ||
if client.collections.exists(LectureSchema.COLLECTION_NAME.value): | ||
return client.collections.get(LectureSchema.COLLECTION_NAME.value) | ||
return client.collections.create( | ||
name=LectureSchema.COLLECTION_NAME.value, | ||
vectorizer_config=Configure.Vectorizer.none(), | ||
vector_index_config=Configure.VectorIndex.hnsw( | ||
distance_metric=VectorDistances.COSINE | ||
), | ||
properties=[ | ||
Property( | ||
name=LectureSchema.COURSE_ID.value, | ||
description="The ID of the course", | ||
data_type=DataType.INT, | ||
), | ||
Property( | ||
name=LectureSchema.COURSE_NAME.value, | ||
description="The name of the course", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=LectureSchema.COURSE_DESCRIPTION.value, | ||
description="The description of the COURSE", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=LectureSchema.LECTURE_ID.value, | ||
description="The ID of the lecture", | ||
data_type=DataType.INT, | ||
), | ||
Property( | ||
name=LectureSchema.LECTURE_NAME.value, | ||
description="The name of the lecture", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=LectureSchema.LECTURE_UNIT_ID.value, | ||
description="The ID of the lecture unit", | ||
data_type=DataType.INT, | ||
), | ||
Property( | ||
name=LectureSchema.LECTURE_UNIT_NAME.value, | ||
description="The name of the lecture unit", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=LectureSchema.PAGE_TEXT_CONTENT.value, | ||
description="The original text content from the slide", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=LectureSchema.PAGE_IMAGE_DESCRIPTION.value, | ||
description="The description of the slide if the slide contains an image", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=LectureSchema.PAGE_BASE64.value, | ||
description="The base64 encoded image of the slide if the slide contains an image", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=LectureSchema.PAGE_NUMBER.value, | ||
description="The page number of the slide", | ||
data_type=DataType.INT, | ||
), | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from enum import Enum | ||
from weaviate.classes.config import Property | ||
from weaviate import WeaviateClient | ||
from weaviate.collections import Collection | ||
from weaviate.collections.classes.config import Configure, VectorDistances, DataType | ||
|
||
|
||
class RepositorySchema(Enum): | ||
""" | ||
Schema for the student repository | ||
""" | ||
|
||
COLLECTION_NAME = "StudentRepository" | ||
CONTENT = "content" | ||
COURSE_ID = "course_id" | ||
EXERCISE_ID = "exercise_id" | ||
REPOSITORY_ID = "repository_id" | ||
FILEPATH = "filepath" | ||
|
||
|
||
def init_repository_schema(client: WeaviateClient) -> Collection: | ||
""" | ||
Initialize the schema for the student repository | ||
""" | ||
if client.collections.exists(RepositorySchema.COLLECTION_NAME.value): | ||
return client.collections.get(RepositorySchema.COLLECTION_NAME.value) | ||
return client.collections.create( | ||
name=RepositorySchema.COLLECTION_NAME.value, | ||
vectorizer_config=Configure.Vectorizer.none(), | ||
vector_index_config=Configure.VectorIndex.hnsw( | ||
distance_metric=VectorDistances.COSINE | ||
), | ||
properties=[ | ||
Property( | ||
name=RepositorySchema.CONTENT.value, | ||
description="The content of this chunk of code", | ||
data_type=DataType.TEXT, | ||
), | ||
Property( | ||
name=RepositorySchema.COURSE_ID.value, | ||
description="The ID of the course", | ||
data_type=DataType.INT, | ||
), | ||
Property( | ||
name=RepositorySchema.EXERCISE_ID.value, | ||
description="The ID of the exercise", | ||
data_type=DataType.INT, | ||
), | ||
Property( | ||
name=RepositorySchema.REPOSITORY_ID.value, | ||
description="The ID of the repository", | ||
data_type=DataType.INT, | ||
), | ||
Property( | ||
name=RepositorySchema.FILEPATH.value, | ||
description="The filepath of the code", | ||
data_type=DataType.TEXT, | ||
), | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters