Skip to content

Commit

Permalink
feat: Support for bge-reranker-v2-m3 (#118)
Browse files Browse the repository at this point in the history
* adds support for bge-reranker-v2-m3

* impl Into<OnnxSource> to avoid breaking existing code

* adds bge-reranker-v2-m3 model to README
  • Loading branch information
rozgo authored Sep 29, 2024
1 parent 18cad72 commit bcd6304
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 11 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ authors = [
"Timon Vonk <[email protected]>",
"Luya Wang <[email protected]>",
"Tri <[email protected]>",
"Denny Wong <[email protected]>"
"Denny Wong <[email protected]>",
"Alex Rozgo <[email protected]>"
]
documentation = "https://docs.rs/fastembed"
repository = "https://github.com/Anush008/fastembed-rs"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
### Reranking

- [**BAAI/bge-reranker-base**](https://huggingface.co/BAAI/bge-reranker-base)
- [**BAAI/bge-reranker-v2-m3**](https://huggingface.co/BAAI/bge-reranker-v2-m3)
- [**jinaai/jina-reranker-v1-turbo-en**](https://huggingface.co/jinaai/jina-reranker-v1-turbo-en)
- [**jinaai/jina-reranker-v2-base-multilingual**](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual)

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub use crate::models::{
pub use crate::output::{EmbeddingOutput, OutputKey, OutputPrecedence, SingleBatchOutput};
pub use crate::pooling::Pooling;
pub use crate::reranking::{
RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
OnnxSource, RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
UserDefinedRerankingModel,
};
pub use crate::sparse_text_embedding::{
Expand Down
13 changes: 13 additions & 0 deletions src/models/reranking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::fmt::Display;
pub enum RerankerModel {
/// BAAI/bge-reranker-base
BGERerankerBase,
/// rozgo/bge-reranker-v2-m3
BGERerankerV2M3,
/// jinaai/jina-reranker-v1-turbo-en
JINARerankerV1TurboEn,
/// jinaai/jina-reranker-v2-base-multilingual
Expand All @@ -17,18 +19,28 @@ pub fn reranker_model_list() -> Vec<RerankerModelInfo> {
description: String::from("reranker model for English and Chinese"),
model_code: String::from("BAAI/bge-reranker-base"),
model_file: String::from("onnx/model.onnx"),
additional_files: vec![],
},
RerankerModelInfo {
model: RerankerModel::BGERerankerV2M3,
description: String::from("reranker model for multilingual"),
model_code: String::from("rozgo/bge-reranker-v2-m3"),
model_file: String::from("model.onnx"),
additional_files: vec![String::from("model.onnx.data")],
},
RerankerModelInfo {
model: RerankerModel::JINARerankerV1TurboEn,
description: String::from("reranker model for English"),
model_code: String::from("jinaai/jina-reranker-v1-turbo-en"),
model_file: String::from("onnx/model.onnx"),
additional_files: vec![],
},
RerankerModelInfo {
model: RerankerModel::JINARerankerV2BaseMultiligual,
description: String::from("reranker model for multilingual"),
model_code: String::from("jinaai/jina-reranker-v2-base-multilingual"),
model_file: String::from("onnx/model.onnx"),
additional_files: vec![],
},
];
reranker_model_list
Expand All @@ -41,6 +53,7 @@ pub struct RerankerModelInfo {
pub description: String,
pub model_code: String,
pub model_file: String,
pub additional_files: Vec<String>,
}

impl Display for RerankerModel {
Expand Down
17 changes: 14 additions & 3 deletions src/reranking/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use tokenizers::Tokenizer;
#[cfg(feature = "online")]
use super::RerankInitOptions;
use super::{
RerankInitOptionsUserDefined, RerankResult, TextRerank, UserDefinedRerankingModel,
OnnxSource, RerankInitOptionsUserDefined, RerankResult, TextRerank, UserDefinedRerankingModel,
DEFAULT_BATCH_SIZE,
};

Expand Down Expand Up @@ -70,6 +70,13 @@ impl TextRerank {
let model_file_reference = model_repo
.get(&model_file_name)
.unwrap_or_else(|_| panic!("Failed to retrieve model file: {}", model_file_name));
let additional_files = TextRerank::get_model_info(&model_name).additional_files;
for additional_file in additional_files {
let _additional_file_reference =
model_repo.get(&additional_file).unwrap_or_else(|_| {
panic!("Failed to retrieve additional file: {}", additional_file)
});
}

let session = Session::builder()?
.with_execution_providers(execution_providers)?
Expand Down Expand Up @@ -98,8 +105,12 @@ impl TextRerank {
let session = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?
.commit_from_memory(&model.onnx_file)?;
.with_intra_threads(threads)?;

let session = match &model.onnx_source {
OnnxSource::Memory(bytes) => session.commit_from_memory(bytes)?,
OnnxSource::File(path) => session.commit_from_file(path)?,
};

let tokenizer = load_tokenizer(model.tokenizer_files, max_length)?;
Ok(Self::new(tokenizer, session))
Expand Down
27 changes: 24 additions & 3 deletions src/reranking/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,41 @@ impl From<RerankInitOptions> for RerankInitOptionsUserDefined {
}
}

/// Where the ONNX file of a user-defined model comes from.
///
/// A user-defined model can either be held in memory or live on disk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OnnxSource {
    /// The ONNX file held in memory as bytes.
    Memory(Vec<u8>),
    /// Path to the ONNX file on disk.
    File(PathBuf),
}

/// Converts in-memory bytes into [`OnnxSource::Memory`].
impl From<Vec<u8>> for OnnxSource {
    fn from(onnx_bytes: Vec<u8>) -> Self {
        Self::Memory(onnx_bytes)
    }
}

/// Converts a filesystem path into [`OnnxSource::File`].
impl From<PathBuf> for OnnxSource {
    fn from(onnx_path: PathBuf) -> Self {
        Self::File(onnx_path)
    }
}

/// Struct for "bring your own" reranking models
///
/// The ONNX model is supplied as an `OnnxSource` (bytes in memory or a path
/// on disk); the tokenizer_files are expecting the files' bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct UserDefinedRerankingModel {
// NOTE(review): both `onnx_file` and `onnx_source` appear in this diff view;
// presumably `onnx_source` replaces `onnx_file` after the commit — confirm.
pub onnx_file: Vec<u8>,
pub onnx_source: OnnxSource,
pub tokenizer_files: TokenizerFiles,
}

impl UserDefinedRerankingModel {
pub fn new(onnx_file: Vec<u8>, tokenizer_files: TokenizerFiles) -> Self {
pub fn new(onnx_source: impl Into<OnnxSource>, tokenizer_files: TokenizerFiles) -> Self {
Self {
onnx_file,
onnx_source: onnx_source.into(),
tokenizer_files,
}
}
Expand Down
72 changes: 69 additions & 3 deletions tests/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

use fastembed::{
read_file_to_bytes, Embedding, EmbeddingModel, ImageEmbedding, ImageInitOptions, InitOptions,
InitOptionsUserDefined, Pooling, QuantizationMode, RerankInitOptions,
InitOptionsUserDefined, OnnxSource, Pooling, QuantizationMode, RerankInitOptions,
RerankInitOptionsUserDefined, RerankerModel, SparseInitOptions, SparseTextEmbedding,
TextEmbedding, TextRerank, TokenizerFiles, UserDefinedEmbeddingModel,
UserDefinedRerankingModel, DEFAULT_CACHE_DIR,
Expand Down Expand Up @@ -284,6 +284,8 @@ fn test_rerank() {
.par_iter()
.for_each(|supported_model| {

println!("supported_model: {:?}", supported_model);

let result = TextRerank::try_new(RerankInitOptions::new(supported_model.model.clone()))
.unwrap();

Expand All @@ -300,14 +302,78 @@ fn test_rerank() {
.unwrap();

assert_eq!(results.len(), documents.len(), "rerank model {:?} failed", supported_model);
assert_eq!(results[0].document.as_ref().unwrap(), "panda is an animal");
assert_eq!(results[1].document.as_ref().unwrap(), "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China.");

let option_a = "panda is an animal";
let option_b = "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China.";

assert!(
results[0].document.as_ref().unwrap() == option_a ||
results[0].document.as_ref().unwrap() == option_b
);
assert!(
results[1].document.as_ref().unwrap() == option_a ||
results[1].document.as_ref().unwrap() == option_b
);
assert_ne!(results[0].document, results[1].document, "The top two results should be different");

// Clear the model cache to avoid running out of space on GitHub Actions.
clean_cache(supported_model.model_code.clone())
});
}

/// Loads `rozgo/bge-reranker-v2-m3` as a user-defined reranker from an ONNX
/// file on disk and checks that reranking a few passages works.
#[test]
fn test_user_defined_reranking_large_model() {
    // Hugging Face client that stores its downloads in the fastembed cache directory
    let hf_cache = hf_hub::Cache::new(std::path::PathBuf::from(fastembed::DEFAULT_CACHE_DIR));
    let hf_api = hf_hub::api::sync::ApiBuilder::from_cache(hf_cache)
        .with_progress(true)
        .build()
        .expect("Failed to build API from cache");
    let model_repo = hf_api.model("rozgo/bge-reranker-v2-m3".to_string());

    // The ONNX model exceeds the 2GB limit for a single file, so its companion
    // data file has to be fetched separately alongside the model file itself.
    let onnx_path = model_repo.download("model.onnx").unwrap();
    let _onnx_data_path = model_repo.get("model.onnx.data").unwrap();

    // Tokenizer files are passed as raw bytes
    let read_repo_file =
        |file_name: &str| read_file_to_bytes(&model_repo.get(file_name).unwrap()).unwrap();
    let tokenizer_files = TokenizerFiles {
        tokenizer_file: read_repo_file("tokenizer.json"),
        config_file: read_repo_file("config.json"),
        special_tokens_map_file: read_repo_file("special_tokens_map.json"),
        tokenizer_config_file: read_repo_file("tokenizer_config.json"),
    };

    // OnnxSource::File makes the session builder load the model with commit_from_file
    let reranker = TextRerank::try_new_from_user_defined(
        UserDefinedRerankingModel::new(OnnxSource::File(onnx_path), tokenizer_files),
        Default::default(),
    )
    .unwrap();

    let documents = vec![
        "Hello, World!",
        "This is an example passage.",
        "fastembed-rs is licensed under Apache-2.0",
        "Some other short text here blah blah blah",
    ];
    let expected_len = documents.len();

    let results = reranker
        .rerank("Ciao, Earth!", documents, false, None)
        .unwrap();

    assert_eq!(results.len(), expected_len);
    assert_eq!(results[0].index, 0);
}

#[test]
fn test_user_defined_reranking_model() {
// Constitute the model in order to ensure it's downloaded and cached
Expand Down

0 comments on commit bcd6304

Please sign in to comment.