Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenELM support #7359

Merged
merged 16 commits into from
Jul 4, 2024
Merged

OpenELM support #7359

merged 16 commits into from
Jul 4, 2024

Conversation

icecream95
Copy link
Contributor

Fixes: #6868.

Thanks to @joshcarp for an initial try at doing this (#6986), it was very helpful as a source to copy-paste from and check against.

Currently a bunch of the configuration is hardcoded into llama.cpp, so only the 270M model works at this point.

The ffn_up tensors in the converted model are actually concatenations of ffn_gate and ffn_up, perhaps the conversion script should separate them out?

The 270M model is impressively fast, and works fine for generation, but "Chat" mode in ./server doesn't really work well. Perhaps that's just because it hasn't been finetuned for that? I'm not really sure.

@icecream95 icecream95 marked this pull request as draft May 18, 2024 07:53
@icecream95 icecream95 changed the title Draft: OpenELM support OpenELM support May 18, 2024
@mofosyne mofosyne added model Model specific Review Complexity : High Generally require indepth knowledge of LLMs or GPUs labels May 18, 2024
@icecream95
Copy link
Contributor Author

It looks like context shift currently causes crashes, because build_k_shift uses the false number of heads in the .gguf.

A few other functions seem like they will be broken as well.

@github-actions github-actions bot added the python python script changes label May 18, 2024
Copy link
Contributor

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 512 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=9168.82ms p(95)=23246.02ms fails=, finish reason: stop=444 truncated=68
  • Prompt processing (pp): avg=111.45tk/s p(95)=515.19tk/s
  • Token generation (tg): avg=31.54tk/s p(95)=46.05tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=openelm commit=60b2e1b9c529f74f5bf881b05a6247ff6f58a71c

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 512 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1716104319 --> 1716104943
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 452.86, 452.86, 452.86, 452.86, 452.86, 896.59, 896.59, 896.59, 896.59, 896.59, 871.99, 871.99, 871.99, 871.99, 871.99, 877.49, 877.49, 877.49, 877.49, 877.49, 882.82, 882.82, 882.82, 882.82, 882.82, 887.52, 887.52, 887.52, 887.52, 887.52, 875.85, 875.85, 875.85, 875.85, 875.85, 890.18, 890.18, 890.18, 890.18, 890.18, 891.96, 891.96, 891.96, 891.96, 891.96, 903.52, 903.52, 903.52, 903.52, 903.52, 927.23, 927.23, 927.23, 927.23, 927.23, 913.62, 913.62, 913.62, 913.62, 913.62, 910.92, 910.92, 910.92, 910.92, 910.92, 926.37, 926.37, 926.37, 926.37, 926.37, 909.03, 909.03, 909.03, 909.03, 909.03, 912.02, 912.02, 912.02, 912.02, 912.02, 909.29, 909.29, 909.29, 909.29, 909.29, 892.43, 892.43, 892.43, 892.43, 892.43, 893.16, 893.16, 893.16, 893.16, 893.16, 887.09, 887.09, 887.09, 887.09, 887.09, 892.09, 892.09, 892.09, 892.09, 892.09, 891.43, 891.43, 891.43, 891.43, 891.43, 892.08, 892.08, 892.08, 892.08, 892.08, 886.73, 886.73, 886.73, 886.73, 886.73, 883.02, 883.02, 883.02, 883.02, 883.02, 882.49, 882.49, 882.49, 882.49, 882.49, 895.2, 895.2, 895.2, 895.2, 895.2, 891.89, 891.89, 891.89, 891.89, 891.89, 890.09, 890.09, 890.09, 890.09, 890.09, 889.3, 889.3, 889.3, 889.3, 889.3, 893.44, 893.44, 893.44, 893.44, 893.44, 892.14, 892.14, 892.14, 892.14, 892.14, 890.24, 890.24, 890.24, 890.24, 890.24, 893.23, 893.23, 893.23, 893.23, 893.23, 900.85, 900.85, 900.85, 900.85, 900.85, 901.77, 901.77, 901.77, 901.77, 901.77, 904.26, 904.26, 904.26, 904.26, 904.26, 907.47, 907.47, 907.47, 907.47, 907.47, 905.15, 905.15, 905.15, 905.15, 905.15, 901.81, 901.81, 901.81, 901.81, 901.81, 904.25, 904.25, 904.25, 904.25, 904.25, 905.24, 905.24, 905.24, 905.24, 905.24, 905.87, 905.87, 905.87, 905.87, 905.87, 909.23, 909.23, 909.23, 909.23, 909.23, 905.0, 905.0, 905.0, 905.0, 905.0, 905.48, 905.48, 905.48, 905.48, 905.48, 903.32, 903.32, 903.32, 903.32, 903.32, 901.1, 901.1, 901.1, 901.1, 901.1, 896.0, 896.0, 896.0, 896.0, 896.0, 899.86, 899.86, 899.86, 899.86, 899.86, 902.19, 902.19, 902.19, 902.19, 902.19, 901.36, 901.36, 901.36, 901.36, 901.36, 904.94, 904.94, 904.94, 904.94, 904.94, 903.99, 903.99, 903.99, 903.99, 903.99, 903.29, 903.29, 903.29, 903.29, 903.29, 906.62, 906.62, 906.62, 906.62, 906.62, 907.12, 907.12, 907.12, 907.12, 907.12, 910.26, 910.26, 910.26, 910.26, 910.26, 909.96, 909.96, 909.96, 909.96, 909.96, 909.5, 909.5, 909.5, 909.5, 909.5, 908.99, 908.99, 908.99]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 512 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1716104319 --> 1716104943
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 38.8, 38.8, 38.8, 38.8, 38.8, 31.58, 31.58, 31.58, 31.58, 31.58, 28.54, 28.54, 28.54, 28.54, 28.54, 28.83, 28.83, 28.83, 28.83, 28.83, 30.45, 30.45, 30.45, 30.45, 30.45, 30.37, 30.37, 30.37, 30.37, 30.37, 31.59, 31.59, 31.59, 31.59, 31.59, 32.59, 32.59, 32.59, 32.59, 32.59, 33.03, 33.03, 33.03, 33.03, 33.03, 33.78, 33.78, 33.78, 33.78, 33.78, 33.73, 33.73, 33.73, 33.73, 33.73, 33.98, 33.98, 33.98, 33.98, 33.98, 33.23, 33.23, 33.23, 33.23, 33.23, 33.23, 33.23, 33.23, 33.23, 33.23, 31.24, 31.24, 31.24, 31.24, 31.24, 30.47, 30.47, 30.47, 30.47, 30.47, 30.72, 30.72, 30.72, 30.72, 30.72, 30.84, 30.84, 30.84, 30.84, 30.84, 30.64, 30.64, 30.64, 30.64, 30.64, 30.15, 30.15, 30.15, 30.15, 30.15, 29.96, 29.96, 29.96, 29.96, 29.96, 29.7, 29.7, 29.7, 29.7, 29.7, 29.91, 29.91, 29.91, 29.91, 29.91, 29.94, 29.94, 29.94, 29.94, 29.94, 29.73, 29.73, 29.73, 29.73, 29.73, 30.07, 30.07, 30.07, 30.07, 30.07, 30.03, 30.03, 30.03, 30.03, 30.03, 30.0, 30.0, 30.0, 30.0, 30.0, 29.92, 29.92, 29.92, 29.92, 29.92, 30.08, 30.08, 30.08, 30.08, 30.08, 30.26, 30.26, 30.26, 30.26, 30.26, 30.4, 30.4, 30.4, 30.4, 30.4, 30.5, 30.5, 30.5, 30.5, 30.5, 30.57, 30.57, 30.57, 30.57, 30.57, 30.55, 30.55, 30.55, 30.55, 30.55, 30.48, 30.48, 30.48, 30.48, 30.48, 29.94, 29.94, 29.94, 29.94, 29.94, 29.86, 29.86, 29.86, 29.86, 29.86, 29.53, 29.53, 29.53, 29.53, 29.53, 29.36, 29.36, 29.36, 29.36, 29.36, 29.49, 29.49, 29.49, 29.49, 29.49, 29.59, 29.59, 29.59, 29.59, 29.59, 29.71, 29.71, 29.71, 29.71, 29.71, 29.84, 29.84, 29.84, 29.84, 29.84, 29.84, 29.84, 29.84, 29.84, 29.84, 29.74, 29.74, 29.74, 29.74, 29.74, 29.42, 29.42, 29.42, 29.42, 29.42, 29.16, 29.16, 29.16, 29.16, 29.16, 27.96, 27.96, 27.96, 27.96, 27.96, 28.02, 28.02, 28.02, 28.02, 28.02, 28.07, 28.07, 28.07, 28.07, 28.07, 28.26, 28.26, 28.26, 28.26, 28.26, 28.27, 28.27, 28.27, 28.27, 28.27, 28.3, 28.3, 28.3, 28.3, 28.3, 28.34, 28.34, 28.34, 28.34, 28.34, 28.31, 28.31, 28.31, 28.31, 28.31, 28.3, 28.3, 28.3, 28.3, 28.3, 28.22, 28.22, 28.22, 28.22, 28.22, 28.27, 28.27, 28.27, 28.27, 28.27, 28.34, 28.34, 28.34, 28.34, 28.34, 28.41, 28.41, 28.41]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 512 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1716104319 --> 1716104943
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17, 0.17, 0.17, 0.17, 0.17, 0.43, 0.43, 0.43, 0.43, 0.43, 0.19, 0.19, 0.19, 0.19, 0.19, 0.11, 0.11, 0.11, 0.11, 0.11, 0.19, 0.19, 0.19, 0.19, 0.19, 0.23, 0.23, 0.23, 0.23, 0.23, 0.12, 0.12, 0.12, 0.12, 0.12, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.09, 0.09, 0.09, 0.09, 0.09, 0.15, 0.15, 0.15, 0.15, 0.15, 0.27, 0.27, 0.27, 0.27, 0.27, 0.24, 0.24, 0.24, 0.24, 0.24, 0.37, 0.37, 0.37, 0.37, 0.37, 0.3, 0.3, 0.3, 0.3, 0.3, 0.17, 0.17, 0.17, 0.17, 0.17, 0.11, 0.11, 0.11, 0.11, 0.11, 0.29, 0.29, 0.29, 0.29, 0.29, 0.32, 0.32, 0.32, 0.32, 0.32, 0.18, 0.18, 0.18, 0.18, 0.18, 0.23, 0.23, 0.23, 0.23, 0.23, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.33, 0.33, 0.33, 0.33, 0.33, 0.09, 0.09, 0.09, 0.09, 0.09, 0.12, 0.12, 0.12, 0.12, 0.12, 0.23, 0.23, 0.23, 0.23, 0.23, 0.22, 0.22, 0.22, 0.22, 0.22, 0.11, 0.11, 0.11, 0.11, 0.11, 0.14, 0.14, 0.14, 0.14, 0.14, 0.13, 0.13, 0.13, 0.13, 0.13, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.26, 0.26, 0.26, 0.26, 0.26, 0.35, 0.35, 0.35, 0.35, 0.35, 0.29, 0.29, 0.29, 0.29, 0.29, 0.45, 0.45, 0.45, 0.45, 0.45, 0.24, 0.24, 0.24, 0.24, 0.24, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.13, 0.13, 0.13, 0.13, 0.13, 0.14, 0.14, 0.14, 0.14, 0.14, 0.26, 0.26, 0.26, 0.26, 0.26, 0.44, 0.44, 0.44, 0.44, 0.44, 0.54, 0.54, 0.54, 0.54, 0.54, 0.66, 0.66, 0.66, 0.66, 0.66, 0.58, 0.58, 0.58, 0.58, 0.58, 0.18, 0.18, 0.18, 0.18, 0.18, 0.19, 0.19, 0.19, 0.19, 0.19, 0.13, 0.13, 0.13, 0.13, 0.13, 0.16, 0.16, 0.16, 0.16, 0.16, 0.19, 0.19, 0.19, 0.19, 0.19, 0.26, 0.26, 0.26, 0.26, 0.26, 0.17, 0.17, 0.17, 0.17, 0.17, 0.23, 0.23, 0.23, 0.23, 0.23, 0.15, 0.15, 0.15, 0.15, 0.15, 0.25, 0.25, 0.25, 0.25, 0.25, 0.22, 0.22, 0.22, 0.22, 0.22, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 512 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1716104319 --> 1716104943
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 2.0, 2.0, 2.0, 2.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0]
                    
