This is the official repo for "SpatialBot: Precise Spatial Understanding with Vision Language Models".
SJTU, Stanford, BAAI, PKU, Oxford, SEU
Model: 🤗3B model in HF | 🤗3B ckpt in HF | 🤖3B model in wisemodel | 🤖3B ckpt in wisemodel
Training set SpatialQA: Please drop an email to [email protected]
Embodiment training set SpatialQA-E: 🤗SpatialQA-E in HF
Benchmark: 🤗SpatialBench in HF | 🤖SpatialBench in wisemodel
Paper: 📃General VQA + embodiment arXiv
Embodiment Videos Preview: ⚙ SpatialBot in embodiment
- Install dependencies first:
  pip install torch transformers accelerate pillow numpy
- Download SpatialBot-3B. Users in mainland China may want to download the HF model from an HF mirror site and change model_name to the local path of the SpatialBot-3B folder.
- Run the model:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings
import numpy as np
# disable some warnings
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')
# set device
device = 'cuda' # or cpu
model_name = 'RussRobin/SpatialBot-3B'
offset_bos = 0
# create model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # float32 for cpu
    device_map='auto',
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True)
# text prompt
prompt = 'What is the depth value of point <0.5,0.2>? Answer directly from depth map.'
text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image 1>\n<image 2>\n{prompt} ASSISTANT:"
text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image 1>\n<image 2>\n')]
input_ids = torch.tensor(text_chunks[0] + [-201] + [-202] + text_chunks[1][offset_bos:], dtype=torch.long).unsqueeze(0).to(device)
image1 = Image.open('rgb.jpg')
image2 = Image.open('depth.png')
# a single-channel depth map is expanded to three channels before it is passed to the model
channels = len(image2.getbands())
if channels == 1:
    img = np.array(image2)
    height, width = img.shape
    three_channel_array = np.zeros((height, width, 3), dtype=np.uint8)
    three_channel_array[:, :, 0] = (img // 1024) * 4
    three_channel_array[:, :, 1] = (img // 32) * 8
    three_channel_array[:, :, 2] = (img % 32) * 8
    image2 = Image.fromarray(three_channel_array, 'RGB')
image_tensor = model.process_images([image1,image2], model.config).to(dtype=model.dtype, device=device)
# If 'Expected all tensors to be on the same device' error is thrown, uncomment the following line
# model.get_vision_tower().to('cuda')
# generate
output_ids = model.generate(
    input_ids,
    images=image_tensor,
    max_new_tokens=100,
    use_cache=True,
    repetition_penalty=1.0  # increase this to avoid chattering
)[0]
print(tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip())
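The single-channel depth conversion in the script above appears to split each 16-bit depth value into its top 6, middle 5, and low 5 bits, each scaled to fill one 8-bit channel. The sketch below makes that reading explicit and adds an inverse for checking it. `encode_depth` and `decode_depth` are hypothetical helper names, not part of the SpatialBot code, and the explicit `% 32` on the middle channel assumes the original relies on uint8 wraparound when the values are stored.

```python
import numpy as np


def encode_depth(depth: np.ndarray) -> np.ndarray:
    """Pack a 2-D uint16 depth map into an (H, W, 3) uint8 array."""
    depth = depth.astype(np.uint16)
    rgb = np.zeros((*depth.shape, 3), dtype=np.uint8)
    rgb[..., 0] = (depth // 1024) * 4        # top 6 bits, scaled by 4
    rgb[..., 1] = ((depth // 32) % 32) * 8   # middle 5 bits, scaled by 8
    rgb[..., 2] = (depth % 32) * 8           # low 5 bits, scaled by 8
    return rgb


def decode_depth(rgb: np.ndarray) -> np.ndarray:
    """Inverse of encode_depth (an assumption, used here only as a check)."""
    rgb = rgb.astype(np.uint16)
    return (rgb[..., 0] // 4) * 1024 + (rgb[..., 1] // 8) * 32 + rgb[..., 2] // 8


depth = np.array([[0, 31, 32, 1023, 1024, 65535]], dtype=np.uint16)
encoded = encode_depth(depth)
restored = decode_depth(encoded)
```

Under this reading no depth precision is lost for 16-bit input, since 6 + 5 + 5 bits cover the full range.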
SpatialBot is a multi-image version of Bunny.
If you've installed Bunny, just replace the code with ours and reinstall the bunny package.
You can start from a docker image or configure a local environment.
We provide a ready-to-run docker container. Please update it with our code:
# 1. download docker image
docker pull russellrobin/bunny:latest
# 2. run container
# docker run -itd ...
# docker exec -it ...
# 3. upgrade transformers and bunny package
cd SpatialBot && pip install --upgrade transformers && pip uninstall bunny && pip install -e .
To configure a local environment, please follow the Bunny installation instructions, but use the code in this repo.
Please download the base LLM and vision tower weights first. To pretrain the model:
sh script/train/pretrain.sh
To finetune SpatialBot with LoRA:
sh script/train/finetune_lora.sh
Parameters:
- MODEL_TYPE: base LLM type. We support phi-2, phi-3, qwen1.5-0.5b, qwen1.5-1.8b (4B), and llama3-8b.
- PRETRAIN_DIR: path to a pretrained model.
- OUTPUT_DIR: path to save the model.
- --model_name_or_path: path to the base LLM.
- --vision_tower: path to the vision encoder. We support CLIP, SigLIP, and EVA-CLIP.
- --version: for Phi-2 and QWen, use bunny. For Phi-3/Llama3, please use phi3/llama.
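The relationship between MODEL_TYPE and --version can be written down as a small lookup. This is only a sketch: `conv_version` is a hypothetical helper, not part of the training scripts, and it reads "phi3/llama" as `phi3` for Phi-3 and `llama` for Llama3.

```python
def conv_version(model_type: str) -> str:
    """Return the --version value for a MODEL_TYPE string (assumed mapping)."""
    name = model_type.lower()
    if name.startswith("phi-3"):
        return "phi3"
    if name.startswith("llama3"):
        return "llama"
    if name.startswith("phi-2") or name.startswith("qwen"):
        return "bunny"
    raise ValueError(f"unsupported MODEL_TYPE: {model_type}")


versions = {m: conv_version(m) for m in ["phi-2", "phi-3", "qwen1.5-0.5b", "llama3-8b"]}
```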
Fine-tuned checkpoints of SpatialBot-3B-LoRA, which is based on Phi-2 and SigLIP, are available in HF. You will need to modify the paths in config.json to run on your device.
A merged, ready-to-run model is available at SpatialBot-3B.
Pretrained models can be found in Model Zoo.
Instructions for continuously fine-tuning SpatialBot on your own data:
- Prepare data: convert your data to a JSON file containing a list of all samples. Images are optional: we support 0-8 images per sample, with 0-2 recommended. Further samples follow the first one in the same list. The format is:
[
    {
        "id": "continuous_1",
        "image": ["/path/to/image_1", "path_to_image_2"],
        "conversations": [
            {
                "from": "human",
                "value": "<image 1>\n<image 2>\nHello SpatialBot."
            },
            {
                "from": "gpt",
                "value": "Hi."
            }
        ]
    },
    {...}
]
- Prepare model:
  - download the merged LoRA SpatialBot
  - add "continuous_training": true in /path/to/merged_model/config.json to ensure the vision tower is loaded from the merged weights
- Edit script: both finetune_full.sh and finetune_lora.sh can be used. Before running:
  - change --model_name_or_path to /path/to/merged_model
  - delete --pretrain_mm_mlp_adapter, because we load the cross-modality projector from the merged weights
  - customize the hyperparameters, e.g. the learning rate, to fit your dataset
Please note that if you continuously fine-tune SpatialBot using LoRA, --model-base should point to the SpatialBot model rather than the original LLM when loading.
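The data and model preparation steps can be scripted roughly as follows. This is a minimal sketch rather than part of the repo: all function names are hypothetical, and it assumes that image tags are numbered from 1 in the first human turn (as in the sample format) and that a text-only sample simply omits the "image" key.

```python
import json
import re
from pathlib import Path

IMAGE_TAG = re.compile(r"<image (\d+)>")


def make_sample(sample_id, image_paths, question, answer):
    """Build one sample dict, prepending one '<image N>' tag per image."""
    tags = "".join(f"<image {i}>\n" for i in range(1, len(image_paths) + 1))
    sample = {"id": sample_id}
    if image_paths:  # assumption: text-only samples omit the key
        sample["image"] = list(image_paths)
    sample["conversations"] = [
        {"from": "human", "value": tags + question},
        {"from": "gpt", "value": answer},
    ]
    return sample


def check_sample(sample):
    """Raise ValueError if the image tags do not match the image list."""
    n_images = len(sample.get("image", []))
    if n_images > 8:
        raise ValueError(f"{sample['id']}: at most 8 images are supported")
    first_turn = sample["conversations"][0]["value"]
    tags = [int(n) for n in IMAGE_TAG.findall(first_turn)]
    if tags != list(range(1, n_images + 1)):
        raise ValueError(f"{sample['id']}: tags {tags} do not match {n_images} image(s)")


def write_dataset(samples, path):
    """Validate every sample, then write the list as one JSON file."""
    for sample in samples:
        check_sample(sample)
    Path(path).write_text(json.dumps(samples, indent=2), encoding="utf-8")


def enable_continuous_training(config_path):
    """Set "continuous_training": true in the merged model's config.json."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    config["continuous_training"] = True
    path.write_text(json.dumps(config, indent=2), encoding="utf-8")
```

A typical call sequence would be `write_dataset([make_sample(...), ...], "data.json")` followed by `enable_continuous_training("/path/to/merged_model/config.json")`.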
Please download SpatialBench and put it under ./eval/spatial_bench.
Use our SpatialBench script to evaluate on it.
Parameters:
- --depth: use this parameter if evaluating the model with RGB-Depth input. Otherwise, the model is evaluated with RGB only.
Please follow our general instructions for LoRA or for full-parameter models to prepare data and evaluate SpatialBot on SpatialBench and general VLM benchmarks.
Please refer to the embodiment instructions to evaluate the model on embodiment tasks.
To merge LoRA-tuned models, see the merge instructions.
RGBD inference:
# --model-type: phi-2, or phi-3/llama3-8b/qwen1.5-1.8b
# --conv-mode: bunny (for Phi-2 and QWen1.5), or phi3/llama
python -m bunny.serve.cli_depth \
    --model-path /path/to/bunny_lora_weights \
    --model-base /path/to/base_llm_model \
    --model-type phi-2 \
    --conv-mode bunny \
    --image-file /path/to/the/test/rgb/image \
    --depth-file /path/to/the/test/depth/image
RGB inference:
# --model-type: phi-2, or phi-3/llama3-8b/qwen1.5-1.8b
# --conv-mode: bunny (for Phi-2 and QWen1.5), or phi3/llama
python -m bunny.serve.cli \
    --model-path /path/to/bunny_lora_weights \
    --model-base /path/to/base_llm_model \
    --model-type phi-2 \
    --conv-mode bunny \
    --image-file /path/to/the/test/rgb/image
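When inference is launched from a script, the two commands above can be assembled programmatically. The sketch below uses only the module names and flags shown in those commands; `build_cli_command` is a hypothetical helper and the paths are the same placeholders.

```python
import shlex


def build_cli_command(model_path, model_base, model_type, conv_mode,
                      image_file, depth_file=None):
    """Return the argument list for RGB or, if depth_file is given, RGB-D inference."""
    module = "bunny.serve.cli_depth" if depth_file else "bunny.serve.cli"
    cmd = [
        "python", "-m", module,
        "--model-path", model_path,
        "--model-base", model_base,
        "--model-type", model_type,
        "--conv-mode", conv_mode,
        "--image-file", image_file,
    ]
    if depth_file:
        cmd += ["--depth-file", depth_file]
    return cmd


rgbd_cmd = build_cli_command(
    "/path/to/bunny_lora_weights", "/path/to/base_llm_model",
    "phi-2", "bunny",
    "/path/to/the/test/rgb/image", "/path/to/the/test/depth/image",
)
rgb_cmd = build_cli_command(
    "/path/to/bunny_lora_weights", "/path/to/base_llm_model",
    "phi-2", "bunny", "/path/to/the/test/rgb/image",
)
print(shlex.join(rgbd_cmd))  # shell-quoted form, for logging or copy-paste
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting issues with paths that contain spaces.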
Please reach out to us if you are interested in SpatialQA: [email protected].
Feel free to try SpatialBot-3B model, which is trained on SpatialQA.
We recommend using depth information from sensors if possible. Follow depthmap instructions to prepare estimated depth information on your own RGB images.
We collect SpatialQA-E, a robot manipulation dataset focusing on spatial understanding and reasoning.
Download: 🤗SpatialQA-E in HF
SpatialBot is fine-tuned on SpatialQA-E for pick-and-place abilities. It is a Vision-Language-Action (VLA) model that supports multi-frame RGB or RGB-D inputs. One version of SpatialBot predicts a delta position (or velocity) for each frame. Another version moves among predicted key points. CKPTs and the dataset will be available soon.
A preview of SpatialBot in embodiment is available.
If you find this repository helpful, please cite our paper.
@article{cai2024spatialbot,
title={SpatialBot: Precise Spatial Understanding with Vision Language Models},
author={Cai, Wenxiao and Ponomarenko, Yaroslav and Yuan, Jianhao and Li, Xiaoqi and Yang, Wankou and Dong, Hao and Zhao, Bo},
journal={arXiv preprint arXiv:2406.13642},
year={2024}
}
The project employs specific datasets and checkpoints that are governed by their original licenses. Users must adhere to all terms and conditions outlined in these licenses. The checkpoints are restricted to uses that comply with the license agreements of Bunny, LLaMA 3, Phi-2, Phi-3, QWen-1.5, and GPT-4. The dataset is provided under the CC-BY-4.0 license.
- The training of this work is built upon Bunny: A family of lightweight multimodal models.
- This work utilizes LLMs from Phi-2, Phi-3, QWen-1.5-0.5B, QWen-1.5-4B, and Llama-3-8B.