lda modeling
jgibson517 committed May 26, 2024
1 parent bbed9af commit 46d1187
Showing 4 changed files with 44 additions and 19 deletions.
6 changes: 3 additions & 3 deletions civiclens/nlp/pipeline.py
@@ -13,7 +13,7 @@
from civiclens.nlp.comments import get_doc_comments, rep_comment_analysis
from civiclens.nlp.models import sentence_transformer, sentiment_pipeline
from civiclens.nlp.tools import RepComments, sentiment_analysis
-from civiclens.nlp.topics import FlanLabeler, HDAModel, topic_comment_analysis
+from civiclens.nlp.topics import FlanLabeler, TopicModel, topic_comment_analysis
from civiclens.utils.database_access import Database, pull_data, upload_comments


@@ -142,7 +142,7 @@ def docs_have_titles():
)

# topic modeling
-topic_model = HDAModel()
+topic_model = TopicModel()
comment_data = topic_comment_analysis(
comment_data,
model=topic_model,
@@ -151,4 +151,4 @@ def docs_have_titles():
)

logger.info(f"Proccessed document: {doc_id}")
-# upload_comments(Database(), comment_data)
+upload_comments(Database(), comment_data)
57 changes: 41 additions & 16 deletions civiclens/nlp/topics.py
@@ -1,20 +1,20 @@
import pickle
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Callable

import gensim.corpora as corpora
from gensim.corpora import Dictionary
-from gensim.models import HdpModel, Phrases
+from gensim.models import LdaModel, Phrases
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from textblob import TextBlob
from transformers import pipeline

from civiclens.nlp.models import title_model, title_tokenizer
from civiclens.nlp.tools import Comment, RepComments
-from civiclens.utils.text import clean_text, regex_tokenize
+from civiclens.utils.text import clean_text


@@ -33,18 +33,20 @@ def stopwords(model_path: Path) -> set[str]:
return stop_words


-class HDAModel:
+class TopicModel:
"""
Performs LDA topic modeling
"""

def __init__(self):
self.model = None
-self.tokenizer = partial(regex_tokenize, pattern=r"\W+")
+# self.tokenizer = partial(regex_tokenize, pattern=r"\W+")
self.stop_words = stopwords(
-Path(__file__).resolve().parent / "saved_models/stop_words.pickle"
+Path(__file__).resolve().parent.parent
+/ "utils/objects/custom_stopwords.pkl"
)
self.terms = None
+self.pos_tags = {"NN", "NNS", "NNP", "NNPS"}

def _process_text(
self, comments: list[Comment]
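The stop-word source above moves from nlp/saved_models/stop_words.pickle to the newly added civiclens/utils/objects/custom_stopwords.pkl (the binary file at the bottom of this diff). A minimal sketch of writing and reading such a file, assuming it holds a plain pickled set[str]; the path and words below are made up for illustration:

```python
import pickle
from pathlib import Path

# Hypothetical path; the real file in this commit is
# civiclens/utils/objects/custom_stopwords.pkl.
path = Path("custom_stopwords_example.pkl")

# Write a pickled set of stop words...
with open(path, "wb") as f:
    pickle.dump({"the", "of", "and", "comment", "rule"}, f)

# ...and load it back, roughly what stopwords() above does with the shipped file.
with open(path, "rb") as f:
    stop_words = pickle.load(f)

print(sorted(stop_words))
```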
@@ -55,16 +57,17 @@ def _process_text(
docs = []
document_ids = {}
for idx, comment in enumerate(comments):
-docs.append(self.tokenizer(clean_text(comment.text).lower()))
+docs.append(clean_text(comment.text).lower())
document_ids[idx] = comment.id

# remove numbers, 2 character tokens, and stop words
docs = [
[
token
-for token in doc
+for token, tag in TextBlob(doc).tags
if not token.isnumeric()
and len(token) > 2
+and tag in self.pos_tags
and token not in self.stop_words
]
for doc in docs
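The change above replaces regex tokenization with TextBlob part-of-speech tagging, so only noun tokens (NN, NNS, NNP, NNPS) survive the filter. A minimal sketch of that filter, assuming textblob and its NLTK corpora are installed (python -m textblob.download_corpora); the sample sentence and stop words are invented:

```python
from textblob import TextBlob

pos_tags = {"NN", "NNS", "NNP", "NNPS"}   # noun tags kept by the new filter
stop_words = {"the", "for", "would"}      # stand-in for the pickled stop words

doc = "the proposed rule would raise emissions reporting costs for 2025"
kept = [
    token
    for token, tag in TextBlob(doc).tags  # (word, POS tag) pairs
    if not token.isnumeric()
    and len(token) > 2
    and tag in pos_tags
    and token not in stop_words
]
print(kept)  # noun-like tokens only, e.g. ['rule', 'emissions', 'costs']
```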
@@ -97,7 +100,7 @@ def run_model(self, comments: list[Comment]):
docs, document_id = self._process_text(comments)
token_dict, corpus = self._create_corpus(docs)

-hdp_model = HdpModel(corpus, token_dict)
+hdp_model = LdaModel(corpus, id2word=token_dict, num_topics=15)
numeric_topics = self._find_best_topic(hdp_model, corpus)

comment_topics = {}
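This hunk is the core of the commit message: gensim's HdpModel, which infers the number of topics, is swapped for LdaModel with a fixed num_topics=15 (the variable keeps its old hdp_model name). A self-contained sketch of the same LDA path on a toy corpus; the documents here are invented, and in the pipeline the input is the POS-filtered comment tokens from _process_text:

```python
from gensim.corpora import Dictionary
from gensim.models import LdaModel

docs = [
    ["emissions", "renewable", "energy", "policy"],
    ["tax", "income", "reform", "brackets"],
]
dictionary = Dictionary(docs)                       # plays the role of token_dict
corpus = [dictionary.doc2bow(doc) for doc in docs]  # bag-of-words corpus

# Unlike HdpModel, LdaModel needs the topic count up front.
lda = LdaModel(corpus, id2word=dictionary, num_topics=15)

# _find_best_topic keeps the most probable topic per document;
# get_document_topics exposes the (topic_id, probability) pairs it picks from.
print(lda.get_document_topics(corpus[0]))
print(lda.show_topic(0, topn=5))  # top terms for one topic
```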
@@ -115,7 +118,7 @@ def run_model(self, comments: list[Comment]):
return comment_topics

def _find_best_topic(
-self, model: HdpModel, corpus: list[tuple]
+self, model: LdaModel, corpus: list[tuple]
) -> dict[int, int]:
"""
Computes most probable topic per document
@@ -187,9 +190,15 @@ def __init__(self) -> None:
federal policy. Ensure the label accurately encompasses the main theme
represented by all the input words.
-Example:
-Input words: ["healthcare", "insurance", "coverage", "affordable"]
-Output label: "Affordable Healthcare Access"
+Examples:
+Input words: ["climate", "emissions", "renewable", "energy", "policy"]
+Output label: "Climate Change and Renewable Energy Policy"
+Input words: ["tax", "reform", "income", "brackets", "reduction"]
+Output label: "Income Tax Reform and Reduction"
+Input words: ["immigration", "policy", "border", "security", "visas"]
+Output label: "Immigration Policy and Border Security"
Now, generate a topic label for the following list of words:
@@ -206,15 +215,30 @@ def __init__(self) -> None:
self.hf_pipeline = HuggingFacePipeline(pipeline=self.pipe)
self.parse = StrOutputParser()

+def _clean_ouput(self, text: str) -> tuple[str]:
+"""
+Converts LLM output formatted as "Output label: red, blue" into list
+of labels, ["Red", "Blue"]
+"""
+label_text = text.split(": ")[-1]
+label_set = set(label_text.split(", "))
+
+return tuple(label.title() for label in label_set)

def generate_label(self, summary, terms) -> str:
"""
Creates label for list of topic terms using FLAN
"""
if summary:
prompt = PromptTemplate.from_template(self.summary_template)
chain = prompt | self.hf_pipeline | self.parse
return chain.invoke({"summary": summary, "words": terms})
return self._clean_ouput(
chain.invoke({"summary": summary, "words": terms})
)

prompt = PromptTemplate.from_template(self.no_summary_template)
chain = prompt | self.hf_pipeline | self.parse
return chain.invoke({"words": terms})
return self._clean_ouput(chain.invoke({"words": terms}))


def label_topics(
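The new _clean_ouput helper strips the "Output label: " prefix from the raw FLAN text and returns title-cased labels as a tuple, so generate_label no longer hands back the model's raw string. A standalone sketch of the same parsing, run on a made-up model output:

```python
def clean_output(text: str) -> tuple[str, ...]:
    # Mirrors _clean_ouput above: keep the text after the last ": ",
    # split on ", ", de-duplicate, and title-case each label.
    label_text = text.split(": ")[-1]
    label_set = set(label_text.split(", "))
    return tuple(label.title() for label in label_set)

print(clean_output("Output label: affordable healthcare access"))
# -> ('Affordable Healthcare Access',)
```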
@@ -239,7 +263,7 @@ def label_topics(

def topic_comment_analysis(
comment_data: RepComments,
-model: HDAModel = None,
+model: TopicModel = None,
labeler: FlanLabeler = None,
sentiment_analyzer: Callable = None,
) -> RepComments:
@@ -269,6 +293,7 @@

comment_topics = model.run_model(comments)
topic_terms = model.get_terms()
+# pprint(topic_terms)
topic_labels = label_topics(topic_terms, comment_data.summary, labeler)

# filter out non_rep comments
Binary file added civiclens/utils/objects/custom_stopwords.pkl
Binary file not shown.
File renamed without changes.