Loading

@ggerganov
Copy link
Owner

ggerganov commented May 19, 2024

The ffn_up tensors in the converted model are actually concatenations of ffn_gate and ffn_up, perhaps the conversion script should separate them out?

We already have this logic for the Refact models:

elif name == f"transformer.h.{bid}.mlp.gate_up_proj.weight":
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim]))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:]))

You can try to reuse it in a similar way for OpenELM

The 270M model is impressively fast, and works fine for generation, but "Chat" mode in ./server doesn't really work well. Perhaps that's just because it hasn't been finetuned for that? I'm not really sure.

Have you ran perplexity runs with this model?

It looks like context shift currently causes crashes, because build_k_shift uses the false number of heads in the .gguf.

A few other functions seem like they will be broken as well.

We'll probably need to generalize the head number to be determined per layer. Do you need to some assistance with that?

@icecream95
Copy link
Contributor Author

I've been quite tired recently, so it might be a while before I'm able to come back to this.

I see that @jart's #7445 has already been merged with similar modifications to llama_model_type_name, but I think git merge will do the right thing here without needing to change that commit.

@joshcarp
Copy link

@icecream95 might jump back on this cause im curious to where i got stuck

@compilade
Copy link
Collaborator

We'll probably need to generalize the head number to be determined per layer.

