Llava / Prismatic with LoRA #16
Comments
Hi @gorjanradevski, the weights and training logs are released by the authors on Hugging Face (https://huggingface.co/TRI-ML/prismatic-vlms/tree/main). As for LoRA, I tried a simple implementation with SigLIP + Phi-2 + LoRA here: https://github.com/NMS05/Prismatic-SigLIP-Phi2-LoRA-VLM. Thanks,
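For context, a minimal sketch of what the language-model side of a SigLIP + Phi-2 + LoRA setup might look like with `peft`; the module names and hyperparameters here are assumptions for illustration, not taken from the linked repo:

```python
# Hypothetical sketch: wrapping the Phi-2 language backbone with LoRA via `peft`.
# target_modules and hyperparameters are assumptions, not the linked repo's settings.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

lora_config = LoraConfig(
    r=16,                        # low-rank dimension
    lora_alpha=32,               # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "dense"],  # assumed attention projections
    task_type="CAUSAL_LM",
)

llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()  # only the LoRA adapters remain trainable
```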
Hi @NMS05. Thanks for your great implementation, which helps me a lot.
Awesome implementation @NMS05! Would love to add first-class LoRA support to this repo - would you be willing to contribute a PR?
Hi @NMS05, I would like to know if you have evaluated the finetuned model. I downloaded and evaluated it.
BTW, it seems the LoRA results are worse than the full fine-tune results on Phi-2.
Hmm, so the evaluation code was able to fully reproduce the paper results for all existing models (as of last week); I'm not sure if there's something different that needs to be changed for Phi-2 models (e.g., tokenization/generation parameters). I'll try to take a look when I can (or might just try training a Phi-2 model). But it would be great if @NMS05 could weigh in on their results!
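As a purely illustrative note, these are the kinds of tokenization/generation knobs that tend to differ across LM backbones and can shift benchmark numbers; the values below are placeholders, not the repo's actual evaluation settings:

```python
# Illustrative only: generation/tokenization settings that commonly vary by backbone.
# Values are placeholders and are NOT the settings used by the evaluation code.
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Phi-2 ships without a dedicated pad token

gen_cfg = GenerationConfig(
    max_new_tokens=128,
    do_sample=False,                      # greedy decoding for reproducible eval
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
```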
Hi @zeyuanyin and @siddk, I resolved the previous issue. The problems were (1) LoRA had too few trainable parameters and (2) the full fine-tune setup was not optimal. Here is the latest model with improved design choices: NMS05/DinoV2-SigLIP-Phi3-LoRA-VLM, evaluated using prismatic-eval. Results (accuracy) for the SLIM splits are as follows:
Thanks,
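A rough sketch of the "LoRA had too few parameters" point above: the trainable-parameter count grows with the LoRA rank and with how many module types the adapters are attached to. The model ID and module names below are assumptions, not the exact configuration used in the linked repo:

```python
# Rough illustration: trainable-parameter count for a small vs. large LoRA config.
# Model and module names are assumptions, not the actual setup referenced above.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM


def count_trainable(rank: int, targets: list[str]) -> int:
    base = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    peft_model = get_peft_model(
        base, LoraConfig(r=rank, lora_alpha=2 * rank, target_modules=targets)
    )
    return sum(p.numel() for p in peft_model.parameters() if p.requires_grad)


# Attention-only LoRA at a small rank vs. attention + MLP at a larger rank.
small = count_trainable(8, ["q_proj", "v_proj"])
large = count_trainable(64, ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"])
print(f"small config: {small:,} trainable params; large config: {large:,}")
```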
This is an (intentionally) massive PR that completely refactors the base Prismatic VLM codebase following TRI-ML#15. **Please do not try to review this entire PR for your sanity; instead, see "Proof-of-Correctness" below.**

All Prismatic models are now instances of HuggingFace `transformers.AutoModelForVision2Seq`, have native compatibility with external HF libraries (e.g., `peft`, `accelerate`, `bitsandbytes`), and can easily be integrated with existing training frameworks (HF Trainer, PyTorch Lightning). Because this PR represents what I feel is the most "stable" version of the Prismatic codebase, I've bumped the major version to `v1.0.0`.

Additionally, this PR implements:

- Support for batched generation (speeds up decoding)
- Conversion scripts for "v0" model weights, with all configs/models/processors pushed to [huggingface.co/TRI-ML](https://huggingface.co/collections/TRI-ML/prismatic-vlms-66857a7c64b6a6b6fbc84ea4)
  - Most of the "popular" Prismatic checkpoints have already been converted and uploaded as **private** models; the remaining models will be converted iteratively.
- Simplified interactive generation script at `scripts/generate.py`
- Basic validation + tests

CC @blake-wulfe-tri @charlesrichter-tri @jmercat @paarthshah-tri @jensen-gao-tri for visibility.

Resolves TRI-ML#15

---

**Proof-of-Correctness**: Rather than reviewing all files, I structured this PR as a series of commits that uphold two invariants:

- **Fully (bitwise) reproducible training** -- For two model configs, assert that running 64 gradient steps across 8 GPUs produces the *exact same loss curves and performance*.
- **Deterministic generation output** -- When loading (converted) checkpoints, assert that generation output is exactly identical (plus/minus some CUDA non-determinism).

**Commits** (for parsing the W&B loss curves below):

- `base` -- Gray line; the original loss curves for training the original models (`siglip-224px+7b` and `prism-dinosiglip-224px-controlled`) from several months ago.
- `#fd2a0e4` -- Purple line; the latest commit on `vlm-core` (upstream branch), as a sanity check that nothing has changed since the original models were trained.
- [NEW] `#fc78732` -- Green line; implements the changes to the Prismatic base class needed to prepare (unify the API) for the full refactor (adds the `<image>` token, batched generation, and restructures the `forward()` pass to remove the dependence on `multimodal_indices`).
- [NEW] `#b322374` -- Red line; **full HF "parallel" implementation** (side-by-side with the original Prismatic code). Adds new preprocessing objects following HF conventions, defines the `PrismaticForVision2Seq` core VLM model, conversion scripts, `hf_pretrain.py`, and `hf_generate.py`.
- [NEW] `#b63f704` -- Orange line; finalizes the refactor. Purges all "old" Prismatic code, makes the `hf_*` files first-class citizens, and refactors the README and installation instructions.

*Note*: The `siglip-224px` training runs are bitwise identical across all of the above commits; `prism-dinosiglip-224px` is essentially the same, modulo a slight difference in randomness that stems from the new "fused backbone" API (it only affects weight initialization in a small way).

![reproducible-loss-curves](https://github.com/TRI-ML/prismatic-dev/assets/126100644/b35e877e-d7c0-4c24-b44d-5373720b4a67)
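For concreteness, here is a hedged sketch of what the refactor enables: loading a converted checkpoint through the HF Auto classes and running batched generation. The model ID, processor call, and generation usage are assumptions based on the PR description, not a verified API:

```python
# Hedged sketch: loading a converted Prismatic checkpoint via HF Auto classes and
# doing batched generation. The model ID is a placeholder; the exact processor and
# generate() behavior are assumptions based on the PR description above.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "TRI-ML/prismatic-vlms"  # placeholder; actual converted model IDs may differ
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

prompts = ["What is in this image?", "Describe the scene."]
images = [Image.open("img_a.jpg"), Image.open("img_b.jpg")]

# Batched decoding: multiple (image, prompt) pairs in a single generate() call.
inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to("cuda")
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated, skip_special_tokens=True))
```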
Great work! One quick question: in the paper you reproduced the results from LLaVA, and for the Prismatic model experiments you fine-tune the whole LM. Did you try using LoRA in the LM for the LLaVA experiments as well as for the Prismatic experiments? More specifically, why did you opt for fine-tuning the entire LM instead of training LoRA modules?
Additionally, have you stored any training logs, e.g., Weights & Biases?
Thanks!