diff --git a/README-sycl.md b/README-sycl.md index 93b623daf6a1a..bd1984706225f 100644 --- a/README-sycl.md +++ b/README-sycl.md @@ -1,6 +1,7 @@ # llama.cpp for SYCL - [Background](#background) +- [Recommended Release](#recommended-release) - [News](#news) - [OS](#os) - [Hardware](#hardware) @@ -31,8 +32,23 @@ When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneM It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [IntelĀ® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose. +## Recommended Release + +The SYCL backend would be broken by some PRs due to no online CI. + +The following release is verified with good quality: + +|Commit ID|Tag|Release|Verified Platform| +|-|-|-|-| +|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggerganov/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| + + ## News +- 2024.5 + - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. + - Arch Linux is verified successfully. + - 2024.4 - Support data types: GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_XS, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M. diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index 8b95cec374c23..36eadaac26a12 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -64,15 +64,15 @@ struct pca_model { struct ggml_tensor * dev_eigenvector; pca_model(struct ggml_tensor * t_input) { -// TODO: enable GPU support when support for GGML_OP_SQRT is added -// #ifdef GGML_USE_CUDA -// fprintf(stderr, "%s: using CUDA backend\n", __func__); -// backend = ggml_backend_cuda_init(0); // init device 0 -// if (!backend) { -// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); -// } -// #endif +#ifdef GGML_USE_CUDA + fprintf(stderr, "%s: using CUDA backend\n", __func__); + backend = ggml_backend_cuda_init(0); // init device 0 + if (!backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } +#endif +// TODO: enable Metal support when support for GGML_OP_SQRT is added // #ifdef GGML_USE_METAL // fprintf(stderr, "%s: using Metal backend\n", __func__); // backend = ggml_backend_metal_init(); diff --git a/ggml-backend.c b/ggml-backend.c index 2bec7bea38a85..26dce7f724213 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1172,7 +1172,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st // check if a backend with higher prio wants to offload the op if (src_backend_id == sched->n_backends - 1) { for (int b = 0; b < src_backend_id; b++) { - if (ggml_backend_offload_op(sched->backends[b], tensor)) { + if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); return b; } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 593fa4cdaa514..b8298ab205e60 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2267,6 +2267,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SQR: ggml_cuda_op_sqr(ctx, dst); break; + case GGML_OP_SQRT: + ggml_cuda_op_sqrt(ctx, dst); + break; case GGML_OP_CLAMP: ggml_cuda_op_clamp(ctx, dst); break; @@ -2830,6 +2833,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_CLAMP: case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu index a5ff96320f23f..f9e208011e2a8 100644 --- a/ggml-cuda/unary.cu +++ b/ggml-cuda/unary.cu @@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) { dst[i] = x[i] * x[i]; } +static __global__ void sqrt_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = sqrtf(x[i]); +} + static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); @@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t sqr_f32<<>>(x, dst, k); } +static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE; + sqrt_f32<<>>(x, dst, k); +} + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} diff --git a/ggml-cuda/unary.cuh b/ggml-cuda/unary.cuh index a1d07c04fcd43..4cfb0479e7169 100644 --- a/ggml-cuda/unary.cuh +++ b/ggml-cuda/unary.cuh @@ -8,6 +8,7 @@ #define CUDA_HARDSIGMOID_BLOCK_SIZE 256 #define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 +#define CUDA_SQRT_BLOCK_SIZE 256 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index 22d9524b8d764..b01ad267446fb 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -73,9 +73,13 @@ struct rpc_tensor { uint64_t view_offs; uint64_t data; char name[GGML_MAX_NAME]; + + char padding[4]; }; #pragma pack(pop) +static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8"); + // RPC commands enum rpc_cmd { ALLOC_BUFFER = 0, @@ -599,9 +603,8 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector & o int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); output.resize(output_size, 0); memcpy(output.data(), &n_nodes, sizeof(n_nodes)); - uint64_t * out_nodes = (uint64_t *)(output.data() + sizeof(n_nodes)); for (uint32_t i = 0; i < n_nodes; i++) { - out_nodes[i] = reinterpret_cast(cgraph->nodes[i]); + memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); } uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); *out_ntensors = n_tensors; @@ -1036,7 +1039,9 @@ bool rpc_server::graph_compute(const std::vector & input, std::vector tensor_map; for (uint32_t i = 0; i < n_nodes; i++) { - graph->nodes[i] = create_node(nodes[i], ctx, tensor_ptrs, tensor_map); + int64_t id; + memcpy(&id, &nodes[i], sizeof(id)); + graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map); } ggml_status status = ggml_backend_graph_compute(backend, graph); // output serialization format: | status (1 byte) | diff --git a/gguf-py/scripts/gguf-dump.py b/gguf-py/scripts/gguf-dump.py index 1a37a7b91409d..92d14d6cd0a69 100755 --- a/gguf-py/scripts/gguf-dump.py +++ b/gguf-py/scripts/gguf-dump.py @@ -14,7 +14,7 @@ if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): sys.path.insert(0, str(Path(__file__).parent.parent)) -from gguf import GGUFReader, GGUFValueType # noqa: E402 +from gguf import GGUFReader, GGUFValueType, ReaderTensor # noqa: E402 logger = logging.getLogger("gguf-dump") @@ -101,25 +101,285 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: json.dump(result, sys.stdout) +def markdown_table_with_alignment_support(header_map: list[dict[str, str]], data: list[dict[str, Any]]): + # JSON to Markdown table formatting: https://stackoverflow.com/a/72983854/2850957 + + # Alignment Utility Function + def strAlign(padding: int, alignMode: str | None, strVal: str): + if alignMode == 'center': + return strVal.center(padding) + elif alignMode == 'right': + return strVal.rjust(padding - 1) + ' ' + elif alignMode == 'left': + return ' ' + strVal.ljust(padding - 1) + else: # default left + return ' ' + strVal.ljust(padding - 1) + + def dashAlign(padding: int, alignMode: str | None): + if alignMode == 'center': + return ':' + '-' * (padding - 2) + ':' + elif alignMode == 'right': + return '-' * (padding - 1) + ':' + elif alignMode == 'left': + return ':' + '-' * (padding - 1) + else: # default left + return '-' * (padding) + + # Calculate Padding For Each Column Based On Header and Data Length + rowsPadding = {} + for index, columnEntry in enumerate(header_map): + padCount = max([len(str(v)) for d in data for k, v in d.items() if k == columnEntry['key_name']], default=0) + 2 + headerPadCount = len(columnEntry['header_name']) + 2 + rowsPadding[index] = headerPadCount if padCount <= headerPadCount else padCount + + # Render Markdown Header + rows = [] + rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(columnEntry['header_name'])) for index, columnEntry in enumerate(header_map))) + rows.append('|'.join(dashAlign(rowsPadding[index], columnEntry.get('align')) for index, columnEntry in enumerate(header_map))) + + # Render Tabular Data + for item in data: + rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(item[columnEntry['key_name']])) for index, columnEntry in enumerate(header_map))) + + # Convert Tabular String Rows Into String + tableString = "" + for row in rows: + tableString += f'|{row}|\n' + + return tableString + + +def element_count_rounded_notation(count: int) -> str: + if count > 1e15 : + # Quadrillion + scaled_amount = count * 1e-15 + scale_suffix = "Q" + elif count > 1e12 : + # Trillions + scaled_amount = count * 1e-12 + scale_suffix = "T" + elif count > 1e9 : + # Billions + scaled_amount = count * 1e-9 + scale_suffix = "B" + elif count > 1e6 : + # Millions + scaled_amount = count * 1e-6 + scale_suffix = "M" + elif count > 1e3 : + # Thousands + scaled_amount = count * 1e-3 + scale_suffix = "K" + else: + # Under Thousands + scaled_amount = count + scale_suffix = "" + return f"{'~' if count > 1e3 else ''}{round(scaled_amount)}{scale_suffix}" + + +def translate_tensor_name(name): + words = name.split(".") + + # Source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#standardized-tensor-names + abbreviation_dictionary = { + 'token_embd': 'Token embedding', + 'pos_embd': 'Position embedding', + 'output_norm': 'Output normalization', + 'output': 'Output', + 'attn_norm': 'Attention normalization', + 'attn_norm_2': 'Attention normalization', + 'attn_qkv': 'Attention query-key-value', + 'attn_q': 'Attention query', + 'attn_k': 'Attention key', + 'attn_v': 'Attention value', + 'attn_output': 'Attention output', + 'ffn_norm': 'Feed-forward network normalization', + 'ffn_up': 'Feed-forward network "up"', + 'ffn_gate': 'Feed-forward network "gate"', + 'ffn_down': 'Feed-forward network "down"', + 'ffn_gate_inp': 'Expert-routing layer for the Feed-forward network in Mixture of Expert models', + 'ffn_gate_exp': 'Feed-forward network "gate" layer per expert in Mixture of Expert models', + 'ffn_down_exp': 'Feed-forward network "down" layer per expert in Mixture of Expert models', + 'ffn_up_exp': 'Feed-forward network "up" layer per expert in Mixture of Expert models', + 'ssm_in': 'State space model input projections', + 'ssm_conv1d': 'State space model rolling/shift', + 'ssm_x': 'State space model selective parametrization', + 'ssm_a': 'State space model state compression', + 'ssm_d': 'State space model skip connection', + 'ssm_dt': 'State space model time step', + 'ssm_out': 'State space model output projection', + 'blk': 'Block' + } + + expanded_words = [] + for word in words: + word_norm = word.strip().lower() + if word_norm in abbreviation_dictionary: + expanded_words.append(abbreviation_dictionary[word_norm].title()) + else: + expanded_words.append(word.title()) + + return ' '.join(expanded_words) + + +def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: + host_endian, file_endian = get_file_host_endian(reader) + markdown_content = "" + markdown_content += f'# {args.model} - GGUF Internal File Dump\n\n' + markdown_content += f'- Endian: {file_endian} endian\n' + markdown_content += '\n' + markdown_content += '## Key Value Metadata Store\n\n' + markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n' + markdown_content += '\n' + + kv_dump_table: list[dict[str, str | int]] = [] + for n, field in enumerate(reader.fields.values(), 1): + if not field.types: + pretty_type = 'N/A' + elif field.types[0] == GGUFValueType.ARRAY: + nest_count = len(field.types) - 1 + pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + else: + pretty_type = str(field.types[-1].name) + + total_elements = len(field.data) + value = "" + if len(field.types) == 1: + curr_type = field.types[0] + if curr_type == GGUFValueType.STRING: + value = repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]) + elif curr_type in reader.gguf_scalar_to_np: + value = str(field.parts[-1][0]) + else: + if field.types[0] == GGUFValueType.ARRAY: + curr_type = field.types[1] + if curr_type == GGUFValueType.STRING: + render_element = min(5, total_elements) + for element_pos in range(render_element): + value += repr(str(bytes(field.parts[-1 - element_pos]), encoding='utf-8')[:5]) + (", " if total_elements > 1 else "") + elif curr_type in reader.gguf_scalar_to_np: + render_element = min(7, total_elements) + for element_pos in range(render_element): + value += str(field.parts[-1 - element_pos][0]) + (", " if total_elements > 1 else "") + value = f'[ {value}{" ..." if total_elements > 1 else ""} ]' + kv_dump_table.append({"n":n, "pretty_type":pretty_type, "total_elements":total_elements, "field_name":field.name, "value":value}) + + kv_dump_table_header_map = [ + {'key_name':'n', 'header_name':'POS', 'align':'right'}, + {'key_name':'pretty_type', 'header_name':'TYPE', 'align':'left'}, + {'key_name':'total_elements', 'header_name':'Count', 'align':'right'}, + {'key_name':'field_name', 'header_name':'Key', 'align':'left'}, + {'key_name':'value', 'header_name':'Value', 'align':'left'}, + ] + + markdown_content += markdown_table_with_alignment_support(kv_dump_table_header_map, kv_dump_table) + + markdown_content += "\n" + + if not args.no_tensors: + # Group tensors by their prefix and maintain order + tensor_prefix_order: list[str] = [] + tensor_name_to_key: dict[str, int] = {} + tensor_groups: dict[str, list[ReaderTensor]] = {} + total_elements = sum(tensor.n_elements for tensor in reader.tensors) + + # Parsing Tensors Record + for key, tensor in enumerate(reader.tensors): + tensor_components = tensor.name.split('.') + + # Classify Tensor Group + tensor_group_name = "base" + if tensor_components[0] == 'blk': + tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}" + + # Check if new Tensor Group + if tensor_group_name not in tensor_groups: + tensor_groups[tensor_group_name] = [] + tensor_prefix_order.append(tensor_group_name) + + # Record Tensor and Tensor Position + tensor_groups[tensor_group_name].append(tensor) + tensor_name_to_key[tensor.name] = key + + # Tensors Mapping Dump + markdown_content += f'## Tensors Overview {element_count_rounded_notation(total_elements)} Elements\n\n' + markdown_content += f'Total number of elements in all tensors: {total_elements} Elements\n' + markdown_content += '\n' + + for group in tensor_prefix_order: + tensors = tensor_groups[group] + group_elements = sum(tensor.n_elements for tensor in tensors) + markdown_content += f"- [{translate_tensor_name(group)} Tensor Group - {element_count_rounded_notation(group_elements)} Elements](#{group.replace('.', '_')})\n" + + markdown_content += "\n" + + for group in tensor_prefix_order: + tensors = tensor_groups[group] + group_elements = sum(tensor.n_elements for tensor in tensors) + group_percentage = group_elements / total_elements * 100 + markdown_content += f"### {translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements\n\n" + + # Precalculate column sizing for visual consistency + prettify_element_est_count_size: int = 1 + prettify_element_count_size: int = 1 + prettify_dimension_max_widths: dict[int, int] = {} + for tensor in tensors: + prettify_element_est_count_size = max(prettify_element_est_count_size, len(str(element_count_rounded_notation(tensor.n_elements)))) + prettify_element_count_size = max(prettify_element_count_size, len(str(tensor.n_elements))) + for i, dimension_size in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))): + prettify_dimension_max_widths[i] = max(prettify_dimension_max_widths.get(i,1), len(str(dimension_size))) + + # Generate Tensor Layer Table Content + tensor_dump_table: list[dict[str, str | int]] = [] + for tensor in tensors: + human_friendly_name = translate_tensor_name(tensor.name.replace(".weight", ".(W)").replace(".bias", ".(B)")) + pretty_dimension = ' x '.join(f'{str(d):>{prettify_dimension_max_widths[i]}}' for i, d in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape)))) + element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})" + element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}" + type_name_string = f"{tensor.tensor_type.name}" + tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string}) + + tensor_dump_table_header_map = [ + {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, + {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'}, + {'key_name':'human_layer_name', 'header_name':'Human Friendly Tensor Layer Name', 'align':'left'}, + {'key_name':'element_count', 'header_name':'Elements', 'align':'left'}, + {'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'}, + {'key_name':'tensor_type', 'header_name':'Type', 'align':'left'}, + ] + + markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table) + + markdown_content += "\n" + markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n" + markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n" + markdown_content += "\n\n" + + print(markdown_content) # noqa: NP100 + + def main() -> None: parser = argparse.ArgumentParser(description="Dump GGUF file metadata") parser.add_argument("model", type=str, help="GGUF format model filename") parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata") parser.add_argument("--json", action="store_true", help="Produce JSON output") parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)") + parser.add_argument("--markdown", action="store_true", help="Produce markdown output") parser.add_argument("--verbose", action="store_true", help="increase output verbosity") args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - if not args.json: + if not args.json and not args.markdown: logger.info(f'* Loading: {args.model}') reader = GGUFReader(args.model, 'r') if args.json: dump_metadata_json(reader, args) + elif args.markdown: + dump_markdown_metadata(reader, args) else: dump_metadata(reader, args) diff --git a/llama.cpp b/llama.cpp index bd4f8ec1865fb..dd7020dc0eeab 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1278,6 +1278,126 @@ struct no_init { }; struct llama_file { + +#if defined(_WIN32) + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + HANDLE fp_win32; + size_t size; + +private: + std::string GetErrorMessageWin32(DWORD error_code) const { + std::string ret; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (!bufLen) { + ret = format("Win32 error code: %s", error_code); + } else { + ret = lpMsgBuf; + LocalFree(lpMsgBuf); + } + + return ret; + } + +public: + + llama_file(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { + // SetFilePointerEx returns the current position when seeking relative 0 bytes + LARGE_INTEGER li; + li.QuadPart = 0; + BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + + return li.QuadPart; + } + + void seek(size_t offset, int whence) const { + // no need to convert SEEK_* to FILE_*. The enums are the same. + // Still, keep static asserts to avoid failures in the future. + static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN"); + static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT"); + static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END"); + + LARGE_INTEGER li; + li.QuadPart = offset; + BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + } + + void read_raw(void * ptr, size_t len) const { + // On Win32 ReadFile is significant faster than fread which is again significant faster than std::fstream. Thus + // use the Win32 API to do file io instead of the C/C++ library functions. + + // There are conditions under which ReadFile cannot read chunks >64MB. + // Thus split the operation into smaller chunks if len exceeds this limit. + size_t bytes_read = 0; + while (bytes_read < len) { + size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); + DWORD chunk_read = 0; + BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL); + if (!result) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_read < chunk_size || chunk_read == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += chunk_read; + } ; + } + + uint32_t read_u32() const { + uint32_t val; + read_raw(&val, sizeof(val)); + return val; + } + + void write_raw(const void * ptr, size_t len) const { + // There are conditions under which WriteFile cannot write chunks >64MB. + // Thus split the operation into smaller chunks if len exceeds this limit. + size_t bytes_written = 0; + while (bytes_written < len) { + size_t chunk_size = std::min(len - bytes_written, 64*1024*1024); + DWORD chunk_written = 0; + BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL); + if (!result) { + throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_written < chunk_size || chunk_written == 0) { + throw std::runtime_error("unexpectedly failed to write bytes"); + } + + bytes_written += chunk_written; + } + } + + void write_u32(std::uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +#else // use FILE * so we don't have to re-open the file to mmap FILE * fp; size_t size; @@ -1298,7 +1418,10 @@ struct llama_file { #else long ret = std::ftell(fp); #endif - GGML_ASSERT(ret != -1); // this really shouldn't fail + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + return (size_t) ret; } @@ -1308,7 +1431,9 @@ struct llama_file { #else int ret = std::fseek(fp, (long) offset, whence); #endif - GGML_ASSERT(ret == 0); // same + if (ret != 0) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } } void read_raw(void * ptr, size_t len) const { @@ -1351,6 +1476,7 @@ struct llama_file { std::fclose(fp); } } +#endif }; using llama_files = std::vector>; @@ -3721,6 +3847,44 @@ struct llama_model_loader { std::vector> read_buf; std::vector>> validation_result; +#if defined(GGML_USE_CUDA) + // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. + // NVMe raid configurations might require more / larger buffers. + constexpr size_t num_buffers = 4; + constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + std::vector host_buffers; + std::vector host_ptrs; + std::vector events; + size_t buffer_idx = 0; // buffer to use for async loads + + ggml_backend_t cuda_backend = nullptr; + if (!use_mmap && !check_tensors) { + // When not using mmaped io use async uploads from pinned memory to GPU memory. + // First determine if the CUDA backend is active, and if so, determine the device ID. + ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr; + if (buf) { + ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf); + for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) { + auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i); + if (buffer_type == cuda_buffer_type) { + cuda_backend = ggml_backend_cuda_init(i); + break; + } + } + } + + // If the cuda backend is active create pinned memory buffers and events for synchronisation. + if (cuda_backend) { + for (size_t idx = 0; idx < num_buffers; ++idx) { + host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size)); + host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx])); + events.emplace_back(ggml_backend_event_new(cuda_backend)); + } + } + } +#endif + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { const auto * weight = get_weight(ggml_get_name(cur)); if (weight == nullptr) { @@ -3776,12 +3940,36 @@ struct llama_model_loader { })); } } else { - read_buf.resize(n_size); - file->seek(weight->offs, SEEK_SET); - file->read_raw(read_buf.data(), n_size); - ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); - if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { - throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); +#if defined(GGML_USE_CUDA) + // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. + if (cuda_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + ggml_backend_event_record(events[buffer_idx]); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= num_buffers; + } + } + else +#endif + { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } } } } @@ -3789,6 +3977,18 @@ struct llama_model_loader { size_done += n_size; } +#if defined(GGML_USE_CUDA) + // free temporary resources used for async cuda uploads + if (cuda_backend) { + for (size_t idx = 0; idx < num_buffers;++idx) { + ggml_backend_event_synchronize(events[idx]); + ggml_backend_event_free(events[idx]); + ggml_backend_buffer_free(host_buffers[idx]); + } + ggml_backend_free(cuda_backend); + } +#endif + // check validation results bool validation_failed = false; for (auto & future : validation_result) { @@ -5183,7 +5383,7 @@ static bool llm_load_tensors( // create tensors for the weights { const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_head = n_embd / hparams.n_head; + const int64_t n_embd_head = (hparams.n_head == 0) ? 0 : n_embd / hparams.n_head; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = n_embd_v_gqa; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2b48e623e3476..7c504e937a851 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1063,6 +1063,33 @@ struct test_sqr : public test_case { } }; +// GGML_OP_SQRT +struct test_sqrt : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_sqrt(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_sqrt(ctx, a); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + // fill with positive values + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, 0.0f, 100.0f); + } + } +}; + // GGML_OP_CLAMP struct test_clamp : public test_case { const ggml_type type; @@ -2200,6 +2227,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } test_cases.emplace_back(new test_sqr()); + test_cases.emplace_back(new test_sqrt()); test_cases.emplace_back(new test_clamp()); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));