Skip to content

Commit

Permalink
Merge branch 'master' into xsn/argparser_v3
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson committed Sep 6, 2024
2 parents e1281d0 + 815b1fb commit 5ae09fd
Show file tree
Hide file tree
Showing 26 changed files with 992 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ jobs:
run: |
mkdir build
cd build
cmake .. -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON
cmake .. -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON -DGGML_RPC=ON
cmake --build . --config Release -j $((${env:NUMBER_OF_PROCESSORS} - 1)) -t ggml
cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
Expand Down
8 changes: 4 additions & 4 deletions CMakePresets.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@

{
"name": "arm64-windows-msvc", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x86_64", "strategy": "external" },
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-msvc.cmake"
}
},

{
"name": "arm64-windows-llvm", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x86_64", "strategy": "external" },
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake"
}
Expand Down
9 changes: 9 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,15 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR}));
add_opt(llama_arg(
{"--output-format"}, "{md,jsonl}",
"output format for batched-bench results (default: md)",
[&params](std::string value) {
/**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; }
else if (value == "md") { params.batched_bench_output_jsonl = false; }
else { std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_BENCH}));
#ifndef LOG_DISABLE_LOGS
// TODO: make this looks less weird
add_opt(llama_arg(
Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ struct gpt_params {
bool spm_infill = false; // suffix/prefix/middle pattern for infill

std::string lora_outfile = "ggml-lora-merged-f16.gguf";

// batched-bench params
bool batched_bench_output_jsonl = false;
};

struct llama_arg {
Expand Down
47 changes: 33 additions & 14 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,20 @@ def prepare_tensors(self):
):
data_qtype = gguf.GGMLQuantizationType.F32

if data_qtype is False and any(
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.TOKEN_EMBD,
gguf.MODEL_TENSOR.OUTPUT,
)
):
if self.ftype in (
gguf.LlamaFileType.MOSTLY_TQ1_0,
gguf.LlamaFileType.MOSTLY_TQ2_0,
):
# TODO: use Q4_K and Q6_K
data_qtype = gguf.GGMLQuantizationType.F16

# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
if isinstance(data_qtype, bool):
if self.ftype == gguf.LlamaFileType.ALL_F32:
Expand All @@ -318,6 +332,10 @@ def prepare_tensors(self):
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
data_qtype = gguf.GGMLQuantizationType.Q8_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0:
data_qtype = gguf.GGMLQuantizationType.TQ1_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
data_qtype = gguf.GGMLQuantizationType.TQ2_0
else:
raise ValueError(f"Unknown file type: {self.ftype.name}")

Expand Down Expand Up @@ -1623,15 +1641,16 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0)

def weight_quant(self, weight):
def weight_quant(self, weight: Tensor) -> Tensor:
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
weight = (weight * s).round().clamp(-1, 1) / s
scale = weight.abs().max().unsqueeze(0)
weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
weight = torch.sign(weight).type(dtype)
return weight.type(dtype), scale.type(torch.float32)
scale = weight.abs().mean().clamp(min=1e-5)
iscale = 1 / scale
# TODO: multiply by the scale directly instead of inverting it twice
# (this is also unnecessarily doubly inverted upstream)
# ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10
result = (weight * iscale).round().clamp(-1, 1) / iscale
return result.type(dtype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
new_name = self.map_tensor_name(name)
Expand All @@ -1646,11 +1665,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
gguf.MODEL_TENSOR.FFN_GATE,
]):
# transform weight into 1/0/-1 (in fp32)
weight_torch, scale_torch = self.weight_quant(data_torch)
yield (new_name, weight_torch)
yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
else:
yield (new_name, data_torch)
data_torch = self.weight_quant(data_torch)

yield (new_name, data_torch)


@Model.register("GrokForCausalLM")
Expand Down Expand Up @@ -4011,8 +4028,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
Expand Down Expand Up @@ -4099,6 +4116,8 @@ def main() -> None:
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
"tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
"auto": gguf.LlamaFileType.GUESSED,
}

Expand Down
9 changes: 9 additions & 0 deletions examples/batched-bench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ There are 2 modes of operation:
| 128 | 256 | 8 | 3072 | 0.751 | 1363.92 | 15.110 | 135.54 | 15.861 | 193.69 |
| 128 | 256 | 16 | 6144 | 1.569 | 1304.93 | 18.073 | 226.64 | 19.642 | 312.80 |
| 128 | 256 | 32 | 12288 | 3.409 | 1201.35 | 19.223 | 426.15 | 22.633 | 542.93 |

### JSONL output

Pass `--output-format jsonl` to output JSONL instead of Markdown, á la

```json lines
{"n_kv_max": 2048, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "is_pp_shared": 0, "n_gpu_layers": 99, "n_threads": 8, "n_threads_batch": 8, "pp": 128, "tg": 128, "pl": 1, "n_kv": 256, "t_pp": 0.233810, "speed_pp": 547.453064, "t_tg": 3.503684, "speed_tg": 36.532974, "t": 3.737494, "speed": 68.495094}
{"n_kv_max": 2048, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "is_pp_shared": 0, "n_gpu_layers": 99, "n_threads": 8, "n_threads_batch": 8, "pp": 128, "tg": 128, "pl": 2, "n_kv": 512, "t_pp": 0.422602, "speed_pp": 605.770935, "t_tg": 11.106112, "speed_tg": 23.050371, "t": 11.528713, "speed": 44.410854}
```
24 changes: 17 additions & 7 deletions examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ int main(int argc, char ** argv) {
}
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
LOG_TEE("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
if (!params.batched_bench_output_jsonl) {
LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
LOG_TEE("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
}

for ( int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
for ( int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
Expand Down Expand Up @@ -193,7 +194,16 @@ int main(int argc, char ** argv) {
const float speed_tg = pl*tg / t_tg;
const float speed = n_kv / t;

LOG_TEE("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
if(params.batched_bench_output_jsonl) {
LOG_TEE(
"{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"is_pp_shared\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, "
"\"pp\": %d, \"tg\": %d, \"pl\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f, \"t\": %f, \"speed\": %f}\n",
n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed
);
} else {
LOG_TEE("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0, " 1.69 bpw ternarization", },
{ "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0, " 2.06 bpw ternarization", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
Expand Down
1 change: 1 addition & 0 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ struct server_queue {

// multi-task version of post()
int post(std::vector<server_task> & tasks, bool front = false) {
std::unique_lock<std::mutex> lock(mutex_tasks);
for (auto & task : tasks) {
if (task.id == -1) {
task.id = id++;
Expand Down
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ option(GGML_VULKAN "ggml: use Vulkan"
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF)
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
option(GGML_VULKAN_PERF "ggml: enable Vulkan perf output" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
Expand Down
2 changes: 2 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ extern "C" {
GGML_TYPE_Q4_0_4_4 = 31,
GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33,
GGML_TYPE_TQ1_0 = 34,
GGML_TYPE_TQ2_0 = 35,
GGML_TYPE_COUNT,
};

Expand Down
4 changes: 4 additions & 0 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,10 @@ if (GGML_VULKAN)
add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
endif()

if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
endif()

if (GGML_VULKAN_PERF)
add_compile_definitions(GGML_VULKAN_PERF)
endif()
Expand Down
20 changes: 20 additions & 0 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,25 @@ typedef struct {
} block_q8_0x8;
static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");

//
// Ternary quantization
//

// 1.6875 bpw
typedef struct {
uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256)
uint8_t qh[QK_K/64]; // 4 elements per byte
ggml_half d;
} block_tq1_0;
static_assert(sizeof(block_tq1_0) == sizeof(ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding");

// 2.0625 bpw
typedef struct {
uint8_t qs[QK_K/4]; // 2 bits per element
ggml_half d;
} block_tq2_0;
static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");

//
// Super-block quantization structures
//
Expand Down Expand Up @@ -361,6 +380,7 @@ typedef struct {
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");

// 1.5625 bpw
typedef struct {
ggml_half d;
uint8_t qs[QK_K/8];
Expand Down
11 changes: 4 additions & 7 deletions ggml/src/ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ typedef __fp16 ggml_fp16_internal_t;

// 32-bit ARM compatibility

// vaddvq_s16
// vaddlvq_s16
// vpaddq_s16
// vpaddq_s32
// vaddvq_s32
Expand All @@ -185,12 +185,9 @@ typedef __fp16 ggml_fp16_internal_t;
// vzip1_u8
// vzip2_u8

inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
inline static int32_t vaddlvq_s16(int16x8_t v) {
int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
}

inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
Expand Down
Loading

0 comments on commit 5ae09fd

Please sign in to comment.