
threadpool: skip polling for unused threads #9461

Merged
merged 6 commits on Sep 17, 2024

Conversation

max-krasnyansky
Collaborator

Currently all threads do N polling rounds even if only 1 thread is active (n_threads_cur == 1).
For smaller graphs/models/prompts the unused threads may end up always polling and never sleeping because we keep getting new graphs to work on.

This PR adds support for skipping the polling for unused threads (ith >= n_threads_cur). They simply go to sleep and we wake them up when we get a new graph to work on.

n_threads_cur is now an atomic_int to explicitly tell the compiler and thread sanitizer that it is written from one thread and read from other threads (free from race conditions). All loads and stores use relaxed memory order so there is no additional overhead.
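
To make the mechanism concrete, here is a minimal sketch of the idea (a hypothetical, trimmed-down threadpool struct and helper for illustration, not the actual ggml code):

#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>

// Hypothetical, trimmed-down view of the threadpool state, for illustration only.
struct tp_state {
    atomic_int      n_threads_cur; // threads needed for the current graph
    atomic_int      n_graph;       // bumped by the main thread for each new graph
    pthread_mutex_t mutex;
    pthread_cond_t  cond;
};

// A worker with index `ith` polls for a new graph only if it is actually
// needed (ith < n_threads_cur); unused threads skip the polling rounds and
// go straight to sleep until the main thread broadcasts a new graph.
static bool poll_for_work(struct tp_state * tp, int ith, int last_graph, int n_rounds) {
    // relaxed loads keep the polling cheap; the ordering guarantees come
    // from the full fences in ggml_barrier and the new-graph signaling
    if (ith < atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed)) {
        for (int i = 0; i < n_rounds; i++) {
            if (atomic_load_explicit(&tp->n_graph, memory_order_relaxed) != last_graph) {
                return true; // a new graph arrived while polling
            }
        }
    }
    // unused thread, or polling rounds exhausted: sleep on the condition variable
    pthread_mutex_lock(&tp->mutex);
    while (atomic_load_explicit(&tp->n_graph, memory_order_relaxed) == last_graph) {
        pthread_cond_wait(&tp->cond, &tp->mutex);
    }
    pthread_mutex_unlock(&tp->mutex);
    return true;
}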

Here are some scenarios with the default build on M2 Max, with debug prints for n_thread updates, and for threads going to sleep.

Full offload (Metal)

8 threads are started. Only 1 is active, so the other 7 skip the polling and go to sleep.

./llama-cli -m ../gguf/llama-v2-115m.q4_0.gguf --seed 42 -p 'what is the most popular cookie in the world?' -n 4

threadpool: n_threads_cur 1 n_threads 1
 what is the most popular cookie in the world?threadpool: n_threads_cur 8 n_threads 1
thread #1 waiting for work (sleeping)
thread #2 waiting for work (sleeping)
thread #3 waiting for work (sleeping)
thread #5 waiting for work (sleeping)
thread #7 waiting for work (sleeping)
thread #6 waiting for work (sleeping)
thread #4 waiting for work (sleeping)

threadpool: n_threads_cur 1 n_threads 1
thread #1 waiting for work (sleeping)
thread #2 waiting for work (sleeping)
thread #5 waiting for work (sleeping)
thread #3 waiting for work (sleeping)
thread #7 waiting for work (sleeping)
thread #4 waiting for work (sleeping)
thread #6 waiting for work (sleeping)

CPU only

8 threads are started, and they are all active. Hybrid polling (enabled by default) prevents them from going to sleep.

./llama-cli -m ../gguf/llama-v2-115m.q4_0.gguf --seed 42 -p 'what is the most popular cookie in the world?' -n 4 -ngl 0

threadpool: n_threads_cur 8 n_threads 8
 what is the most popular cookie in the world?threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8

No KV offload

8 threads are started, and we alternate between using all 8 and just one for different parts of the graph.

./llama-cli -m ../gguf/llama-v2-115m.q4_0.gguf --seed 42 -p 'what is the most popular cookie in the world?' -n 4 -nkvo

threadpool: n_threads_cur 1 n_threads 1
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8

 what is the most popular cookie in the world?threadpool: n_threads_cur 8 n_threads 1
threadpool: n_threads_cur 1 n_threads 8
thread #5 waiting for work (sleeping)
thread #7 waiting for work (sleeping)
thread #6 waiting for work (sleeping)
thread #2 waiting for work (sleeping)
thread #4 waiting for work (sleeping)
thread #3 waiting for work (sleeping)
thread #1 waiting for work (sleeping)
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 8
threadpool: n_threads_cur 8 n_threads 1
threadpool: n_threads_cur 1 n_threads 8

@github-actions bot added the ggml label (changes relating to the ggml tensor library for machine learning) on Sep 13, 2024
@ggerganov
Member

Haven't looked at the changes yet, but on Mac have you tried running the thread sanitizer? It detects data races when running in CPU-only mode, even with this PR. Started happening after #8672. Here are the steps to run it:

cmake \
	-DCMAKE_BUILD_TYPE=Debug \
	-DLLAMA_SANITIZE_THREAD=ON \
	-DGGML_METAL=OFF \
	-DGGML_LLAMAFILE=OFF ..

make -j && ./bin/llama-simple -m ${model} -p "hello" -ngl 0

You should see output like this:

llama_new_context_with_model: graph splits = 514

main: n_predict = 32, n_ctx = 32768, n_kv_req = 32

<s> hello==================
WARNING: ThreadSanitizer: data race (pid=40272)
  Read of size 6 at 0x00010b80000a by thread T3:
    #0 ggml_vec_dot_q4_0_q8_0 ggml-quants.c:4170 (libggml.dylib:arm64+0x9e31c)
    #1 ggml_compute_forward_mul_mat_one_chunk ggml.c:12771 (libggml.dylib:arm64+0x5d398)
    #2 ggml_compute_forward_mul_mat ggml.c:12987 (libggml.dylib:arm64+0x40374)
    #3 ggml_compute_forward ggml.c:17574 (libggml.dylib:arm64+0x3d29c)
    #4 ggml_graph_compute_thread ggml.c:19953 (libggml.dylib:arm64+0x2a1c0)
    #5 ggml_graph_compute_secondary_thread ggml.c:20061 (libggml.dylib:arm64+0x3c890)

  Previous write of size 1 at 0x00010b80000d by main thread:
    #0 quantize_row_q8_0 ggml-quants.c:900 (libggml.dylib:arm64+0x8cd44)
    #1 ggml_compute_forward_mul_mat ggml.c:12872 (libggml.dylib:arm64+0x3fba8)
    #2 ggml_compute_forward ggml.c:17574 (libggml.dylib:arm64+0x3d29c)
    #3 ggml_graph_compute_thread ggml.c:19953 (libggml.dylib:arm64+0x2a1c0)
    #4 ggml_graph_compute ggml.c:20244 (libggml.dylib:arm64+0x29e3c)
    #5 ggml_backend_cpu_graph_compute ggml-backend.c:817 (libggml.dylib:arm64+0x898bc)
    #6 ggml_backend_graph_compute_async ggml-backend.c:282 (libggml.dylib:arm64+0x7f5bc)
    #7 ggml_backend_sched_compute_splits ggml-backend.c:1818 (libggml.dylib:arm64+0x86b94)
    #8 ggml_backend_sched_graph_compute_async ggml-backend.c:2006 (libggml.dylib:arm64+0x86594)
    #9 llama_graph_compute(llama_context&, ggml_cgraph*, int, ggml_threadpool*) llama.cpp:16057 (libllama.dylib:arm64+0xfe7d0)

  Location is heap block of size 46208 at 0x00010b800000 allocated by main thread:
    #0 malloc <null>:92589700 (libclang_rt.tsan_osx_dynamic.dylib:arm64e+0x5609c)
    #1 ggml_backend_cpu_graph_compute ggml-backend.c:805 (libggml.dylib:arm64+0x897c0)
    #2 ggml_backend_graph_compute_async ggml-backend.c:282 (libggml.dylib:arm64+0x7f5bc)
    #3 ggml_backend_sched_compute_splits ggml-backend.c:1818 (libggml.dylib:arm64+0x86b94)
    #4 ggml_backend_sched_graph_compute_async ggml-backend.c:2006 (libggml.dylib:arm64+0x86594)
    #5 llama_graph_compute(llama_context&, ggml_cgraph*, int, ggml_threadpool*) llama.cpp:16057 (libllama.dylib:arm64+0xfe7d0)

  Thread T3 (tid=15403991, running) created by main thread at:
    #0 pthread_create <null>:92589700 (libclang_rt.tsan_osx_dynamic.dylib:arm64e+0x3062c)
    #1 ggml_threadpool_new_impl ggml.c:20171 (libggml.dylib:arm64+0x299f8)
    #2 ggml_graph_compute ggml.c:20208 (libggml.dylib:arm64+0x29cfc)
    #3 ggml_backend_cpu_graph_compute ggml-backend.c:817 (libggml.dylib:arm64+0x898bc)
    #4 ggml_backend_graph_compute_async ggml-backend.c:282 (libggml.dylib:arm64+0x7f5bc)
    #5 ggml_backend_sched_compute_splits ggml-backend.c:1818 (libggml.dylib:arm64+0x86b94)
    #6 ggml_backend_sched_graph_compute_async ggml-backend.c:2006 (libggml.dylib:arm64+0x86594)
    #7 llama_graph_compute(llama_context&, ggml_cgraph*, int, ggml_threadpool*) llama.cpp:16057 (libllama.dylib:arm64+0xfe7d0)

SUMMARY: ThreadSanitizer: data race ggml-quants.c:4170 in ggml_vec_dot_q4_0_q8_0
==================
==================
WARNING: ThreadSanitizer: data race (pid=40272)
  Read of size 8 at 0x00010b800068 by thread T6:
    #0 ggml_vec_dot_q4_0_q8_0 ggml-quants.c:4172 (libggml.dylib:arm64+0x9e37c)
...

We have to find a way to resolve these warnings.

@max-krasnyansky
Collaborator Author

Haven't looked at the changes yet, but on Mac have you tried running the thread sanitizer? It detects data races when running in CPU-only mode, even with this PR. Started happening after #8672.
...
We have to find a way to resolve these warnings.

Oh. I assumed those were benign. We did have a bunch of similar thread sanitizer warnings on x86-64 with OpenMP before the threadpool changes, so I figured it was some minor overlap in the matmul kernels.
Will double-check in a bit and report back.

@max-krasnyansky
Collaborator Author

@ggerganov @slaren Quick update on the thread sanitizer warnings.
I reverted to commit 9f7d4bc, just before the threadpool merge.

git log --oneline
...
42c76d13 Threadpool: take 2 (#8672)
9f7d4bcf server : fix crash when error handler dumps invalid utf-8 json (#9195). <<<<<<<<<<<
1d1ccce6 flake.lock: Update (#9162)
...

Built for Mac with OpenMP and Thread Sanitizer, using LLVM 18 installed via Homebrew.

CXX=clang++ CC=clang CFLAGS="-march=armv8.7a" CCFLAGS="-march=armv8.7a" cmake -DCMAKE_BUILD_TYPE=Debug -DLLAMA_SANITIZE_THREAD=ON -DGGML_METAL=OFF -DGGML_LLAMAFILE=OFF ..
-- The C compiler identification is Clang 18.1.8
-- Check for working C compiler: /opt/homebrew/opt/llvm/bin/clang - skipped
...
-- Found OpenMP_C: -fopenmp=libomp (found version "5.1")
-- Found OpenMP_CXX: -fopenmp=libomp (found version "5.1")
-- Found OpenMP: TRUE (found version "5.1")

And we're getting a bunch of those warnings.

./bin/llama-simple -m ../../gguf/llama-v2-115m.q4_0.gguf -p Hello -ngl 0
...
WARNING: ThreadSanitizer: data race (pid=20652)
  Write of size 4 at 0x000107500180 by thread T2:
    #0 ggml_vec_cpy_f32 ggml.c:1926 (libggml.dylib:arm64+0x44938)
    #1 ggml_compute_forward_soft_max_f32 ggml.c:13989 (libggml.dylib:arm64+0x4e184)
    #2 ggml_compute_forward_soft_max ggml.c:14037 (libggml.dylib:arm64+0x2de10)
    #3 ggml_compute_forward ggml.c:17279 (libggml.dylib:arm64+0x29524)
    #4 ggml_graph_compute_thread ggml.c:19339 (libggml.dylib:arm64+0x28ff0)
    #5 ggml_graph_compute.omp_outlined_debug__ ggml.c:19390 (libggml.dylib:arm64+0x28e1c)
    #6 ggml_graph_compute.omp_outlined ggml.c:19376 (libggml.dylib:arm64+0x2918c)
    #7 __kmp_invoke_microtask <null>:77598336 (libomp.dylib:arm64+0x861b8)

  Previous read of size 8 at 0x000107500180 by thread T5:
    #0 ggml_vec_dot_f16 ggml.c:2058 (libggml.dylib:arm64+0x597c)
    #1 ggml_compute_forward_mul_mat_one_chunk ggml.c:12514 (libggml.dylib:arm64+0x483a4)
    #2 ggml_compute_forward_mul_mat ggml.c:12730 (libggml.dylib:arm64+0x2c0c4)
    #3 ggml_compute_forward ggml.c:17215 (libggml.dylib:arm64+0x29424)
    #4 ggml_graph_compute_thread ggml.c:19339 (libggml.dylib:arm64+0x28ff0)
    #5 ggml_graph_compute.omp_outlined_debug__ ggml.c:19390 (libggml.dylib:arm64+0x28e1c)
    #6 ggml_graph_compute.omp_outlined ggml.c:19376 (libggml.dylib:arm64+0x2918c)
    #7 __kmp_invoke_microtask <null>:77598336 (libomp.dylib:arm64+0x861b8)
...    
  Read of size 6 at 0x000107500bb2 by thread T3:
    #0 ggml_vec_dot_q4_0_q8_0 ggml-quants.c:3783 (libggml.dylib:arm64+0x96f74)
    #1 ggml_compute_forward_mul_mat_one_chunk ggml.c:12514 (libggml.dylib:arm64+0x483a4)
    #2 ggml_compute_forward_mul_mat ggml.c:12730 (libggml.dylib:arm64+0x2c0c4)
    #3 ggml_compute_forward ggml.c:17215 (libggml.dylib:arm64+0x29424)
    #4 ggml_graph_compute_thread ggml.c:19339 (libggml.dylib:arm64+0x28ff0)
    #5 ggml_graph_compute.omp_outlined_debug__ ggml.c:19390 (libggml.dylib:arm64+0x28e1c)
    #6 ggml_graph_compute.omp_outlined ggml.c:19376 (libggml.dylib:arm64+0x2918c)
    #7 __kmp_invoke_microtask <null>:77598336 (libomp.dylib:arm64+0x861b8)

  Previous write of size 1 at 0x000107500bb5 by thread T1:
    #0 quantize_row_q8_0 ggml-quants.c:900 (libggml.dylib:arm64+0x85d98)
    #1 ggml_compute_forward_mul_mat ggml.c:12615 (libggml.dylib:arm64+0x2b970)
    #2 ggml_compute_forward ggml.c:17215 (libggml.dylib:arm64+0x29424)
    #3 ggml_graph_compute_thread ggml.c:19339 (libggml.dylib:arm64+0x28ff0)
    #4 ggml_graph_compute.omp_outlined_debug__ ggml.c:19390 (libggml.dylib:arm64+0x28e1c)
    #5 ggml_graph_compute.omp_outlined ggml.c:19376 (libggml.dylib:arm64+0x2918c)
    #6 __kmp_invoke_microtask <null>:77598336 (libomp.dylib:arm64+0x861b8)
...
Read of size 4 at 0x000116005e80 by thread T1:
    #0 ggml_fp32_to_fp16_row ggml.c:457 (libggml.dylib:arm64+0x4d30)
    #1 ggml_compute_forward_mul_mat ggml.c:12615 (libggml.dylib:arm64+0x2b970)
    #2 ggml_compute_forward ggml.c:17215 (libggml.dylib:arm64+0x29424)
    #3 ggml_graph_compute_thread ggml.c:19339 (libggml.dylib:arm64+0x28ff0)
    #4 ggml_graph_compute.omp_outlined_debug__ ggml.c:19390 (libggml.dylib:arm64+0x28e1c)
    #5 ggml_graph_compute.omp_outlined ggml.c:19376 (libggml.dylib:arm64+0x2918c)
    #6 __kmp_invoke_microtask <null>:115347072 (libomp.dylib:arm64+0x861b8)

  Previous write of size 8 at 0x000116005e80 by thread T6:
    #0 ggml_vec_soft_max_f32 ggml.c:2662 (libggml.dylib:arm64+0x4e70c)
    #1 ggml_compute_forward_soft_max_f32 ggml.c:14013 (libggml.dylib:arm64+0x4e3d4)
    #2 ggml_compute_forward_soft_max ggml.c:14037 (libggml.dylib:arm64+0x2de10)
    #3 ggml_compute_forward ggml.c:17279 (libggml.dylib:arm64+0x29524)
    #4 ggml_graph_compute_thread ggml.c:19339 (libggml.dylib:arm64+0x28ff0)
    #5 ggml_graph_compute.omp_outlined_debug__ ggml.c:19390 (libggml.dylib:arm64+0x28e1c)
    #6 ggml_graph_compute.omp_outlined ggml.c:19376 (libggml.dylib:arm64+0x2918c)
    #7 __kmp_invoke_microtask <null>:115347072 (libomp.dylib:arm64+0x861b8)

That's why I assumed those were sort of known / benign. Perhaps they are not?
Interestingly enough, we don't get any warnings without OpenMP on that same commit.

The threadpool support makes the timing/behavior very similar to OpenMP, which is why those warnings are now showing up in the default builds (i.e. the threadpool is enabled by default on the Mac with Apple toolchains).

As I mentioned earlier, we do see a bunch of these sanitizer warnings on x86-64 with OpenMP/threadpool as well.

How do you guys want to proceed?
I suggest we merge this, since it addresses real issues with unused threads spinning, and makes n_threads_cur accesses explicit. We can then follow up on those sanitizer warnings.

@slaren
Member

slaren commented Sep 13, 2024

I don't think that the warnings reported by the thread sanitizer here are benign. OpenMP has known compatibility issues with the Thread Sanitizer, since it is not aware of the synchronization mechanisms used by OpenMP, but this should not happen when using plain pthreads and atomics. I believe that this is due to using relaxed memory order in ggml_barrier and ggml_compute_forward_mul_mat. There must be a memory fence between every operation; previously this was done via the implicit seq_cst atomic load/stores in ggml_barrier. In ggml_compute_forward_mul_mat there must also be a memory fence between the quantization of src1 and the matrix multiplication, since all the threads need that data.

It could probably be done with more relaxed memory order, but these changes (on top of this PR) seem to fix the tsan warnings:

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index c3b462b3..a49d3992 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3188,7 +3188,7 @@ static void ggml_barrier(struct ggml_threadpool * threadpool) {
 }
 #else
 static void ggml_barrier(struct ggml_threadpool * threadpool) {
-    int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
+    int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_seq_cst);
     if (n_threads == 1) {
         return;
     }
@@ -3196,16 +3196,16 @@ static void ggml_barrier(struct ggml_threadpool * threadpool) {
     atomic_int * n_barrier = &threadpool->n_barrier;
     atomic_int * n_barrier_passed = &threadpool->n_barrier_passed;

-    int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);
+    int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_seq_cst);

     if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
         // last thread
         atomic_store(n_barrier, 0);
-        atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed);
+        atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_seq_cst);
     } else {
         // wait for other threads
         while (true) {
-            if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) {
+            if (atomic_load_explicit(n_barrier_passed, memory_order_seq_cst) != passed_old) {
                 return;
             }
             ggml_thread_cpu_relax();
@@ -12879,7 +12879,8 @@ UseGgmlGemm1:;

     if (ith == 0) {
         // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
-        atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
+        //atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
+        atomic_store(&params->threadpool->current_chunk, nth);
     }

     ggml_barrier(params->threadpool);
@@ -12990,7 +12991,8 @@ UseGgmlGemm2:;
             break;
         }

-        current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
+        //current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
+        current_chunk = atomic_fetch_add(&params->threadpool->current_chunk, 1);
     }
 }

@ggerganov
Member

@slaren This patch fixes the sanitizer warnings on my end.

I don't think that the warnings reported by the thread sanitizer here are benign. OpenMP has known compatibility issues with the Thread Sanitizer, since it is not aware of the synchronization mechanisms used by OpenMP, but this should not happen when using plain pthreads and atomics.

Yes, I agree. Reading about this on the internet, it appears that when OpenMP is enabled, the sanitizers can report issues, which is to be expected, and there is not much we can do about it. We just have to make sure there are no issues when OpenMP is off.

@max-krasnyansky
Collaborator Author

@slaren This patch fixes the sanitizer warnings on my end.

I don't think that the warnings reported by the thread sanitizer here are benign. OpenMP has known compatibility issues with the Thread Sanitizer, since it is not aware of the synchronization mechanisms used by OpenMP, but this should not happen when using plain pthreads and atomics.

Yes, I agree. Reading about this on the internet, it appears that when OpenMP is enabled, the sanitizers can report issues, which is to be expected, and there is not much we can do about it. We just have to make sure there are no issues when OpenMP is off.

I did a bunch of digging and actually convinced myself that the warnings are benign :).
There are several discussions on how the thread sanitizer has issues with fences and the like:
google/sanitizers#1415
Different projects end up doing different workarounds for that issue.

In our case, we just need to make sure that the thread sanitizer understands that ggml_barrier() enforces ordering.
We already use full (seq_cst) memory order for updating n_barrier, which guarantees the ordering of the ops,
i.e. once we pass that atomic_inc, all previous ops will have completed.

Take a look at the latest updates. I made everything explicit and documented it in ggml_barrier.

Thread sanitizer is happy now and performance looks good (same as before).

@max-krasnyansky
Collaborator Author

Quick update.

I realized llama-cli produces additional thread sanitizer warnings. I'm hunting those down and will post updates here. Those are also benign, so I'm trying to find the best way to silence them without ugly hacks.

In terms of overall correctness, I further convinced myself that we should be good. As I mentioned above, the main thing we need to ensure is that all CPUs finish processing an Op (matmul, copy, etc.) and that all of the memory writes complete before we exit ggml_barrier(). I.e. a true race condition here would mean that some of the memory writes got reordered and are still in flight while a thread (or several threads) has already exited the barrier. This would cause us to start processing the next Op prematurely. If we had something like that happening, it should be easy to trigger on machines with many cores/threads (like the EPYC), where you'd get different output for the same seed (or even garbage output).
This scenario should be impossible since, as I mentioned, every thread goes through the strictly ordered seq_cst atomic_add of n_barrier.

The updates to the threadpool state itself are done only from the main thread, under the mutex (which is also a full barrier), and while the worker threads are either spinning on n_graph or sleeping on the cond_var. So the warnings in that area are due to the thread sanitizer not seeing the expected acquire/release patterns, not because we actually have race conditions there.

If you can think of a scenario where we do have a true race condition, do let me know. Maybe I'm missing something.

llama-simple runs with the current version are clean on my M2 Max and Ryzen 9 setups (with AppleClang 15 and LLVM/Clang 18).

llama-cli and llama-bench still generate some warnings.

@slaren
Member

slaren commented Sep 14, 2024

I am not convinced that the seq_cst in the increase of n_barrier is enough to ensure that all the threads have the full result of the previous operations, because at the time this operation is performed not every thread may have finished the computation, so additional synchronization is necessary between the waiting threads and the other threads as they complete their jobs. The issue with the thread sanitizer that you linked seems to be specifically about atomic_thread_fence, but I don't think that it applies to this scenario. I also ran some tests and didn't see any performance drop from always performing the additional acquire/release (not just with the thread sanitizer), so there doesn't seem to be any downside to keeping it.

@max-krasnyansky
Collaborator Author

I am not convinced that the seq_cst in the increase of n_barrier is enough to ensure that all the threads have the full result of the previous operations, because at the time this operation is performed not every thread may have finished the computation, so additional synchronization is necessary between the waiting threads and the other threads as they complete their jobs. The issue with the thread sanitizer that you linked seems to be specifically about atomic_thread_fence, but I don't think that it applies to this scenario. I also ran some tests and didn't see any performance drop from always performing the additional acquire/release (not just with the thread sanitizer), so there doesn't seem to be any downside to keeping it.

I'm going to try to convince you :)
Here is a sample timeline with three threads and how I think about this.

thread-0                    |  thread-1                   |  thread-2
                            |                             |
- completes Op-X            | - working on Op-X           | - working on Op-X
- mem updates complete      | ...                         | ...
- atomic_inc(n_barrier)     | ...                         | ...
- spin_on(n_barrier_passed) | ...                         | ...
...                         | ...                         | ...
...                         | - completes Op-X            | ...
...                         | - mem updates complete      | ...
...                         | - atomic_inc(n_barrier)     | ...
...                         | - spin_on(n_barrier_passed) | ...
...                         | ...                         | - completes Op-X
...                         | ...                         | - mem updates complete
...                         | ...                         | - atomic_inc(n_barrier) : last-thread
...                         | ...                         | - atomic_inc(n_barrier_passed)
- exit ggml_barrier         | ...                         | - exit ggml_barrier
- kickoff Op-Y              | ...                         | - spin_on(n_graph)
- update state              | ...                         | ...
- mutex + cond_broadcast    | - exit ggml_barrier         | ...
- atomic_inc(n_graph)       | - spin_on(n_graph)          | ...
- working on Op-Y           | ...                         | - working on Op-Y
- using Op-X output         | - working on Op-Y           | ...
...                         | - using Op-X output         | ...

There is no need for the threads to complete Op processing at the same time.
What's important is that their updates are not reordered with respect to ggml_barrier, which is exactly what that
strict atomic_add(n_barrier, seq_cst) does.
I don't see where we could have issues with incomplete outputs. We won't start processing a new Op until all threads
go through that.

Re: just doing strict ordering everywhere. It's hard to measure the overhead with high-level tests.
I definitely would like to avoid spinning loops that do full barriers. That's expensive and more power-hungry,
even if it doesn't affect token rates in obvious ways. I'll see if we can get some easy-to-use
metrics (from perf or maybe the Snapdragon profiler) that show the issue, but just speaking from
experience, it'd be better to avoid doing lots of memory barriers.

@slaren
Member

slaren commented Sep 14, 2024

I have no doubt that what you are saying is true in practice for the specific hardware. It certainly is for x86 where all atomic load/stores have rel/acq semantics and, chances are, both versions of the code generate the exact same asm. I defer to your knowledge about the way this works in ARM.

But ultimately we are not programming for any specific hardware; we are programming for the C virtual machine and the semantics specified therein. Quoting cppreference.com:

If an atomic store in thread A is tagged memory_order_release, an atomic load in thread B from the same variable is tagged memory_order_acquire, and the load in thread B reads a value written by the store in thread A, then the store in thread A synchronizes-with the load in thread B.

The important part here is "and the load in thread B reads a value written by the store in thread A". Thread 0 in your example does not load a value written by thread 1 or thread 2, so there is no guarantee that it will see the writes that happened before n_barrier is increased by these threads.
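
A toy example of the rule being quoted (hypothetical variables, not ggml code): thread B synchronizes-with thread A only if B's acquire load actually reads the value written by A's release store.

#include <assert.h>
#include <stdatomic.h>

atomic_int flag = 0; // guard variable
int        data = 0; // plain, non-atomic payload

void thread_a(void) {
    data = 42;                                             // plain write
    atomic_store_explicit(&flag, 1, memory_order_release); // release store
}

void thread_b(void) {
    // Only if this acquire load reads the value written by thread A's release
    // store does the store synchronize-with the load, making the plain write
    // to `data` visible here.
    if (atomic_load_explicit(&flag, memory_order_acquire) == 1) {
        assert(data == 42); // guaranteed by the synchronizes-with relation
    }
}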

@github-actions bot added the testing label (everything test related) on Sep 16, 2024
@max-krasnyansky
Collaborator Author

max-krasnyansky commented Sep 16, 2024

@slaren Sorry for the delayed response.
I did a bunch more experiments and added a micro-benchmark/test (test-barrier) to further tune/debug/profile things.

The threading/sequence example I provided above is actually generic and assumes the C/C++ memory order semantics (not a specific arch). Perhaps I shortened the ops a bit too much. The atomic_inc(n_barrier) in the above example was referring to the atomic_fetch_add_explicit(n_barrier, memory_order_seq_cst) that we have in the code. It's a read-modify-write operation that enforces full memory ordering.

Here is the reference from the same source (https://en.cppreference.com/w/c/atomic/memory_order)

memory_order_seq_cst | A load operation with this memory order performs an acquire operation, a store performs
a release operation, and read-modify-write performs both an acquire operation and a release operation, plus a single total
order exists in which all threads observe all modifications in the same order (see Sequentially-consistent ordering below).

On arm64 (armv8.2-a and up) that translates to the LDADDAL instruction.
LDADDAL Xs, Xt, [Xn|SP] ; 64-bit, acquire and release general registers

In other words, once all the threads go through that atomic_fetch_add, it's guaranteed that all their outstanding writes and reads (i.e. the output of the previous Op-X) have been flushed. And our threads do not issue any new writes or reads till all of the threads go through the barrier (i.e. till they see n_barrier_passed change).

Btw, the Thread Sanitizer issue I linked to earlier (about the fences) is similar, in the sense that this atomic_fetch_add_explicit(n_barrier, memory_order_seq_cst) is acting as a full fence. And OpenMP causes the exact same confusion for the Thread Sanitizer: #pragma omp barrier is a full-blown fence/sync-point, but the Thread Sanitizer doesn't understand that.

Now, the new test that I added (it does tons of ggml_barriers) did highlight the need for making the n_barrier_passed update use memory_order_seq_cst as well. So I made ggml_barrier simpler and more robust. Thread Sanitizer is now happy (with llama-simple and test-barrier) without any hacks, and I'm happy that we don't need a heavy barrier while spinning :)

M2 Max and Snapdragon Gen 3 are looking good, but I didn't get a chance to do more testing on the Ryzen, EPYC and X-Elite yet. Will try to do that later today and provide an update.
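
For reference, here is a hedged sketch of the barrier shape described above (my reading of the description, not the exact ggml source): seq_cst read-modify-write ops act as full fences on barrier entry and exit, while the spin loop itself uses relaxed loads.

#include <stdatomic.h>

static void barrier(atomic_int * n_barrier, atomic_int * n_barrier_passed, int n_threads) {
    int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);

    // entry: full fence via a seq_cst read-modify-write
    if (atomic_fetch_add_explicit(n_barrier, 1, memory_order_seq_cst) == n_threads - 1) {
        // last thread: reset the entry counter and release the waiters (full fence)
        atomic_store_explicit(n_barrier, 0, memory_order_relaxed);
        atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_seq_cst);
        return;
    }

    // cheap relaxed polling while waiting for the last thread
    while (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) == passed_old) {
        // spin; the real code inserts a cpu-relax hint here
    }

    // exit: another full fence (a seq_cst RMW adding 0) so that both the memory
    // model and the thread sanitizer see this waiter synchronize with the
    // releasing thread
    atomic_fetch_add_explicit(n_barrier_passed, 0, memory_order_seq_cst);
}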

* threadpool: skip polling for unused threads

Currently all threads do N polling rounds even if only 1 thread is active (n_threads_cur == 1).
This commit adds a check to skip the polling for unused threads (ith >= n_threads_cur).

n_threads_cur is now an atomic_int to explicitly tell the thread sanitizer that it is written
from one thread and read from other threads (not a race condition).

* threadpool: further simplify and improve ggml_barrier

Avoid using strict memory order while polling, yet make sure that all threads go through
a full memory barrier (memory fence) on ggml_barrier entrance and exit.

* threads: add simple barrier test

This test does lots of small, parallel matmul ops where the barriers in between dominate the overhead.

* threadpool: improve thread sync for new-graphs

Using the same tricks as ggml_barrier: all the polling is done with relaxed memory order
to keep it efficient; once the new graph is detected we do a full fence using a read-modify-write
op with strict memory order.

* threadpool: improve abort handling

Do not use threadpool->ec (exit code) to decide whether to exit the compute loop.
threadpool->ec is not atomic, which makes the thread sanitizer rightfully unhappy about it.

Instead, introduce an atomic threadpool->abort flag used for this. This is consistent with
how we handle threadpool->stop or pause.

While at it, add an explicit atomic_load for n_threads_cur for consistency.

* test-barrier: release threadpool before releasing the context

Fixes a use-after-free detected by the GCC thread sanitizer on x86-64;
for some reason the LLVM sanitizer does not detect this issue.
@max-krasnyansky
Collaborator Author

max-krasnyansky commented Sep 17, 2024

@slaren @ggerganov

Take a look at the latest. Seems like this should be good to go:

  • All TSAN (thread-sanitizer) warnings have been resolved with GCC and LLVM on x86-64 and arm64
    • This applies to all the tools, which exercise different scenarios: llama-simple, llama-cli, llama-bench, test-barrier (new test)
  • All graph_compute paths should be clean and optimal now
    • We use relaxed / low-overhead ops while polling and enforce full memory fences on ggml_barrier entry & exit and for new graph signaling
    • Some of the changes (i.e. the extra strictness) are a bit of an overkill, but TSAN is happy and overall things should be robust and easier to follow (I added comments in the code about the sync assumptions).
  • Reviewed generated assembly for ggml_barrier on arm64 and x86-64. Looks really clean.
  • Added new test test-barrier that specifically stresses ggml_barrier sync

Tested on M2 Max, Snapdragon Gen3 (S24 Ultra), Ryzen 3950X, EPYC 7543.
Tested the following build flavors:

  • Clang 18 with and without TSAN, with and without OpenMP
  • GCC 11 with and without TSAN, with and without OpenMP
  • AppleClang 15 with and without TSAN

Performance-wise things are looking good. llama-bench numbers are about the same. I didn't get a chance to do detailed sweeps, but just eyeballing the same benchmark runs we did during threadpool development (small models, medium models, etc.), everything looks about the same.

On arm64 our threadpool does quite a bit better than OpenMP.
On x86-64 (Ryzen and EPYC) we do better than OpenMP on 4-8 threads, about the same on up to 10 threads, and quite a bit worse on 16+ threads. Looks like OpenMP has some more tricks for larger numbers of threads that we can add later
(they are probably grouping threads and using separate atomic counters per group to reduce contention, or something like that).

Here are some microbenchmark numbers using that new test.

S24 Ultra

Our default Android NDK armv8.7 build with and without OpenMP.

$ adb shell 'cd /data/local/tmp/lmcp/omp; LD_LIBRARY_PATH=$(pwd) simpleperf stat ./test-barrier 4 1000'
graph-compute with
 n_threads: 4
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 1951499 usec 
 1951.5 usec per-iter
 975.75 nsec per-node         <<<<<<<<<<<<<<
Performance counter statistics:

#           count  event_name                # count / runtime
   24,997,955,221  cpu-cycles                # 3.174760 GHz      
      420,417,199  stalled-cycles-frontend   # 53.393 M/sec      
   11,869,259,205  stalled-cycles-backend    # 1.507 G/sec       
   62,218,492,542  instructions              # 7.902 G/sec       
       32,574,249  branch-misses             # 4.137 M/sec       
  7873.983098(ms)  task-clock                # 3.979660 cpus used
               55  context-switches          # 6.985 /sec        
            1,194  page-faults               # 151.639 /sec      

Total test time: 1.978557 seconds.

$ adb shell 'cd /data/local/tmp/lmcp/threadpool; LD_LIBRARY_PATH=$(pwd) simpleperf stat ./test-barrier 4 1000'
graph-compute with
 n_threads: 4
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 1512980 usec 
 1512.98 usec per-iter
 756.49 nsec per-node      <<<<<<<<<<<<<<<
Performance counter statistics:

#           count  event_name                # count / runtime
   19,398,999,287  cpu-cycles                # 3.177219 GHz      
      335,061,364  stalled-cycles-frontend   # 54.877 M/sec      
    9,729,533,849  stalled-cycles-backend    # 1.594 G/sec       
   43,824,011,735  instructions              # 7.178 G/sec       
       30,679,054  branch-misses             # 5.025 M/sec       
  6105.568248(ms)  task-clock                # 3.974757 cpus used
               69  context-switches          # 11.301 /sec       
            1,191  page-faults               # 195.068 /sec      

Total test time: 1.536086 seconds.

AMD Ryzen 9 3950X

LLVM 18 build.

$ for i in omp tp; do echo $i ---; perf stat ./build-$i/bin/test-barrier 8 1000; done
omp ---
graph-compute with
 n_threads: 8
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 4140192 usec 
 4140.19 usec per-iter
 2070.1 nsec per-node     <<<<<<<<<<<<<<<

 Performance counter stats for './build-omp/bin/test-barrier 8 1000':

         33,194.24 msec task-clock                       #    7.978 CPUs utilized             
               107      context-switches                 #    3.223 /sec                      
                39      cpu-migrations                   #    1.175 /sec                      
             5,215      page-faults                      #  157.106 /sec                      
   140,454,317,989      cycles                           #    4.231 GHz                         (83.31%)
     6,908,434,135      stalled-cycles-frontend          #    4.92% frontend cycles idle        (83.40%)
   101,658,825,977      stalled-cycles-backend           #   72.38% backend cycles idle         (83.33%)
   105,428,191,577      instructions                     #    0.75  insn per cycle            
                                                  #    0.96  stalled cycles per insn     (83.32%)
    23,001,083,248      branches                         #  692.924 M/sec                       (83.32%)
        53,624,810      branch-misses                    #    0.23% of all branches             (83.32%)

       4.160851819 seconds time elapsed

      33.167168000 seconds user
       0.027999000 seconds sys


tp ---
graph-compute with
 n_threads: 8
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 3997023 usec 
 3997.02 usec per-iter
 1998.51 nsec per-node     <<<<<<<<<<<<<<<

 Performance counter stats for './build-tp/bin/test-barrier 8 1000':

         32,026.64 msec task-clock                       #    7.981 CPUs utilized             
                60      context-switches                 #    1.873 /sec                      
                 4      cpu-migrations                   #    0.125 /sec                      
             5,117      page-faults                      #  159.773 /sec                      
   135,415,400,254      cycles                           #    4.228 GHz                         (83.38%)
     3,043,182,557      stalled-cycles-frontend          #    2.25% frontend cycles idle        (83.35%)
   123,767,255,179      stalled-cycles-backend           #   91.40% backend cycles idle         (83.31%)
    35,564,865,094      instructions                     #    0.26  insn per cycle            
                                                  #    3.48  stalled cycles per insn     (83.31%)
     3,756,973,809      branches                         #  117.308 M/sec                       (83.31%)
        36,963,824      branch-misses                    #    0.98% of all branches             (83.35%)

       4.012665163 seconds time elapsed

      32.011009000 seconds user
       0.007998000 seconds sys

$ for i in omp tp; do echo $i ---; perf stat ./build-$i/bin/test-barrier 16 1000; done
omp ---
graph-compute with
 n_threads: 16
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 5465623 usec 
 5465.62 usec per-iter
 2732.81 nsec per-node   <<<<<<<<<<<<

 Performance counter stats for './build-omp/bin/test-barrier 16 1000':

         87,613.02 msec task-clock                       #   15.959 CPUs utilized             
               188      context-switches                 #    2.146 /sec                      
                76      cpu-migrations                   #    0.867 /sec                      
             7,199      page-faults                      #   82.168 /sec                      
   364,924,013,957      cycles                           #    4.165 GHz                         (83.34%)
    14,137,988,220      stalled-cycles-frontend          #    3.87% frontend cycles idle        (83.34%)
   268,616,864,028      stalled-cycles-backend           #   73.61% backend cycles idle         (83.31%)
   266,896,527,467      instructions                     #    0.73  insn per cycle            
                                                  #    1.01  stalled cycles per insn     (83.32%)
    69,209,501,698      branches                         #  789.945 M/sec                       (83.35%)
       109,803,340      branch-misses                    #    0.16% of all branches             (83.34%)

       5.489813742 seconds time elapsed

      87.605444000 seconds user
       0.007998000 seconds sys


tp ---
graph-compute with
 n_threads: 16
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 8510498 usec 
 8510.5 usec per-iter
 4255.25 nsec per-node        <<<<<<<<<<<<<<<<<<

 Performance counter stats for './build-tp/bin/test-barrier 16 1000':

        136,323.95 msec task-clock                       #   15.981 CPUs utilized             
               228      context-switches                 #    1.672 /sec                      
                54      cpu-migrations                   #    0.396 /sec                      
             6,964      page-faults                      #   51.084 /sec                      
   569,904,400,152      cycles                           #    4.181 GHz                         (83.33%)
    13,056,232,384      stalled-cycles-frontend          #    2.29% frontend cycles idle        (83.36%)
   536,932,593,729      stalled-cycles-backend           #   94.21% backend cycles idle         (83.33%)
    55,025,127,922      instructions                     #    0.10  insn per cycle            
                                                  #    9.76  stalled cycles per insn     (83.33%)
     6,798,291,363      branches                         #   49.869 M/sec                       (83.33%)
        87,812,236      branch-misses                    #    1.29% of all branches             (83.33%)

       8.530302492 seconds time elapsed

     136.276698000 seconds user
       0.007997000 seconds sys



@slaren
Member

slaren commented Sep 17, 2024

Looks good to me. I get these results with test-barrier:

13900k
omp - 1 threads
graph-compute with
 n_threads: 1
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 947999 usec
 947.999 usec per-iter
 474 nsec per-node

tp - 1 threads
graph-compute with
 n_threads: 1
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 976557 usec
 976.557 usec per-iter
 488.279 nsec per-node

==
omp - 2 threads
graph-compute with
 n_threads: 2
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 1243895 usec
 1243.9 usec per-iter
 621.948 nsec per-node

tp - 2 threads
graph-compute with
 n_threads: 2
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 1112920 usec
 1112.92 usec per-iter
 556.46 nsec per-node

==
omp - 4 threads
graph-compute with
 n_threads: 4
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 1490911 usec
 1490.91 usec per-iter
 745.456 nsec per-node

tp - 4 threads
graph-compute with
 n_threads: 4
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 1440426 usec
 1440.43 usec per-iter
 720.213 nsec per-node

==
omp - 8 threads
graph-compute with
 n_threads: 8
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 2078942 usec
 2078.94 usec per-iter
 1039.47 nsec per-node

tp - 8 threads
graph-compute with
 n_threads: 8
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 3161917 usec
 3161.92 usec per-iter
 1580.96 nsec per-node

==
omp - 16 threads
graph-compute with
 n_threads: 16
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 3584546 usec
 3584.55 usec per-iter
 1792.27 nsec per-node

tp - 16 threads
graph-compute with
 n_threads: 16
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 9161812 usec
 9161.81 usec per-iter
 4580.91 nsec per-node

==
omp - 32 threads
graph-compute with
 n_threads: 32
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 6308977 usec
 6308.98 usec per-iter
 3154.49 nsec per-node

tp - 32 threads
graph-compute with
 n_threads: 32
   n_nodes: 2000
  n_rounds: 1000
graph-compute took 20437237 usec
 20437.2 usec per-iter
 10218.6 nsec per-node

@max-krasnyansky
Collaborator Author

Looks good to me. I get these results with test-barrier:

13900k

Definitely makes sense to follow up on improving the n_threads > 8 cases on x86-64 in the next iteration (i.e. as part of the threadpool v3 that we discussed).
Thanks for sharing the numbers.

@ggerganov
Member

I've done a few tests on M1 Pro, M2 Ultra and Ryzen 9 5950X and all seems good. Thank you.

@ggerganov merged commit 0226613 into ggml-org:master on Sep 17, 2024
52 checks passed
@jxy
Contributor

jxy commented Sep 18, 2024

OpenMP with Metal is broken after this commit on an M2 with Sequoia 15.0, using clang 18.1.8.

  • works: -t 4 -ngl 0 or -t 1 -ngl 99
  • broken (wrong output): -t 4 -ngl 99

@max-krasnyansky
Collaborator Author

max-krasnyansky commented Sep 18, 2024

OpenMP with Metal is broken after this commit on an M2 with Sequoia 15.0, using clang 18.1.8.

  • works: -t 4 -ngl 0 or -t 1 -ngl 99
  • broken (wrong output): -t 4 -ngl 99

Oh. Interesting. Can you please share how exactly you build it (cmake command, etc.)?
And which model you're using?

@max-krasnyansky
Collaborator Author

OpenMP with Metal is broken after this commit on an M2 with Sequoia 15.0, using clang 18.1.8.

  • works: -t 4 -ngl 0 or -t 1 -ngl 99
  • broken (wrong output): -t 4 -ngl 99

Ok. I was able to reproduce it on an M2 Max with Sequoia 15 and LLVM 18.

CC=clang CXX=clang++ CFLAGS="-march=armv8.7-a" CXXFLAGS="-march=armv8.7-a" cmake -D GGML_OPENMP=ON -B build-metal-omp .
...
CC=clang CXX=clang++ CFLAGS="-march=armv8.7-a" CXXFLAGS="-march=armv8.7-a" cmake -D GGML_OPENMP=OFF -B build-metal-tp .
~/src/llama.cpp-master$ ./build-metal-tp/bin/llama-cli -m ../gguf/llama-v2-7b.q4_0.gguf -n 40 -ngl 99 --seed 42 -t 8 -f ../hawaii-v2
...
Please summarize previous passage.
[/INST]  Sure! Here's a summary of the passage:
Hawaii is an island state of the United States located in the Pacific Ocean, about 2,000 miles from

llama_perf_sampler_print:    sampling time =       0.90 ms /  1248 runs   (    0.00 ms per token, 1391304.35 tokens per second)
llama_perf_context_print:        load time =     363.42 ms
llama_perf_context_print: prompt eval time =    2261.68 ms /  1208 tokens (    1.87 ms per token,   534.12 tokens per second)
llama_perf_context_print:        eval time =     748.75 ms /    39 runs   (   19.20 ms per token,    52.09 tokens per second)
llama_perf_context_print:       total time =    3015.21 ms /  1247 tokens

~/src/llama.cpp-master$ ./build-metal-omp/bin/llama-cli -m ../gguf/llama-v2-7b.q4_0.gguf -n 40 -ngl 99 --seed 42 -t 8 -f ../hawaii-v2
...
Please summarize previous passage.
[/INST]MSMSMSMSMSMSMSMSMSMSMSMSPAMSPAMSMSMSMSMSMSMSPAMSMSMSMS MSMSPAMSMSMS MSMSMSMSMSMSMS

llama_perf_sampler_print:    sampling time =       0.89 ms /  1248 runs   (    0.00 ms per token, 1406989.85 tokens per second)
llama_perf_context_print:        load time =     376.54 ms
llama_perf_context_print: prompt eval time =    2259.75 ms /  1208 tokens (    1.87 ms per token,   534.57 tokens per second)
llama_perf_context_print:        eval time =     747.02 ms /    39 runs   (   19.15 ms per token,    52.21 tokens per second)
llama_perf_context_print:       total time =    3011.25 ms /  1247 tokens
ggml_metal_free: deallocating

Interestingly enough, if just a single layer runs on the CPU then it works fine:

~/src/llama.cpp-master$ ./build-metal-omp/bin/llama-cli -m ../gguf/llama-v2-7b.q4_0.gguf -n 40 -ngl 31 --seed 42 -t 8 -f ../hawaii-v2
...
Please summarize previous passage.
[/INST]  Sure! Here's a summary of the passage:
Hawaii is an island state of the United States located in the Pacific Ocean, about 2,000 miles from

llama_perf_sampler_print:    sampling time =       1.16 ms /  1248 runs   (    0.00 ms per token, 1076790.34 tokens per second)
llama_perf_context_print:        load time =     372.13 ms
llama_perf_context_print: prompt eval time =    2625.05 ms /  1208 tokens (    2.17 ms per token,   460.18 tokens per second)
llama_perf_context_print:        eval time =     875.85 ms /    39 runs   (   22.46 ms per token,    44.53 tokens per second)
llama_perf_context_print:       total time =    3505.91 ms /  1247 tokens
ggml_metal_free: deallocating

I'll try to figure out what exactly broke with OpenMP in this case. It's not immediately obvious.

@max-krasnyansky
Collaborator Author

OpenMP with Metal is broken after this commit on an M2 with Sequoia 15.0, using clang 18.1.8.

  • works: -t 4 -ngl 0 or -t 1 -ngl 99
  • broken (wrong output): -t 4 -ngl 99

Fixed in
#9538

dsx1986 pushed a commit to dsx1986/llama.cpp that referenced this pull request Oct 29, 2024
* threadpool: skip polling for unused threads

Currently all threads do N polling rounds even if only 1 thread is active (n_threads_cur == 1).
This commit adds a check to skip the polling for unused threads (ith >= n_threads_cur).

n_threads_cur is now an atomic_int to explicitly tell the thread sanitizer that it is written
from one thread and read from other threads (not a race condition).

* threadpool: further simplify and improve ggml_barrier

Avoid using strict memory order while polling, yet make sure that all threads go through
a full memory barrier (memory fence) on ggml_barrier entrance and exit.

* threads: add simple barrier test

This test does lots of small, parallel matmul ops where the barriers in between dominate the overhead.

* threadpool: improve thread sync for new-graphs

Using the same tricks as ggml_barrier: all the polling is done with relaxed memory order
to keep it efficient; once the new graph is detected we do a full fence using a read-modify-write
op with strict memory order.

* threadpool: improve abort handling

Do not use threadpool->ec (exit code) to decide whether to exit the compute loop.
threadpool->ec is not atomic, which makes the thread sanitizer rightfully unhappy about it.

Instead, introduce an atomic threadpool->abort flag used for this. This is consistent with
how we handle threadpool->stop or pause.

While at it, add an explicit atomic_load for n_threads_cur for consistency.

* test-barrier: release threadpool before releasing the context

Fixes a use-after-free detected by the GCC thread sanitizer on x86-64;
for some reason the LLVM sanitizer does not detect this issue.
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Feb 25, 2025
commit bb2a4d125e33148d2b4e9363bf8ace14f722a610
Author: Nexesenex <[email protected]>
Date:   Mon Nov 11 08:59:32 2024 +0100

    8x22b

commit 9d4926ff9559ecae25f19fadcb55586677575b61
Merge: 9c65f44 b0cefea
Author: Nexesenex <[email protected]>
Date:   Mon Nov 11 08:59:07 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit 9c65f44
Author: Nexesenex <[email protected]>
Date:   Sun Nov 3 04:30:14 2024 +0100

    Test base 2048

commit 8ccafe8
Merge: d0d276f 9830b69
Author: Nexesenex <[email protected]>
Date:   Sun Nov 3 04:28:33 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit d0d276f
Merge: 7cecefd 418f5ee
Author: Nexesenex <[email protected]>
Date:   Fri Nov 1 20:18:27 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit 7cecefd
Merge: a5303b7 8841ce3
Author: Nexesenex <[email protected]>
Date:   Mon Oct 28 06:45:16 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit a5303b7
Merge: f21ab1e 167a515
Author: Nexesenex <[email protected]>
Date:   Thu Oct 24 19:55:36 2024 +0200

    Merge branch 'master' into Nexes_CQ30

commit f21ab1e
Merge: c72289e 20011f1
Author: Nexesenex <[email protected]>
Date:   Wed Oct 23 20:26:42 2024 +0200

    Merge branch 'gg/default-kq-f32-prec' into Nexes_CQ20

commit c72289e
Merge: eaee12e 190a37d
Author: Nexesenex <[email protected]>
Date:   Wed Oct 23 20:26:34 2024 +0200

    Merge branch 'master' into Nexes_CQ20

commit 20011f1
Author: Georgi Gerganov <[email protected]>
Date:   Wed Oct 23 14:32:27 2024 +0300

    llama : switch KQ multiplication to use F32 precision by default

    ggml-ci

commit eaee12e
Author: Nexesenex <[email protected]>
Date:   Mon Oct 21 15:41:24 2024 +0200

    EXL SXL and UXL types to test the new bits formula

commit 6abef2a
Merge: aa73a4e d5ebd79
Author: Nexesenex <[email protected]>
Date:   Mon Oct 21 15:40:22 2024 +0200

    Merge branch 'master' into Nexes_CQ20

commit aa73a4e
Author: Nexesenex <[email protected]>
Date:   Sat Oct 19 19:04:33 2024 +0200

    use_some_bits and use_most_bits

commit 7794c8f
Merge: 1cf274d cda0e4b
Author: Nexesenex <[email protected]>
Date:   Sat Oct 19 19:04:05 2024 +0200

    Merge branch 'master' into Nexes_CQ20

commit 1cf274d
Author: Nexesenex <[email protected]>
Date:   Fri Oct 18 21:00:56 2024 +0200

    ML UXL and EXL boost

commit f105e0f
Author: Nexesenex <[email protected]>
Date:   Fri Oct 18 21:05:49 2024 +0200

    Revert compile for Ampere

commit 1b25cbb
Author: Nexesenex <[email protected]>
Date:   Fri Oct 18 21:05:04 2024 +0200

    Delete CMakePresets.json

commit 1c440a8
Merge: 366e0c8 afd9909
Author: Nexesenex <[email protected]>
Date:   Fri Oct 18 20:42:34 2024 +0200

    Merge branch 'master' into Nexes_CQ20

commit 366e0c8
Author: Nexesenex <[email protected]>
Date:   Wed Oct 16 16:57:30 2024 +0200

    Fix indent model sizes

commit cf8375c
Author: Nexesenex <[email protected]>
Date:   Wed Oct 16 16:41:57 2024 +0200

    continue Q5_K mixes

commit 2d052f7
Author: Nexesenex <[email protected]>
Date:   Tue Oct 15 17:42:48 2024 +0200

    difquants three/four eights alt for Mistral Large

commit 29cecae
Author: Nexesenex <[email protected]>
Date:   Tue Oct 15 16:03:12 2024 +0200

    Q5_K_XSR, SR, ML, and XL revamp

commit 412b56f
Author: Nexesenex <[email protected]>
Date:   Mon Oct 14 17:08:23 2024 +0200

    IQ3_X5L and IQ3_X7L fix for Mistral Large

commit ca86ce8
Author: Nexesenex <[email protected]>
Date:   Mon Oct 14 15:24:37 2024 +0200

    Pursue IQ3 revamp

commit 6c51f39
Author: Nexesenex <[email protected]>
Date:   Sun Oct 13 22:22:40 2024 +0200

    IQ3_XXXXL, EXL and renaming >=IQ3_ML scheme

    Test for Mistral Large

    IQ3_XL = IQ3_X5L and so on.

commit 64bfe69
Author: Nexesenex <[email protected]>
Date:   Sun Oct 13 22:33:05 2024 +0200

    Activate F16

commit 575ebc2
Merge: 38229d3 d4c19c0
Author: Nexesenex <[email protected]>
Date:   Sun Oct 13 22:22:30 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit 38229d3
Author: Nexesenex <[email protected]>
Date:   Sat Oct 12 20:57:18 2024 +0200

    Fix specify tensors in quantize

commit b947b6e
Author: Nexesenex <[email protected]>
Date:   Sat Oct 12 13:38:22 2024 +0200

    New FTYPE Q5_K_XL

commit ba1b854
Author: Nexesenex <[email protected]>
Date:   Sat Oct 12 13:36:25 2024 +0200

    New FTYPE IQ4_XXSR

    and beef up attn_k IQ4_XSR

commit 79fa98c
Author: Nexesenex <[email protected]>
Date:   Sun Oct 13 02:00:38 2024 +0200

    GGML_MAX_COPIES_1 in CML

commit f95ed01
Merge: accd71d edc2656
Author: Nexesenex <[email protected]>
Date:   Sun Oct 13 02:02:06 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit accd71d
Merge: b5103f4 11ac980
Author: Nexesenex <[email protected]>
Date:   Sat Oct 12 13:23:11 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit b5103f4
Author: Nexesenex <[email protected]>
Date:   Fri Oct 11 13:43:48 2024 +0200

    Better model info (ikawrakow#84)

    Co-Authored-By: Kawrakow <[email protected]>

commit b302561
Author: Nexesenex <[email protected]>
Date:   Fri Oct 11 13:17:39 2024 +0200

    IQ3_UXL for test

commit 8c6e408
Merge: 66a9b05 7eee341
Author: Nexesenex <[email protected]>
Date:   Fri Oct 11 13:17:30 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit 66a9b05
Author: Nexesenex <[email protected]>
Date:   Wed Oct 9 04:30:45 2024 +0200

    correct iQ4_LR

commit 298990a
Merge: f1814f1 dca1d4b
Author: Nexesenex <[email protected]>
Date:   Tue Oct 8 22:11:53 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit f1814f1
Author: Nexesenex <[email protected]>
Date:   Mon Oct 7 23:21:56 2024 +0200

    Rebump attn_v

commit b94a9b0
Merge: 18677c8 6374743
Author: Nexesenex <[email protected]>
Date:   Mon Oct 7 23:21:38 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit 18677c8
Author: Nexesenex <[email protected]>
Date:   Sun Oct 6 02:12:09 2024 +0200

    IQ4_LR

commit a2500c1
Author: Nexesenex <[email protected]>
Date:   Sun Oct 6 02:12:55 2024 +0200

    Crack down fallback GGML_types

commit 75b8800
Author: Nexesenex <[email protected]>
Date:   Sat Oct 5 23:18:02 2024 +0200

    More overhaul for IQ4_XSR and new IQ4_MR

commit 167a3c5
Author: Nexesenex <[email protected]>
Date:   Sat Oct 5 17:17:50 2024 +0200

    GGML SCHED MAX COPIES 1

commit 8433050
Author: Nexesenex <[email protected]>
Date:   Sat Oct 5 17:14:39 2024 +0200

    Adapt CML

commit 1e0f64e
Author: Nexesenex <[email protected]>
Date:   Sat Oct 5 17:07:07 2024 +0200

    Compile for Ampere

commit 35ce3f6
Merge: 6480054 8c475b9
Author: Nexesenex <[email protected]>
Date:   Sat Oct 5 17:03:34 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit 6480054
Author: Nexesenex <[email protected]>
Date:   Fri Oct 4 18:21:54 2024 +0200

    IQ4_XSR revamp

commit 1ec8328
Author: Nexesenex <[email protected]>
Date:   Mon Aug 19 17:00:34 2024 +0200

    Clarify PPL result

commit de50e13
Merge: ed67589 d5ed2b9
Author: Nexesenex <[email protected]>
Date:   Thu Oct 3 22:23:08 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit ed67589
Merge: 06ab3a2 70392f1
Author: Nexesenex <[email protected]>
Date:   Tue Sep 24 10:22:50 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit 06ab3a2
Author: Nexesenex <[email protected]>
Date:   Tue Sep 24 10:22:46 2024 +0200

    More size logging

commit 9d97928
Author: Nexesenex <[email protected]>
Date:   Tue Sep 24 10:21:25 2024 +0200

    Update llama.cpp

commit 700d205
Author: Nexesenex <[email protected]>
Date:   Tue Sep 24 03:51:26 2024 +0200

    IQ3_XS more

commit da840a3
Merge: 056c47d 116efee
Author: Nexesenex <[email protected]>
Date:   Tue Sep 24 03:30:18 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit 056c47d
Author: Nexesenex <[email protected]>
Date:   Tue Sep 24 03:30:15 2024 +0200

    Reapply "threadpool : skip polling for unused threads (ggml-org#9461)"

    This reverts commit 2a8dbf8.
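
    For reference, a minimal sketch of the mechanism being reapplied here, assuming a C11 atomic thread count; names are illustrative, not the actual llama.cpp code:

    #include <stdatomic.h>
    #include <stdbool.h>

    atomic_int n_threads_cur; // written by the main thread, read by workers

    // a worker with index ith skips the polling loop entirely when it is
    // not part of the current graph's active thread set, and sleeps instead
    static bool thread_active(int ith) {
        return ith < atomic_load_explicit(&n_threads_cur, memory_order_relaxed);
    }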

commit 8d789ac
Author: Nexesenex <[email protected]>
Date:   Tue Sep 24 03:20:58 2024 +0200

    IQ3_XS

commit 413fc43
Author: Nexesenex <[email protected]>
Date:   Mon Sep 23 19:34:45 2024 +0200

    Fix IQ3 <=M

commit 9ed3522
Merge: 2a8dbf8 1d48e98
Author: Nexesenex <[email protected]>
Date:   Mon Sep 23 18:50:43 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit 2a8dbf8
Author: Nexesenex <[email protected]>
Date:   Sun Sep 22 02:48:50 2024 +0200

    Revert "threadpool : skip polling for unused threads (ggml-org#9461)"

    This reverts commit 0226613.

commit 6faac9f
Author: Nexesenex <[email protected]>
Date:   Sun Sep 22 02:46:37 2024 +0200

    Revert "Update CUDA graph on scale change plus clear nodes/params  (ggml-org#9550)"

    This reverts commit 41f4778.

commit f377f88
Merge: e3ec684 d09770c
Author: Nexesenex <[email protected]>
Date:   Sat Sep 21 17:25:04 2024 +0200

    Merge branch 'master' into Nexes_CQ_10

commit e3ec684
Author: Nexesenex <[email protected]>
Date:   Fri Sep 20 06:36:47 2024 +0200

    reinsert cqs

commit d48aad3
Author: Nexesenex <[email protected]>
Date:   Mon Sep 2 05:50:08 2024 +0200

    Play with IQ3 quants

commit 5af6481
Author: Nexesenex <[email protected]>
Date:   Mon Sep 2 01:41:19 2024 +0200

    IQ4_XSR_rework

commit dd770d2
Author: Nexesenex <[email protected]>
Date:   Sat Aug 31 17:05:00 2024 +0200

    refine IQ3 quants

commit 32ce04a
Author: Nexesenex <[email protected]>
Date:   Sat Aug 31 14:22:00 2024 +0200

    Use of vocab as difquant criteria

    The pre-vocab>128k models are more sensitive to ffn_down quant than to ffn_gate and up.

commit 86a7e4a
Author: Nexesenex <[email protected]>
Date:   Fri Aug 30 12:15:54 2024 +0200

    IQ3_UXL

commit 97fbd74
Author: Nexesenex <[email protected]>
Date:   Thu Aug 29 22:40:32 2024 +0200

    New difquant seven_eights

commit c6732bf
Author: Nexesenex <[email protected]>
Date:   Wed Aug 28 16:06:38 2024 +0200

    Bump a bit output for big models in IQ2 and IQ3

commit cce61d3
Author: Nexesenex <[email protected]>
Date:   Wed Aug 28 13:00:53 2024 +0200

    Difquant attn_q and attn_o for IQ3_XXS, XS, and S

    And also establishing a bump to difquant_first_last_tensors for attn_k and attn_v

commit 1e7e816
Author: Nexesenex <[email protected]>
Date:   Wed Aug 28 02:24:55 2024 +0200

    Add IQ3_ML, reinstate IQ3_XXXL

commit 7b0dc30
Author: Nexesenex <[email protected]>
Date:   Wed Aug 28 00:52:45 2024 +0200

    Bump IQ3_XS

commit 6263649
Author: Nexesenex <[email protected]>
Date:   Tue Aug 27 16:19:10 2024 +0200

    Revert variable V below Q5_K

commit eb4a69e
Author: Nexesenex <[email protected]>
Date:   Tue Aug 27 13:26:15 2024 +0200

    Difquant for IQ2_XL & IQ3 for attn_k and attn_v

    And prepare difquant for these quants for attn_o and attn_q

commit c84d981
Author: Nexesenex <[email protected]>
Date:   Tue Aug 27 06:13:39 2024 +0200

    correct settings

commit c667f2e
Author: Nexesenex <[email protected]>
Date:   Mon Aug 26 23:05:04 2024 +0200

    Temporary settings for IQ3 attn_k and attn_v

commit 294aeec
Author: Nexesenex <[email protected]>
Date:   Mon Aug 26 18:18:05 2024 +0200

    Corrections and clean-up

    Back to Q8_0 for attn_k and attn_v if 8 experts or more.

    for attn_v and attn_k if experts>=4
    GQA>=12 brought back to expert>=4 quant level instead of 8
    GQA8 brought to GQA7, and GQA7 brought to GQA4.

commit e7c5163
Author: Nexesenex <[email protected]>
Date:   Mon Aug 26 14:33:34 2024 +0200

    Shrink a bit Q2_K when GQA<2

    and optimize difquants_first_last and fl_more

commit ff48606
Author: Nexesenex <[email protected]>
Date:   Mon Aug 26 14:02:09 2024 +0200

    IQ1_XL, IQ2_S, IQ2_XS enhanced

commit 8a1ab24
Author: Nexesenex <[email protected]>
Date:   Mon Aug 26 12:57:21 2024 +0200

    IQ1_XS, IQ1_S, IQ1_M, IQ2_XXS, Q2_M, Q2_K enhanced

    testing templates for other quants.

commit 26aac8e
Author: Nexesenex <[email protected]>
Date:   Sun Aug 25 14:42:33 2024 +0200

    Soften the token embeddings bump for experts >= 4

commit 5644d4c
Merge: 16aee45 6026da5
Author: Nexesenex <[email protected]>
Date:   Fri Sep 20 01:38:20 2024 +0200

    Merge branch 'master' into pr/8836

commit 16aee45
Author: Nexesenex <[email protected]>
Date:   Sun Aug 25 14:25:46 2024 +0200

    correction

commit dd3df75
Author: Nexesenex <[email protected]>
Date:   Sun Aug 25 03:30:36 2024 +0200

    Bad indents and trailing whitespaces

commit f63860e
Author: Nexesenex <[email protected]>
Date:   Sun Aug 25 03:17:21 2024 +0200

    Put back ffn_down tree where it was before.

commit 8fc46df
Author: Nexesenex <[email protected]>
Date:   Sat Aug 24 22:30:45 2024 +0200

    Bump a bit ffn_gate and down for some GQA<2 models

commit 53b8eaa
Author: Nexesenex <[email protected]>
Date:   Sat Aug 24 21:57:07 2024 +0200

    Remove deprecated rules for token embeddings

commit 844d11b
Author: Nexesenex <[email protected]>
Date:   Sat Aug 24 21:02:51 2024 +0200

    bad indent

commit 5ae5971
Author: Nexesenex <[email protected]>
Date:   Sat Aug 24 20:50:07 2024 +0200

    Revamp Q2_K and Q3_K quants

    Q3_K_XL takes the place of Q3_K_L.
    Q3_K_L becomes intermediary between Q3_K_M and XL.

commit 1bde168
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 23:27:26 2024 +0200

    Usage of n_head to discriminate very small models

    whose size is more sensitive to the non-repeating tensors

commit 16e9c37
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 23:18:59 2024 +0200

    various corrections on IQ2_S+ and IQ3 quants

commit 380b53d
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 21:59:34 2024 +0200

    Fix IQ4_XSR

commit 6081085
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 17:48:31 2024 +0200

    Revamp attn_output

commit 6b5cebf
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 16:40:40 2024 +0200

    Revamp a bit output weight

    for more granularity in low quants.

commit f796954
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 14:17:19 2024 +0200

    Revamp FFN down and attn_k

    And complete FFN up
    Shrink a bit more non GQA models

commit 596a4ae
Author: Nexesenex <[email protected]>
Date:   Thu Aug 22 19:12:25 2024 +0200

    Readd variable attn_k, attn_q, attn_o after merge

commit fb2b9ea
Merge: 3a027b8 e11bd85
Author: Nexesenex <[email protected]>
Date:   Sun Aug 25 02:59:57 2024 +0200

    Merge branch 'master' into pr/8836

commit 3a027b8
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 00:08:42 2024 +0200

    Revamp IQ4_XSR, remove IQ3_XXXL

commit e05da54
Author: Nexesenex <[email protected]>
Date:   Thu Aug 22 19:12:13 2024 +0200

    Overhaul of FFN, if GQA and if not

commit 1607a02
Author: Nexesenex <[email protected]>
Date:   Fri Aug 23 12:38:45 2024 +0200

    Further adjustments difquant formulas

commit 179ad0f
Author: Nexesenex <[email protected]>
Date:   Wed Aug 21 13:10:54 2024 +0200

    Little rework of the difquant formulas

commit 644aa9f
Author: Nexesenex <[email protected]>
Date:   Wed Aug 21 13:07:32 2024 +0200

    Correction: too-small tensor embeddings to quantize

    IQ2_XS doesn't seem to work as such, back to IQ2_S

commit 32f6ead
Author: Nexesenex <[email protected]>
Date:   Mon Aug 19 17:58:12 2024 +0200

    Improve IQ1 and IQ2 quants

    And fix mistakes for the attn.output of IQ2_XL and the ffn gate and up of IQ2_XS

    Reformat attn_output mess and split GQA4/GQA2

commit d7b9d21
Author: Nexesenex <[email protected]>
Date:   Tue Aug 20 12:45:30 2024 +0200

    Shrink a bit IQ3_XXS, bump a bit IQ3_M

commit dbadcdd
Author: Nexesenex <[email protected]>
Date:   Tue Aug 20 11:59:41 2024 +0200

    harmonize formatting of tensor type conditions

commit ce86019
Author: Nexesenex <[email protected]>
Date:   Wed Aug 21 12:25:38 2024 +0200

    change function use_*_bits into difquant_*_tensors

    this to clarify what it does, especially with the 5 additional levels of difquant
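
    The difquant_*_tensors naming suggests predicates over layer indices that pick which fraction of a tensor's layers gets a bumped quant; a hypothetical sketch of one such level (the body is an assumption, not the actual formula):

    // hypothetical: bump the first eighth of the layers plus the final layer
    static bool difquant_first_last_tensors(int i_layer, int n_layers) {
        return i_layer < n_layers/8 || i_layer == n_layers - 1;
    }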

commit cfe866e
Merge: fddff02 fc54ef0
Author: Nexesenex <[email protected]>
Date:   Wed Aug 21 12:23:41 2024 +0200

    Merge branch 'master' into pr/8836

commit fddff02
Author: Nexesenex <[email protected]>
Date:   Mon Aug 19 01:43:31 2024 +0200

    Rework IQ3_XXS and IQ3_XS

    and fix parenthesis mistake on IQ3_S

commit 207ffe6
Author: Nexesenex <[email protected]>
Date:   Sun Aug 18 23:28:13 2024 +0200

    Reorder, corrections, settling lower IQ3 quants

commit 8c1a3c5
Merge: a7f9164 cfac111
Author: Nexesenex <[email protected]>
Date:   Tue Aug 20 00:48:05 2024 +0200

    Merge branch 'master' into pr/8836

commit a7f9164
Author: Nexesenex <[email protected]>
Date:   Mon Aug 19 16:02:00 2024 +0200

    Fix mistake

commit caeb839
Author: Nexesenex <[email protected]>
Date:   Sun Aug 18 17:58:17 2024 +0200

    Boost embeddings and output weights for MOEs.

    They are single and non-repeating, so the boost is reasonable compared to the size of 4 or more experts.

commit 503048a
Author: Nexesenex <[email protected]>
Date:   Sun Aug 18 17:44:11 2024 +0200

    Correct IQ3_M

commit ddb1373
Author: Nexesenex <[email protected]>
Date:   Sun Aug 18 16:56:55 2024 +0200

    IQ3_XXL and IQ3_XXXL

    We now have a full range of quants between IQ3_M and IQ4_XS

commit a79633b
Merge: b02eaf6 554b049
Author: Nexesenex <[email protected]>
Date:   Sun Aug 18 22:12:39 2024 +0200

    Merge branch 'master' into pr/8836

commit b02eaf6
Author: Nexesenex <[email protected]>
Date:   Sat Aug 17 14:58:25 2024 +0200

    Mass use of the few/some/more/many bits bump logic

    Add few bits logic and rework the 4 settings for 25/37.5/50/75% quant bump when used.

commit 4ba5618
Author: Nexesenex <[email protected]>
Date:   Sat Aug 17 12:31:36 2024 +0200

    Adapt token embeddings and output.weight to vocab size

    Due to the huge increase in embeddings and output weight size for models with a huge vocab, they seem to quantize with less loss.
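
    A hypothetical illustration of that criterion (threshold, names, and direction of the adjustment are assumptions): since a huge vocab makes token_embd and output.weight disproportionately large, they can tolerate a slightly more aggressive quant.

    #include "ggml.h"

    // hypothetical: shrink output.weight one notch when the vocab is huge
    static enum ggml_type pick_output_type(int n_vocab, enum ggml_type base) {
        if (n_vocab > 128000 && base == GGML_TYPE_Q6_K) {
            return GGML_TYPE_Q5_K; // large-vocab models quantize with less loss
        }
        return base;
    }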

commit 17b7151
Author: Nexesenex <[email protected]>
Date:   Sat Aug 17 00:17:41 2024 +0200

    Update IQ3_M attn_k and IQ3_XL token_embd

commit e4c506d
Merge: eeccd31 2fb9267
Author: Nexesenex <[email protected]>
Date:   Sun Aug 18 04:09:22 2024 +0200

    Merge branch 'master' into pr/8836

commit eeccd31
Merge: 8c9017b 5fd89a7
Author: Nexesenex <[email protected]>
Date:   Thu Aug 15 02:30:10 2024 +0200

    Merge branch 'master' into pr/8836

commit 8c9017b
Author: Nexesenex <[email protected]>
Date:   Mon Aug 12 22:20:02 2024 +0200

    Simplify IQ4_XSR

    But leave in place as a "demo" the more complex template set by Ikawrakow to customize the layers quants, with the added attn_q, attn_k, and attn_output tensors.

commit 8c10533
Merge: cd92ba6 fc4ca27
Author: Nexesenex <[email protected]>
Date:   Mon Aug 12 20:28:38 2024 +0200

    Merge branch 'master' into pr/8836

commit cd92ba6
Author: Nexesenex <[email protected]>
Date:   Mon Aug 12 19:45:46 2024 +0200

    IQ4_XSR (test FTYPE) and attention_wv logic for all attn_*.weights

    Also, advise iMatrix for IQ2_M and Q2_K FTypes

commit 3e2eb6d
Merge: df9e6fd df5478f
Author: Nexesenex <[email protected]>
Date:   Mon Aug 12 14:25:23 2024 +0200

    Merge branch 'master' into pr/8836

commit df9e6fd
Author: Nexesenex <[email protected]>
Date:   Sun Aug 11 21:49:23 2024 +0200

    Adjustments on output and embeddings

commit 1ad18f8
Author: Nexesenex <[email protected]>
Date:   Sun Aug 11 21:44:29 2024 +0200

    Adjustments on attn_k

commit 8c2c03f
Merge: 91db53b 8cd1bcf
Author: Nexes the Old <[email protected]>
Date:   Sun Aug 11 16:46:15 2024 +0200

    Merge b3569

    b3569

commit 91db53b
Author: Nexesenex <[email protected]>
Date:   Sun Aug 11 16:41:23 2024 +0200

    IQ1_XL and some corrections

    notably on attn_q and parenthesis

commit 1268d58
Author: Nexesenex <[email protected]>
Date:   Sun Aug 11 02:13:08 2024 +0200

    More adjustments

commit ef83a87
Author: Nexesenex <[email protected]>
Date:   Sun Aug 11 01:30:18 2024 +0200

    Revert of ffn gate and up on IQ3_M

    and indent

commit e2e2d77
Author: Nexesenex <[email protected]>
Date:   Sun Aug 11 01:13:12 2024 +0200

    misplaced file lol

commit 8ad71f4
Author: Nexesenex <[email protected]>
Date:   Sun Aug 11 01:11:24 2024 +0200

    IQ1_XS

    and small adjustments.

commit 14f4f40
Merge: 8bc7a98 6e02327
Author: Nexes the Old <[email protected]>
Date:   Sat Aug 10 20:45:26 2024 +0200

    Merge b3565

    Merge b3565

commit 8bc7a98
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 20:40:27 2024 +0200

    2 forgotten files

commit f0806ac
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 20:34:17 2024 +0200

    IQ2_XL , IQ3_XL , Q2_K_L

    Plus some adjustments on the FFNs

commit 49617b1
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 18:37:29 2024 +0200

    Advancing on several tensors

    - Progressivity for token embeddings and attn_qkv
    - FFN down for IQ1 and IQ2 quants
    - FFN gate and up for IQ2_S and IQ2_M, for progressivity in the IQ2 range.

commit 415d5e4
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 17:32:29 2024 +0200

    Further refactor attn.v

    And also lower attn_q for IQ2_XS, in order to separate it more from the quite misnamed IQ2_S

commit 8c8e43c
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 16:38:11 2024 +0200

    Settings for MOE >= 8 experts applied to >= 4 experts

commit aa4eb59
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 16:33:55 2024 +0200

    Further refactor attn_k

    With attn_k set for all quants below 3bpw except Q2_K_S.

commit 8f1b99f
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 13:09:11 2024 +0200

    Shortening formatting

commit 7212098
Author: Nexesenex <[email protected]>
Date:   Sat Aug 10 12:52:57 2024 +0200

    IQ1 and IQ2 refactor

    Attn_q in Q3_K for experts >= 8
    Attn_k in Q5_K for experts >= 8
    Attn_v in Q6_K for experts >= 8, in IQ3_XXS for IQ2_XXS and IQ2_XS
    Attn_output in Q4_K for experts >= 8
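
    A hedged sketch of how per-tensor rules like the ones above typically look in a quantize strategy (patterned after llama.cpp's tensor-type selection; the names and exact conditions are assumptions):

    #include "ggml.h"

    // hypothetical sketch of the expert-count rules listed above
    static enum ggml_type pick_attn_v_type(int n_expert, bool iq2_xxs_or_xs,
                                           enum ggml_type fallback) {
        if (n_expert >= 8) return GGML_TYPE_Q6_K;    // many-expert MoE: spend bits on attn_v
        if (iq2_xxs_or_xs) return GGML_TYPE_IQ3_XXS; // IQ2_XXS / IQ2_XS ftypes
        return fallback;
    }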

commit 1bc4dc5
Author: Nexesenex <[email protected]>
Date:   Fri Aug 9 22:49:42 2024 +0200

    Bump IQ3_M

    attn.v in Q5_K
    attn.k in IQ4_XS

commit 1118c04
Author: Nexes the Old <[email protected]>
Date:   Thu Aug 8 18:56:20 2024 +0200

    correct mistake in conditionality for attn.k

commit 8006b15
Author: Nexes the Old <[email protected]>
Date:   Thu Aug 8 18:50:48 2024 +0200

    Avoid shrinking attn.k.weight for IQ3_XS and XXS when GQA or MoE

commit 59c5d47
Author: Nexes the Old <[email protected]>
Date:   Sun Aug 4 12:06:06 2024 +0200

    attn_qkv.weight in IQ4_XS for FTYPE IQ3_M

    If FTYPE IQ4_XS has attn_qkv.weight in IQ4_XS, then FTYPE IQ3_M should not have it in Q4_K (4.5BPW), but in IQ4_XS (4.25BPW) also.

commit 93c35f8
Author: Nexes the Old <[email protected]>
Date:   Sun Aug 4 11:59:52 2024 +0200

    attn.output.tensor of FTYPE IQ3_M in IQ4_XS

    If FTYPE IQ4_XS has attn.output.tensor in IQ4_XS (4.25BPW), there's no reason for FTYPE IQ3_M to have attn.output.tensor in Q4_K (4.5BPW).
    In terms of perplexity, on a Llama 3.1 70b model, the proposed change reduces the size by 1% and increases the perplexity by 0.25%.

commit d5779c2
Author: Nexes the Old <[email protected]>
Date:   Sat Aug 3 03:04:25 2024 +0200

    More occurrences of n_experts == 8 changed to >= in quant strategies

commit 7d337d0
Author: Nexes the Old <[email protected]>
Date:   Sat Aug 3 01:35:08 2024 +0200

    Slight reorder of the attn.weight tree

    And application of the attn.v.weight logic I used for IQ2 and IQ3, but only when such logic is already implied by the existing quant strategies, as a compromise so as not to disturb Ikawrakow's quant strategies too much.

commit 6398663
Author: Nexes the Old <[email protected]>
Date:   Fri Aug 2 23:49:03 2024 +0200

    Apply the GQA2/Expert2 conditionality to the IQ3 quants

    In coherence with the proposed modifications to the IQ2 quant strategies, which make even more sense for the IQ3 quant strategies.

commit b77cdd8
Author: Nexes the Old <[email protected]>
Date:   Fri Aug 2 20:40:04 2024 +0200

    Small changes for IQ2 quant strategies (notably IQ2_S and IQ2_M)

    Here are a few edits I consider useful to slightly improve the IQ2 model quant strategies for some models:

    - The tensor attn.v.weight passed in Q4_K for models like Gemma (GQA 2) and the various franken MoEs having 2 experts, so as not to sabotage them with too small a value-head quant (Q2_K is weak for such an important head) while the size of that head is low relative to the total size of the affected models.

    - The tensor attn.k.weight passed in Q4_K for models with 8 experts or more, rather than simply 8 experts.

    - The tensor attn.output.weight passed in IQ3_XXS (instead of IQ3_S) for the quant strategies IQ2_S and IQ2_M, to have progressiveness between the IQ2_XS quant strategies (which use IQ2_XS for attn.output.weight) and the IQ3_XXS quant strategies (which use IQ3_S for attn.output.weight). The benefit of an IQ3_S quant instead of an IQ3_XXS for that tensor is quasi-nonexistent on the IQ2_S and IQ2_M quant strategies, especially compared to the size bump it provokes.

    More broadly, I think that the whole IQ2 quant strategies bunch should be harmonized/refactored like the rest of the quant strategies are established (tensor by tensor), rather than under a different kind of tree mixing these 5 quant strategies.

    I've been using these settings (and many more edits) for a long time, with benefit, and I think they could be standard.
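
    A minimal sketch of the attn.output.weight rule proposed above, assuming llama.cpp-style ftype and type enums (an illustration, not the merged code):

    #include "llama.h"

    // hypothetical: progressive attn.output.weight types across the IQ2 range
    static enum ggml_type pick_iq2_attn_output(enum llama_ftype ftype) {
        switch (ftype) {
            case LLAMA_FTYPE_MOSTLY_IQ2_XXS:
            case LLAMA_FTYPE_MOSTLY_IQ2_XS: return GGML_TYPE_IQ2_XS;  // unchanged
            case LLAMA_FTYPE_MOSTLY_IQ2_S:
            case LLAMA_FTYPE_MOSTLY_IQ2_M:  return GGML_TYPE_IQ3_XXS; // was IQ3_S
            default:                        return GGML_TYPE_IQ3_S;
        }
    }
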
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Feb 25, 2025
commit bb2a4d125e33148d2b4e9363bf8ace14f722a610
Author: Nexesenex <[email protected]>
Date:   Mon Nov 11 08:59:32 2024 +0100

    8x22b

commit 9d4926ff9559ecae25f19fadcb55586677575b61
Merge: 9c65f44 b0cefea
Author: Nexesenex <[email protected]>
Date:   Mon Nov 11 08:59:07 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit 9c65f44
Author: Nexesenex <[email protected]>
Date:   Sun Nov 3 04:30:14 2024 +0100

    Test base 2048

commit 8ccafe8
Merge: d0d276f 9830b69
Author: Nexesenex <[email protected]>
Date:   Sun Nov 3 04:28:33 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit d0d276f
Merge: 7cecefd 418f5ee
Author: Nexesenex <[email protected]>
Date:   Fri Nov 1 20:18:27 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit 7cecefd
Merge: a5303b7 8841ce3
Author: Nexesenex <[email protected]>
Date:   Mon Oct 28 06:45:16 2024 +0100

    Merge branch 'master' into Nexes_CQ30

commit a5303b7
Merge: f21ab1e 167a515
Author: Nexesenex <[email protected]>
Date:   Thu Oct 24 19:55:36 2024 +0200

    Merge branch 'master' into Nexes_CQ30

commit f21ab1e
Merge: c72289e 20011f1
Author: Nexesenex <[email protected]>
Date:   Wed Oct 23 20:26:42 2024 +0200

    Merge branch 'gg/default-kq-f32-prec' into Nexes_CQ20

Labels
ggml (changes relating to the ggml tensor library for machine learning)
testing (Everything test related)

4 participants