
HF Integration #89

Open
sedrick-keh-tri opened this issue Nov 14, 2023 · 3 comments

@sedrick-keh-tri
Collaborator

Hi OpenLM team! Is there interest in making OpenLM models loadable using just HF?

I see some OpenLM models up on HF, but they are not readily loadable using HF. The proposed changes would involve adding an OpenLM class on HF, similar to how other models are hosted on HF (e.g. Mistral).

For comparison, both #54 and #20 allow saved OpenLM models to be loaded through HF functions, but under the hood they still call OpenLM functions and require the OpenLM library to be installed locally. What I'm thinking is basically porting OpenLM's model.py into the transformers library itself, so that OpenLM-trained models can be shared and loaded more easily. I can work on this if you think it's a good idea.
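For a rough sense of what this involves, the sketch below shows how a custom architecture is typically wired into the transformers Auto classes. OpenLMConfig and OpenLMForCausalLM are hypothetical names standing in for a ported model.py; an in-tree integration would add these inside transformers itself rather than registering them at runtime.

```python
# Hedged sketch, not an actual implementation: the usual hooks for exposing a
# custom architecture through transformers' Auto* classes. OpenLMConfig and
# OpenLMForCausalLM are hypothetical stand-ins for a ported open_lm model.py.
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

class OpenLMConfig(PretrainedConfig):
    model_type = "openlm"  # the architecture key stored in config.json

class OpenLMForCausalLM(PreTrainedModel):
    config_class = OpenLMConfig
    # __init__/forward would come from the ported model.py (omitted here).

# Once registered, AutoModelForCausalLM.from_pretrained(...) can resolve
# checkpoints whose config declares model_type == "openlm".
AutoConfig.register("openlm", OpenLMConfig)
AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM)
```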

@mitchellnw @sagadre @achalddave

@achalddave
Collaborator

Discussed with @ruixin31 and @sedrick-keh-tri offline, summarizing here: this is generally something we'd like to have. The only question is one of timing and priority: we're improving openlm rapidly (e.g. #74), so we may want to put off integrating into HF to reduce maintenance effort. @sedrick-keh-tri will look into this and add it if it's easy, otherwise we'll punt until later this year.

@achalddave added the enhancement label on Nov 15, 2023
@sedrick-keh-tri
Collaborator Author

sedrick-keh-tri commented Dec 15, 2023

Implemented here: https://github.com/sedrick-keh-tri/transformers

Steps:

  1. Install HF transformers from the repo above instead of the usual HF transformers.
  2. You can now use AutoModelForCausalLM to load the model. (Note: Requires CUDA. Does not work on CPU.)

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("TRI-ML/openlm-1b")
model = AutoModelForCausalLM.from_pretrained("TRI-ML/openlm-1b").to("cuda")

a = tokenizer("hi", return_tensors="pt")
out = model.generate(a['input_ids'].to("cuda"), max_length=60, do_sample=False)
print(tokenizer.decode(out[0]))
```

  3. Another example: 7B code model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("TRI-ML/openlm-7b-code")
model = AutoModelForCausalLM.from_pretrained("TRI-ML/openlm-7b-code").to("cuda")

a = tokenizer("def find_most_common(arr)", return_tensors="pt")
out = model.generate(a['input_ids'].to("cuda"), max_length=60, do_sample=False)
print(tokenizer.decode(out[0]))
```

Note: This is an unofficial implementation, so we aren't merging it with the HF transformers repo right now. If OpenLM wants to eventually release models, I would be in favor of integrating with HF then.

@sedrick-keh-tri
Collaborator Author

Note to self (and to future OpenLM folks who want to work on this):

  • OpenLM uses xformers attention, which means that when x is fed into each layer (the open_lm Block, whose signature is def forward(self, x, past_key_value=None, use_cache=False)), x has shape (bsz, seq_len, hidden_dim). In HF's generation loop, however, x has shape (bsz, 1, hidden_dim). As a result, I had to cache Q as well, rather than just caching K and V (see this commit). Caching Q allows us to reconstruct the full-length Q when we pass it to xformers; see the sketch after this list.
  • There are instances where x.shape suddenly becomes (bsz, >1, hidden_dim) instead of (bsz, 1, hidden_dim). I'm not 100% sure why this happens, but to stop it, I just do x = x[:, -1, :].
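Here is a minimal sketch of the idea, assuming simple projection callables and a (q, k, v) cache layout; it is illustrative, not the actual open_lm code:

```python
# Hedged sketch of why Q is cached: HF's generation loop feeds one new token at
# a time, while the xformers attention path in open_lm expects the full
# (bsz, seq_len, ...) sequence, so the cached Q/K/V prefixes are concatenated
# with the new token's projections before attention is computed.
import torch

def attention_cache_step(wq, wk, wv, x_new, past_key_value=None):
    # x_new: (bsz, 1, hidden_dim) -- the single new token from the HF loop.
    q_new, k_new, v_new = wq(x_new), wk(x_new), wv(x_new)
    if past_key_value is not None:
        past_q, past_k, past_v = past_key_value
        # Rebuild full-length projections so xformers sees the whole sequence.
        q = torch.cat([past_q, q_new], dim=1)
        k = torch.cat([past_k, k_new], dim=1)
        v = torch.cat([past_v, v_new], dim=1)
    else:
        q, k, v = q_new, k_new, v_new
    present = (q, k, v)  # cache Q too, unlike a standard K/V-only cache
    # ... xformers memory-efficient attention would be called on q, k, v here ...
    return present
```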

Testing:

  • I've tested the above implementation and compared it against OpenLM's scripts/generate.py. For greedy decoding, the results are identical, and I also compared the logits to make sure they are the same. For sampling-based decoding, I didn't do thorough tests, but I took a qualitative look at a few results and they looked fine.
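For reference, a hedged sketch of the kind of logits check described above; reference_logits is a placeholder for the output of the OpenLM code path and is not computed here:

```python
# Hedged sketch of a logits-parity check between the HF port and OpenLM.
# `reference_logits` stands in for logits produced by the OpenLM implementation
# (e.g. via scripts/generate.py or a direct forward pass) on the same prompt.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("TRI-ML/openlm-1b")
model = AutoModelForCausalLM.from_pretrained("TRI-ML/openlm-1b").to("cuda").eval()

inputs = tokenizer("hi", return_tensors="pt")
with torch.no_grad():
    hf_logits = model(inputs["input_ids"].to("cuda")).logits  # (bsz, seq_len, vocab_size)

# reference_logits = ...  # same-shape tensor from the OpenLM code path
# assert torch.allclose(hf_logits, reference_logits, atol=1e-4)
```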

Some other things we want to consider/fix for a future release:

  • Model formatting -- The OpenLM saved model format isn't directly compatible with HF. If you look at the TRI-ML model, there is a .pt file, which is the OpenLM checkpoint, and a separate pytorch_model.bin, which is the HF-compatible version. The main difference is that pytorch_model.bin contains only the state_dict, and some of the state dict keys are renamed from {name} to {model.name}. Eventually we might want to create a converter script (see the sketch after this list) or just save two copies of the model.
  • Proper documentation and code cleanliness -- I just copied the modeling.py files from LLaMA/Mistral and edited them accordingly, so there are a bunch of unused variables, outdated docstrings, etc.
  • Completeness -- OpenLM is continuously being updated, so I took a minimal approach and copied only what was needed. For example, for the norms I only included gain_only_layer_norm, since that was what we were mostly using. Eventually we definitely want to copy everything over.
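A minimal sketch of what such a converter could look like, assuming the key renaming described above; the function name and the fallback for a nested "state_dict" entry are illustrative:

```python
# Hedged converter sketch (hypothetical helper, not part of OpenLM): renames
# state-dict keys from "{name}" to "model.{name}" so an OpenLM .pt checkpoint
# matches the layout the HF wrapper expects, then saves it as pytorch_model.bin.
import torch

def convert_openlm_checkpoint(openlm_pt_path, out_path="pytorch_model.bin"):
    ckpt = torch.load(openlm_pt_path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" key; fall back to
    # the raw object otherwise (assumption, adjust to the actual format).
    state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    renamed = {f"model.{k}": v for k, v in state_dict.items()}
    torch.save(renamed, out_path)

# Example: convert_openlm_checkpoint("open_lm_1b.pt")
```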
