"""
convert_openvla_weights_to_hf.py
Utility script for converting full OpenVLA VLA weights (from this repository, in the default "Prismatic" format) to
the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers``
via `trust_remote_code = True`.
Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the
line, with first-class support.
Usage:
python vla-scripts/extern/convert_openvla_weights_to_hf.py \
--openvla_model_path_or_id <PATH TO PRISMATIC TRAINING RUN DIR> \
--output_hf_model_local_path <OUTPUT DIR FOR CONVERTED CHECKPOINT>
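
After conversion, the resulting checkpoint should (in principle) load via the HF AutoClasses. A minimal sketch,
assuming the converted model is saved or pushed under the ID `openvla/openvla-7b` (adjust to your own path / ID):

    import torch
    from transformers import AutoModelForVision2Seq, AutoProcessor

    processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
    vla = AutoModelForVision2Seq.from_pretrained(
        "openvla/openvla-7b", torch_dtype=torch.bfloat16, trust_remote_code=True
    )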
"""

import json
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Union

import draccus
import timm
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from timm.models.vision_transformer import LayerScale
from transformers import AutoTokenizer

from prismatic.conf import ModelConfig
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor


@dataclass
class HFConvertConfig:
    # fmt: off
    openvla_model_path_or_id: Union[str, Path] = (              # Path to pretrained VLA (on disk or HF Hub)
        "runs/prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7"
    )
    output_hf_model_local_path: Path = Path(                    # Local path for saving the converted HF checkpoint
        "hf-convert/openvla-7b"
    )
    output_hf_model_hub_path: str = "openvla/openvla-7b"        # (Optional) HF Hub path to push the model to

    # HF Hub Credentials (required for Gated Models like LLaMa-2)
    hf_token: Union[str, Path] = Path(".hf_token")              # Environment variable or Path to HF Token

    def __post_init__(self) -> None:
        self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token

    # fmt: on


# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma
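

# Example (illustrative sketch, not invoked by this script) :: patching every `LayerScale` in a standalone TIMM ViT
#   so its `gamma` parameters survive an HF `save_pretrained` / `from_pretrained` round-trip; the model ID below is
#   just a placeholder:
#       vit = timm.create_model("vit_large_patch14_reg4_dinov2.lvd142m", pretrained=False, num_classes=0)
#       for m in vit.modules():
#           if isinstance(m, LayerScale):
#               ls_apply_patch(m)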


# === Conversion Constants ===
PROJECTOR_KEY_MAPPING = {
    "projector.0.weight": "projector.fc1.weight",
    "projector.0.bias": "projector.fc1.bias",
    "projector.2.weight": "projector.fc2.weight",
    "projector.2.bias": "projector.fc2.bias",
    "projector.4.weight": "projector.fc3.weight",
    "projector.4.bias": "projector.fc3.bias",
}
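
# NOTE :: The index-based keys above reflect the Prismatic projector presumably being an `nn.Sequential` (Linear
#   layers at positions 0 / 2 / 4, with activations in between); the HF port names them `fc1` / `fc2` / `fc3`.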


def remap_state_dicts_for_hf(
    prismatic_vision_backbone_state_dict: Dict[str, torch.Tensor],
    projector_state_dict: Dict[str, torch.Tensor],
    llm_backbone_state_dict: Dict[str, torch.Tensor],
    use_fused_vision_backbone: bool = False,
) -> Dict[str, torch.Tensor]:
    """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion."""
    hf_state_dict = {}

    # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING`
    for key, value in projector_state_dict.items():
        hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value

    # Iterate through LLM Backbone =>> replace `llm.` with `language_model.`
    for key, value in llm_backbone_state_dict.items():
        hf_state_dict[key.replace("llm.", "language_model.")] = value

    # Iterate through Vision Backbone =>> add "vision_backbone." prefix
    if not use_fused_vision_backbone:
        for key, value in prismatic_vision_backbone_state_dict.items():
            hf_state_dict[key.replace("featurizer.", "vision_backbone.featurizer.")] = value
    else:
        # Note =>> Assumes that backbones are always DINO + SigLIP...
        for key, value in prismatic_vision_backbone_state_dict.items():
            if key.startswith("dino_featurizer"):
                if key.endswith(".gamma"):
                    # Handle `LayerScale gamma` =>> DINOv2 only!
                    key = key.replace(".gamma", ".scale_factor")
                hf_state_dict[key.replace("dino_featurizer.", "vision_backbone.featurizer.")] = value
            elif key.startswith("siglip_featurizer"):
                hf_state_dict[key.replace("siglip_featurizer.", "vision_backbone.fused_featurizer.")] = value

    return hf_state_dict
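

# Illustrative sketch of the remapping above (key names are hypothetical, for orientation only):
#   "projector.0.weight"                         -> "projector.fc1.weight"
#   "llm.model.layers.0.mlp.up_proj.weight"      -> "language_model.model.layers.0.mlp.up_proj.weight"
#   "dino_featurizer.blocks.0.ls1.gamma"         -> "vision_backbone.featurizer.blocks.0.ls1.scale_factor"
#   "siglip_featurizer.blocks.0.attn.qkv.weight" -> "vision_backbone.fused_featurizer.blocks.0.attn.qkv.weight"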


@draccus.wrap()
def convert_openvla_weights_to_hf(cfg: HFConvertConfig) -> None:
    print(f"[*] Converting OpenVLA Model `{cfg.openvla_model_path_or_id}` to HF Transformers Format")
    torch.set_default_dtype(torch.bfloat16)

    # Get `config.json`, `dataset_statistics.json` and `checkpoint_pt` -- mirrors logic in `prismatic.models.load`
    if os.path.isdir(cfg.openvla_model_path_or_id):
        print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.openvla_model_path_or_id))}`")
        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
        dataset_statistics_json = run_dir / "dataset_statistics.json"

        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
    else:
        print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `openvla/openvla-dev/{cfg.openvla_model_path_or_id}`")
        config_json = hf_hub_download("openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/config.json")
        checkpoint_pt = hf_hub_download(
            "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/checkpoints/latest-checkpoint.pt"
        )
        dataset_statistics_json = hf_hub_download(
            "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/dataset_statistics.json"
        )
# Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
with open(config_json, "r") as f:
vla_cfg = json.load(f)["vla"]
prismatic_config = ModelConfig.get_choice_class(vla_cfg["base_vlm"])().__dict__
# Load Normalization Statistics
with open(dataset_statistics_json, "r") as f:
norm_stats = json.load(f)
# Create HF OpenVLAConfig (`transformers.PretrainedConfig`)
hf_config = OpenVLAConfig(
vision_backbone_id=prismatic_config["vision_backbone_id"],
llm_backbone_id=prismatic_config["llm_backbone_id"],
arch_specifier=prismatic_config["arch_specifier"],
image_resize_strategy=prismatic_config["image_resize_strategy"],
llm_max_length=prismatic_config["llm_max_length"],
torch_dtype=torch.bfloat16,
norm_stats=norm_stats,
)
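
    # Note (for orientation) :: the default run above points at the `prism-dinosiglip-224px` base VLM -- i.e., a
    #   fused DINOv2 + SigLIP vision backbone with a LLaMa-2 LLM backbone -- for which
    #   `hf_config.use_fused_vision_backbone` should resolve to True.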

    # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
    #   TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
    print("[*] Instantiating and Patching Tokenizer, LLM Config")
    tokenizer = AutoTokenizer.from_pretrained(
        hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
    )
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    tokenizer.init_kwargs.pop("add_prefix_space", None)  # Pop to prevent unnecessary warning on reload...

    assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
    assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"

    # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
    hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
    hf_config.text_config.pad_token_id = hf_config.pad_token_id
    hf_config.text_config.torch_dtype = torch.bfloat16

    assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"
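
    # Worked example (illustrative; assumes a LLaMa-2 backbone and `pad_to_multiple_of = 64`) :: LLaMa-2 ships with
    #   vocab_size = 32000; adding `<PAD>` makes the tokenizer length 32001, and padding the vocabulary out by 64
    #   yields a patched LLM vocab_size of 32064, with `<PAD>` living at index 32000.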

    # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform`
    #   =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py`
    print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor")
    input_sizes, interpolations, means, stds = [], [], [], []
    for idx, timm_model_id in enumerate(hf_config.timm_model_ids):
        timm_vision_backbone = timm.create_model(
            timm_model_id,
            pretrained=True,
            num_classes=0,
            img_size=hf_config.image_sizes[idx],
            act_layer=hf_config.timm_override_act_layers[idx],
        )

        # Get Per-Backbone Image Processing
        data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone)
        input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx]))
        interpolations.append(data_cfg["interpolation"])
        means.append(data_cfg["mean"])
        stds.append(data_cfg["std"])

        # Patch `LayerScale` because of HF's annoying `fix_key` overwrite...
        for module in timm_vision_backbone.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

    # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`)
    hf_image_processor = PrismaticImageProcessor(
        use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
        image_resize_strategy=hf_config.image_resize_strategy,
        input_sizes=input_sizes,
        interpolations=interpolations,
        means=means,
        stds=stds,
    )

    # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor)
    print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor")
    hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer)

    # Load Prismatic Model State Dictionary (in preparation for conversion)
    print("[*] Loading Prismatic VLM State Dictionary from Checkpoint")
    model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
    assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?"
    assert all(k in model_state_dict for k in ["vision_backbone", "projector", "llm_backbone"]), "Missing keys!"

    # Convert
    print("[*] Running Conversion")
    converted_state_dict = remap_state_dicts_for_hf(
        model_state_dict["vision_backbone"],
        model_state_dict["projector"],
        model_state_dict["llm_backbone"],
        use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
    )

    # Create OpenVLAForActionPrediction (a PrismaticForConditionalGeneration) =>> Note that we can't initialize on
    #   the `meta` device because of TIMM; instead, build on CPU and load the converted state dict with `assign=True`.
    print("[*] Building (Randomly Initialized) Model =>> OpenVLAForActionPrediction")
    hf_model = OpenVLAForActionPrediction(hf_config)
    hf_model.load_state_dict(converted_state_dict, strict=True, assign=True)

    # Cast Model to BF16 before Saving
    hf_model.to(torch.bfloat16)

    # Save Pretrained Versions to Local Path
    print("[*] Saving Model & Processor to Local Path")
    hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB")
    hf_image_processor.save_pretrained(cfg.output_hf_model_local_path)
    hf_processor.save_pretrained(cfg.output_hf_model_local_path)

    # Copy `dataset_statistics.json` File to Converted Checkpoint Directory
    output_dataset_statistics_json = cfg.output_hf_model_local_path / "dataset_statistics.json"
    shutil.copyfile(dataset_statistics_json, output_dataset_statistics_json)

    print(f"[*] Saving Complete! Saved converted checkpoint to: {cfg.output_hf_model_local_path}")

    #####################################################################################
    # Optional: Push Model to Hugging Face Hub
    #####################################################################################

    # # Register AutoClasses
    # OpenVLAConfig.register_for_auto_class()
    # PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor")
    # PrismaticProcessor.register_for_auto_class("AutoProcessor")
    # OpenVLAForActionPrediction.register_for_auto_class("AutoModelForVision2Seq")

    # # Push to HF Hub
    # print("[*] Pushing Model & Processor to HF Hub")
    # hf_config.push_to_hub(cfg.output_hf_model_hub_path)
    # hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB")
    # hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
    # hf_processor.push_to_hub(cfg.output_hf_model_hub_path)


if __name__ == "__main__":
    convert_openvla_weights_to_hf()