Merge pull request #15 from tracel-ai/bert
Add BERT family of models
Showing 13 changed files with 989 additions and 5 deletions.

@@ -0,0 +1,37 @@
[package]
authors = ["Aasheesh Singh [email protected]"]
license = "MIT OR Apache-2.0"
name = "bert-burn"
version = "0.1.0"
edition = "2021"

[features]
default = ["burn/dataset"]
f16 = []
ndarray = ["burn/ndarray"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
fusion = ["burn/fusion"]
# To be replaced by burn-safetensors once supported: https://github.com/tracel-ai/burn/issues/626
safetensors = ["candle-core/default"]

[dependencies]
# Burn
burn = { version = "0.12.1", default-features = false }
candle-core = { version = "0.3.2", optional = true }
# Tokenizer
tokenizers = { version = "0.15.0", default-features = false, features = [
    "onig",
    "http",
] }
burn-import = "0.12.1"
derive-new = "0.6.0"
hf-hub = { version = "0.3.2", features = ["tokio"] }

# Utils
serde = { version = "1.0.196", features = ["std", "derive"] }
libm = "0.2.8"
serde_json = "1.0.113"
tokio = "1.35.1"

@@ -0,0 +1 @@
../LICENSE-APACHE

@@ -0,0 +1 @@
../LICENSE-MIT

@@ -0,0 +1,39 @@
# Bert-Burn Model

This project provides an example implementation of inference with the BERT family of models. The following compatible BERT variants can be loaded: `roberta-base` (**default**), `roberta-large`, `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, and `bert-large-cased`. Pre-trained weights and config files are downloaded automatically from the [HuggingFace Model hub](https://huggingface.co/FacebookAI/roberta-base/tree/main); a loading sketch is shown below.

### To include the model in your project

Add this to your `Cargo.toml`:

```toml
[dependencies]
bert-burn = { git = "https://github.com/burn-rs/models", package = "bert-burn", default-features = false }
```

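Once the dependency is in place, a model variant can be loaded programmatically. The following minimal sketch is distilled from the inference example bundled in this commit; it assumes a Burn backend type `B` and a matching `device` are already in scope:

```rust
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModel;

// Download config.json and the pre-trained weights from the HuggingFace Hub,
// then load them into a Burn model on the chosen device.
let model_variant = "roberta-base".to_string();
let (config_file, model_file) = download_hf_model(&model_variant);
let model_config = load_model_config(config_file);
let model: BertModel<B> = BertModel::from_safetensors(model_file, &device, model_config);
```
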
## Example Usage

The example below computes sentence embeddings for given input text. The model supports multiple Burn backends (e.g. `ndarray`, `wgpu`, `tch-gpu`, `tch-cpu`), selected with the `--features` flag. An example with the `wgpu` backend is shown below. The `fusion` feature enables kernel fusion for the `wgpu` backend and is not required for the other backends. The `safetensors` feature enables loading weights in `safetensors` format via the `candle-core` crate.

### WGPU backend

```bash
cd bert-burn/
# Get sentence embeddings from the RoBERTa encoder (default)
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors

# Using the bert-base-uncased model
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors bert-base-uncased

# Using the roberta-large model
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors roberta-large
```

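The forward pass returns a tensor of shape `[batch_size, seq_len, hidden_size]`. As in the bundled example, a sentence embedding can be taken from the first (`[CLS]`) token. A sketch, assuming `output`, `batch_size`, and the loaded `model_config` are in scope:

```rust
// Slice out the [CLS] token (position 0) and drop the sequence dimension,
// leaving a [batch_size, hidden_size] sentence embedding.
let cls_token_idx = 0;
let d_model = model_config.hidden_size;
let sentence_embedding = output.slice([0..batch_size, cls_token_idx..cls_token_idx + 1, 0..d_model]);
let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);
```
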
@@ -0,0 +1,154 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModel;
use burn::data::dataloader::batcher::Batcher;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use std::env;
use std::sync::Arc;

#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;

pub fn launch<B: Backend>(device: B::Device) {
    let args: Vec<String> = env::args().collect();
    let default_model = "roberta-base".to_string();
    let model_variant = if args.len() > 1 {
        // Use the argument provided by the user,
        // e.g. "bert-base-uncased" or "roberta-large"
        &args[1]
    } else {
        // Fall back to the default model if no argument is provided
        &default_model
    };

    println!("Model variant: {}", model_variant);

    let text_samples = vec![
        "Jays power up to take finale Contrary to popular belief, the power never really \
         snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it \
         took some extra time for the batting orders to provide some extra wattage."
            .to_string(),
        "Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one \
         man to death and 14 others to prison terms for a series of attacks and terrorist \
         plots in 2002, including the bombing of a French oil tanker."
            .to_string(),
        "IBM puts grids to work at U.S. Open IBM will put a collection of its On \
         Demand-related products and technologies to this test next week at the U.S. Open \
         tennis championships, implementing a grid-based infrastructure capable of running \
         multiple workloads including two not associated with the tournament."
            .to_string(),
    ];

    let (config_file, model_file) = download_hf_model(model_variant);
    let model_config = load_model_config(config_file);

    let model: BertModel<B> =
        BertModel::from_safetensors(model_file, &device, model_config.clone());

    let tokenizer = Arc::new(BertTokenizer::new(
        model_variant.to_string(),
        model_config.pad_token_id,
    ));

    // Batcher that pads the input samples to the max sequence length
    let batcher = Arc::new(BertInputBatcher::<B>::new(
        tokenizer.clone(),
        device.clone(),
        model_config.max_seq_len.unwrap(),
    ));

    // Batch the input samples; shape: [batch_size, seq_len]
    let input = batcher.batch(text_samples.clone());
    let [batch_size, _seq_len] = input.tokens.dims();
    println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());

    let output = model.forward(input);

    // Get the sentence embedding from the first [CLS] token
    let cls_token_idx = 0;

    // Embedding size
    let d_model = model_config.hidden_size;
    let sentence_embedding =
        output.slice([0..batch_size, cls_token_idx..cls_token_idx + 1, 0..d_model]);

    let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);
    println!(
        "Sentence embedding {:?} // (Batch Size, Embedding_dim)",
        sentence_embedding.shape()
    );
}

#[cfg(any(
    feature = "ndarray",
    feature = "ndarray-blas-netlib",
    feature = "ndarray-blas-openblas",
    feature = "ndarray-blas-accelerate",
))]
mod ndarray {
    use burn::backend::ndarray::{NdArray, NdArrayDevice};

    use crate::{launch, ElemType};

    pub fn run() {
        launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
    }
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use crate::{launch, ElemType};
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    pub fn run() {
        // Use CUDA on Linux/Windows and Metal (MPS) on macOS
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        launch::<LibTorch<ElemType>>(device);
    }
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use crate::{launch, ElemType};
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    pub fn run() {
        launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
    }
}

#[cfg(feature = "wgpu")]
mod wgpu {
    use crate::{launch, ElemType};
    use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
    use burn::backend::Fusion;

    pub fn run() {
        launch::<Fusion<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
    }
}

fn main() {
    #[cfg(any(
        feature = "ndarray",
        feature = "ndarray-blas-netlib",
        feature = "ndarray-blas-openblas",
        feature = "ndarray-blas-accelerate",
    ))]
    ndarray::run();
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
    #[cfg(feature = "wgpu")]
    wgpu::run();
}

@@ -0,0 +1,51 @@
use super::tokenizer::Tokenizer;
use burn::{
    data::dataloader::batcher::Batcher,
    nn::attention::generate_padding_mask,
    tensor::{backend::Backend, Bool, Int, Tensor},
};
use derive_new::new;
use std::sync::Arc;

#[derive(new)]
pub struct BertInputBatcher<B: Backend> {
    /// Tokenizer for converting input text strings to token IDs
    tokenizer: Arc<dyn Tokenizer>,
    /// Device on which to perform computation (e.g., CPU or CUDA device)
    device: B::Device,
    /// Maximum sequence length for tokenized text
    max_seq_length: usize,
}

#[derive(Debug, Clone, new)]
pub struct BertInferenceBatch<B: Backend> {
    /// Tokenized text as a 2D tensor: [batch_size, max_seq_length]
    pub tokens: Tensor<B, 2, Int>,
    /// Padding mask for the tokenized text, with booleans marking padding locations
    pub mask_pad: Tensor<B, 2, Bool>,
}

impl<B: Backend> Batcher<String, BertInferenceBatch<B>> for BertInputBatcher<B> {
    /// Batches a vector of strings into an inference batch
    fn batch(&self, items: Vec<String>) -> BertInferenceBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());

        // Tokenize each string
        for item in items {
            tokens_list.push(self.tokenizer.encode(&item));
        }

        // Generate a padding mask for the tokenized text
        let mask = generate_padding_mask(
            self.tokenizer.pad_token(),
            tokens_list,
            Some(self.max_seq_length),
            &self.device,
        );

        // Create and return the inference batch
        BertInferenceBatch {
            tokens: mask.tensor,
            mask_pad: mask.mask,
        }
    }
}

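For reference, the inference example earlier in this commit wires the batcher up roughly as follows (a sketch, assuming a backend `B`, a `device`, an `Arc`-wrapped tokenizer, and a `max_seq_len` taken from the model config):

```rust
// Build the batcher, then turn raw strings into padded token IDs plus a mask.
let batcher = BertInputBatcher::<B>::new(tokenizer.clone(), device.clone(), max_seq_len);
let batch = batcher.batch(vec!["Hello, world!".to_string()]);
// batch.tokens has shape [batch_size, max_seq_length]; batch.mask_pad marks the padding.
```
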
@@ -0,0 +1,5 @@
mod batcher;
mod tokenizer;

pub use batcher::*;
pub use tokenizer::*;

@@ -0,0 +1,64 @@
pub trait Tokenizer: Send + Sync {
    /// Converts a text string into a sequence of tokens.
    fn encode(&self, value: &str) -> Vec<usize>;

    /// Converts a sequence of tokens back into a text string.
    fn decode(&self, tokens: &[usize]) -> String;

    /// Gets the size of the tokenizer's vocabulary.
    fn vocab_size(&self) -> usize;

    /// Gets the token used for padding sequences to a consistent length.
    fn pad_token(&self) -> usize;

    /// Gets the string representation of the padding token.
    /// The default implementation uses `decode` on the padding token.
    fn pad_token_value(&self) -> String {
        self.decode(&[self.pad_token()])
    }
}

/// A tokenizer using the RoBERTa BPE tokenization strategy.
pub struct BertTokenizer {
    /// The underlying tokenizer from the `tokenizers` library.
    tokenizer: tokenizers::Tokenizer,
    pad_token: usize,
}

// Creates a new BertTokenizer by downloading the tokenizer for the given
// model_name (e.g. "roberta-base"). pad_token_id is the ID of the padding
// token used to pad sequences to a consistent length, as specified in the
// model's config.json.
impl BertTokenizer {
    pub fn new(model_name: String, pad_token_id: usize) -> Self {
        Self {
            tokenizer: tokenizers::Tokenizer::from_pretrained(model_name, None).unwrap(),
            pad_token: pad_token_id,
        }
    }
}

// Implementation of the Tokenizer trait for BertTokenizer.
impl Tokenizer for BertTokenizer {
    /// Converts a text string into a sequence of tokens using the model's tokenization strategy.
    fn encode(&self, value: &str) -> Vec<usize> {
        let tokens = self.tokenizer.encode(value, true).unwrap();
        tokens.get_ids().iter().map(|t| *t as usize).collect()
    }

    /// Converts a sequence of tokens back into a text string.
    fn decode(&self, tokens: &[usize]) -> String {
        let tokens = tokens.iter().map(|t| *t as u32).collect::<Vec<u32>>();
        self.tokenizer.decode(&tokens, false).unwrap()
    }

    /// Gets the size of the BERT tokenizer's vocabulary.
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }

    /// Gets the token used for padding sequences to a consistent length.
    fn pad_token(&self) -> usize {
        self.pad_token
    }
}

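A hypothetical round trip with this tokenizer; the pad token ID of 1 here is an assumption matching RoBERTa's config.json, and real code should take it from the loaded model config as the example does:

```rust
// Hypothetical usage: pad token ID 1 is assumed (RoBERTa's default).
let tokenizer = BertTokenizer::new("roberta-base".to_string(), 1);
let ids = tokenizer.encode("The quick brown fox");
// decode keeps special tokens, since it is called with skip_special_tokens = false.
let text = tokenizer.decode(&ids);
println!("{} tokens, vocab size {}", ids.len(), tokenizer.vocab_size());
```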