Merge pull request #15 from tracel-ai/bert
Add BERT family of models
ashdtu authored Feb 13, 2024
2 parents bdd62f1 + 833cf80 commit b225ed6
Showing 13 changed files with 989 additions and 5 deletions.
11 changes: 6 additions & 5 deletions README.md
@@ -5,17 +5,18 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep learning framework

## Collection of Official Models

| Model | Description | Repository Link |
|------------------------------------------------|-------------------------------------------------------|----------------------------------------------|
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |

## Community Contributions

Explore the curated list of models developed by the community ♥.

| Model | Description | Repository Link |
|---------------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------------------------------------|
| [Llama 2](https://arxiv.org/abs/2307.09288) | LLMs by Meta AI, ranging from 7 billion to 70 billion parameters. | [Gadersd/llama2-burn](https://github.com/Gadersd/llama2-burn) |
| [Whisper](https://arxiv.org/abs/2212.04356) | A general-purpose speech recognition model by OpenAI. | [Gadersd/whisper-burn](https://github.com/Gadersd/whisper-burn) |
| Stable Diffusion v1.4 | An image generation model developed by Stability AI. | [Gadersd/stable-diffusion-burn](https://github.com/Gadersd/stable-diffusion-burn) |
37 changes: 37 additions & 0 deletions bert-burn/Cargo.toml
@@ -0,0 +1,37 @@
[package]
authors = ["Aasheesh Singh [email protected]"]
license = "MIT OR Apache-2.0"
name = "bert-burn"
version = "0.1.0"
edition = "2021"

[features]
default = ["burn/dataset"]
f16 = []
ndarray = ["burn/ndarray"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
fusion = ["burn/fusion"]
# To be replaced by burn-safetensors once supported: https://github.com/tracel-ai/burn/issues/626
safetensors = ["candle-core/default"]


[dependencies]
# Burn
burn = { version = "0.12.1", default-features = false }
candle-core = { version = "0.3.2", optional = true }
# Tokenizer
tokenizers = { version = "0.15.0", default-features = false, features = [
"onig",
"http",
] }
burn-import = "0.12.1"
derive-new = "0.6.0"
hf-hub = { version = "0.3.2", features = ["tokio"] }

# Utils
serde = { version = "1.0.196", features = ["std", "derive"] }
libm = "0.2.8"
serde_json = "1.0.113"
tokio = "1.35.1"
1 change: 1 addition & 0 deletions bert-burn/LICENSE-APACHE
1 change: 1 addition & 0 deletions bert-burn/LICENSE-MIT
39 changes: 39 additions & 0 deletions bert-burn/README.md
@@ -0,0 +1,39 @@
# Bert-Burn Model

This project provides an example implementation for inference with the BERT family of models. The following compatible
BERT variants can be loaded: `roberta-base` (**default**), `roberta-large`, `bert-base-uncased`, `bert-large-uncased`,
`bert-base-cased`, and `bert-large-cased`. The pre-trained weights and config files are downloaded automatically
from the [HuggingFace Model Hub](https://huggingface.co/FacebookAI/roberta-base/tree/main).
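
Concretely, the download-and-load flow boils down to a few calls. Below is a minimal sketch using this crate's `loader` and `model` modules (requires the `safetensors` feature; add the dependency as shown in the next section first):

```rust
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModel;
use burn::tensor::backend::Backend;

fn load_model<B: Backend>(device: &B::Device) -> BertModel<B> {
    // Fetch config.json and the weights for the chosen variant from the
    // HuggingFace Hub (hf-hub caches the files locally between runs).
    let variant = String::from("roberta-base");
    let (config_file, model_file) = download_hf_model(&variant);
    let config = load_model_config(config_file);
    BertModel::from_safetensors(model_file, device, config)
}
```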

### To include the model in your project

Add this to your `Cargo.toml`:

```toml
[dependencies]
bert-burn = { git = "https://github.com/burn-rs/models", package = "bert-burn", default-features = false }
```

## Example Usage

The example below computes sentence embeddings for the given input text. The model supports multiple Burn backends
(e.g. `ndarray`, `wgpu`, `tch-gpu`, `tch-cpu`), selected via the `--features` flag; an example with the `wgpu`
backend is shown below. The `fusion` feature enables kernel fusion for the `wgpu` backend and is not required
for the other backends. The `safetensors` feature enables loading weights in the `safetensors` format via the
`candle-core` crate.

### WGPU backend

```bash
cd bert-burn/
# Get sentence embeddings from the RoBERTa encoder (default)
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors

# Using bert-base-uncased model
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors bert-base-uncased

# Using roberta-large model
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors roberta-large
```


154 changes: 154 additions & 0 deletions bert-burn/examples/infer-embedding.rs
@@ -0,0 +1,154 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModel;
use burn::data::dataloader::batcher::Batcher;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use std::env;
use std::sync::Arc;

// Tensor element type for the example: f32 by default, half precision with the `f16` feature.
#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;

pub fn launch<B: Backend>(device: B::Device) {
let args: Vec<String> = env::args().collect();
let default_model = "roberta-base".to_string();
let model_variant = if args.len() > 1 {
// Use the argument provided by the user
// Possible values: "bert-base-uncased", "roberta-large" etc.
&args[1]
} else {
// Use the default value if no argument is provided
&default_model
};

println!("Model variant: {}", model_variant);

let text_samples = vec![
"Jays power up to take finale Contrary to popular belief, the power never really \
snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it \
took some extra time for the batting orders to provide some extra wattage."
.to_string(),
"Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one \
man to death and 14 others to prison terms for a series of attacks and terrorist \
plots in 2002, including the bombing of a French oil tanker."
.to_string(),
"IBM puts grids to work at U.S. Open IBM will put a collection of its On \
Demand-related products and technologies to this test next week at the U.S. Open \
tennis championships, implementing a grid-based infrastructure capable of running \
multiple workloads including two not associated with the tournament."
.to_string(),
];

let (config_file, model_file) = download_hf_model(model_variant);
let model_config = load_model_config(config_file);

let model: BertModel<B> =
BertModel::from_safetensors(model_file, &device, model_config.clone());

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id,
));

// Batch the input samples to max sequence length with padding
let batcher = Arc::new(BertInputBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
model_config.max_seq_len.unwrap(),
));

// Batch the input samples; tokens shape: [batch_size, seq_len]
let input = batcher.batch(text_samples.clone());
let [batch_size, _seq_len] = input.tokens.dims();
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());

let output = model.forward(input);

// Get the sentence embedding from the first token ([CLS] in BERT, <s> in RoBERTa)
let cls_token_idx = 0;

// Embedding size
let d_model = model_config.hidden_size;
let sentence_embedding =
output
.clone()
.slice([0..batch_size, cls_token_idx..cls_token_idx + 1, 0..d_model]);

let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);
println!(
"Roberta Sentence embedding {:?} // (Batch Size, Embedding_dim)",
sentence_embedding.shape()
);
}

#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<LibTorch<ElemType>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Fusion;

pub fn run() {
launch::<Fusion<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "tch-cpu")]
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
}
51 changes: 51 additions & 0 deletions bert-burn/src/data/batcher.rs
@@ -0,0 +1,51 @@
use super::tokenizer::Tokenizer;
use burn::{
data::dataloader::batcher::Batcher,
nn::attention::generate_padding_mask,
tensor::{backend::Backend, Bool, Int, Tensor},
};
use derive_new::new;
use std::sync::Arc;

#[derive(new)]
pub struct BertInputBatcher<B: Backend> {
/// Tokenizer for converting input text string to token IDs
tokenizer: Arc<dyn Tokenizer>,
/// Device on which to perform computation (e.g., CPU or CUDA device)
device: B::Device,
/// Maximum sequence length for tokenized text
max_seq_length: usize,
}

#[derive(Debug, Clone, new)]
pub struct BertInferenceBatch<B: Backend> {
/// Tokenized text as 2D tensor: [batch_size, max_seq_length]
pub tokens: Tensor<B, 2, Int>,
/// Padding mask for the tokenized text containing booleans for padding locations
pub mask_pad: Tensor<B, 2, Bool>,
}

impl<B: Backend> Batcher<String, BertInferenceBatch<B>> for BertInputBatcher<B> {
/// Batches a vector of strings into an inference batch
fn batch(&self, items: Vec<String>) -> BertInferenceBatch<B> {
let mut tokens_list = Vec::with_capacity(items.len());

// Tokenize each string
for item in items {
tokens_list.push(self.tokenizer.encode(&item));
}

// Generate padding mask for tokenized text
let mask = generate_padding_mask(
self.tokenizer.pad_token(),
tokens_list,
Some(self.max_seq_length),
&self.device,
);

// Create and return inference batch
BertInferenceBatch {
tokens: mask.tensor,
mask_pad: mask.mask,
}
}
}
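
As a usage note, here is a minimal sketch of driving the batcher on its own. The pad token id `1` and maximum length `512` are illustrative values for `roberta-base`; in the bundled example both come from the model's config.json:

```rust
use bert_burn::data::{BertInferenceBatch, BertInputBatcher, BertTokenizer};
use burn::data::dataloader::batcher::Batcher;
use burn::tensor::backend::Backend;
use std::sync::Arc;

fn batch_texts<B: Backend>(device: B::Device) -> BertInferenceBatch<B> {
    let tokenizer = Arc::new(BertTokenizer::new("roberta-base".to_string(), 1));
    // Pads every sample up to the maximum sequence length and records the
    // padding positions in `mask_pad`.
    let batcher = BertInputBatcher::<B>::new(tokenizer, device, 512);
    batcher.batch(vec!["first sample".to_string(), "second sample".to_string()])
}
```
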
5 changes: 5 additions & 0 deletions bert-burn/src/data/mod.rs
@@ -0,0 +1,5 @@
mod batcher;
mod tokenizer;

pub use batcher::*;
pub use tokenizer::*;
64 changes: 64 additions & 0 deletions bert-burn/src/data/tokenizer.rs
@@ -0,0 +1,64 @@
pub trait Tokenizer: Send + Sync {
/// Converts a text string into a sequence of tokens.
fn encode(&self, value: &str) -> Vec<usize>;

/// Converts a sequence of tokens back into a text string.
fn decode(&self, tokens: &[usize]) -> String;

/// Gets the size of the tokenizer's vocabulary.
fn vocab_size(&self) -> usize;

/// Gets the token used for padding sequences to a consistent length.
fn pad_token(&self) -> usize;

/// Gets the string representation of the padding token.
/// The default implementation uses `decode` on the padding token.
fn pad_token_value(&self) -> String {
self.decode(&[self.pad_token()])
}
}

/// A tokenizer using the RoBERTa BPE tokenization strategy from the `tokenizers` library.
pub struct BertTokenizer {
// The underlying tokenizer from the `tokenizers` library.
tokenizer: tokenizers::Tokenizer,
pad_token: usize,
}

// Constructor for BertTokenizer.
// Downloads the tokenizer for the given model_name (e.g. "roberta-base").
// `pad_token_id` is the id of the padding token used to pad sequences to a
// consistent length, as specified in the model's config.json.
impl BertTokenizer {
pub fn new(model_name: String, pad_token_id: usize) -> Self {
Self {
tokenizer: tokenizers::Tokenizer::from_pretrained(model_name, None).unwrap(),
pad_token: pad_token_id,
}
}
}

// Implementation of the Tokenizer trait for BertTokenizer.
impl Tokenizer for BertTokenizer {
/// Convert a text string into a sequence of tokens using the BERT model's tokenization strategy.
fn encode(&self, value: &str) -> Vec<usize> {
let tokens = self.tokenizer.encode(value, true).unwrap();
tokens.get_ids().iter().map(|t| *t as usize).collect()
}

/// Converts a sequence of tokens back into a text string.
fn decode(&self, tokens: &[usize]) -> String {
let tokens = tokens.iter().map(|t| *t as u32).collect::<Vec<u32>>();
self.tokenizer.decode(&tokens, false).unwrap()
}

/// Gets the size of the BERT tokenizer's vocabulary.
fn vocab_size(&self) -> usize {
self.tokenizer.get_vocab_size(true)
}

/// Gets the token used for padding sequences to a consistent length.
fn pad_token(&self) -> usize {
self.pad_token
}
}
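
A round-trip usage sketch (the pad token id `1` assumes `roberta-base`'s config.json; note that `decode` keeps special tokens, so they show up in the decoded string):

```rust
use bert_burn::data::{BertTokenizer, Tokenizer};

fn roundtrip() {
    // Downloads the tokenizer definition from the HuggingFace Hub on first use.
    let tokenizer = BertTokenizer::new("roberta-base".to_string(), 1);
    let ids = tokenizer.encode("Hello, Burn!"); // special tokens are added
    let text = tokenizer.decode(&ids); // e.g. "<s>Hello, Burn!</s>"
    println!("{} tokens -> {}", ids.len(), text);
    println!("vocab size: {}, pad id: {}", tokenizer.vocab_size(), tokenizer.pad_token());
}
```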