diff --git a/README-sycl.md b/README-sycl.md
index 93b623daf6a1a..bd1984706225f 100644
--- a/README-sycl.md
+++ b/README-sycl.md
@@ -1,6 +1,7 @@
# llama.cpp for SYCL
- [Background](#background)
+- [Recommended Release](#recommended-release)
- [News](#news)
- [OS](#os)
- [Hardware](#hardware)
@@ -31,8 +32,23 @@ When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneM
It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [IntelĀ® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose.
+## Recommended Release
+
+The SYCL backend would be broken by some PRs due to no online CI.
+
+The following release is verified with good quality:
+
+|Commit ID|Tag|Release|Verified Platform|
+|-|-|-|-|
+|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggerganov/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1|
+
+
## News
+- 2024.5
+ - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
+ - Arch Linux is verified successfully.
+
- 2024.4
- Support data types: GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_XS, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M.
diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp
index 8b95cec374c23..36eadaac26a12 100644
--- a/examples/cvector-generator/pca.hpp
+++ b/examples/cvector-generator/pca.hpp
@@ -64,15 +64,15 @@ struct pca_model {
struct ggml_tensor * dev_eigenvector;
pca_model(struct ggml_tensor * t_input) {
-// TODO: enable GPU support when support for GGML_OP_SQRT is added
-// #ifdef GGML_USE_CUDA
-// fprintf(stderr, "%s: using CUDA backend\n", __func__);
-// backend = ggml_backend_cuda_init(0); // init device 0
-// if (!backend) {
-// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-// }
-// #endif
+#ifdef GGML_USE_CUDA
+ fprintf(stderr, "%s: using CUDA backend\n", __func__);
+ backend = ggml_backend_cuda_init(0); // init device 0
+ if (!backend) {
+ fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+ }
+#endif
+// TODO: enable Metal support when support for GGML_OP_SQRT is added
// #ifdef GGML_USE_METAL
// fprintf(stderr, "%s: using Metal backend\n", __func__);
// backend = ggml_backend_metal_init();
diff --git a/ggml-backend.c b/ggml-backend.c
index 2bec7bea38a85..26dce7f724213 100644
--- a/ggml-backend.c
+++ b/ggml-backend.c
@@ -1172,7 +1172,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
// check if a backend with higher prio wants to offload the op
if (src_backend_id == sched->n_backends - 1) {
for (int b = 0; b < src_backend_id; b++) {
- if (ggml_backend_offload_op(sched->backends[b], tensor)) {
+ if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
SET_CAUSE(tensor, "1.off");
return b;
}
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 593fa4cdaa514..b8298ab205e60 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2267,6 +2267,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SQR:
ggml_cuda_op_sqr(ctx, dst);
break;
+ case GGML_OP_SQRT:
+ ggml_cuda_op_sqrt(ctx, dst);
+ break;
case GGML_OP_CLAMP:
ggml_cuda_op_clamp(ctx, dst);
break;
@@ -2830,6 +2833,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SQR:
+ case GGML_OP_SQRT:
case GGML_OP_CLAMP:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu
index a5ff96320f23f..f9e208011e2a8 100644
--- a/ggml-cuda/unary.cu
+++ b/ggml-cuda/unary.cu
@@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] * x[i];
}
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sqrtf(x[i]);
+}
+
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<>>(x, dst, k);
@@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
sqr_f32<<>>(x, dst, k);
}
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+ sqrt_f32<<>>(x, dst, k);
+}
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
diff --git a/ggml-cuda/unary.cuh b/ggml-cuda/unary.cuh
index a1d07c04fcd43..4cfb0479e7169 100644
--- a/ggml-cuda/unary.cuh
+++ b/ggml-cuda/unary.cuh
@@ -8,6 +8,7 @@
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp
index 22d9524b8d764..b01ad267446fb 100644
--- a/ggml-rpc.cpp
+++ b/ggml-rpc.cpp
@@ -73,9 +73,13 @@ struct rpc_tensor {
uint64_t view_offs;
uint64_t data;
char name[GGML_MAX_NAME];
+
+ char padding[4];
};
#pragma pack(pop)
+static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
+
// RPC commands
enum rpc_cmd {
ALLOC_BUFFER = 0,
@@ -599,9 +603,8 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector & o
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
output.resize(output_size, 0);
memcpy(output.data(), &n_nodes, sizeof(n_nodes));
- uint64_t * out_nodes = (uint64_t *)(output.data() + sizeof(n_nodes));
for (uint32_t i = 0; i < n_nodes; i++) {
- out_nodes[i] = reinterpret_cast(cgraph->nodes[i]);
+ memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
}
uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
*out_ntensors = n_tensors;
@@ -1036,7 +1039,9 @@ bool rpc_server::graph_compute(const std::vector & input, std::vector tensor_map;
for (uint32_t i = 0; i < n_nodes; i++) {
- graph->nodes[i] = create_node(nodes[i], ctx, tensor_ptrs, tensor_map);
+ int64_t id;
+ memcpy(&id, &nodes[i], sizeof(id));
+ graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
}
ggml_status status = ggml_backend_graph_compute(backend, graph);
// output serialization format: | status (1 byte) |
diff --git a/gguf-py/scripts/gguf-dump.py b/gguf-py/scripts/gguf-dump.py
index 1a37a7b91409d..92d14d6cd0a69 100755
--- a/gguf-py/scripts/gguf-dump.py
+++ b/gguf-py/scripts/gguf-dump.py
@@ -14,7 +14,7 @@
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
sys.path.insert(0, str(Path(__file__).parent.parent))
-from gguf import GGUFReader, GGUFValueType # noqa: E402
+from gguf import GGUFReader, GGUFValueType, ReaderTensor # noqa: E402
logger = logging.getLogger("gguf-dump")
@@ -101,25 +101,285 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
json.dump(result, sys.stdout)
+def markdown_table_with_alignment_support(header_map: list[dict[str, str]], data: list[dict[str, Any]]):
+ # JSON to Markdown table formatting: https://stackoverflow.com/a/72983854/2850957
+
+ # Alignment Utility Function
+ def strAlign(padding: int, alignMode: str | None, strVal: str):
+ if alignMode == 'center':
+ return strVal.center(padding)
+ elif alignMode == 'right':
+ return strVal.rjust(padding - 1) + ' '
+ elif alignMode == 'left':
+ return ' ' + strVal.ljust(padding - 1)
+ else: # default left
+ return ' ' + strVal.ljust(padding - 1)
+
+ def dashAlign(padding: int, alignMode: str | None):
+ if alignMode == 'center':
+ return ':' + '-' * (padding - 2) + ':'
+ elif alignMode == 'right':
+ return '-' * (padding - 1) + ':'
+ elif alignMode == 'left':
+ return ':' + '-' * (padding - 1)
+ else: # default left
+ return '-' * (padding)
+
+ # Calculate Padding For Each Column Based On Header and Data Length
+ rowsPadding = {}
+ for index, columnEntry in enumerate(header_map):
+ padCount = max([len(str(v)) for d in data for k, v in d.items() if k == columnEntry['key_name']], default=0) + 2
+ headerPadCount = len(columnEntry['header_name']) + 2
+ rowsPadding[index] = headerPadCount if padCount <= headerPadCount else padCount
+
+ # Render Markdown Header
+ rows = []
+ rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(columnEntry['header_name'])) for index, columnEntry in enumerate(header_map)))
+ rows.append('|'.join(dashAlign(rowsPadding[index], columnEntry.get('align')) for index, columnEntry in enumerate(header_map)))
+
+ # Render Tabular Data
+ for item in data:
+ rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(item[columnEntry['key_name']])) for index, columnEntry in enumerate(header_map)))
+
+ # Convert Tabular String Rows Into String
+ tableString = ""
+ for row in rows:
+ tableString += f'|{row}|\n'
+
+ return tableString
+
+
+def element_count_rounded_notation(count: int) -> str:
+ if count > 1e15 :
+ # Quadrillion
+ scaled_amount = count * 1e-15
+ scale_suffix = "Q"
+ elif count > 1e12 :
+ # Trillions
+ scaled_amount = count * 1e-12
+ scale_suffix = "T"
+ elif count > 1e9 :
+ # Billions
+ scaled_amount = count * 1e-9
+ scale_suffix = "B"
+ elif count > 1e6 :
+ # Millions
+ scaled_amount = count * 1e-6
+ scale_suffix = "M"
+ elif count > 1e3 :
+ # Thousands
+ scaled_amount = count * 1e-3
+ scale_suffix = "K"
+ else:
+ # Under Thousands
+ scaled_amount = count
+ scale_suffix = ""
+ return f"{'~' if count > 1e3 else ''}{round(scaled_amount)}{scale_suffix}"
+
+
+def translate_tensor_name(name):
+ words = name.split(".")
+
+ # Source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#standardized-tensor-names
+ abbreviation_dictionary = {
+ 'token_embd': 'Token embedding',
+ 'pos_embd': 'Position embedding',
+ 'output_norm': 'Output normalization',
+ 'output': 'Output',
+ 'attn_norm': 'Attention normalization',
+ 'attn_norm_2': 'Attention normalization',
+ 'attn_qkv': 'Attention query-key-value',
+ 'attn_q': 'Attention query',
+ 'attn_k': 'Attention key',
+ 'attn_v': 'Attention value',
+ 'attn_output': 'Attention output',
+ 'ffn_norm': 'Feed-forward network normalization',
+ 'ffn_up': 'Feed-forward network "up"',
+ 'ffn_gate': 'Feed-forward network "gate"',
+ 'ffn_down': 'Feed-forward network "down"',
+ 'ffn_gate_inp': 'Expert-routing layer for the Feed-forward network in Mixture of Expert models',
+ 'ffn_gate_exp': 'Feed-forward network "gate" layer per expert in Mixture of Expert models',
+ 'ffn_down_exp': 'Feed-forward network "down" layer per expert in Mixture of Expert models',
+ 'ffn_up_exp': 'Feed-forward network "up" layer per expert in Mixture of Expert models',
+ 'ssm_in': 'State space model input projections',
+ 'ssm_conv1d': 'State space model rolling/shift',
+ 'ssm_x': 'State space model selective parametrization',
+ 'ssm_a': 'State space model state compression',
+ 'ssm_d': 'State space model skip connection',
+ 'ssm_dt': 'State space model time step',
+ 'ssm_out': 'State space model output projection',
+ 'blk': 'Block'
+ }
+
+ expanded_words = []
+ for word in words:
+ word_norm = word.strip().lower()
+ if word_norm in abbreviation_dictionary:
+ expanded_words.append(abbreviation_dictionary[word_norm].title())
+ else:
+ expanded_words.append(word.title())
+
+ return ' '.join(expanded_words)
+
+
+def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
+ host_endian, file_endian = get_file_host_endian(reader)
+ markdown_content = ""
+ markdown_content += f'# {args.model} - GGUF Internal File Dump\n\n'
+ markdown_content += f'- Endian: {file_endian} endian\n'
+ markdown_content += '\n'
+ markdown_content += '## Key Value Metadata Store\n\n'
+ markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n'
+ markdown_content += '\n'
+
+ kv_dump_table: list[dict[str, str | int]] = []
+ for n, field in enumerate(reader.fields.values(), 1):
+ if not field.types:
+ pretty_type = 'N/A'
+ elif field.types[0] == GGUFValueType.ARRAY:
+ nest_count = len(field.types) - 1
+ pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count
+ else:
+ pretty_type = str(field.types[-1].name)
+
+ total_elements = len(field.data)
+ value = ""
+ if len(field.types) == 1:
+ curr_type = field.types[0]
+ if curr_type == GGUFValueType.STRING:
+ value = repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60])
+ elif curr_type in reader.gguf_scalar_to_np:
+ value = str(field.parts[-1][0])
+ else:
+ if field.types[0] == GGUFValueType.ARRAY:
+ curr_type = field.types[1]
+ if curr_type == GGUFValueType.STRING:
+ render_element = min(5, total_elements)
+ for element_pos in range(render_element):
+ value += repr(str(bytes(field.parts[-1 - element_pos]), encoding='utf-8')[:5]) + (", " if total_elements > 1 else "")
+ elif curr_type in reader.gguf_scalar_to_np:
+ render_element = min(7, total_elements)
+ for element_pos in range(render_element):
+ value += str(field.parts[-1 - element_pos][0]) + (", " if total_elements > 1 else "")
+ value = f'[ {value}{" ..." if total_elements > 1 else ""} ]'
+ kv_dump_table.append({"n":n, "pretty_type":pretty_type, "total_elements":total_elements, "field_name":field.name, "value":value})
+
+ kv_dump_table_header_map = [
+ {'key_name':'n', 'header_name':'POS', 'align':'right'},
+ {'key_name':'pretty_type', 'header_name':'TYPE', 'align':'left'},
+ {'key_name':'total_elements', 'header_name':'Count', 'align':'right'},
+ {'key_name':'field_name', 'header_name':'Key', 'align':'left'},
+ {'key_name':'value', 'header_name':'Value', 'align':'left'},
+ ]
+
+ markdown_content += markdown_table_with_alignment_support(kv_dump_table_header_map, kv_dump_table)
+
+ markdown_content += "\n"
+
+ if not args.no_tensors:
+ # Group tensors by their prefix and maintain order
+ tensor_prefix_order: list[str] = []
+ tensor_name_to_key: dict[str, int] = {}
+ tensor_groups: dict[str, list[ReaderTensor]] = {}
+ total_elements = sum(tensor.n_elements for tensor in reader.tensors)
+
+ # Parsing Tensors Record
+ for key, tensor in enumerate(reader.tensors):
+ tensor_components = tensor.name.split('.')
+
+ # Classify Tensor Group
+ tensor_group_name = "base"
+ if tensor_components[0] == 'blk':
+ tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}"
+
+ # Check if new Tensor Group
+ if tensor_group_name not in tensor_groups:
+ tensor_groups[tensor_group_name] = []
+ tensor_prefix_order.append(tensor_group_name)
+
+ # Record Tensor and Tensor Position
+ tensor_groups[tensor_group_name].append(tensor)
+ tensor_name_to_key[tensor.name] = key
+
+ # Tensors Mapping Dump
+ markdown_content += f'## Tensors Overview {element_count_rounded_notation(total_elements)} Elements\n\n'
+ markdown_content += f'Total number of elements in all tensors: {total_elements} Elements\n'
+ markdown_content += '\n'
+
+ for group in tensor_prefix_order:
+ tensors = tensor_groups[group]
+ group_elements = sum(tensor.n_elements for tensor in tensors)
+ markdown_content += f"- [{translate_tensor_name(group)} Tensor Group - {element_count_rounded_notation(group_elements)} Elements](#{group.replace('.', '_')})\n"
+
+ markdown_content += "\n"
+
+ for group in tensor_prefix_order:
+ tensors = tensor_groups[group]
+ group_elements = sum(tensor.n_elements for tensor in tensors)
+ group_percentage = group_elements / total_elements * 100
+ markdown_content += f"### {translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements\n\n"
+
+ # Precalculate column sizing for visual consistency
+ prettify_element_est_count_size: int = 1
+ prettify_element_count_size: int = 1
+ prettify_dimension_max_widths: dict[int, int] = {}
+ for tensor in tensors:
+ prettify_element_est_count_size = max(prettify_element_est_count_size, len(str(element_count_rounded_notation(tensor.n_elements))))
+ prettify_element_count_size = max(prettify_element_count_size, len(str(tensor.n_elements)))
+ for i, dimension_size in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))):
+ prettify_dimension_max_widths[i] = max(prettify_dimension_max_widths.get(i,1), len(str(dimension_size)))
+
+ # Generate Tensor Layer Table Content
+ tensor_dump_table: list[dict[str, str | int]] = []
+ for tensor in tensors:
+ human_friendly_name = translate_tensor_name(tensor.name.replace(".weight", ".(W)").replace(".bias", ".(B)"))
+ pretty_dimension = ' x '.join(f'{str(d):>{prettify_dimension_max_widths[i]}}' for i, d in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))))
+ element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})"
+ element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}"
+ type_name_string = f"{tensor.tensor_type.name}"
+ tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string})
+
+ tensor_dump_table_header_map = [
+ {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'},
+ {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'},
+ {'key_name':'human_layer_name', 'header_name':'Human Friendly Tensor Layer Name', 'align':'left'},
+ {'key_name':'element_count', 'header_name':'Elements', 'align':'left'},
+ {'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'},
+ {'key_name':'tensor_type', 'header_name':'Type', 'align':'left'},
+ ]
+
+ markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table)
+
+ markdown_content += "\n"
+ markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n"
+ markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n"
+ markdown_content += "\n\n"
+
+ print(markdown_content) # noqa: NP100
+
+
def main() -> None:
parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
parser.add_argument("model", type=str, help="GGUF format model filename")
parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata")
parser.add_argument("--json", action="store_true", help="Produce JSON output")
parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)")
+ parser.add_argument("--markdown", action="store_true", help="Produce markdown output")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
- if not args.json:
+ if not args.json and not args.markdown:
logger.info(f'* Loading: {args.model}')
reader = GGUFReader(args.model, 'r')
if args.json:
dump_metadata_json(reader, args)
+ elif args.markdown:
+ dump_markdown_metadata(reader, args)
else:
dump_metadata(reader, args)
diff --git a/llama.cpp b/llama.cpp
index bd4f8ec1865fb..dd7020dc0eeab 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1278,6 +1278,126 @@ struct no_init {
};
struct llama_file {
+
+#if defined(_WIN32)
+ // use FILE * so we don't have to re-open the file to mmap
+ FILE * fp;
+ HANDLE fp_win32;
+ size_t size;
+
+private:
+ std::string GetErrorMessageWin32(DWORD error_code) const {
+ std::string ret;
+ LPSTR lpMsgBuf = NULL;
+ DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+ NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
+ if (!bufLen) {
+ ret = format("Win32 error code: %s", error_code);
+ } else {
+ ret = lpMsgBuf;
+ LocalFree(lpMsgBuf);
+ }
+
+ return ret;
+ }
+
+public:
+
+ llama_file(const char * fname, const char * mode) {
+ fp = ggml_fopen(fname, mode);
+ if (fp == NULL) {
+ throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
+ }
+ fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
+ seek(0, SEEK_END);
+ size = tell();
+ seek(0, SEEK_SET);
+ }
+
+ size_t tell() const {
+ // SetFilePointerEx returns the current position when seeking relative 0 bytes
+ LARGE_INTEGER li;
+ li.QuadPart = 0;
+ BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
+ if (!ret) {
+ throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+ }
+
+ return li.QuadPart;
+ }
+
+ void seek(size_t offset, int whence) const {
+ // no need to convert SEEK_* to FILE_*. The enums are the same.
+ // Still, keep static asserts to avoid failures in the future.
+ static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
+ static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
+ static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
+
+ LARGE_INTEGER li;
+ li.QuadPart = offset;
+ BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
+ if (!ret) {
+ throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+ }
+ }
+
+ void read_raw(void * ptr, size_t len) const {
+ // On Win32 ReadFile is significant faster than fread which is again significant faster than std::fstream. Thus
+ // use the Win32 API to do file io instead of the C/C++ library functions.
+
+ // There are conditions under which ReadFile cannot read chunks >64MB.
+ // Thus split the operation into smaller chunks if len exceeds this limit.
+ size_t bytes_read = 0;
+ while (bytes_read < len) {
+ size_t chunk_size = std::min(len - bytes_read, 64*1024*1024);
+ DWORD chunk_read = 0;
+ BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
+ if (!result) {
+ throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+ }
+ if (chunk_read < chunk_size || chunk_read == 0) {
+ throw std::runtime_error("unexpectedly reached end of file");
+ }
+
+ bytes_read += chunk_read;
+ } ;
+ }
+
+ uint32_t read_u32() const {
+ uint32_t val;
+ read_raw(&val, sizeof(val));
+ return val;
+ }
+
+ void write_raw(const void * ptr, size_t len) const {
+ // There are conditions under which WriteFile cannot write chunks >64MB.
+ // Thus split the operation into smaller chunks if len exceeds this limit.
+ size_t bytes_written = 0;
+ while (bytes_written < len) {
+ size_t chunk_size = std::min(len - bytes_written, 64*1024*1024);
+ DWORD chunk_written = 0;
+ BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
+ if (!result) {
+ throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+ }
+ if (chunk_written < chunk_size || chunk_written == 0) {
+ throw std::runtime_error("unexpectedly failed to write bytes");
+ }
+
+ bytes_written += chunk_written;
+ }
+ }
+
+ void write_u32(std::uint32_t val) const {
+ write_raw(&val, sizeof(val));
+ }
+
+ ~llama_file() {
+ if (fp) {
+ std::fclose(fp);
+ }
+ }
+#else
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
size_t size;
@@ -1298,7 +1418,10 @@ struct llama_file {
#else
long ret = std::ftell(fp);
#endif
- GGML_ASSERT(ret != -1); // this really shouldn't fail
+ if (ret == -1) {
+ throw std::runtime_error(format("ftell error: %s", strerror(errno)));
+ }
+
return (size_t) ret;
}
@@ -1308,7 +1431,9 @@ struct llama_file {
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
- GGML_ASSERT(ret == 0); // same
+ if (ret != 0) {
+ throw std::runtime_error(format("seek error: %s", strerror(errno)));
+ }
}
void read_raw(void * ptr, size_t len) const {
@@ -1351,6 +1476,7 @@ struct llama_file {
std::fclose(fp);
}
}
+#endif
};
using llama_files = std::vector>;
@@ -3721,6 +3847,44 @@ struct llama_model_loader {
std::vector> read_buf;
std::vector>> validation_result;
+#if defined(GGML_USE_CUDA)
+ // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
+ // NVMe raid configurations might require more / larger buffers.
+ constexpr size_t num_buffers = 4;
+ constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
+
+ std::vector host_buffers;
+ std::vector host_ptrs;
+ std::vector events;
+ size_t buffer_idx = 0; // buffer to use for async loads
+
+ ggml_backend_t cuda_backend = nullptr;
+ if (!use_mmap && !check_tensors) {
+ // When not using mmaped io use async uploads from pinned memory to GPU memory.
+ // First determine if the CUDA backend is active, and if so, determine the device ID.
+ ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
+ if (buf) {
+ ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf);
+ for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+ auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
+ if (buffer_type == cuda_buffer_type) {
+ cuda_backend = ggml_backend_cuda_init(i);
+ break;
+ }
+ }
+ }
+
+ // If the cuda backend is active create pinned memory buffers and events for synchronisation.
+ if (cuda_backend) {
+ for (size_t idx = 0; idx < num_buffers; ++idx) {
+ host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
+ host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
+ events.emplace_back(ggml_backend_event_new(cuda_backend));
+ }
+ }
+ }
+#endif
+
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
const auto * weight = get_weight(ggml_get_name(cur));
if (weight == nullptr) {
@@ -3776,12 +3940,36 @@ struct llama_model_loader {
}));
}
} else {
- read_buf.resize(n_size);
- file->seek(weight->offs, SEEK_SET);
- file->read_raw(read_buf.data(), n_size);
- ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
- if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
- throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+#if defined(GGML_USE_CUDA)
+ // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
+ if (cuda_backend) {
+ file->seek(weight->offs, SEEK_SET);
+
+ size_t bytes_read = 0;
+
+ while (bytes_read < n_size) {
+ size_t read_iteration = std::min(buffer_size, n_size - bytes_read);
+
+ ggml_backend_event_synchronize(events[buffer_idx]);
+ file->read_raw(host_ptrs[buffer_idx], read_iteration);
+ ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
+ ggml_backend_event_record(events[buffer_idx]);
+
+ bytes_read += read_iteration;
+ ++buffer_idx;
+ buffer_idx %= num_buffers;
+ }
+ }
+ else
+#endif
+ {
+ read_buf.resize(n_size);
+ file->seek(weight->offs, SEEK_SET);
+ file->read_raw(read_buf.data(), n_size);
+ ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
+ if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
+ throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+ }
}
}
}
@@ -3789,6 +3977,18 @@ struct llama_model_loader {
size_done += n_size;
}
+#if defined(GGML_USE_CUDA)
+ // free temporary resources used for async cuda uploads
+ if (cuda_backend) {
+ for (size_t idx = 0; idx < num_buffers;++idx) {
+ ggml_backend_event_synchronize(events[idx]);
+ ggml_backend_event_free(events[idx]);
+ ggml_backend_buffer_free(host_buffers[idx]);
+ }
+ ggml_backend_free(cuda_backend);
+ }
+#endif
+
// check validation results
bool validation_failed = false;
for (auto & future : validation_result) {
@@ -5183,7 +5383,7 @@ static bool llm_load_tensors(
// create tensors for the weights
{
const int64_t n_embd = hparams.n_embd;
- const int64_t n_embd_head = n_embd / hparams.n_head;
+ const int64_t n_embd_head = (hparams.n_head == 0) ? 0 : n_embd / hparams.n_head;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_embd_gqa = n_embd_v_gqa;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2b48e623e3476..7c504e937a851 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1063,6 +1063,33 @@ struct test_sqr : public test_case {
}
};
+// GGML_OP_SQRT
+struct test_sqrt : public test_case {
+ const ggml_type type;
+ const std::array ne;
+
+ std::string vars() override {
+ return VARS_TO_STR2(type, ne);
+ }
+
+ test_sqrt(ggml_type type = GGML_TYPE_F32,
+ std::array ne = {10, 10, 10, 10})
+ : type(type), ne(ne) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * out = ggml_sqrt(ctx, a);
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ // fill with positive values
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ init_tensor_uniform(t, 0.0f, 100.0f);
+ }
+ }
+};
+
// GGML_OP_CLAMP
struct test_clamp : public test_case {
const ggml_type type;
@@ -2200,6 +2227,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
test_cases.emplace_back(new test_sqr());
+ test_cases.emplace_back(new test_sqrt());
test_cases.emplace_back(new test_clamp());
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));