Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
jncraton committed Apr 18, 2024
1 parent 0757cc0 commit 042f81d
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 25 deletions.
13 changes: 9 additions & 4 deletions languagemodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ def complete(prompt: str) -> str:
"""

result = generate(
- ["Write a sentence"], prefix=prompt,
- max_tokens=config["max_tokens"], temperature=0.7, topk=40
+ ["Write a sentence"],
+ prefix=prompt,
+ max_tokens=config["max_tokens"],
+ temperature=0.7,
+ topk=40,
)[0]

if result.startswith(prompt):
Expand Down Expand Up @@ -267,8 +270,10 @@ def classify(doc: str, label1: str, label2: str) -> str:
'negative'
"""

- return do(f"Classify as {label1} or {label2}: {doc}\n\nClassification:",
- choices=[label1, label2])
+ return do(
+ f"Classify as {label1} or {label2}: {doc}\n\nClassification:",
+ choices=[label1, label2],
+ )


def store_doc(doc: str, name: str = "") -> None:
Expand Down
27 changes: 11 additions & 16 deletions languagemodels/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class ModelFilterException(Exception):
"architecture": "decoder-only-transformer",
"license": "apache-2.0",
"prompt_fmt": (
- "GPT4 Correct User: {instruction}<|end_of_turn|>"
- "GPT4 Correct Assistant:"
+ "GPT4 Correct User: {instruction}<|end_of_turn|>" "GPT4 Correct Assistant:"
),
},
{
Expand All @@ -48,9 +47,10 @@ class ModelFilterException(Exception):
"backend": "ct2",
"architecture": "decoder-only-transformer",
"license": "llama3",
- "prompt_fmt": ("<|start_header_id|>user<|end_header_id|>\n\n"
- "{instruction}<|eot_id|>"
- "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ "prompt_fmt": (
+ "<|start_header_id|>user<|end_header_id|>\n\n"
+ "{instruction}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
),
},
{
Expand All @@ -63,8 +63,7 @@ class ModelFilterException(Exception):
"architecture": "decoder-only-transformer",
"license": "apache-2.0",
"prompt_fmt": (
- "GPT4 Correct User: {instruction}<|end_of_turn|>"
- "GPT4 Correct Assistant:"
+ "GPT4 Correct User: {instruction}<|end_of_turn|>" "GPT4 Correct Assistant:"
),
},
{
Expand Down Expand Up @@ -204,9 +203,7 @@ class ModelFilterException(Exception):
"backend": "ct2",
"architecture": "encoder-decoder-transformer",
"license": "apache-2.0",
- "prompt_fmt": (
- "Instruction: Be helpful. <USER> {instruction}"
- ),
+ "prompt_fmt": ("Instruction: Be helpful. <USER> {instruction}"),
},
{
"name": "LaMini-Flan-T5-77M",
Expand Down Expand Up @@ -249,8 +246,8 @@ class ModelFilterException(Exception):
"architecture": "decoder-only-transformer",
"license": "gemma-terms-of-use",
"prompt_fmt": "<bos><start_of_turn>user\n"
- "{instruction}<end_of_turn>\n"
- "<start_of_turn>model",
+ "{instruction}<end_of_turn>\n"
+ "<start_of_turn>model",
},
{
"name": "h2o-danube2-1.8b-chat",
Expand Down Expand Up @@ -324,9 +321,7 @@ class ModelFilterException(Exception):
"backend": "ct2",
"architecture": "decoder-only-transformer",
"license": "mit",
- "prompt_fmt": (
- "<|user|>{instruction}<|assistant|>"
- ),
+ "prompt_fmt": ("<|user|>{instruction}<|assistant|>"),
},
{
"name": "codet5p-770m-py",
Expand Down Expand Up @@ -553,7 +548,7 @@ def convert_to_gb(space):

multipliers = {
"g": 1.0,
- "m": 2 ** -10,
+ "m": 2**-10,
}

space = space.lower()
Expand Down
4 changes: 2 additions & 2 deletions languagemodels/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ def rank_instruct(inputs, targets):
in_tok += toks * len(targets)

if "Generator" in str(type(model)):
- scores = model.score_batch([i+t for i, t in zip(in_tok, targ_tok)])
+ scores = model.score_batch([i + t for i, t in zip(in_tok, targ_tok)])
else:
scores = model.score_batch(in_tok, target=targ_tok)

ret = []
for i in range(0, len(inputs) * len(targets), len(targets)):
- logprobs = [sum(r.log_probs) for r in scores[i:i+len(targets)]]
+ logprobs = [sum(r.log_probs) for r in scores[i : i + len(targets)]]
results = sorted(zip(targets, logprobs), key=lambda r: -r[1])
ret.append([r[0] for r in results])

Expand Down
11 changes: 8 additions & 3 deletions languagemodels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,16 @@ def initialize_tokenizer(model_type, model_name):
def initialize_model(model_type, model_name):
model_info = get_model_info(model_type)

- path = snapshot_download(model_info["path"], max_workers=1,
- allow_patterns=["*.bin", "*.txt", "*.json"])
+ path = snapshot_download(
+ model_info["path"], max_workers=1, allow_patterns=["*.bin", "*.txt", "*.json"]
+ )

if model_info["architecture"] == "encoder-only-transformer":
- return ctranslate2.Encoder(path, "cpu", compute_type="int8", )
+ return ctranslate2.Encoder(
+ path,
+ "cpu",
+ compute_type="int8",
+ )
elif model_info["architecture"] == "decoder-only-transformer":
return ctranslate2.Generator(path, config["device"], compute_type="int8")
else:
Expand Down
1 change: 1 addition & 0 deletions test/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import psutil


def mem_used_gb():
process = psutil.Process(os.getpid())
bytes = process.memory_info().rss
Expand Down

0 comments on commit 042f81d

Please sign in to comment.