From 2bc3dcf05abeee4a47a489b6dde3f49c186e09c7 Mon Sep 17 00:00:00 2001
From: Santiago Medina
Date: Wed, 13 Mar 2024 22:27:40 -0700
Subject: [PATCH] compiler error

---
 orca-core/src/llm/openai.rs    | 21 +++++++++++++++++++--
 orca-core/src/llm/quantized.rs |  8 ++++----
 orca-models/src/mistral.rs     |  5 ++---
 orca-models/src/quantized.rs   |  6 +++---
 4 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/orca-core/src/llm/openai.rs b/orca-core/src/llm/openai.rs
index 96ac396..5f2ecaf 100644
--- a/orca-core/src/llm/openai.rs
+++ b/orca-core/src/llm/openai.rs
@@ -46,6 +46,21 @@ pub struct Response {
     choices: Vec<Choice>,
 }
 
+#[derive(Serialize, Deserialize, Debug)]
+pub struct QuotaError {
+    message: String,
+    #[serde(rename = "type")]
+    _type: String,
+    param: String,
+    code: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub enum OpenAIResponse {
+    Response(Response),
+    QuotaError(QuotaError),
+}
+
 #[derive(Serialize, Deserialize, Debug, Default, Clone)]
 pub struct OpenAIEmbeddingResponse {
     object: String,
@@ -292,8 +307,10 @@ impl LLM for OpenAI {
         let messages = prompt.to_chat()?;
         let req = self.generate_request(messages.to_vec_ref())?;
         let res = self.client.execute(req).await?;
-        let res = res.json::<Response>().await?;
-        Ok(res.into())
+        match res.json::<OpenAIResponse>().await? {
+            OpenAIResponse::Response(response) => Ok(response.into()),
+            OpenAIResponse::QuotaError(e) => Err(anyhow::anyhow!("Quota error: {}", e.message)),
+        }
     }
 }
 
diff --git a/orca-core/src/llm/quantized.rs b/orca-core/src/llm/quantized.rs
index 6e8c909..3f75a58 100644
--- a/orca-core/src/llm/quantized.rs
+++ b/orca-core/src/llm/quantized.rs
@@ -198,7 +198,7 @@ impl Quantized {
                 let mut total_size_in_bytes = 0;
                 for (_, tensor) in model.tensor_infos.iter() {
                     let elem_count = tensor.shape.elem_count();
-                    total_size_in_bytes += elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
+                    total_size_in_bytes += elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
                 }
                 log::info!(
                     "loaded {:?} tensors ({}) in {:.2}s",
@@ -206,14 +206,14 @@ impl Quantized {
                     &format_size(total_size_in_bytes),
                     start.elapsed().as_secs_f32(),
                 );
-                Some(ModelWeights::from_gguf(model, &mut file, &Device::Cpu)?)
+                Some(ModelWeights::from_gguf(model, &mut file)?)
             }
             Some("ggml" | "bin") | Some(_) | None => {
-                let model = ggml_file::Content::read(&mut file, &Device::Cpu)?;
+                let model = ggml_file::Content::read(&mut file)?;
                 let mut total_size_in_bytes = 0;
                 for (_, tensor) in model.tensors.iter() {
                     let elem_count = tensor.shape().elem_count();
-                    total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
+                    total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
                 }
                 log::info!(
                     "loaded {:?} tensors ({}) in {:.2}s",
diff --git a/orca-models/src/mistral.rs b/orca-models/src/mistral.rs
index 14db6a7..ab5520b 100644
--- a/orca-models/src/mistral.rs
+++ b/orca-models/src/mistral.rs
@@ -1,4 +1,3 @@
-use candle::Device;
 use crate::utils::text_generation::{Model, TextGeneration};
 use candle_transformers::models::mistral;
 use candle_transformers::models::quantized_mistral;
@@ -73,7 +72,7 @@ impl Mistral {
         P: AsRef<std::path::Path>,
     {
         let cfg = mistral::Config::config_7b_v0_1(config.flash_attn);
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(weights, &Device::Cpu)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(weights)?;
         let model = quantized_mistral::Model::new(&cfg, vb)?;
         let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
@@ -89,7 +88,7 @@ impl Mistral {
 
     pub fn from_stream(weights: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
         let cfg = mistral::Config::config_7b_v0_1(config.flash_attn);
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights, &Device::Cpu)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
         let model = quantized_mistral::Model::new(&cfg, vb)?;
         let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
diff --git a/orca-models/src/quantized.rs b/orca-models/src/quantized.rs
index 6ce1552..e6ce59d 100644
--- a/orca-models/src/quantized.rs
+++ b/orca-models/src/quantized.rs
@@ -3,8 +3,8 @@
 // #![allow(unused_variables)]
 // #![allow(unused_imports)]
 
-use candle::Device;
 use candle::quantized::{ggml_file, gguf_file};
+use candle::Device;
 use candle_transformers::models::quantized_llama::ModelWeights;
 use crate::utils::text_generation::{Model, TextGeneration};
 
@@ -65,7 +65,7 @@ impl Quantized {
     pub fn from_gguf_stream(model: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
         let mut model_reader = std::io::Cursor::new(model);
         let model_content = gguf_file::Content::read(&mut model_reader)?;
-        let model = ModelWeights::from_gguf(model_content, &mut model_reader, &Device::Cpu)?;
+        let model = ModelWeights::from_gguf(model_content, &mut model_reader)?;
         let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
             model,
@@ -80,7 +80,7 @@
 
     pub fn from_ggml_stream(model: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
         let mut model_reader = std::io::Cursor::new(model);
-        let model_content = ggml_file::Content::read(&mut model_reader, &Device::Cpu)?;
+        let model_content = ggml_file::Content::read(&mut model_reader)?;
         let model = ModelWeights::from_ggml(model_content, 1)?;
         let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
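
Note, not part of the patch: the OpenAIResponse enum added above uses serde's default
externally tagged representation, so each variant is matched by its name as the top-level
JSON key. A minimal standalone sketch of that deserialization behavior follows, using
simplified stand-in structs instead of the real Response/QuotaError fields; it only
illustrates the serde representation, not what the OpenAI API itself returns on the wire.

    use serde::Deserialize;

    // Simplified stand-ins for the structs in the patch.
    #[derive(Deserialize, Debug)]
    struct Response {
        id: String,
    }

    #[derive(Deserialize, Debug)]
    struct QuotaError {
        message: String,
    }

    // Same shape as the enum added in the patch: externally tagged by default.
    #[derive(Deserialize, Debug)]
    enum OpenAIResponse {
        Response(Response),
        QuotaError(QuotaError),
    }

    fn main() -> anyhow::Result<()> {
        // Externally tagged: the variant name wraps its payload.
        let ok: OpenAIResponse = serde_json::from_str(r#"{"Response":{"id":"cmpl-1"}}"#)?;
        let err: OpenAIResponse = serde_json::from_str(r#"{"QuotaError":{"message":"quota exceeded"}}"#)?;
        println!("{ok:?} / {err:?}");
        Ok(())
    }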