I'll be working on porting the variable GQA stuff written for Jamba (in #7531) to OpenELM in the next days, which will both help with making OpenELM work and with reducing the size of the Jamba PR (ref: #7531 (comment)).

I think the ffn_multipliers could also be simplified similarly by taking care of them in the convert script while the resulting FFN sizes could be directly part of {arch}.feed_forward_length as an array of integers.

@icecream95 do you mind if I push to this branch, or would you prefer that I use another one? In the meantime I'll work on a local branch.

* llama : add variable GQA and variable FFN sizes

Some metadata keys can now also be arrays to support setting
their value per-layer for models like OpenELM.
@compilade
Copy link
Collaborator

compilade commented Jul 1, 2024

I've got this working for the 270M model, the 1.1B model, and the 450M model. Getting reasonable perplexity. I did not yet test the 3B model, but will try once it finishes downloading. EDIT: the 3B model works too!

According to https://github.com/apple/corenet/blob/2261885b6696950aaf481a862e8926921ef1a067/projects/openelm/instruction_tuning/openelm-instruct.yaml#L7, the prompt template is zephyr, but it seems to perform extremely poorly in practice, while an alpaca-like template (with ### Instruction:) seems to work better for some reason. Even no template at all works pretty well from what I've seen.

This might be because the tokenizer doesn't have tokens for <|user|> and <|assistant|>, since it's using Llama-2's tokenizer.

