langchain_bot.py
import pathlib
import pickle
import subprocess
import tempfile

from dagster import FreshnessPolicy, RetryPolicy, asset
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS


def get_github_docs(repo_owner, repo_name):
    """Clone a GitHub repo and yield one Document per Markdown file,
    with a source URL pinned to the cloned commit."""
    with tempfile.TemporaryDirectory() as d:
        subprocess.check_call(
            f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
            cwd=d,
            shell=True,
        )
        # Record the commit SHA so each document's source link stays stable
        # even after the repo's default branch moves on.
        git_sha = (
            subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
            .decode("utf-8")
            .strip()
        )
        repo_path = pathlib.Path(d)
        markdown_files = list(repo_path.glob("**/*.md")) + list(
            repo_path.glob("**/*.mdx")
        )
        for markdown_file in markdown_files:
            with open(markdown_file, "r") as f:
                relative_path = markdown_file.relative_to(repo_path)
                github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
                yield Document(page_content=f.read(), metadata={"source": github_url})


@asset
def source_docs():
    # Pull every Markdown page of the Dagster repo as LangChain Documents.
    return list(get_github_docs("dagster-io", "dagster"))


@asset(
    retry_policy=RetryPolicy(max_retries=5, delay=5),
    freshness_policy=FreshnessPolicy(maximum_lag_minutes=60 * 24),
)
def search_index(source_docs):
    # Split each document into ~1 KB chunks, carrying the source URL along
    # in each chunk's metadata so answers can cite where they came from.
    source_chunks = []
    splitter = CharacterTextSplitter(separator=" ", chunk_size=1024, chunk_overlap=0)
    for source in source_docs:
        for chunk in splitter.split_text(source.page_content):
            source_chunks.append(
                Document(page_content=chunk, metadata=source.metadata)
            )
    # Embed the chunks with OpenAI and persist the FAISS index to disk.
    with open("search_index.pickle", "wb") as f:
        pickle.dump(FAISS.from_documents(source_chunks, OpenAIEmbeddings()), f)


chain = load_qa_with_sources_chain(OpenAI(temperature=0))


def print_answer(question):
    # Load the FAISS index persisted by the search_index asset.
    with open("search_index.pickle", "rb") as f:
        search_index = pickle.load(f)
    # Answer over the four most similar chunks, returning only the final text.
    print(
        chain(
            {
                "input_documents": search_index.similarity_search(question, k=4),
                "question": question,
            },
            return_only_outputs=True,
        )["output_text"]
    )
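

# Usage sketch (not part of the original file; the question below is a
# hypothetical example). It assumes the source_docs and search_index assets
# have already been materialized in Dagster, so search_index.pickle exists,
# and that OPENAI_API_KEY is set in the environment.
if __name__ == "__main__":
    print_answer("What is a software-defined asset in Dagster?")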