Merge branch 'main' into feat/sources-in-react-widget
utin-francis-peter committed Nov 17, 2024
2 parents 6f83bd8 + 5971ff8 commit ba59042
Showing 31 changed files with 604 additions and 101 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -128,7 +128,7 @@ docker compose -f docker-compose-dev.yaml up -d
> Make sure you have Python 3.10 or 3.11 installed.
1. Export required environment variables or prepare a `.env` file in the project folder:
- Copy [.env_sample](https://github.com/arc53/DocsGPT/blob/main/application/.env_sample) and create `.env`.
- Copy [.env-template](https://github.com/arc53/DocsGPT/blob/main/application/.env-template) and create `.env`.

(check out [`application/core/settings.py`](application/core/settings.py) if you want to see more config options.)

1 change: 1 addition & 0 deletions application/api/answer/routes.py
@@ -241,6 +241,7 @@ def complete_stream(
yield f"data: {data}\n\n"
except Exception as e:
print("\033[91merr", str(e), file=sys.stderr)
traceback.print_exc()
data = json.dumps(
{
"type": "error",
115 changes: 105 additions & 10 deletions application/api/user/routes.py
@@ -2,11 +2,12 @@
import os
import shutil
import uuid
import math

from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, jsonify, make_response, request
from flask import Blueprint, jsonify, make_response, request, redirect
from flask_restx import inputs, fields, Namespace, Resource
from werkzeug.utils import secure_filename

@@ -315,14 +316,34 @@ def post(self):
for file in files:
filename = secure_filename(file.filename)
file.save(os.path.join(temp_dir, filename))

print(f"Saved file: {filename}")
zip_path = shutil.make_archive(
base_name=os.path.join(save_dir, job_name),
format="zip",
root_dir=temp_dir,
)
final_filename = os.path.basename(zip_path)
shutil.rmtree(temp_dir)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
],
job_name,
final_filename,
user,
)
else:
file = files[0]
final_filename = secure_filename(file.filename)
@@ -349,9 +370,10 @@ def post(self):
final_filename,
user,
)

except Exception as err:
print(f"Error: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)

return make_response(jsonify({"success": True, "task_id": task.id}), 200)


@@ -422,19 +444,82 @@ def get(self):

task = celery.AsyncResult(task_id)
task_meta = task.info
print(f"Task status: {task.status}")
if not isinstance(
task_meta, (dict, list, str, int, float, bool, type(None))
):
task_meta = str(task_meta) # Convert to a string representation
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)

return make_response(jsonify({"status": task.status, "result": task_meta}), 200)


@user_ns.route("/api/combine")
class RedirectToSources(Resource):
@api.doc(
description="Redirects /api/combine to /api/sources for backward compatibility"
)
def get(self):
return redirect("/api/sources", code=301)


@user_ns.route("/api/sources/paginated")
class PaginatedSources(Resource):
@api.doc(description="Get document with pagination, sorting and filtering")
def get(self):
user = "local"
sort_field = request.args.get("sort", "date") # Default to 'date'
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
rows_per_page = int(request.args.get("rows", 10)) # Default to 10

# Prepare
query = {"user": user}
total_documents = sources_collection.count_documents(query)
total_pages = max(1, math.ceil(total_documents / rows_per_page))
sort_order = 1 if sort_order == "asc" else -1
skip = (page - 1) * rows_per_page

try:
documents = (
sources_collection.find(query)
.sort(sort_field, sort_order)
.skip(skip)
.limit(rows_per_page)
)

paginated_docs = []
for doc in documents:
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
"date": doc.get("date", ""),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
}
paginated_docs.append(doc_data)

response = {
"total": total_documents,
"totalPages": total_pages,
"currentPage": page,
"paginated": paginated_docs,
}
return make_response(jsonify(response), 200)

except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)


@user_ns.route("/api/sources")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
def get(self):
user = "local"
sort_field = request.args.get('sort', 'date') # Default to 'date'
sort_order = request.args.get('order', "desc") # Default to 'desc'
data = [
{
"name": "default",
@@ -447,7 +532,7 @@ def get(self):
]

try:
for index in sources_collection.find({"user": user}).sort(sort_field, 1 if sort_order=="asc" else -1):
for index in sources_collection.find({"user": user}).sort("date", -1):
data.append(
{
"id": str(index["_id"]),
@@ -485,6 +570,7 @@ def get(self):
"retriever": "brave_search",
}
)

except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)

@@ -1674,7 +1760,9 @@ class TextToSpeech(Resource):
tts_model = api.model(
"TextToSpeechModel",
{
"text": fields.String(required=True, description="Text to be synthesized as audio"),
"text": fields.String(
required=True, description="Text to be synthesized as audio"
),
},
)

@@ -1686,8 +1774,15 @@ def post(self):
try:
tts_instance = GoogleTTS()
audio_base64, detected_language = tts_instance.text_to_speech(text)
return make_response(jsonify({"success": True,'audio_base64': audio_base64,'lang':detected_language}), 200)
return make_response(
jsonify(
{
"success": True,
"audio_base64": audio_base64,
"lang": detected_language,
}
),
200,
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
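The new /api/sources/paginated endpoint above accepts sort, order, page and rows query parameters and returns total, totalPages, currentPage and paginated keys, while the old /api/combine path now 301-redirects to /api/sources. A minimal client sketch of both, assuming a local backend at http://localhost:7091 (the host and port are an assumption, not part of this diff):

import requests

# Assumed local backend URL; adjust to your deployment.
BASE_URL = "http://localhost:7091"

# Parameters mirror the defaults in PaginatedSources.get():
# sort="date", order="desc", page=1, rows=10.
resp = requests.get(
    f"{BASE_URL}/api/sources/paginated",
    params={"sort": "date", "order": "desc", "page": 1, "rows": 10},
)
resp.raise_for_status()
payload = resp.json()
print(payload["total"], payload["totalPages"], payload["currentPage"])
for doc in payload["paginated"]:
    print(doc["id"], doc["name"], doc["date"], doc["tokens"])

# The legacy path keeps working: requests follows the 301 to /api/sources.
legacy = requests.get(f"{BASE_URL}/api/combine")
print(legacy.url)  # ends with /api/sources after the redirect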


48 changes: 48 additions & 0 deletions application/llm/google_ai.py
@@ -0,0 +1,48 @@
from application.llm.base import BaseLLM

class GoogleLLM(BaseLLM):

def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):

super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key

def _clean_messages_google(self, messages):
return [
{
"role": "model" if message["role"] == "system" else message["role"],
"parts": [message["content"]],
}
for message in messages[1:]
]

def _raw_gen(
self,
baseself,
model,
messages,
stream=False,
**kwargs
):
import google.generativeai as genai
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
response = model.generate_content(self._clean_messages_google(messages))
return response.text

def _raw_gen_stream(
self,
baseself,
model,
messages,
stream=True,
**kwargs
):
import google.generativeai as genai
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
response = model.generate_content(self._clean_messages_google(messages), stream=True)
for line in response:
if line.text is not None:
yield line.text
4 changes: 3 additions & 1 deletion application/llm/llm_creator.py
@@ -6,6 +6,7 @@
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM


class LLMCreator:
@@ -18,7 +19,8 @@ class LLMCreator:
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
"premai": PremAILLM,
"groq": GroqLLM
"groq": GroqLLM,
"google": GoogleLLM
}

@classmethod
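Taken together, the new application/llm/google_ai.py module and the llm_creator.py change register a Gemini-backed provider under the "google" key. A minimal sketch of the message handling, assuming BaseLLM can be constructed without further required arguments (its signature is not shown in this diff):

from application.llm.google_ai import GoogleLLM

messages = [
    {"role": "system", "content": "You are a documentation assistant."},
    {"role": "user", "content": "Summarize the install steps."},
]

# _clean_messages_google drops the first (system) message -- GoogleLLM passes it
# to genai.GenerativeModel(..., system_instruction=...) instead -- and renames
# any remaining "system" roles to "model".
llm = GoogleLLM(api_key="YOUR_GEMINI_API_KEY")  # assumes BaseLLM needs no extra args
print(llm._clean_messages_google(messages))
# [{'role': 'user', 'parts': ['Summarize the install steps.']}]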
11 changes: 10 additions & 1 deletion application/parser/remote/reddit_loader.py
@@ -1,10 +1,19 @@
from application.parser.remote.base import BaseRemote
from langchain_community.document_loaders import RedditPostsLoader
import json


class RedditPostsLoaderRemote(BaseRemote):
def load_data(self, inputs):
data = eval(inputs)
try:
data = json.loads(inputs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON input: {e}")

required_fields = ["client_id", "client_secret", "user_agent", "search_queries"]
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
client_id = data.get("client_id")
client_secret = data.get("client_secret")
user_agent = data.get("user_agent")
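RedditPostsLoaderRemote now parses its inputs with json.loads instead of eval and validates the required keys before use. A minimal usage sketch with placeholder credentials (the part of load_data after the validation is not shown here, so the return value is assumed to be the loaded documents):

import json

from application.parser.remote.reddit_loader import RedditPostsLoaderRemote

inputs = json.dumps(
    {
        "client_id": "YOUR_REDDIT_CLIENT_ID",
        "client_secret": "YOUR_REDDIT_CLIENT_SECRET",
        "user_agent": "docsgpt-ingest/0.1",
        "search_queries": ["docsgpt", "retrieval augmented generation"],
    }
)

# Missing keys raise ValueError("Missing required fields: ..."),
# and malformed JSON raises ValueError("Invalid JSON input: ...").
documents = RedditPostsLoaderRemote().load_data(inputs)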
3 changes: 0 additions & 3 deletions application/retriever/classic_rag.py
@@ -45,7 +45,6 @@ def _get_data(self):
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=self.chunks)
print(docs_temp)
docs = [
{
"title": i.metadata.get(
@@ -60,8 +59,6 @@
}
for i in docs_temp
]
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]

return docs

2 changes: 1 addition & 1 deletion docs/pages/_app.mdx
@@ -4,7 +4,7 @@ export default function MyApp({ Component, pageProps }) {
return (
<>
<Component {...pageProps} />
<DocsGPTWidget apiKey="d61a020c-ac8f-4f23-bb98-458e4da3c240" theme="dark" />
<DocsGPTWidget apiKey="d61a020c-ac8f-4f23-bb98-458e4da3c240" theme="dark" size="medium" />
</>
)
}
16 changes: 15 additions & 1 deletion frontend/src/Navigation.tsx
@@ -18,6 +18,7 @@ import SourceDropdown from './components/SourceDropdown';
import {
setConversation,
updateConversationId,
handleAbort,
} from './conversation/conversationSlice';
import ConversationTile from './conversation/ConversationTile';
import { useDarkTheme, useMediaQuery, useOutsideAlerter } from './hooks';
@@ -34,10 +35,12 @@
selectSelectedDocs,
selectSelectedDocsStatus,
selectSourceDocs,
selectPaginatedDocuments,
setConversations,
setModalStateDeleteConv,
setSelectedDocs,
setSourceDocs,
setPaginatedDocuments,
} from './preferences/preferenceSlice';
import Spinner from './assets/spinner.svg';
import SpinnerDark from './assets/spinner-dark.svg';
@@ -72,6 +75,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
const conversations = useSelector(selectConversations);
const modalStateDeleteConv = useSelector(selectModalStateDeleteConv);
const conversationId = useSelector(selectConversationId);
const paginatedDocuments = useSelector(selectPaginatedDocuments);
const [isDeletingConversation, setIsDeletingConversation] = useState(false);

const { isMobile } = useMediaQuery();
@@ -143,9 +147,18 @@
})
.then((updatedDocs) => {
dispatch(setSourceDocs(updatedDocs));
const updatedPaginatedDocs = paginatedDocuments?.filter(
(document) => document.id !== doc.id,
);
dispatch(
setPaginatedDocuments(updatedPaginatedDocs || paginatedDocuments),
);
dispatch(
setSelectedDocs(
updatedDocs?.find((doc) => doc.name.toLowerCase() === 'default'),
Array.isArray(updatedDocs) &&
updatedDocs?.find(
(doc: Doc) => doc.name.toLowerCase() === 'default',
),
),
);
})
@@ -168,6 +181,7 @@
};

const resetConversation = () => {
handleAbort();
dispatch(setConversation([]));
dispatch(
updateConversationId({
3 changes: 2 additions & 1 deletion frontend/src/api/endpoints.ts
@@ -1,7 +1,8 @@
const endpoints = {
USER: {
DOCS: '/api/combine',
DOCS: '/api/sources',
DOCS_CHECK: '/api/docs_check',
DOCS_PAGINATED: '/api/sources/paginated',
API_KEYS: '/api/get_api_keys',
CREATE_API_KEY: '/api/create_api_key',
DELETE_API_KEY: '/api/delete_api_key',