Merge pull request #609 from GraphScope/remove_glove
fix: remove GloveEmbedding implementation in clustering
longbinlai authored Nov 29, 2024
2 parents 81e6ae1 + f937084 commit 94a7d65
Showing 8 changed files with 0 additions and 54 deletions.
4 changes: 0 additions & 4 deletions python/graphy/apps/demo_app.py
@@ -1,7 +1,3 @@
-import torchtext
-
-torchtext.disable_torchtext_deprecation_warning()
-
 from workflow import SurveyPaperReading, ThreadPoolWorkflowExecutor
 from graph.nodes.paper_reading_nodes import ProgressInfo
 from config import WF_UPLOADS_DIR, WF_OUTPUT_DIR, WF_DATA_DIR, WF_VECTDB_DIR
2 changes: 0 additions & 2 deletions python/graphy/models/__init__.py
@@ -10,7 +10,6 @@
     TextEmbedding,
     DefaultEmbedding,
     TfidfEmbedding,
-    GloveEmbedding,
     SentenceTransformerEmbedding,
 )
 
@@ -21,7 +20,6 @@
     "TextEmbedding",
     "DefaultEmbedding",
     "TfidfEmbedding",
-    "GloveEmbedding",
     "SentenceTransformerEmbedding",
 ]
 
31 changes: 0 additions & 31 deletions python/graphy/models/embedding_model.py
@@ -2,7 +2,6 @@
 from typing import List
 from chromadb.utils import embedding_functions
 from sentence_transformers import SentenceTransformer
-from torchtext.vocab import GloVe
 from sklearn.feature_extraction.text import TfidfVectorizer
 
 import numpy as np
@@ -57,36 +56,6 @@ def get_name(self):
         return "TF-IDF"
 
 
-class GloveEmbedding(TextEmbedding):
-    def __init__(self):
-        # TODO: the parameters can be configurable if necessary
-        self.embeddings = GloVe(name="6B", dim=100)
-        self.max_length = 100
-        self.embedding_dim = 100
-
-    def embed(self, text_data: List[str]):
-        def sentence_embedding(sentence):
-            words = sentence.split()
-            num_words = min(len(words), self.max_length)
-            embedding_sentence = np.zeros((self.max_length, self.embedding_dim))
-            for i in range(num_words):
-                word = words[i]
-                if word in self.embeddings.stoi:
-                    embedding_sentence[i] = self.embeddings.vectors[
-                        self.embeddings.stoi[word]
-                    ]
-            return embedding_sentence.flatten()
-
-        return np.vstack([sentence_embedding(data) for data in text_data])
-
-    def chroma_embedding_model(self):
-        # TODO
-        return None
-
-    def get_name(self):
-        return "GloVe"
-
-
 class SentenceTransformerEmbedding(TextEmbedding):
     def __init__(self, embedding_model_name: str = ""):
         if not embedding_model_name:
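For reference, the deleted class embedded each sentence by stacking per-word GloVe vectors into a fixed 100x100 matrix (zero-padded or truncated to 100 words) and flattening it into one vector. The following is a minimal standalone sketch of that approach, reconstructed from the removed code; it assumes torchtext (the dependency this commit drops) is still installed and that the "6B" GloVe vectors can be downloaded on first use.

# Sketch of the removed GloVe sentence-embedding technique (not part of this commit).
from typing import List

import numpy as np
from torchtext.vocab import GloVe


def glove_sentence_embeddings(
    sentences: List[str], max_length: int = 100, dim: int = 100
) -> np.ndarray:
    """Embed each sentence as the flattened (max_length, dim) matrix of its
    word vectors, zero-padded or truncated to max_length words."""
    vectors = GloVe(name="6B", dim=dim)  # downloads vectors on first call
    out = np.zeros((len(sentences), max_length * dim))
    for row, sentence in enumerate(sentences):
        words = sentence.split()[:max_length]
        sent = np.zeros((max_length, dim))
        for i, word in enumerate(words):
            idx = vectors.stoi.get(word)  # skip out-of-vocabulary words
            if idx is not None:
                sent[i] = vectors.vectors[idx].numpy()
        out[row] = sent.flatten()
    return out


if __name__ == "__main__":
    emb = glove_sentence_embeddings(["graph clustering with embeddings"])
    print(emb.shape)  # (1, 10000)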
4 changes: 0 additions & 4 deletions python/graphy/paper_scrapper.py
@@ -1,7 +1,3 @@
-import torchtext
-
-torchtext.disable_torchtext_deprecation_warning()
-
 import ray
 
 from workflow import ThreadPoolWorkflowExecutor, SurveyPaperReading
1 change: 0 additions & 1 deletion python/graphy/requirements.txt
@@ -38,7 +38,6 @@ spacy==3.7.4
 tiktoken>=0.8.0
 tools>=0.1.9
 torch==2.3.0
-torchtext==0.18.0
 transformers==4.41
 webdriver-manager>=4.0
 Werkzeug>=3.0.3
4 changes: 0 additions & 4 deletions python/graphy/tests/workflow/inspector_navigator_test.py
@@ -1,7 +1,3 @@
-import torchtext
-
-torchtext.disable_torchtext_deprecation_warning()
-
 import pytest
 from models import DEFAULT_LLM_MODEL_CONFIG
 from workflow import ThreadPoolWorkflowExecutor
4 changes: 0 additions & 4 deletions python/graphy/tests/workflow/paper_inspector_test.py
@@ -1,7 +1,3 @@
-import torchtext
-
-torchtext.disable_torchtext_deprecation_warning()
-
 import pytest
 from unittest.mock import MagicMock, create_autospec
 from graph import BaseGraph
4 changes: 0 additions & 4 deletions python/graphy/utils/text_clustering.py
@@ -1,7 +1,3 @@
-import torchtext
-
-torchtext.disable_torchtext_deprecation_warning()
-
 import json
 import os
 import numpy as np
