BERT

An implementation of BERT (Devlin et al., 2019) in MLX.

Downloading and Converting Weights

The convert.py script relies on transformers to download the weights and exports them as a single .npz file.

python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
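
If you want to sanity-check the exported file, here is a minimal sketch using NumPy that just prints a few parameter names and shapes (the slice of five entries is arbitrary):

import numpy as np

# Load the converted checkpoint and list a few parameters.
weights = np.load("weights/bert-base-uncased.npz")
for name in list(weights.keys())[:5]:
    print(name, weights[name].shape)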

Usage

To use the BERT model in your own code, you can load it with:

import mlx.core as mx

from model import Bert, load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)

The output is a Batch x Tokens x Dims tensor, containing a vector for every input token. If you want to train anything at the token level, you'll want to use this.

The pooled is a Batch x Dims tensor, which is the pooled representation for each input sequence. If you want to train a classification model, you'll want to use this.
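
For example, a minimal sketch of a classification head on top of pooled might look like the following. The class name, hidden size, and number of labels are assumptions for illustration, not part of this repo:

import mlx.core as mx
import mlx.nn as nn

class BertClassifier(nn.Module):
    # Hypothetical head: a single linear layer over the pooled output.
    def __init__(self, dims: int = 768, num_classes: int = 2):
        super().__init__()
        self.classifier = nn.Linear(dims, num_classes)

    def __call__(self, pooled):
        return self.classifier(pooled)

head = BertClassifier()
logits = head(pooled)                 # Batch x num_classes
probs = mx.softmax(logits, axis=-1)   # class probabilities per input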

Comparison with 🤗 transformers Implementation

To run the model and have it perform inference on a batch of examples:

python model.py \
  --bert-model bert-base-uncased \
  --mlx-model weights/bert-base-uncased.npz

This will show the following output:

MLX BERT:
[[[-0.52508914 -0.1993871  -0.28210318 ... -0.61125606  0.19114694
    0.8227601 ]
  [-0.8783862  -0.37107834 -0.52238125 ... -0.5067165   1.0847603
    0.31066895]
  [-0.70010054 -0.5424497  -0.26593682 ... -0.2688697   0.38338926
    0.6557663 ]
  ...

They can be compared against the 🤗 implementation with:

python hf_model.py \
  --bert-model bert-base-uncased

Which will show:

HF BERT:
[[[-0.52508944 -0.1993877  -0.28210333 ... -0.6112575   0.19114678
    0.8227603 ]
  [-0.878387   -0.371079   -0.522381   ... -0.50671494  1.0847601
    0.31066933]
  [-0.7001008  -0.5424504  -0.26593733 ... -0.26887015  0.38339025
    0.65576553]
  ...
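
To check the agreement programmatically rather than by eye, here is a minimal sketch that runs both models on the same batch and compares the token-level outputs. It assumes torch and transformers are installed, and the tolerance value is an assumption:

import numpy as np
import mlx.core as mx
import torch
from transformers import AutoModel, AutoTokenizer
from model import load_model

batch = ["This is an example of BERT working on MLX."]

# MLX forward pass.
mlx_model, mlx_tokenizer = load_model(
    "bert-base-uncased", "weights/bert-base-uncased.npz")
tokens = mlx_tokenizer(batch, return_tensors="np", padding=True)
mlx_output, _ = mlx_model(**{k: mx.array(v) for k, v in tokens.items()})

# Hugging Face forward pass.
hf_model = AutoModel.from_pretrained("bert-base-uncased")
hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hf_tokens = hf_tokenizer(batch, return_tensors="pt", padding=True)
with torch.no_grad():
    hf_output = hf_model(**hf_tokens).last_hidden_state.numpy()

# The two implementations should agree to within a small tolerance.
print(np.allclose(np.array(mlx_output), hf_output, atol=1e-4))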