Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama_sampling_sample with default args is more naively usable #6519

Merged
merged 2 commits into from
Apr 8, 2024

Conversation

TheFlipbook
Copy link
Contributor

  • Hopefully reduces chances of error case cited in batch_add gives lower quality results than batch_get_one #6475
  • Batches populated by either llama_batch_get_one or llama_batch_add work with default args
    • Previously get_one could use the default argument
    • Previously add should usually have used the last index where batch->logits[idx] == true
  • This hopefully encourages the use of llama_batch_add
    • By giving expected results when using default arguments.
  • Believed to work with any currently well behaved program
    • Default arg now works for both cases (previously would give strange results for add case)
    • Any non-negative number is unaffected and behaves as previously
    • Negative arguments were previously invalid.

Copy link
Contributor

github-actions bot commented Apr 7, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 436 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=10778.95ms p(95)=28462.42ms fails=, finish reason: stop=385 truncated=51
  • Prompt processing (pp): avg=112.97tk/s p(95)=496.09tk/s
  • Token generation (tg): avg=25.98tk/s p(95)=35.95tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=unified-sampling-default commit=4d54281580c40ad689f0c43318cb78b642c81f54

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 436 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1712455148 --> 1712455770
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 515.02, 515.02, 515.02, 515.02, 515.02, 428.84, 428.84, 428.84, 428.84, 428.84, 416.97, 416.97, 416.97, 416.97, 416.97, 432.43, 432.43, 432.43, 432.43, 432.43, 459.03, 459.03, 459.03, 459.03, 459.03, 490.53, 490.53, 490.53, 490.53, 490.53, 506.26, 506.26, 506.26, 506.26, 506.26, 511.36, 511.36, 511.36, 511.36, 511.36, 531.6, 531.6, 531.6, 531.6, 531.6, 548.25, 548.25, 548.25, 548.25, 548.25, 550.62, 550.62, 550.62, 550.62, 550.62, 561.38, 561.38, 561.38, 561.38, 561.38, 564.43, 564.43, 564.43, 564.43, 564.43, 566.23, 566.23, 566.23, 566.23, 566.23, 579.46, 579.46, 579.46, 579.46, 579.46, 587.57, 587.57, 587.57, 587.57, 587.57, 594.62, 594.62, 594.62, 594.62, 594.62, 617.24, 617.24, 617.24, 617.24, 617.24, 587.36, 587.36, 587.36, 587.36, 587.36, 597.64, 597.64, 597.64, 597.64, 597.64, 600.29, 600.29, 600.29, 600.29, 600.29, 599.64, 599.64, 599.64, 599.64, 599.64, 604.49, 604.49, 604.49, 604.49, 604.49, 606.51, 606.51, 606.51, 606.51, 606.51, 606.8, 606.8, 606.8, 606.8, 606.8, 611.95, 611.95, 611.95, 611.95, 611.95, 613.06, 613.06, 613.06, 613.06, 613.06, 615.53, 615.53, 615.53, 615.53, 615.53, 631.85, 631.85, 631.85, 631.85, 631.85, 630.77, 630.77, 630.77, 630.77, 630.77, 633.33, 633.33, 633.33, 633.33, 633.33, 635.54, 635.54, 635.54, 635.54, 635.54, 635.39, 635.39, 635.39, 635.39, 635.39, 642.41, 642.41, 642.41, 642.41, 642.41, 641.83, 641.83, 641.83, 641.83, 641.83, 643.64, 643.64, 643.64, 643.64, 643.64, 644.7, 644.7, 644.7, 644.7, 644.7, 650.85, 650.85, 650.85, 650.85, 650.85, 651.75, 651.75, 651.75, 651.75, 651.75, 653.07, 653.07, 653.07, 653.07, 653.07, 660.53, 660.53, 660.53, 660.53, 660.53, 666.48, 666.48, 666.48, 666.48, 666.48, 670.94, 670.94, 670.94, 670.94, 670.94, 672.22, 672.22, 672.22, 672.22, 672.22, 676.53, 676.53, 676.53, 676.53, 676.53, 675.46, 675.46, 675.46, 675.46, 675.46, 676.03, 676.03, 676.03, 676.03, 676.03, 678.88, 678.88, 678.88, 678.88, 678.88, 681.62, 681.62, 681.62, 681.62, 681.62, 682.62, 682.62, 682.62, 682.62, 682.62, 669.75, 669.75, 669.75, 669.75, 669.75, 655.43, 655.43, 655.43, 655.43, 655.43, 654.1, 654.1, 654.1, 654.1, 654.1, 653.24, 653.24, 653.24, 653.24, 653.24, 652.4, 652.4, 652.4, 652.4, 652.4, 650.72, 650.72, 650.72, 650.72, 650.72, 651.69, 651.69, 651.69, 651.69, 651.69, 654.67, 654.67, 654.67, 654.67, 654.67, 656.22, 656.22, 656.22, 656.22, 656.22, 656.31, 656.31, 656.31, 656.31, 656.31, 656.62, 656.62]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 436 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1712455148 --> 1712455770
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.99, 33.99, 33.99, 33.99, 33.99, 34.1, 34.1, 34.1, 34.1, 34.1, 21.97, 21.97, 21.97, 21.97, 21.97, 22.64, 22.64, 22.64, 22.64, 22.64, 22.31, 22.31, 22.31, 22.31, 22.31, 22.83, 22.83, 22.83, 22.83, 22.83, 22.83, 22.83, 22.83, 22.83, 22.83, 23.56, 23.56, 23.56, 23.56, 23.56, 24.49, 24.49, 24.49, 24.49, 24.49, 24.58, 24.58, 24.58, 24.58, 24.58, 24.71, 24.71, 24.71, 24.71, 24.71, 24.62, 24.62, 24.62, 24.62, 24.62, 23.87, 23.87, 23.87, 23.87, 23.87, 23.65, 23.65, 23.65, 23.65, 23.65, 23.48, 23.48, 23.48, 23.48, 23.48, 23.38, 23.38, 23.38, 23.38, 23.38, 22.77, 22.77, 22.77, 22.77, 22.77, 22.76, 22.76, 22.76, 22.76, 22.76, 22.34, 22.34, 22.34, 22.34, 22.34, 21.75, 21.75, 21.75, 21.75, 21.75, 21.82, 21.82, 21.82, 21.82, 21.82, 22.02, 22.02, 22.02, 22.02, 22.02, 22.1, 22.1, 22.1, 22.1, 22.1, 21.76, 21.76, 21.76, 21.76, 21.76, 21.72, 21.72, 21.72, 21.72, 21.72, 21.6, 21.6, 21.6, 21.6, 21.6, 21.47, 21.47, 21.47, 21.47, 21.47, 21.63, 21.63, 21.63, 21.63, 21.63, 21.72, 21.72, 21.72, 21.72, 21.72, 21.6, 21.6, 21.6, 21.6, 21.6, 21.79, 21.79, 21.79, 21.79, 21.79, 21.81, 21.81, 21.81, 21.81, 21.81, 21.89, 21.89, 21.89, 21.89, 21.89, 21.82, 21.82, 21.82, 21.82, 21.82, 21.55, 21.55, 21.55, 21.55, 21.55, 21.64, 21.64, 21.64, 21.64, 21.64, 21.88, 21.88, 21.88, 21.88, 21.88, 21.97, 21.97, 21.97, 21.97, 21.97, 22.02, 22.02, 22.02, 22.02, 22.02, 22.27, 22.27, 22.27, 22.27, 22.27, 22.37, 22.37, 22.37, 22.37, 22.37, 22.33, 22.33, 22.33, 22.33, 22.33, 22.33, 22.33, 22.33, 22.33, 22.33, 22.28, 22.28, 22.28, 22.28, 22.28, 22.06, 22.06, 22.06, 22.06, 22.06, 22.08, 22.08, 22.08, 22.08, 22.08, 22.11, 22.11, 22.11, 22.11, 22.11, 22.22, 22.22, 22.22, 22.22, 22.22, 22.37, 22.37, 22.37, 22.37, 22.37, 22.51, 22.51, 22.51, 22.51, 22.51, 22.51, 22.51, 22.51, 22.51, 22.51, 22.41, 22.41, 22.41, 22.41, 22.41, 22.22, 22.22, 22.22, 22.22, 22.22, 22.19, 22.19, 22.19, 22.19, 22.19, 21.88, 21.88, 21.88, 21.88, 21.88, 21.56, 21.56, 21.56, 21.56, 21.56, 20.79, 20.79, 20.79, 20.79, 20.79, 20.72, 20.72, 20.72, 20.72, 20.72, 20.73, 20.73, 20.73, 20.73, 20.73, 20.79, 20.79, 20.79, 20.79, 20.79, 20.91, 20.91]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 436 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1712455148 --> 1712455770
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11, 0.11, 0.11, 0.11, 0.11, 0.37, 0.37, 0.37, 0.37, 0.37, 0.16, 0.16, 0.16, 0.16, 0.16, 0.19, 0.19, 0.19, 0.19, 0.19, 0.17, 0.17, 0.17, 0.17, 0.17, 0.23, 0.23, 0.23, 0.23, 0.23, 0.11, 0.11, 0.11, 0.11, 0.11, 0.15, 0.15, 0.15, 0.15, 0.15, 0.13, 0.13, 0.13, 0.13, 0.13, 0.16, 0.16, 0.16, 0.16, 0.16, 0.18, 0.18, 0.18, 0.18, 0.18, 0.21, 0.21, 0.21, 0.21, 0.21, 0.26, 0.26, 0.26, 0.26, 0.26, 0.22, 0.22, 0.22, 0.22, 0.22, 0.14, 0.14, 0.14, 0.14, 0.14, 0.3, 0.3, 0.3, 0.3, 0.3, 0.18, 0.18, 0.18, 0.18, 0.18, 0.29, 0.29, 0.29, 0.29, 0.29, 0.19, 0.19, 0.19, 0.19, 0.19, 0.13, 0.13, 0.13, 0.13, 0.13, 0.16, 0.16, 0.16, 0.16, 0.16, 0.15, 0.15, 0.15, 0.15, 0.15, 0.29, 0.29, 0.29, 0.29, 0.29, 0.27, 0.27, 0.27, 0.27, 0.27, 0.28, 0.28, 0.28, 0.28, 0.28, 0.2, 0.2, 0.2, 0.2, 0.2, 0.15, 0.15, 0.15, 0.15, 0.15, 0.12, 0.12, 0.12, 0.12, 0.12, 0.32, 0.32, 0.32, 0.32, 0.32, 0.12, 0.12, 0.12, 0.12, 0.12, 0.11, 0.11, 0.11, 0.11, 0.11, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.32, 0.32, 0.32, 0.32, 0.32, 0.11, 0.11, 0.11, 0.11, 0.11, 0.12, 0.12, 0.12, 0.12, 0.12, 0.08, 0.08, 0.08, 0.08, 0.08, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.14, 0.14, 0.14, 0.14, 0.14, 0.16, 0.16, 0.16, 0.16, 0.16, 0.17, 0.17, 0.17, 0.17, 0.17, 0.21, 0.21, 0.21, 0.21, 0.21, 0.2, 0.2, 0.2, 0.2, 0.2, 0.19, 0.19, 0.19, 0.19, 0.19, 0.15, 0.15, 0.15, 0.15, 0.15, 0.19, 0.19, 0.19, 0.19, 0.19, 0.12, 0.12, 0.12, 0.12, 0.12, 0.11, 0.11, 0.11, 0.11, 0.11, 0.14, 0.14, 0.14, 0.14, 0.14, 0.33, 0.33, 0.33, 0.33, 0.33, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.58, 0.58, 0.58, 0.58, 0.58, 0.52, 0.52, 0.52, 0.52, 0.52, 0.42, 0.42, 0.42, 0.42, 0.42, 0.12, 0.12, 0.12, 0.12, 0.12, 0.18, 0.18, 0.18, 0.18, 0.18, 0.12, 0.12, 0.12, 0.12, 0.12, 0.16, 0.16, 0.16, 0.16, 0.16, 0.14, 0.14]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 436 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1712455148 --> 1712455770
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 1.0, 1.0]
                    
