-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
134 lines (121 loc) · 6.52 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import asyncio
from functools import lru_cache
from typing import List
from openai import OpenAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from config import (
key, CHUNK_SIZE, CHUNK_OVERLAP, EMBEDDING_MODEL,
GPT_MODEL, TEMPERATURE, MAX_TOKENS, DEFAULT_SIMILAR_DOCS_COUNT
)
class VineyardAssistant:
"""
Класс для обработки запросов о виноградарстве с использованием векторного поиска и GPT.
"""
def __init__(self):
"""Инициализация ассистента с настройкой всех необходимых компонентов."""
self.client = OpenAI(api_key=key)
self.embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL,
encode_kwargs={'normalize_embeddings': True}
)
self.vector_store = None
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
length_function=len,
separators=["\n\n", "\n", " ", ""],
)
self.data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
self.initialize_vector_store()
@staticmethod
def clean_text(text: str) -> str:
"""Очистка текста от лишних пробелов и пустых строк."""
if not isinstance(text, str):
raise ValueError("Input must be a string")
text = '\n'.join(line for line in text.splitlines() if line.strip())
text = ' '.join(text.split())
return text.strip()
def initialize_vector_store(self):
"""Инициализация векторного хранилища."""
try:
index_path = os.path.join(self.data_dir, "faiss_index")
if os.path.exists(index_path):
self.vector_store = FAISS.load_local(
index_path,
self.embeddings,
allow_dangerous_deserialization=True
)
else:
documents = self.load_training_data()
self.vector_store = FAISS.from_documents(documents, self.embeddings)
self.vector_store.save_local(index_path)
except Exception as e:
print(f"Error initializing vector store: {e}")
raise
def load_training_data(self) -> List[Document]:
"""Загрузка обучающих данных из текстовых файлов."""
documents = []
for filename in os.listdir(self.data_dir):
if filename.endswith(".txt"):
try:
with open(os.path.join(self.data_dir, filename), 'r', encoding='utf-8') as f:
text = f.read()
cleaned_text = self.clean_text(text)
documents.extend(self.text_splitter.create_documents([cleaned_text]))
except Exception as e:
print(f"Error loading file {filename}: {str(e)}")
return documents
@staticmethod
def preprocess_query(query: str) -> str:
"""Предобработка запроса пользователя."""
if not query or not isinstance(query, str):
raise ValueError("Query must be a non-empty string")
query = query.strip().lower()
return query
@lru_cache(maxsize=100)
def get_similar_documents(self, query: str, k: int = DEFAULT_SIMILAR_DOCS_COUNT) -> List[Document]:
"""Получение похожих документов из векторного хранилища."""
retriever = self.vector_store.as_retriever(search_kwargs={"k": k})
similar_docs = retriever.get_relevant_documents(query)
cleaned_docs = [
Document(page_content=self.clean_text(doc.page_content))
for doc in similar_docs
]
return cleaned_docs
async def process_query(self, user_query: str) -> str:
"""Асинхронная обработка запроса пользователя."""
try:
if not user_query or not isinstance(user_query, str):
raise ValueError("Query must be a non-empty string")
processed_query = self.preprocess_query(user_query)
similar_docs = self.get_similar_documents(processed_query)
context = "\n\n".join(doc.page_content for doc in similar_docs)
response = await asyncio.to_thread(
self.client.chat.completions.create,
model=GPT_MODEL,
messages=[
{"role": "system", "content": """Вы - эксперт по виноградарству компании Ceres Pro.
Ваша задача - отвечать на вопросы пользователя.
Если в предоставленном контексте есть информация, используйте её для ответа.
Если информации недостаточно, опирайтесь на собственные знания.
Пишите ответ так, чтобы он был полезен и информативен.
В конце каждого ответа спрашивайте, остались ли у пользователя ещё вопросы."""},
{"role": "user", "content": f"Контекст:\n\n{context}\n\nВопрос: {user_query}\n\n"}
],
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
)
answer = response.choices[0].message.content
return answer
except Exception as e:
error_msg = f"Error processing query: {str(e)}"
print(error_msg)
return "Произошла ошибка при обработке запроса. Пожалуйста, попробуйте еще раз или переформулируйте вопрос."
async def process_queries_batch(self, queries: List[str]) -> List[str]:
"""Пакетная обработка нескольких запросов."""
return await asyncio.gather(*(self.process_query(query) for query in queries))