-
Notifications
You must be signed in to change notification settings - Fork 0
/
extract.py
executable file
·174 lines (131 loc) · 5.3 KB
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import asyncio
import argparse
import csv
from http import HTTPStatus
from typing import Dict, List
from io import StringIO
import requests
import tiktoken
from bs4 import BeautifulSoup
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.schema import ChatResult
from langchain.text_splitter import CharacterTextSplitter
def scrape_hacker_news_comments(post_id: int, timeout: float = 30) -> List[Dict]:
    """Scrape the comments of a Hacker News post.

    Args:
        post_id: Numeric HN item id (the ``id=`` query parameter).
        timeout: Seconds to wait for the HTTP response before aborting.
            (Previously no timeout was set, so the request could hang forever.)

    Returns:
        One dict per comment with keys "id", "user", "age", "text".

    Raises:
        Exception: if the page does not return HTTP 200.
    """
    url = f"https://news.ycombinator.com/item?id={post_id}"
    response = requests.get(url, timeout=timeout)
    if response.status_code != HTTPStatus.OK:
        raise Exception(f"Failed to retrieve webpage ({response.status_code}): {url}")
    soup = BeautifulSoup(response.content, "html.parser")
    comments = []
    for comment in soup.select(".athing.comtr"):
        comment_id = comment["id"]
        user = comment.select_one(".hnuser")
        age = comment.select_one(".age")
        text = comment.select_one(".commtext")
        # Deleted/flagged comments may lack one of these nodes; skip them.
        if all([user, age, text]):
            comments.append(
                {
                    "id": comment_id,
                    "user": user.text,
                    "age": age.text,
                    "text": text.get_text(separator="\n", strip=True),
                }
            )
    return comments
def get_chat_prompt_template() -> ChatPromptTemplate:
    """Build the single-system-message prompt that asks the model to extract
    book title/author pairs from user comments as pipe-delimited CSV.

    The ``{comments}`` placeholder is filled in by the caller per chunk.
    """
    template = """You will be provided a list of user comments.
Extract all books as a CSV, omit quotes and use pipes as a separator.
Return only the CSV.
Input:
===
I really enjoyed reading The Lord of the Rings by J.R.R. Tolkien
I don't read much these days
Zero to One' by Peter Thiel, thank me later.
More of a podcast guy myself
You should check out Black Swan by Nassim Taleb.
===
Output:
title|author
The Lord of the Rings|J.R.R. Tolkien
Zero to One|Peter Thiel
Black Swan|Nassim Taleb
Input:
===
{comments}
===
Output:
"""
    return ChatPromptTemplate.from_messages(
        [SystemMessagePromptTemplate.from_template(template)]
    )
def parse_pipe_delimited_csv(csv_string: str) -> List[Dict]:
    """Parse a pipe-delimited CSV string into a list of dicts keyed by the header row.

    Args:
        csv_string: CSV text whose first row is the header, delimited by ``|``.

    Returns:
        One dict per data row mapping header names to cell values; ``[]``
        when *csv_string* is empty (previously this raised StopIteration).
    """
    csv_reader = csv.reader(StringIO(csv_string), delimiter="|")
    # next() with a default guards against empty input having no header row.
    header = next(csv_reader, None)
    if header is None:
        return []
    return [dict(zip(header, row)) for row in csv_reader]
def process_raw_output(raw_output: ChatResult) -> List[Dict]:
    """Flatten every generation in *raw_output* into a single list of parsed CSV rows.

    Only the first candidate of each generation is used.
    """
    parsed: List[Dict] = []
    for generation in raw_output.generations:
        parsed += parse_pipe_delimited_csv(generation[0].text)
    return parsed
def get_number_of_tokens(model_name: str, _input: str) -> int:
    """Return the token count of *_input* under *model_name*'s tiktoken encoding."""
    encoding = tiktoken.encoding_for_model(model_name)
    return len(encoding.encode(_input))
def write_csv(items: List[Dict], filename: str) -> None:
    """Write a list of dicts to *filename* as CSV, using the first item's keys as header.

    Args:
        items: Non-empty list of row dicts sharing the same keys.
        filename: Destination path (overwritten if it exists).

    Raises:
        Exception: if *items* is empty (no header can be derived).
    """
    if not items:
        raise Exception("No items to write")
    # newline="" is required by the csv module to avoid blank lines between
    # rows on Windows; utf-8 keeps output stable across platforms.
    with open(filename, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=items[0].keys())
        writer.writeheader()
        writer.writerows(items)
async def main(args: argparse.Namespace):
    """Scrape HN comments, chunk them to fit the model's context window,
    extract book recommendations via the LLM, and write them to CSV."""
    prompt = get_chat_prompt_template()
    # This is not quite right, but close enough;
    # Langchain appends 'System: ' to the string, which is not part of the API request.
    prompt_token_count = get_number_of_tokens(
        model_name=args.model_name,
        _input=prompt.format_prompt(comments="").to_string(),  # empty input -> tokens of the bare prompt
    )
    # Half of the leftover tokens go to input; the rest is completion headroom.
    input_token_budget = (args.model_token_limit - prompt_token_count) // 2
    scraped = scrape_hacker_news_comments(args.post_id)
    joined_comments = "\n".join(f'"{c["text"].strip()}"' for c in scraped)
    splitter = CharacterTextSplitter.from_tiktoken_encoder(
        separator="\n",
        chunk_size=input_token_budget,
        chunk_overlap=0,
        encoding_name=tiktoken.encoding_for_model(args.model_name).name,
    )
    chunks = splitter.split_text(joined_comments)
    messages = [prompt.format_prompt(comments=chunk).to_messages() for chunk in chunks]
    chat = ChatOpenAI(temperature=0, model_name=args.model_name, request_timeout=args.timeout)
    raw_output = await chat.agenerate(messages)
    print("Total token usage: ", raw_output.llm_output["token_usage"]["total_tokens"])
    write_csv(process_raw_output(raw_output), args.output_file)
if __name__ == "__main__":
    # CLI entry point: parse arguments, derive the default output file
    # name from the post id, then run the async pipeline.
    arg_parser = argparse.ArgumentParser(description="Scrape and process Hacker News comments")
    arg_parser.add_argument(
        "-m", "--model_name", default="gpt-3.5-turbo",
        help="Model name (default: gpt-3.5-turbo) see: https://platform.openai.com/docs/models",
    )
    arg_parser.add_argument(
        "-l", "--model_token_limit", type=int, default=4096,
        help="Model token limit (default: 4096) see: https://platform.openai.com/docs/models",
    )
    arg_parser.add_argument("-t", "--timeout", type=int, default=120, help="Request timeout (default: 120)")
    arg_parser.add_argument("-p", "--post_id", type=int, required=True, help="Hacker News post ID")
    arg_parser.add_argument("-o", "--output_file", help="Output file name (default: <post_id>.csv)")
    cli_args = arg_parser.parse_args()
    if cli_args.output_file is None:
        cli_args.output_file = f"{cli_args.post_id}.csv"
    asyncio.run(main(cli_args))