Loading

llama.cpp Outdated
Comment on lines 15551 to 15557
// non-trivial case scan for the last output.
for (int32_t i = ctx->output_ids.size() - 1; i >= 0; --i) {
const int32_t candidate = ctx->output_ids[i];
if (candidate >= 0) {
return i;
}
}
Copy link
Collaborator

@compilade compilade Apr 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure there's a faster and simpler way to do this. If ctx->n_outputs is set to the total number of outputs after a batch is processed, the last logits would be available at ctx->logits + (ctx->n_outputs - 1)*ctx->model.hparams.n_vocab.

(this could be a special case of llama_get_logits_ith instead, to avoid trying to find the last logits index, which is still non-trivial)

In llama_decode_internal, the number of outputs is already calculated as n_outputs, so a simple lctx.n_outputs = n_outputs near the end of the function (way after the compute graph and its inputs are built, outside of the ubatch loop) would do the trick.

(BTW, I worked on adding ctx->output_ids, ctx->n_outputs, and other related stuff in #6122)

(click to expand) A possible way to implement what I'm trying to explain (diffed from `master` (d4f220a at the time of writing))
diff --git a/common/sampling.h b/common/sampling.h
index 56ed991b..639b819a 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -129,7 +129,7 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        int idx = 0);
+        int idx = -1);
 
 // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
 llama_token_data_array llama_sampling_prepare(
diff --git a/llama.cpp b/llama.cpp
index 21772618..adfb55fa 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2177,7 +2177,7 @@ struct llama_context {
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
     size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch
+    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch
 
     bool logits_all = false;
 
@@ -10411,6 +10411,9 @@ static int llama_decode_internal(
         n_outputs_prev += lctx.n_outputs;
     }
 
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    lctx.n_outputs = n_outputs;
+
     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);
 
@@ -15511,23 +15514,29 @@ float * llama_get_logits(struct llama_context * ctx) {
 }
 
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j;
+
     llama_synchronize(ctx);
 
     try {
         if (ctx->logits == nullptr) {
             throw std::runtime_error("no logits");
         }
-        if ((size_t) i >= ctx->output_ids.size()) {
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+        } else if ((size_t) i >= ctx->output_ids.size()) {
             throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
         }
-        const int32_t j = ctx->output_ids[i];
 
         if (j < 0) {
             throw std::runtime_error(format("batch.logits[%d] != true", i));
         }
-        if ((size_t) j >= ctx->output_size) {
+        if (j >= ctx->n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
         }
 
         return ctx->logits + j*ctx->model.hparams.n_vocab;
@@ -15547,23 +15556,29 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j;
+
     llama_synchronize(ctx);
 
     try {
         if (ctx->embd == nullptr) {
             throw std::runtime_error("no embeddings");
         }
-        if ((size_t) i >= ctx->output_ids.size()) {
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+        } else if ((size_t) i >= ctx->output_ids.size()) {
             throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
         }
-        const int32_t j = ctx->output_ids[i];
 
         if (j < 0) {
             throw std::runtime_error(format("batch.logits[%d] != true", i));
         }
-        if ((size_t) j >= ctx->output_size) {
+        if (j >= ctx->n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
         }
 
         return ctx->embd + j*ctx->model.hparams.n_embd;

Note that this also makes llama_get_embeddings_ith support negative indices, which may or may not be useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems reasonable to me!

I had been worried about changing the public api of llama_get_logits_ith, but if that is viable, then I concur, it seems guaranteed optimal to do it there, rather than on the index.

I had been hopeful that the worst case was uncommon, as it seemed likely that the last index would be set for wanting logits, and therefore the for loop would, in practice, usually exit on the first loop.

Would you like me to update this PR to conform to the suggested code? I think the suggestion looks great, and even adds a cool piece of functionality of python-style negative indexing.

Copy link
Collaborator

@compilade compilade Apr 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had been worried about changing the public api of llama_get_logits_ith

Since this seems like a backward-compatible change, there is no need to worry. API changes are quite common in llama.cpp, there's even a "Recent API changes" section in the README.md for this purpose.

I had been hopeful that the worst case was uncommon, as it seemed likely that the last index would be set for wanting logits, and therefore the for loop would, in practice, usually exit on the first loop.

You are right, the worst case isn't common (at least when using batch.logits). Another way would have been to always return ctx->cparams.n_batch - 1 or ctx->output_ids.size() - 1, and let llama_get_logits_ith do its asserts to fail in unexpected cases, but this would not work with llama_batch_get_one.

Thinking a bit more about this, the worst case is actually quite common in the sense that it happens all the time when using llama_batch_get_one, because the "last valid index" in this case is 0.

Would you like me to update this PR to conform to the suggested code?

Yes. You had the idea of making this, so feel free to make commits in your name.

I hereby allow @TheFlipbook to use my code review suggestion from this comment thread and do absolutely anything with it, including but not limited to committing it (even with modifications) without having to give me attribution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the PR with these suggestions. Apologies for the double-push, the first one had a formatting mistake.

Let me know if you'd suggest any further updates! Thanks for all the help!

TheFlipbook added a commit to TheFlipbook/llama.cpp that referenced this pull request Apr 7, 2024
* Batches populated by either llama_batch_get_one or llama_batch_add work with default args
  * Previously get_one could use the default argument
  * Previously add should usually have used the last index where logits[idx] == true
* This hopefully encourages the use of llama_batch_add
  * By giving expected results when using default arguments.
* Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith
* Believed to work with any currently well behaved program
  * Default arg now works for both cases (previously would give strange results for add case)
  * Any non-negative number is unaffected and behaves as previously
  * Negative arguments were previously invalid.
* Implemented as a special case of indexing as suggested by @compilade in ggerganov#6519
@TheFlipbook TheFlipbook force-pushed the unified-sampling-default branch from 4d54281 to 0d574fc Compare April 7, 2024 20:58
* Batches populated by either llama_batch_get_one or llama_batch_add work with default args
  * Previously get_one could use the default argument
  * Previously add should usually have used the last index where logits[idx] == true
* This hopefully encourages the use of llama_batch_add
  * By giving expected results when using default arguments.
* Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith
* Believed to work with any currently well behaved program
  * Default arg now works for both cases (previously would give strange results for add case)
  * Any non-negative number is unaffected and behaves as previously
  * Negative arguments were previously invalid.
* Implemented as a special case of indexing as suggested by @compilade in ggerganov#6519
@TheFlipbook TheFlipbook force-pushed the unified-sampling-default branch from 0d574fc to 95bf5f7 Compare April 7, 2024 21:00
* cited in macOS CI tests
* Missed in original updates based on PR feedback in ggerganov#6519
@TheFlipbook
Copy link
Contributor Author

It looks like CI / macOS-latest-cmake-x64 might be failing but I'm not quite sure how to verify if the PR is causal.

I'm trying to read the logs, and I see:
Test #21: test-backend-ops ....................***Failed 0.33 sec
and above that lines like this:

21:   SQR(type=f32,ne=[10,10,10,10]): ggml_backend_alloc_ctx_tensors_from_buft: tensor sent_0 is too large to fit in a Metal buffer (tensor size: 4096, max buffer size: 0)
21: failed to allocate tensors

If I'm following the chain for that error correctly it's because ggml_backend_metal_buffer_type_get_max_size is returning 0. Based on https://developer.apple.com/documentation/metal/mtldevice/2966563-maxbufferlength, I'm not quite sure how this PR could be causal for that error.

Perhaps the CI error is a transient error?

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The macOS workflow has been failing occasionally recently - last time I looked into this, it seemed the runner can end up randomly without a GPU device for some reason. Anyway, it's not relevant for this change

@ggerganov ggerganov merged commit e3c337d into ggerganov:master Apr 8, 2024
55 of 59 checks passed
github-actions bot pushed a commit to KerfuffleV2/ggml-sys-bleedingedge that referenced this pull request Apr 8, 2024
== Relevant log messages from source repo:

commit b73e564b16086845a8b4fffd26e22685d3e0c3db
Author: Georgi Gerganov <[email protected]>
Date:   Mon Apr 8 16:23:01 2024 +0300

    quantize : fix precedence of cli args (#6541)

commit e3c337d87ca650972105a51c6ce302dd236c07ad
Author: Rick G <[email protected]>
Date:   Mon Apr 8 06:02:30 2024 -0700

    llama : support negative ith in llama_get_ API (#6519)

    * llama_sampling_sample with default args is more naively usable

    * Batches populated by either llama_batch_get_one or llama_batch_add work with default args
      * Previously get_one could use the default argument
      * Previously add should usually have used the last index where logits[idx] == true
    * This hopefully encourages the use of llama_batch_add
      * By giving expected results when using default arguments.
    * Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith
    * Believed to work with any currently well behaved program
      * Default arg now works for both cases (previously would give strange results for add case)
      * Any non-negative number is unaffected and behaves as previously
      * Negative arguments were previously invalid.
    * Implemented as a special case of indexing as suggested by @compilade in ggerganov/llama.cpp#6519

    * Fixed mismatch type errors

    * cited in macOS CI tests
    * Missed in original updates based on PR feedback in ggerganov/llama.cpp#6519

commit beea6e1b16e783a0886e78dec01002a8c00db24d
Author: Jan Boon <[email protected]>
Date:   Mon Apr 8 20:43:30 2024 +0800

    llama : save and restore kv cache for single seq id (#6341)

    * llama : save and restore kv cache for single seq id

    * remove trailing whitespace

    * respond error in case there's no space in the kv cache

    * add kv seq save restore to test case

    * add --slot-save-path arg to enable save restore and restrict save location

    * Returning 0 for some cases, instead of asserting.

    * cleanup error cases

    * rename sequence state functions

    * rename state get set functions

    * add previous function names back in with DEPRECATED notice

    * update doc

    * adjust endpoints to preferred style

    * fix restoring zero cell count

    * handle seq rm return value

    * unused param

    * keep in the size check

    * fix return types

    * add server test case for slot save restore

    * cleanup

    * add cake

    * cleanup style

    * add special

    * removing a whole sequence never fails

    * move sequence state file functionality from server to llama to match session api and add version tags

    * catch exceptions on save as well

    * error log messages

    * check types for stricter restore

    * update server doc

    * readme : update API changes date

    * strict filename validation

    * move include, reject bom as well

    * also reject empty filename

    * reject whitespace and trailing dot

    ---------

    Co-authored-by: Martin Evans <[email protected]>
    Co-authored-by: Georgi Gerganov <[email protected]>
@TheFlipbook TheFlipbook deleted the unified-sampling-default branch April 9, 2024 08:05
tybalex pushed a commit to rubra-ai/tools.cpp that referenced this pull request Apr 17, 2024
* llama_sampling_sample with default args is more naively usable

* Batches populated by either llama_batch_get_one or llama_batch_add work with default args
  * Previously get_one could use the default argument
  * Previously add should usually have used the last index where logits[idx] == true
* This hopefully encourages the use of llama_batch_add
  * By giving expected results when using default arguments.
* Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith
* Believed to work with any currently well behaved program
  * Default arg now works for both cases (previously would give strange results for add case)
  * Any non-negative number is unaffected and behaves as previously
  * Negative arguments were previously invalid.
* Implemented as a special case of indexing as suggested by @compilade in ggerganov#6519

* Fixed mismatch type errors

* cited in macOS CI tests
* Missed in original updates based on PR feedback in ggerganov#6519
shg8 pushed a commit to shg8/llama-exp that referenced this pull request Apr 27, 2024
* llama_sampling_sample with default args is more naively usable

* Batches populated by either llama_batch_get_one or llama_batch_add work with default args
  * Previously get_one could use the default argument
  * Previously add should usually have used the last index where logits[idx] == true
* This hopefully encourages the use of llama_batch_add
  * By giving expected results when using default arguments.
* Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith
* Believed to work with any currently well behaved program
  * Default arg now works for both cases (previously would give strange results for add case)
  * Any non-negative number is unaffected and behaves as previously
  * Negative arguments were previously invalid.
* Implemented as a special case of indexing as suggested by @compilade in ggerganov/llama.cpp#6519

* Fixed mismatch type errors

* cited in macOS CI tests
* Missed in original updates based on PR feedback in ggerganov/llama.cpp#6519
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants