Commit

Merge branch 'main' of https://github.com/jafioti/luminal

jafioti committed May 2, 2024
2 parents 07b7bb6 + 9e048ff commit dc01280
Showing 11 changed files with 1,344 additions and 4 deletions.
5 changes: 1 addition & 4 deletions Cargo.toml
@@ -32,7 +32,4 @@ members = [
"crates/luminal_nn",
"crates/luminal_training",
]
exclude = [
"crates/luminal_metal",
"crates/luminal_cuda",
]
exclude = ["crates/luminal_metal", "crates/luminal_cuda"]
17 changes: 17 additions & 0 deletions examples/llama_server/.gitignore
@@ -0,0 +1,17 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
setup/*.gguf
setup/*.json
.vscode
39 changes: 39 additions & 0 deletions examples/llama_server/Cargo.toml
@@ -0,0 +1,39 @@
[package]
name = "llama_server"
version = "0.1.0"
edition = "2021"

[features]
metal = ["dep:luminal_metal", "dep:metal-rs"]
cuda = ["dep:luminal_cuda", "dep:luminal_cudarc"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cpu = { path = "../../crates/luminal_cpu" }
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
luminal_cuda = { path = "../../crates/luminal_cuda", optional = true }
clap = { version = "4.4.18", features = ["derive"] }
byteorder = "1.5.0"
memmap2 = "0.9.4"
metal-rs = { version = "0.27.0", package = "metal", features = [
"mps",
], optional = true }
colored = "2.1.0"
itertools = "0.12.1"
luminal_cudarc = { version = "0.10.0", features = [
"cublas",
"f16",
], optional = true }
tokenizers = "0.15.2"
axum = "0.7.5"
serde = { version = "1.0.199", features = ["derive"] }
tokio = { version = "1.37.0", features = ["rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
chrono = "0.4.38"
uuid = { version = "1.8.0", features = ["v4"] }
async-trait = "0.1.80"
serde_json = "1.0.116"
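The metal and cuda features above gate the optional backend crates. A minimal sketch of how such flags are typically consumed at compile time; the function names are illustrative, not from this repository, and it assumes at most one backend feature is enabled at a time:

// Hypothetical backend selection via Cargo features (illustrative only).
#[cfg(feature = "metal")]
fn backend_name() -> &'static str {
    "metal" // compiled against luminal_metal
}

#[cfg(feature = "cuda")]
fn backend_name() -> &'static str {
    "cuda" // compiled against luminal_cuda
}

#[cfg(not(any(feature = "metal", feature = "cuda")))]
fn backend_name() -> &'static str {
    "cpu" // default: luminal_cpu
}

fn main() {
    println!("llama_server backend: {}", backend_name());
}

Built with, for example, cargo run --release --features metal.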
7 changes: 7 additions & 0 deletions examples/llama_server/setup/setup.sh
@@ -0,0 +1,7 @@
#!/usr/bin/env bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

echo "Downloading Model and Tokenizer..."
curl --location "https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json?download=true" --output "$SCRIPT_DIR/tokenizer.json"
curl --location "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q8_0.gguf?download=true" --output "$SCRIPT_DIR/llama3-8b.gguf"
echo "Done!"
124 changes: 124 additions & 0 deletions examples/llama_server/src/chat.rs
@@ -0,0 +1,124 @@
// src/chat.rs
use chrono::Utc;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::llama::setup::Model; // Import the Model struct

#[derive(Deserialize)]
pub struct ChatRequest {
    pub messages: Vec<Message>,
}

#[derive(Deserialize, Serialize)]
pub struct Message {
    pub role: Role,
    pub content: String,
}

#[derive(Deserialize, Serialize, PartialEq, Eq, Debug)]
pub enum Role {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "assistant")]
    Assistant,
    #[serde(rename = "user")]
    User,
}

#[derive(Serialize)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}

#[derive(Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}

pub fn apply_chat_template(messages: Vec<Message>) -> String {
    let mut output = "<|begin_of_text|>".to_string();
    for message in messages {
        output += "<|start_header_id|>";
        output += match message.role {
            Role::Assistant => "assistant",
            Role::User => "user",
            Role::System => "system",
        };
        output += "<|end_header_id|>";
        output += "\n";
        output += message.content.as_str();
        output += "<|eot_id|>";
    }
    output
}

/// Respond to a chat request.
pub async fn respond_chat_request(model: &mut Model, request: ChatRequest) -> ChatResponse {
    let created = Utc::now().timestamp();
    let id = format!("chatcmpl-{}", Uuid::new_v4());

    let mut prompt = apply_chat_template(request.messages);
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n";
    // let prompt = "<|begin_of_text|>Here is an implementation of merge sort:
    //
    // ```python"
    // .to_string();
    let prompt_tokens = model.tokenizer.encode(prompt.clone(), false).unwrap();
    let prompt_tokens = prompt_tokens.get_ids().len();
    println!("Prompt: {:?}", prompt);

    // Generate until the model emits the end-of-turn token
    let mut completion = vec![];
    model.generate(&prompt, |token| {
        const EOS_TOKEN: u32 = 128009; // <|eot_id|> in the Llama 3 vocabulary
        if token != EOS_TOKEN {
            completion.push(token);
        }
        true
    });
    // For now, just clear the KV cache after every request
    model.clear_cache();
    let completion_str = model.tokenizer.decode(&completion, false).unwrap();
    let completion_tokens = completion.len();

    ChatResponse {
        id,
        created,
        object: "chat.completion".to_string(),
        // The setup script downloads the 8B instruct model, so report that here
        model: "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
        choices: vec![Choice {
            index: 0,
            message: Message {
                role: Role::Assistant,
                content: completion_str,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            total_tokens: prompt_tokens + completion_tokens,
            prompt_tokens,
            completion_tokens,
        },
    }
}
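The Deserialize/Serialize derives above give the server an OpenAI-style wire format. A minimal sketch of the request JSON that ChatRequest accepts (the HTTP route itself lives in files not shown in this view; serde_json is already a dependency in the Cargo.toml above):

use crate::chat::ChatRequest;

fn parse_example() {
    let body = r#"{
        "messages": [
            { "role": "system", "content": "You are a helpful assistant." },
            { "role": "user", "content": "Hello!" }
        ]
    }"#;
    let request: ChatRequest = serde_json::from_str(body).unwrap();
    assert_eq!(request.messages.len(), 2);
}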
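And a sketch of what apply_chat_template renders for that exchange, with the output shown in comments; the line breaks come from the single "\n" the implementation emits after each header:

use crate::chat::{apply_chat_template, Message, Role};

fn demo_template() {
    let prompt = apply_chat_template(vec![
        Message {
            role: Role::System,
            content: "You are a helpful assistant.".to_string(),
        },
        Message {
            role: Role::User,
            content: "Hello!".to_string(),
        },
    ]);
    // Produces:
    // <|begin_of_text|><|start_header_id|>system<|end_header_id|>
    // You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
    // Hello!<|eot_id|>
    println!("{prompt}");
}

respond_chat_request then appends the assistant header to this string, so generation begins inside the assistant turn.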