-
Notifications
You must be signed in to change notification settings - Fork 896
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat/batch_generate
- Loading branch information
Showing
43 changed files
with
1,150 additions
and
690 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,9 @@ __pycache__/ | |
# C extensions | ||
*.so | ||
|
||
# Vim | ||
*.swp | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
# Copyright © 2023-2024 Apple Inc. | ||
|
||
__version__ = "0.18.2" | ||
__version__ = "0.19.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright © 2023-2024 Apple Inc. | ||
|
||
import argparse | ||
import json | ||
|
||
import mlx.core as mx | ||
|
||
from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache | ||
from .utils import load, stream_generate | ||
|
||
# Defaults used by the command-line options declared in setup_arg_parser().
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


def setup_arg_parser():
    """Set up and return the argument parser.

    Returns:
        argparse.ArgumentParser: A parser exposing ``--model``,
        ``--adapter-path``, ``--temp``, ``--top-p``, ``--seed`` and
        ``--max-kv-size``. ``--adapter-path`` and ``--max-kv-size`` default
        to ``None``; the others fall back to the ``DEFAULT_*`` constants.
    """
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    return parser
|
||
|
||
def main():
    """Run an interactive chat session in the terminal.

    Parses the command-line arguments, seeds the MLX PRNG, loads the model
    and tokenizer, then reads queries from standard input until the user
    enters ``q`` or the input stream ends. Each query is wrapped as a single
    user message in the tokenizer's chat template and the response is
    streamed to standard output as it is generated.

    One prompt cache is created before the loop and passed to every
    generation call.
    """
    parser = setup_arg_parser()
    args = parser.parse_args()

    mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
    # NOTE(review): only the newest user message is templated each turn, so
    # earlier turns presumably reach the model through this shared cache --
    # confirm against make_prompt_cache / stream_generate.
    prompt_cache = make_prompt_cache(model, args.max_kv_size)
    while True:
        try:
            query = input(">> ")
        except EOFError:
            # Ctrl-D or an exhausted pipe: end the session without a traceback.
            print()
            break
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        for response in stream_generate(
            model,
            tokenizer,
            prompt,
            temp=args.temp,
            top_p=args.top_p,
            prompt_cache=prompt_cache,
        ):
            print(response, flush=True, end="")
        # Terminate the streamed response line before the next prompt.
        print()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright © 2024 Apple Inc. | ||
|
||
""" | ||
An example of a multi-turn chat with prompt caching. | ||
""" | ||
|
||
from mlx_lm import generate, load | ||
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache | ||
|
||
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)

# Two consecutive user turns. Each turn is templated on its own and generated
# with the same prompt cache object.
for user_turn in ("Hi my name is <Name>.", "What's my name?"):
    # User turn
    messages = [{"role": "user", "content": user_turn}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Assistant response
    response = generate(
        model,
        tokenizer,
        prompt=prompt,
        verbose=True,
        temp=0.0,
        prompt_cache=prompt_cache,
    )

# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)

# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
# Copyright © 2024 Apple Inc. | ||
|
||
from mlx_lm import generate, load | ||
|
||
# Specify the checkpoint | ||
|
Oops, something went wrong.