Skip to content

Commit

Permalink
feat(RagScoringPipeline): propagate and expose document metadata to LLM context (references etc)
Browse files Browse the repository at this point in the history
  • Loading branch information
nRamstedt committed Nov 15, 2024
1 parent 31e009d commit 1551c5e
Showing 1 changed file with 13 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from collections.abc import AsyncGenerator
from typing import Any

from langstream import Stream
from langstream import Stream, as_async_generator

from fai_backend.assistant.protocol import IAssistantContextStore, IAssistantPipelineStrategy
from fai_backend.collection.dependencies import get_collection_service
from fai_backend.llm.service import query_vector
from fai_backend.projects.dependencies import get_project_service
from fai_backend.vector.factory import vector_db
from fai_backend.vector.service import VectorService
Expand All @@ -20,25 +19,22 @@ async def create_pipeline(
async def run_rag_stream(query: list[str]):
collection_id = context_store.get_mutable().files_collection_id
vector_service = VectorService(vector_db=vector_db, collection_meta_service=get_collection_service())
vector_db_query_result = await query_vector(
vector_service=vector_service,

result = await vector_service.query_from_collection(
collection_name=collection_id,
query=query[0],
query_texts=[query[0]],
n_results=10,
)

documents: [str] = []

def store_and_return_document(document: str):
documents.append(document)
return document
documents, documents_metadata = result['documents'][0], result['metadatas'][0]

def append_score_to_documents(scores):
z = zip(documents, [s[0] for s in scores])
z = zip(documents, documents_metadata, [s[0] for s in scores])
return z

def sort_and_slice_documents(scored_documents, slice_size: int):
first_element = list(scored_documents)[0]
sorted_scores = sorted(first_element, key=lambda x: x[1], reverse=True)
sorted_scores = sorted(first_element, key=lambda x: x[2], reverse=True)
return sorted_scores[:slice_size]

projects = await get_project_service().read_projects()
Expand All @@ -61,8 +57,10 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]:
yield score

full_stream = (
vector_db_query_result
.map(store_and_return_document)
Stream[None, str](
'QueryVectorStream',
lambda _: as_async_generator(*documents)
)
.map(scoring_stream)
.gather()
.and_then(append_score_to_documents)
Expand All @@ -75,7 +73,7 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]:

def rag_postprocess(in_data: Any):
results: list[str] = in_data[0]['results']
concatenated = '\n'.join([s for (s, _) in results])
concatenated = '\n\n'.join([(s + '\n' + str(m)) for (s, m, _) in results])
context_store.get_mutable().rag_output = concatenated
return concatenated

Expand Down

0 comments on commit 1551c5e

Please sign in to comment.