Session file saving support isn't yet done, but everything else should pretty much work.

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be because the tokenizer doesn't have tokens for <|user|> and <|assistant|>, since it's using Llama-2's tokenizer.

Do we know what is the reason for the tokenizer confusion? It seems the model is unusable without the correct tokenizer

src/llama.cpp Outdated Show resolved Hide resolved
src/llama.cpp Outdated
Comment on lines 2115 to 2118
// TODO: find a more compact way to add more per-layer hyper-parameters
std::vector<int32_t> n_head_vec;
std::vector<int32_t> n_head_kv_vec;
std::vector<int32_t> n_ff_vec;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect negative values here? If not, we should change to uint32_t

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've used int32_t here due to a limitation of GGUFWriter, which doesn't allow choosing the specific type of integer arrays and defaults to INT32. And also because get_arr in src/llama.cpp can only load arrays into vectors which have exactly the same type as the loaded array.

I'll see if this can be fixed, but it will likely require using Numpy types (e.g. np.uint32) in GGUFValueType.get_type() in gguf-py/gguf/constants.py.

src/llama.cpp Outdated
@@ -2173,18 +2202,53 @@ struct llama_hparams {
return false;
}

uint32_t n_gqa() const {
// TODO: deduplicate per-layer getters
uint32_t n_head_l(uint32_t layer) const {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove hparams.n_head and rename this to:

uint32_t n_head(uint32_t layer) const {

This way everywhere we reference hparams.n_head, we will use hparams.n_head().
We can do this for the rest of the parameters that make sense to be per-layer, the main goal being to avoid having duplicated information in hparams.n_head and hparams.n_head_vec.

We can do this change in a follow-up refactoring PR, together with finding a more compact way to have per-layer hparams (maybe via new struct llama_hparams_layer)

Copy link
Collaborator

@compilade compilade Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the main goal being to avoid having duplicated information in hparams.n_head and hparams.n_head_vec.

The advantage of duplicating the information is to have a fallback value, and it also avoids having to allocate the vectors when there's only one value for all the layers. The disadvantage is the possible confusion of which value to use.

We can do this change in a follow-up refactoring PR, together with finding a more compact way to have per-layer hparams (maybe via new struct llama_hparams_layer)

Yes a follow-up refactor seems appropriate. The goal with this initial implementation was to minimize having to change the accesses to hparams.n_head, hparams.n_head_kv, and hparams.n_ff all over the place, and also to re-use the nice abstraction which llama_model_loader.get_arr seemed to be (it was introduced in #7225, but isn't used by anything on master).

Although a possible huge problem with using std::vector in llama_hparams is that it's no longer trivially copyable as assumed by llama_state_save_file_internal and llama_state_load_file_internal.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although a possible huge problem with using std::vector in llama_hparams is that it's no longer trivially copyable as assumed by llama_state_save_file_internal and llama_state_load_file_internal.

Hm, correct. We could switch to uint32 n_head_vec[LLAMA_MAX_LAYERS]; instead and use n_layer as the container size?

@ggerganov ggerganov marked this pull request as ready for review July 1, 2024 11:19
@compilade
Copy link
Collaborator

Do we know what is the reason for the tokenizer confusion? It seems the model is unusable without the correct tokenizer

I'm not sure. They clearly say they used the Llama 2 tokenizer, while they also suggest to use it verbatim. This doesn't feel right, but this is apparently what they have done.

From playing around with the OpenELM Instruct models with only the BOS token as a prompt and different seeds at --temp 1, they show no sign of having been trained on a chat template.

When asking questions or instructions, at least they seem to follow them, but the answers are never short (which can be useful in some cases).

It's a bit unfortunate they did not add special tokens (or even re-use the BOS and EOS token in a chatml-style way, like the bagel models do), which makes their instruct models worse than they could have been. Although this could likely be fixed by better fine-tuning.

Co-authored-by: Georgi Gerganov <[email protected]>
@sqzhang-jeremy
Copy link

I've got this working for the 270M model, the 1.1B model, and the 450M model. Getting reasonable perplexity. I did not yet test the 3B model, but will try once it finishes downloading. EDIT: the 3B model works too!

According to https://github.com/apple/corenet/blob/2261885b6696950aaf481a862e8926921ef1a067/projects/openelm/instruction_tuning/openelm-instruct.yaml#L7, the prompt template is zephyr, but it seems to perform extremely poorly in practice, while an alpaca-like template (with ### Instruction:) seems to work better for some reason. Even no template at all works pretty well from what I've seen.

This might be because the tokenizer doesn't have tokens for <|user|> and <|assistant|>, since it's using Llama-2's tokenizer.

Session file saving support isn't yet done, but everything else should pretty much work.

Question: How did you get the .gguf format OpenELM models? Could you please share the experience? Thanks!

src/llama.cpp Outdated Show resolved Hide resolved
@ggerganov
Copy link
Owner

@compilade Shall we merge after green CI?

Copy link
Collaborator

@compilade compilade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we merge after green CI?

Yes, this feels pretty much ready.

src/llama.cpp Outdated
Comment on lines 108 to 109
#define LLAMA_MAX_LAYERS 256
#define LLAMA_MAX_EXPERTS 160
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to note from which models these values come (or maybe not; it's also fine as-is). The 160 experts come from DeepSeekV2, but I don't know about models with 256 layers.

There are some models with more than 128 layers, like Goliath-120B (137 layers), and TheProfessor-155B (180 layers), so I assume 256 is there for some future-proofing, being the next power of 2.

(PaLM 540B (even though not open-weights) has 118 layers, so Llama-3-405B might still be fine with this limit. The models with more layers are usually merges.)

@ggerganov ggerganov merged commit d7fd29f into ggerganov:master Jul 4, 2024
12 checks passed
@ngxson
Copy link
Collaborator

ngxson commented Jul 4, 2024

I've been trying some chat templates, but there is no sign that it is trained for chat.

Apparently, apple did not make it clear if "Instruction" means single turn Q&A or multi-turn conversation. Hopefully someone will release a fine tune with proper chat template support.

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 4, 2024
* Initial OpenELM support (270M only so far)

* Fill out missing entries in llama_model_type_name

* fixup! Initial OpenELM support (270M only so far)

Fix formatting

* llama : support all OpenELM models

* llama : add variable GQA and variable FFN sizes

Some metadata keys can now also be arrays to support setting
their value per-layer for models like OpenELM.

* llama : minor spacing changes

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : use std::array for per-layer hparams

* llama : fix save/load state

* llama : do not print hparams for vocab-only models

* llama : handle n_head == 0

* llama : use const ref for print_f and fix division by zero

* llama : fix t5 uses of n_head and n_ff

* llama : minor comment

---------

Co-authored-by: Francis Couture-Harpin <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 5, 2024
* Initial OpenELM support (270M only so far)

* Fill out missing entries in llama_model_type_name

* fixup! Initial OpenELM support (270M only so far)

Fix formatting

* llama : support all OpenELM models

* llama : add variable GQA and variable FFN sizes

Some metadata keys can now also be arrays to support setting
their value per-layer for models like OpenELM.

* llama : minor spacing changes

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : use std::array for per-layer hparams

* llama : fix save/load state

* llama : do not print hparams for vocab-only models

* llama : handle n_head == 0

* llama : use const ref for print_f and fix division by zero

* llama : fix t5 uses of n_head and n_ff

* llama : minor comment

---------

Co-authored-by: Francis Couture-Harpin <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 6, 2024
* Initial OpenELM support (270M only so far)

* Fill out missing entries in llama_model_type_name

* fixup! Initial OpenELM support (270M only so far)

Fix formatting

* llama : support all OpenELM models

* llama : add variable GQA and variable FFN sizes

Some metadata keys can now also be arrays to support setting
their value per-layer for models like OpenELM.

* llama : minor spacing changes

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : use std::array for per-layer hparams

* llama : fix save/load state

* llama : do not print hparams for vocab-only models

* llama : handle n_head == 0

* llama : use const ref for print_f and fix division by zero

* llama : fix t5 uses of n_head and n_ff

* llama : minor comment

---------

Co-authored-by: Francis Couture-Harpin <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
@sqzhang-jeremy
Copy link

sqzhang-jeremy commented Jul 6, 2024

Hi there!

I tried to deploy the openELM series models like openELM-270M-IT. After transforming to .gguf format , I met a bug which showed that unknown architecture openelm

Platform: Android Pixel 8 Pro(ARM)
Git Hash: a38b884

Q: How to tackle load model issue?
~/llama.cpp/bin $ ./llama-cli -m ../../model/apple/openelm-270M-it-model-q8_0.gguf -n 128 --color -if Log start main: build = 3291 (f6190247) main: built with clang version 18.1.8 for aarch64-unknown-linux-android24 main: seed = 1720242044 llama_model_loader: loaded meta data with 25 key-value pairs and 146 tensors from ../../model/apple/openelm-270M-it-model-q8_0.gguf (version GGUF V3 (latest)) llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output. llama_model_loader: - kv 0: general.architecture str = openelm llama_model_loader: - kv 1: general.name str = openelm-270M llama_model_loader: - kv 2: openelm.block_count u32 = 16 llama_model_loader: - kv 3: openelm.context_length u32 = 2048 llama_model_loader: - kv 4: openelm.embedding_length u32 = 1280 llama_model_loader: - kv 5: openelm.feed_forward_length arr[i32,16] = [768, 1024, 1280, 1536, 1792, 2048, 2... llama_model_loader: - kv 6: openelm.attention.head_count arr[i32,16] = [12, 12, 12, 12, 12, 16, 16, 16, 16, ... llama_model_loader: - kv 7: openelm.attention.head_count_kv arr[i32,16] = [3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, ... llama_model_loader: - kv 8: openelm.rope.freq_base f32 = 10000.000000 llama_model_loader: - kv 9: openelm.attention.layer_norm_rms_epsilon f32 = 0.000001 llama_model_loader: - kv 10: openelm.rope.dimension_count u32 = 64 llama_model_loader: - kv 11: openelm.attention.key_length u32 = 64 llama_model_loader: - kv 12: openelm.attention.value_length u32 = 64 llama_model_loader: - kv 13: general.file_type u32 = 7 llama_model_loader: - kv 14: tokenizer.ggml.model str = llama llama_model_loader: - kv 15: tokenizer.ggml.pre str = default llama_model_loader: - kv 16: tokenizer.ggml.tokens arr[str,32000] = ["<unk>", "<s>", "</s>", "<0x00>", "<... llama_model_loader: - kv 17: tokenizer.ggml.scores arr[f32,32000] = [0.000000, 0.000000, 0.000000, 0.0000... llama_model_loader: - kv 18: tokenizer.ggml.token_type arr[i32,32000] = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... llama_model_loader: - kv 19: tokenizer.ggml.bos_token_id u32 = 1 llama_model_loader: - kv 20: tokenizer.ggml.eos_token_id u32 = 2 llama_model_loader: - kv 21: tokenizer.ggml.unknown_token_id u32 = 0 llama_model_loader: - kv 22: tokenizer.ggml.add_bos_token bool = true llama_model_loader: - kv 23: tokenizer.ggml.add_eos_token bool = false llama_model_loader: - kv 24: general.quantization_version u32 = 2 llama_model_loader: - type f32: 65 tensors llama_model_loader: - type q8_0: 81 tensors llama_model_load: error loading model: error loading model architecture: unknown model architecture: 'openelm' llama_load_model_from_file: failed to load model llama_init_from_gpt_params: error: failed to load model '../../model/apple/openelm-270M-it-model-q8_0.gguf' main: error: unable to load model

image

@compilade
Copy link
Collaborator

compilade commented Jul 6, 2024

@sqzhang-jeremy Thanks for trying this.

How to tackle load model issue?

Try to rebuild your binaries.

Log start main: build = 3291 (f6190247)

This build is from a commit from before OpenELM support was merged.

@sqzhang-jeremy
Copy link

@sqzhang-jeremy Thanks for trying this.

How to tackle load model issue?

Try to rebuild your binaries.

Log start main: build = 3291 (f6190247)

This build is from a commit from before OpenELM support was merged.

After rebuilding, It worked! Thank you!

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 7, 2024
* Initial OpenELM support (270M only so far)

* Fill out missing entries in llama_model_type_name

* fixup! Initial OpenELM support (270M only so far)

Fix formatting

* llama : support all OpenELM models

* llama : add variable GQA and variable FFN sizes

Some metadata keys can now also be arrays to support setting
their value per-layer for models like OpenELM.

* llama : minor spacing changes

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : use std::array for per-layer hparams

* llama : fix save/load state

* llama : do not print hparams for vocab-only models

* llama : handle n_head == 0

* llama : use const ref for print_f and fix division by zero

* llama : fix t5 uses of n_head and n_ff

* llama : minor comment

---------

Co-authored-by: Francis Couture-Harpin <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 11, 2024
* Initial OpenELM support (270M only so far)

* Fill out missing entries in llama_model_type_name

* fixup! Initial OpenELM support (270M only so far)

Fix formatting

* llama : support all OpenELM models

* llama : add variable GQA and variable FFN sizes

Some metadata keys can now also be arrays to support setting
their value per-layer for models like OpenELM.

* llama : minor spacing changes

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : use std::array for per-layer hparams

* llama : fix save/load state

* llama : do not print hparams for vocab-only models

* llama : handle n_head == 0

* llama : use const ref for print_f and fix division by zero

* llama : fix t5 uses of n_head and n_ff

* llama : minor comment

---------

Co-authored-by: Francis Couture-Harpin <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
model Model specific python python script changes Review Complexity : High Generally require indepth knowledge of LLMs or GPUs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for OpenELM of Apple
7 participants