Commit

from stream mistral
santiagomed committed Nov 9, 2023
1 parent 07ae2b3 commit 51e1e25
Showing 5 changed files with 126 additions and 12 deletions.
2 changes: 1 addition & 1 deletion models/Cargo.toml
@@ -9,7 +9,7 @@ edition = "2021"
 candle = { git = "https://github.com/huggingface/candle", package = "candle-core" }
 candle-transformers = { git = "https://github.com/huggingface/candle" }
 candle-nn = { git = "https://github.com/huggingface/candle" }
-tokenizers = { version = "0.13.4", features = ["unstable_wasm"] }
+tokenizers = { version = "0.13.4", default-features = false, features = ["unstable_wasm"] }
 serde_json = "1.0.99"
 anyhow = "1"
 serde = { version = "1.0.171", features = ["derive"] }
1 change: 1 addition & 0 deletions models/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod bert;
 pub mod common;
 pub mod mistral;
+pub mod quantized;
 pub(crate) mod utils;
25 changes: 17 additions & 8 deletions models/src/mistral.rs
@@ -67,21 +67,30 @@ impl Default for Config {
 }
 
 impl Mistral {
-    fn tokenizer<P>(tokenizer: P) -> anyhow::Result<tokenizers::Tokenizer>
+    pub fn from_path<P>(weights: P, tokenizer: P, config: Config) -> anyhow::Result<Self>
     where
         P: AsRef<std::path::Path>,
     {
-        tokenizers::Tokenizer::from_file(tokenizer).map_err(|m| anyhow::anyhow!(m))
+        let cfg = mistral::Config::config_7b_v0_1(config.flash_attn);
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(weights)?;
+        let model = quantized_mistral::Model::new(&cfg, vb)?;
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
+        Ok(Self {
+            model,
+            tokenizer,
+            temperature: config.temperature,
+            top_p: config.top_p,
+            seed: config.seed,
+            repeat_penalty: config.repeat_penalty,
+            repeat_last_n: config.repeat_last_n,
+        })
     }
 
-    pub fn from_path<P>(weights: P, tokenizer: P, config: Config) -> anyhow::Result<Mistral>
-    where
-        P: AsRef<std::path::Path>,
-    {
+    pub fn from_stream(weights: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
         let cfg = mistral::Config::config_7b_v0_1(config.flash_attn);
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(weights)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
         let model = quantized_mistral::Model::new(&cfg, vb)?;
-        let tokenizer = Mistral::tokenizer(tokenizer)?;
+        let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
             model,
             tokenizer,
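
The new from_stream constructor mirrors from_path but takes the GGUF weights and the tokenizer as in-memory byte buffers, so callers that cannot touch a filesystem (wasm builds in particular, or anything that downloads the artifacts) can still construct a model. A minimal usage sketch, assuming the package is importable as `models` and using illustrative file names; only from_stream, Config, and its Default impl come from this diff:

use models::mistral::{Config, Mistral};

fn main() -> anyhow::Result<()> {
    // Any byte source works (HTTP fetch, IndexedDB, embedded assets, ...);
    // reading local files is just the simplest way to obtain Vec<u8> here.
    let weights: Vec<u8> = std::fs::read("mistral-7b-v0.1.Q4_K_S.gguf")?;
    let tokenizer: Vec<u8> = std::fs::read("tokenizer.json")?;

    let _model = Mistral::from_stream(weights, tokenizer, Config::default())?;
    // The resulting model is used exactly like one built with from_path.
    Ok(())
}
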
95 changes: 95 additions & 0 deletions models/src/quantized.rs
@@ -0,0 +1,95 @@
//! WIP for quantized models.
#![allow(dead_code)]
#![allow(unused_variables)]
#![allow(unused_imports)]

use candle::quantized::{ggml_file, gguf_file};
use candle_transformers::models::quantized_llama::ModelWeights;

use crate::utils::text_generation::TextGeneration;

pub struct Config {
    /// The temperature used to generate samples, use 0 for greedy sampling.
    pub temperature: f64,

    /// Nucleus sampling probability cutoff.
    pub top_p: Option<f64>,

    /// The seed to use when generating random samples.
    pub seed: u64,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    pub repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    pub repeat_last_n: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_p: None,
            seed: 42,
            repeat_penalty: 1.0,
            repeat_last_n: 1,
        }
    }
}

pub struct Quantized {
    /// The model weights.
    model: ModelWeights,

    /// The tokenizer config.
    tokenizer: tokenizers::Tokenizer,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    seed: u64,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    repeat_last_n: usize,
}

impl Quantized {
    pub fn from_gguf_stream(model: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
        let mut model_reader = std::io::Cursor::new(model);
        let model_content = gguf_file::Content::read(&mut model_reader)?;
        let model = ModelWeights::from_gguf(model_content, &mut model_reader)?;
        let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
        Ok(Self {
            model,
            tokenizer,
            temperature: config.temperature,
            top_p: config.top_p,
            seed: config.seed,
            repeat_penalty: config.repeat_penalty,
            repeat_last_n: config.repeat_last_n,
        })
    }

    pub fn from_ggml_stream(model: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
        let mut model_reader = std::io::Cursor::new(model);
        let model_content = ggml_file::Content::read(&mut model_reader)?;
        let model = ModelWeights::from_ggml(model_content, 1)?;
        let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
        Ok(Self {
            model,
            tokenizer,
            temperature: config.temperature,
            top_p: config.top_p,
            seed: config.seed,
            repeat_penalty: config.repeat_penalty,
            repeat_last_n: config.repeat_last_n,
        })
    }
}
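
The two Quantized constructors cover both container formats candle can read from a byte buffer: gguf_file::Content for GGUF and ggml_file::Content for legacy GGML files (from_ggml_stream currently hard-codes the second argument of ModelWeights::from_ggml, which appears to be the GQA factor, to 1). A usage sketch under the same assumptions as above; crate path and file names are illustrative:

use models::quantized::{Config, Quantized};

fn main() -> anyhow::Result<()> {
    let tokenizer: Vec<u8> = std::fs::read("tokenizer.json")?;

    // GGUF container:
    let gguf_bytes: Vec<u8> = std::fs::read("llama-2-7b.Q4_K_M.gguf")?;
    let _gguf = Quantized::from_gguf_stream(gguf_bytes, tokenizer.clone(), Config::default())?;

    // Legacy GGML container:
    let ggml_bytes: Vec<u8> = std::fs::read("llama-7b.ggmlv3.q4_0.bin")?;
    let _ggml = Quantized::from_ggml_stream(ggml_bytes, tokenizer, Config::default())?;

    Ok(())
}
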
15 changes: 12 additions & 3 deletions models/src/utils/text_generation.rs
@@ -1,10 +1,19 @@
 use super::token_stream::TokenOutputStream;
 use candle::{DType, Device, Tensor};
-use candle_transformers::{generation::LogitsProcessor, models::quantized_mistral::Model};
+use candle_transformers::{
+    generation::LogitsProcessor,
+    models::{quantized_llama::ModelWeights, quantized_mistral::Model as MistralModel},
+};
 use std::io::Write;
 
+#[allow(unused)] // We might repurpose this to generate for multiple models.
+pub enum Model {
+    Llama(ModelWeights),
+    Mistral(MistralModel),
+}
+
 pub struct TextGeneration {
-    model: Model,
+    model: MistralModel,
     device: Device,
     tokenizer: TokenOutputStream,
     logits_processor: LogitsProcessor,
@@ -15,7 +24,7 @@ pub struct TextGeneration {
 impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
-        model: Model,
+        model: MistralModel,
         tokenizer: tokenizers::Tokenizer,
         seed: u64,
         temp: Option<f64>,
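
The Model enum added above is scaffolding only; TextGeneration still holds a MistralModel directly, as the second hunk shows. A hypothetical sketch of how the enum could eventually be dispatched (not part of this commit; it assumes both quantized_llama::ModelWeights::forward and quantized_mistral::Model::forward keep their (&Tensor, usize) signatures in candle-transformers):

impl Model {
    // Hypothetical: route a forward pass to whichever backend this enum wraps,
    // so a future generation loop would not need to know the concrete model type.
    fn forward(&mut self, input: &Tensor, offset: usize) -> candle::Result<Tensor> {
        match self {
            Model::Llama(weights) => weights.forward(input, offset),
            Model::Mistral(model) => model.forward(input, offset),
        }
    }
}
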
