Merge pull request rustformers#416 from rustformers/rust-1.72-fixes
Rust 1.72 fixes
philpax authored Aug 27, 2023
2 parents 2f6ffd4 + 1c9efac commit 18b2a7d
Showing 15 changed files with 71 additions and 62 deletions.
3 changes: 2 additions & 1 deletion crates/ggml/Cargo.toml
@@ -7,8 +7,9 @@ description = "Semi-idiomatic Rust bindings for the ggml library (from `ggml-sys
license = "MIT"

[dependencies]
thiserror = { workspace = true }
ggml-sys = { path = "sys", version = "0.2.0-dev" }

thiserror = { workspace = true }
memmap2 = { workspace = true }

[dev-dependencies]
13 changes: 11 additions & 2 deletions crates/ggml/src/context.rs
@@ -56,6 +56,13 @@ impl PartialEq for ContextInner {
impl Eq for ContextInner {}
impl ContextInner {
pub(crate) fn new(ptr: *mut ggml_sys::ggml_context) -> Arc<Self> {
// This context can only be used from one thread at a time - hence why
// it doesn't implement `Send`/`Sync` - but higher-level abstractions
// may layer their own synchronization on top of it to provide
// thread-safety guarantees. To avoid breaking those, we still use an
// `Arc` here.
// TODO: check if this is correct?
#[allow(clippy::arc_with_non_send_sync)]
Arc::new(Self {
ptr: NonNull::new(ptr).expect("Should not be null"),
offloaded_tensors: Default::default(),
@@ -118,7 +125,9 @@ impl PartialEq for ContextStorage {
impl Eq for ContextStorage {}

impl Context {
/// Creates a new [Context] with the given storage..
// See explanation in [`ContextInner::new`].
#[allow(clippy::arc_with_non_send_sync)]
/// Creates a new [Context] with the given storage.
pub fn new(storage: ContextStorage) -> Self {
let init_params = match &storage {
ContextStorage::Buffer(buffer) => sys::ggml_init_params {
@@ -296,7 +305,7 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// Repeats the `a` tensor along the first dimension of the `b` tensor.
/// Repeats the `a` tensor along the first dimension of the `b` tensor.
pub fn op_repeat(&self, a: &Tensor, b: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_repeat(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
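For context on the annotations above: `clippy::arc_with_non_send_sync` is a lint new in Rust 1.72 that warns when `Arc::new` wraps a type that is neither `Send` nor `Sync`, since such an `Arc` can never actually be shared across threads. A minimal sketch of the pattern the commit is annotating (type and function names are illustrative, not from the codebase):

```rust
use std::{ptr::NonNull, sync::Arc};

// Holds a raw pointer, so it is neither `Send` nor `Sync`, matching
// the `ContextInner` situation above.
struct RawContext {
    ptr: NonNull<u8>,
}

fn share(ptr: NonNull<u8>) -> Arc<RawContext> {
    // Without the `allow`, clippy on Rust 1.72 warns that this `Arc`
    // cannot be sent or shared across threads; the annotation records
    // that the reference counting is intentionally single-threaded.
    #[allow(clippy::arc_with_non_send_sync)]
    Arc::new(RawContext { ptr })
}

fn main() {
    let mut byte = 0u8;
    let ctx = share(NonNull::from(&mut byte));
    assert_eq!(Arc::strong_count(&ctx), 1);
    let _ = ctx.ptr; // silence the dead-code warning on the field
}
```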
2 changes: 1 addition & 1 deletion crates/ggml/src/format/loader.rs
@@ -167,7 +167,7 @@ pub fn load<E: Error, R: BufRead + Seek>(
match container_type {
ContainerType::Ggml
| ContainerType::Ggmf(1)
| ContainerType::Ggjt(1 | 2 | 3)
| ContainerType::Ggjt(1..=3)
| ContainerType::Ggla(1) => {}
_ => return Err(LoadError::InvalidFormatVersion(container_type)),
}
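The `Ggjt(1 | 2 | 3)` to `Ggjt(1..=3)` change is behavior-preserving: both patterns match exactly versions 1 through 3. It was most likely prompted by `clippy::manual_range_patterns`, a lint that arrived around the 1.72 release (an inference from the commit title, not stated in the diff). A quick equivalence check, with illustrative function names:

```rust
// Both functions accept exactly the versions 1, 2 and 3.
fn accepts_or(version: u32) -> bool {
    matches!(version, 1 | 2 | 3)
}

fn accepts_range(version: u32) -> bool {
    matches!(version, 1..=3)
}

fn main() {
    // The two patterns agree on every input we try.
    for v in 0..6 {
        assert_eq!(accepts_or(v), accepts_range(v));
    }
    println!("or-pattern and range pattern are equivalent");
}
```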
12 changes: 8 additions & 4 deletions crates/llm-base/src/inference_session.rs
@@ -8,8 +8,8 @@ use tracing::{instrument, log};
use ggml::accelerator::metal::MetalContext;

use crate::{
mulf, util, InferenceParameters, Model, ModelParameters, OutputRequest, Prompt, TokenId,
TokenUtf8Buffer, TokenizationError,
mulf, util, InferenceParameters, Model, ModelContext, ModelParameters, OutputRequest, Prompt,
TokenId, TokenUtf8Buffer, TokenizationError,
};

// The size of a scratch buffer used for inference. This is used for temporary
@@ -148,6 +148,10 @@ impl InferenceSession {
ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024);
}

// TODO: revisit this with `Rc`, maybe? We should be able to prove that the session
// context is only accessed from one thread at a time, but I've already spent enough
// time on this as-is.
#[allow(clippy::arc_with_non_send_sync)]
let session_ctx = Arc::new(ggml::Context::new_with_allocate(context_byte_size));

// Initialize key + value memory tensors
@@ -215,7 +219,7 @@ impl InferenceSession {
/// Compute a model (possibly building a graph in the provided closure when called for the first time and/or when the parameters have changed)
pub fn compute<F>(
&mut self,
#[allow(unused_variables)] model_context: Arc<Context>,
#[allow(unused_variables)] model_context: ModelContext,
input_tokens: &[TokenId],
builder: F,
) -> GraphOutputs
@@ -242,7 +246,7 @@
#[cfg(feature = "metal")]
{
if let Some(ref mut metal_context) = self.metal_context {
metal_context.add_context(model_context);
metal_context.add_context(model_context.0);
}
}

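On the `Rc` TODO in this file: replacing the `Arc` with an `Rc` would make any struct holding it `!Send`, because `Rc`'s reference count is not atomic, so the session could never be moved to another thread. A sketch of that constraint, with stand-in types rather than the real session:

```rust
use std::rc::Rc;
use std::sync::Arc;

// Stand-ins for a session that owns its context; not the real types.
struct ArcSession {
    _ctx: Arc<u8>,
}
struct RcSession {
    _ctx: Rc<u8>,
}

fn require_send<T: Send>(_: T) {}

fn main() {
    // An `Arc`-holding session can be moved to another thread...
    require_send(ArcSession { _ctx: Arc::new(0) });
    // ...but uncommenting the next line fails to compile, because
    // `Rc<u8>` is `!Send` and that property propagates to the struct.
    // require_send(RcSession { _ctx: Rc::new(0) });
    let _ = RcSession { _ctx: Rc::new(0) };
}
```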
2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
@@ -35,7 +35,7 @@ pub use loader::{
};
pub use lora::{LoraAdapter, LoraParameters};
pub use memmap2::Mmap;
pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
pub use model::{Hyperparameters, KnownModel, Model, ModelContext, ModelParameters, OutputRequest};
pub use quantize::{quantize, QuantizeError, QuantizeProgress};
pub use regex::Regex;
pub use tokenizer::{
32 changes: 12 additions & 20 deletions crates/llm-base/src/loader.rs
@@ -5,11 +5,12 @@ use std::{
fs::File,
io::{BufRead, BufReader, Read, Seek, SeekFrom},
path::{Path, PathBuf},
sync::Arc,
};

use crate::{
util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId,
Tokenizer, TokenizerLoadError, TokenizerSource,
util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelContext, ModelParameters,
TokenId, Tokenizer, TokenizerLoadError, TokenizerSource,
};
pub use ggml::{format::FormatMagic, ContainerType};
use ggml::{
@@ -398,7 +399,7 @@ pub trait TensorLoader<E: std::error::Error> {
/// Gets a tensor from the loader.
fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
/// Finish loading the model, returning the context.
fn finish(self) -> Context;
fn finish(self) -> ModelContext;
}

/// Load a GGML model from the `path` and configure it per the `params`. The status
@@ -653,12 +654,7 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
path: Default::default(),
})?;

let mut main_context = FileContext::new(
&self.context,
&mut self.file,
&self.path,
self.context.storage().as_mmap(),
);
let mut main_context = FileContext::new(&self.context, &mut self.file, &self.path);

let mut tensor = main_context.get_tensor(info)?;

@@ -681,29 +677,25 @@
Ok(tensor)
}

fn finish(self) -> Context {
self.context
fn finish(self) -> ModelContext {
// We can ignore this warning as it's OK to share this particular
// context around, being that it is immutable.
#[allow(clippy::arc_with_non_send_sync)]
ModelContext(Arc::new(self.context))
}
}

pub(crate) struct FileContext<'a> {
context: &'a Context,
file: &'a mut File,
path: &'a Path,
mmap: Option<&'a Mmap>,
}
impl<'a> FileContext<'a> {
pub(crate) fn new(
context: &'a Context,
file: &'a mut File,
path: &'a Path,
mmap: Option<&'a Mmap>,
) -> Self {
pub(crate) fn new(context: &'a Context, file: &'a mut File, path: &'a Path) -> Self {
Self {
context,
file,
path,
mmap,
}
}

@@ -738,7 +730,7 @@ impl<'a> FileContext<'a> {
}
};

match self.mmap {
match self.context.storage().as_mmap() {
Some(mmap) => unsafe {
let ptr = mmap.as_ptr().offset(info.start_offset as isize);
tensor.set_data(ptr as *mut std::ffi::c_void);
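The `FileContext` change above removes a parameter that could disagree with the context it accompanied: the mmap handle is now fetched from the context's own storage. For reference, the mmap branch amounts to pointing tensor data at bytes inside a memory map instead of reading them into a buffer; a standalone sketch using `memmap2` (the crate the workspace already depends on; the file name and offset here are arbitrary):

```rust
use memmap2::Mmap;
use std::fs::File;

fn main() -> std::io::Result<()> {
    let file = File::open("Cargo.toml")?; // any non-empty file will do
    // SAFETY: the underlying file must not be truncated or modified
    // while the map is alive.
    let mmap = unsafe { Mmap::map(&file)? };

    // Mirrors the diff's mmap branch: derive a pointer into the map at
    // a tensor's start offset rather than copying the bytes out.
    let start_offset = 0usize;
    let ptr = unsafe { mmap.as_ptr().add(start_offset) };
    println!("first byte: {:#04x}", unsafe { *ptr });
    Ok(())
}
```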
2 changes: 1 addition & 1 deletion crates/llm-base/src/lora.rs
@@ -106,7 +106,7 @@ impl LoraAdapter {
// Create a temporary context for the patching operations
// TODO: test if GPU can be enabled (make it configurable)
let patch_context = ggml::Context::new_with_allocate(patch_context_size);
let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path, None);
let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path);

// Load the A and B tensors
let a = patch_file.get_tensor(&a_info)?;
11 changes: 11 additions & 0 deletions crates/llm-base/src/model/mod.rs
@@ -5,6 +5,7 @@ use std::{
fmt::Debug,
io::{BufRead, Write},
path::{Path, PathBuf},
sync::Arc,
};

use ggml::accelerator::Backend;
@@ -263,3 +264,13 @@ pub struct OutputRequest {
/// `n_batch * n_embd`.
pub embeddings: Option<Vec<f32>>,
}

/// Contains the GGML context for a [`Model`]. Implements `Send` and `Sync`
/// to allow models to be transferred freely between threads; this is made
/// possible by the context being effectively inert after creation, so it
/// is never modified across threads.
#[derive(Clone)]
#[allow(clippy::arc_with_non_send_sync)]
pub struct ModelContext(pub(crate) Arc<ggml::Context>);
unsafe impl Send for ModelContext {}
unsafe impl Sync for ModelContext {}
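`ModelContext` is the centerpiece of this commit: a newtype over `Arc<ggml::Context>` with manual `Send`/`Sync` implementations, justified by the context being immutable once built. A minimal sketch of that pattern and what it buys (type names are hypothetical, not the real `ggml` bindings):

```rust
use std::sync::Arc;

// A `!Send`/`!Sync` payload, e.g. a handle into a C library.
struct InertHandle {
    _raw: *mut u8,
}

// Newtype mirroring `ModelContext`.
#[derive(Clone)]
struct SharedHandle(Arc<InertHandle>);

// SAFETY: the handle is never mutated after creation, so sharing it
// across threads cannot data-race. The compiler cannot check this;
// the invariant rests entirely on the surrounding code.
unsafe impl Send for SharedHandle {}
unsafe impl Sync for SharedHandle {}

fn main() {
    let shared = SharedHandle(Arc::new(InertHandle {
        _raw: std::ptr::null_mut(),
    }));
    let for_thread = shared.clone();
    // `Send` is what lets the clone cross the thread boundary here,
    // just as it lets models holding a `ModelContext` move.
    std::thread::spawn(move || drop(for_thread)).join().unwrap();
}
```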
8 changes: 3 additions & 5 deletions crates/models/bloom/src/lib.rs
@@ -2,13 +2,11 @@
//! for the `llm` ecosystem.
#![deny(missing_docs)]

use std::sync::Arc;

use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel,
ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
};

/// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom)
@@ -37,7 +35,7 @@ pub struct Bloom {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Bloom {}
@@ -101,7 +99,7 @@ impl KnownModel for Bloom {
output_norm_bias,
output,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 3 additions & 5 deletions crates/models/falcon/src/lib.rs
@@ -7,14 +7,12 @@
//! supported. It is currently only available as a preview.
#![deny(missing_docs)]

use std::sync::Arc;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
};

/// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae)
@@ -39,7 +37,7 @@ pub struct Falcon {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Falcon {}
@@ -138,7 +136,7 @@ impl KnownModel for Falcon {
output_norm_b,
lm_head,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 3 additions & 5 deletions crates/models/gpt2/src/lib.rs
@@ -1,14 +1,12 @@
//! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::sync::Arc;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
};

/// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
@@ -38,7 +36,7 @@ pub struct Gpt2 {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Gpt2 {}
@@ -123,7 +121,7 @@ impl KnownModel for Gpt2 {
wte,
wpe,
lm_head,
context: Arc::new(context),
context,
})
}

8 changes: 4 additions & 4 deletions crates/models/gptj/src/lib.rs
@@ -1,14 +1,14 @@
//! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, sync::Arc};
use std::error::Error;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
};

/// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b)
@@ -35,7 +35,7 @@ pub struct GptJ {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for GptJ {}
@@ -117,7 +117,7 @@ impl KnownModel for GptJ {
lmh_g,
lmh_b,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 4 additions & 4 deletions crates/models/gptneox/src/lib.rs
@@ -2,14 +2,14 @@
//! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model.
#![deny(missing_docs)]

use std::{error::Error, sync::Arc};
use std::error::Error;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
};

/// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox)
@@ -35,7 +35,7 @@ pub struct GptNeoX {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for GptNeoX {}
@@ -137,7 +137,7 @@ impl KnownModel for GptNeoX {
wte,
lmh_g,
layers,
context: Arc::new(context),
context,
})
}
