-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/jafioti/luminal
- Loading branch information
Showing
11 changed files
with
1,344 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Generated by Cargo
# will have compiled files and executables
debug/
target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
# Model weights and tokenizer downloaded by the setup script — too large / regenerable
setup/*.gguf
setup/*.json
# Editor-local settings
.vscode
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
[package]
name = "llama_server"
version = "0.1.0"
edition = "2021"

# Backend accelerators are opt-in: each feature pulls in the matching optional deps.
[features]
metal = ["dep:luminal_metal", "dep:metal-rs"]
cuda = ["dep:luminal_cuda", "dep:luminal_cudarc"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# Core luminal crates come from the workspace by path.
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cpu = { path = "../../crates/luminal_cpu" }
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
luminal_cuda = { path = "../../crates/luminal_cuda", optional = true }
clap = { version = "4.4.18", features = ["derive"] }
byteorder = "1.5.0"
memmap2 = "0.9.4"
# Renamed dependency: crate key is `metal-rs` but the registry package is `metal`.
metal-rs = { version = "0.27.0", package = "metal", features = [
    "mps",
], optional = true }
colored = "2.1.0"
itertools = "0.12.1"
luminal_cudarc = { version = "0.10.0", features = [
    "cublas",
    "f16",
], optional = true }
tokenizers = "0.15.2"
# HTTP server stack: axum on the tokio multi-threaded runtime.
axum = "0.7.5"
serde = { version = "1.0.199", features = ["derive"] }
tokio = { version = "1.37.0", features = ["rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
chrono = "0.4.38"
uuid = { version = "1.8.0", features = ["v4"] }
async-trait = "0.1.80"
serde_json = "1.0.116"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/usr/bin/env bash
# Download the Llama 3 8B Instruct tokenizer and Q8_0 GGUF weights next to this script.

# Resolve the directory this script lives in, so output paths are correct
# regardless of the caller's working directory.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

echo "Downloading Model and Tokenizer..."
# URLs are quoted: the '?' in the query string is a shell glob character and
# could otherwise be expanded against files in the current directory.
# "$SCRIPT_DIR" is quoted so paths containing spaces work.
# --fail makes curl exit non-zero on HTTP errors instead of saving the error
# page as the output file.
curl --fail --location "https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json?download=true" --output "$SCRIPT_DIR/tokenizer.json"
curl --fail --location "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q8_0.gguf?download=true" --output "$SCRIPT_DIR/llama3-8b.gguf"
echo "Done!"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
use chrono::Utc; | ||
use serde::{Deserialize, Serialize}; | ||
use uuid::Uuid; | ||
|
||
// src/chat.rs | ||
use crate::llama::setup::Model; // Import the Model struct | ||
|
||
/// Incoming chat-completion request body: the conversation history to respond to.
#[derive(Deserialize)]
pub struct ChatRequest {
    /// Ordered messages of the conversation so far.
    pub messages: Vec<Message>,
}
|
||
/// A single chat message: who said it and its text content.
#[derive(Deserialize, Serialize)]
pub struct Message {
    // Speaker of this message (system / assistant / user).
    pub role: Role,
    // Raw message text; inserted verbatim into the prompt template.
    pub content: String,
}
|
||
/// Speaker role of a chat message, serialized in OpenAI-compatible
/// lowercase form ("system" / "assistant" / "user").
#[derive(Deserialize, Serialize, PartialEq, Eq, Debug)]
pub enum Role {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "assistant")]
    Assistant,
    #[serde(rename = "user")]
    User,
}
|
||
/// Chat-completion response body, shaped like the OpenAI API response.
#[derive(Serialize)]
pub struct ChatResponse {
    // Unique response id ("chatcmpl-<uuid>" — see respond_chat_request).
    pub id: String,
    // Object type marker; set to "chat.completion".
    pub object: String,
    // Unix timestamp (seconds) of when the response was created.
    pub created: i64,
    // Model identifier reported to the client.
    pub model: String,
    // Generated completions; this server produces a single choice.
    pub choices: Vec<Choice>,
    // Token accounting for prompt and completion.
    pub usage: Usage,
}
|
||
/// One generated completion within a [`ChatResponse`].
#[derive(Serialize)]
pub struct Choice {
    // Position of this choice in the response (0-based).
    pub index: usize,
    // The generated assistant message.
    pub message: Message,
    // Why generation ended, e.g. "stop".
    pub finish_reason: String,
}
|
||
/// Token-count accounting for a single chat completion.
#[derive(Serialize)]
pub struct Usage {
    // Tokens in the rendered prompt.
    pub prompt_tokens: usize,
    // Tokens generated for the completion (EOS excluded).
    pub completion_tokens: usize,
    // prompt_tokens + completion_tokens.
    pub total_tokens: usize,
}
|
||
pub fn apply_chat_template(messages: Vec<Message>) -> String { | ||
let mut output = "<|begin_of_text|>".to_string(); | ||
for message in messages { | ||
output += "<|start_header_id|>"; | ||
if message.role == Role::Assistant { | ||
output += "assistant" | ||
} else if message.role == Role::User { | ||
output += "user" | ||
} else if message.role == Role::System { | ||
output += "system" | ||
} | ||
output += "<|end_header_id|>"; | ||
output += "\n"; | ||
output += message.content.as_str(); | ||
output += "<|eot_id|>"; | ||
} | ||
output | ||
} | ||
|
||
/// Respond to chat request | ||
pub async fn respond_chat_request(model: &mut Model, request: ChatRequest) -> ChatResponse { | ||
let created = Utc::now().timestamp(); | ||
let raw_uuid = Uuid::new_v4(); | ||
let id = format!("chatcmpl-{}", raw_uuid); | ||
|
||
let mut prompt = apply_chat_template(request.messages); | ||
prompt += "<|start_header_id|>assistant<|end_header_id|>\n"; | ||
// let prompt = "<|begin_of_text|>Here is an implementation of merge sort: | ||
// | ||
// ```python" | ||
// .to_string(); | ||
let prompt_tokens = model.tokenizer.encode(prompt.clone(), false).unwrap(); | ||
let prompt_tokens = prompt_tokens.get_ids(); | ||
let prompt_tokens = prompt_tokens.len(); | ||
println!("Prompt: {:?}", prompt); | ||
|
||
// Generate | ||
let mut completion = vec![]; | ||
model.generate(&prompt, |token| { | ||
const EOS_TOKEN: u32 = 128009; | ||
if token != EOS_TOKEN { | ||
completion.push(token); | ||
} | ||
true | ||
}); | ||
// For now, just clear the cache each time | ||
model.clear_cache(); | ||
let completion_str = model.tokenizer.decode(&completion, false).unwrap(); | ||
let completion_tokens = completion.len(); | ||
|
||
let response = ChatResponse { | ||
id, | ||
created, | ||
object: "chat.completion".to_string(), | ||
model: "meta-llama/Meta-Llama-3-70B-Instruct".to_string(), | ||
choices: vec![Choice { | ||
index: 0, | ||
message: Message { | ||
role: Role::Assistant, | ||
content: completion_str, | ||
}, | ||
finish_reason: "stop".to_string(), | ||
}], | ||
usage: Usage { | ||
total_tokens: prompt_tokens + completion_tokens, | ||
prompt_tokens, | ||
completion_tokens, | ||
}, | ||
}; | ||
|
||
response | ||
} |
Oops, something went wrong.