Skip to content

Commit

Permalink
Better library usage
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Mar 5, 2024
1 parent 9937be3 commit a7736f8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
4 changes: 2 additions & 2 deletions examples/community_lm/community_lm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"from community_lm_constants import politician_feelings, groups_feelings, anes_df\n",
"from community_lm_utils import generate_community_opinion, compute_group_stance\n",
"\n",
"device = 'mps' # change to 'mps' if you have a mac, or 'cuda:0' if you have an NVIDIA GPU "
"device = 'cpu' # change to 'mps' if you have a mac, or 'cuda:0' if you have an NVIDIA GPU "
]
},
{
Expand Down Expand Up @@ -384,7 +384,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
7 changes: 6 additions & 1 deletion llments/lm/base/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from llments.lm.lm import LanguageModel
from transformers import pipeline, set_seed, TextGenerationPipeline


class HuggingFaceLM(LanguageModel):
Expand All @@ -14,6 +13,12 @@ def __init__(
model: The name of the model.
device: The device to run the model on.
"""
try:
from transformers import pipeline, set_seed, TextGenerationPipeline
except ImportError:
raise ImportError(
"You need to install the `transformers` package to use this class."
)
self.text_generator: TextGenerationPipeline = pipeline(
"text-generation", model=model, device=device
)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
]
dependencies = [
"pandas",
"tqdm",
]
dynamic = ["version"]

Expand Down

0 comments on commit a7736f8

Please sign in to comment.