diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..24c50c1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.pt
+*.egg-info
+__pycache__
+wip
\ No newline at end of file
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000..d41edc2
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,9 @@
+MIT LICENSE
+
+Copyright 2024 Kyle Finn
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ea2e468
--- /dev/null
+++ b/README.md
@@ -0,0 +1,30 @@
+# torch2cpp
+
+Status: WIP
+
+Some features are supported, but most are not yet.
+
+Models are traced with torch.fx to extract a flattened graph of low-level function calls.
+
+The graph is then traversed in order, generating C++ code for each call.
+
+Temporary buffer shapes and ref-counts are tracked to enable compile-time-scheduled memory reuse (roughly a 10x reduction in buffer count in common cases).
+
+Weights are stored in the binary as bfloat16 and unpacked to float32 at runtime. (More options for both storage and inference are worth investigating here.)
+
+The bundled tensor math lib uses compile-time shapes and in-place storage, so there is no dynamic memory allocation at all.
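+
+A minimal usage sketch of the current Python API (subject to change while WIP; `MyModel` and its example input are placeholders):
+
+```python
+import torch
+from torch2cpp.torch2cpp import codegen
+
+model = MyModel()                  # any torch.fx-traceable nn.Module
+example_args = [torch.zeros(1)]    # example inputs drive tracing and shape recording
+with open('build/model.cpp', 'w') as f:
+    codegen(model, f, args=example_args, c_prefix='model')
+```
+
+See `example/Makefile` for building the generated `model.cpp` natively or with emscripten.
+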
diff --git a/example/Makefile b/example/Makefile
new file mode 100644
index 0000000..19ef5f5
--- /dev/null
+++ b/example/Makefile
@@ -0,0 +1,16 @@
+INCLUDES := $(shell python -m torch2cpp.includes)
+
+build/model.js : build/model.cpp
+	em++ -Os build/model.cpp -I$(INCLUDES) \
+	  -o build/model.js -s MODULARIZE=1 -s EXPORT_NAME=load_model \
+	  -s EXPORTED_FUNCTIONS=_model_step,_model_reset,_model_encode,_model_decode
+
+build/chat_cli : chat_cli.cpp build/model.cpp
+	c++ -std=c++17 -Os -march=native -ffast-math \
+	  build/model.cpp chat_cli.cpp -I$(INCLUDES) -o build/chat_cli
+
+.PHONY: model.js
+model.js: build/model.js
+
+.PHONY: chat_cli
+chat_cli: build/chat_cli
\ No newline at end of file
diff --git a/example/build/.gitignore b/example/build/.gitignore
new file mode 100644
index 0000000..c96a04f
--- /dev/null
+++ b/example/build/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
\ No newline at end of file
diff --git a/example/chat_cli.cpp b/example/chat_cli.cpp
new file mode 100644
index 0000000..a86eebc
--- /dev/null
+++ b/example/chat_cli.cpp
@@ -0,0 +1,46 @@
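+// Minimal chat REPL for a model exported by torch2cpp (c_prefix='model'):
+// each input line is encoded to tokens, fed through the model one step
+// at a time, and sampled continuation tokens are decoded and printed.
+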
+#include <iostream>
+#include <string>
+
+extern "C" {
+void model_reset();
+int model_step(int prevtok, float temperature);
+int model_encode(char const* str, int str_len, int * out, int out_len);
+int model_decode(int const* toks, int toks_len, char * out, int out_len);
+} // extern C
+
+int main(int argc, char ** argv)
+{
+    std::string prompt;
+    constexpr int max_tokens = 256;
+    int toks[max_tokens];
+    char decode[max_tokens];
+    float temperature = 1;
+    while(true)
+    {
+        std::getline(std::cin, prompt);
+        prompt += "\n";
+
+        int n_tok = model_encode(prompt.c_str(), prompt.size(), toks, max_tokens);
+        if(n_tok < 0) { return 1; }
+
+        // feed the prompt; the last step yields the first sampled token
+        int tok = 0;
+        for(int i=0 ; i<n_tok ; i++)
+            tok = model_step(toks[i], temperature);
+
+        // sample until the model emits a newline (loop tail reconstructed)
+        while(true)
+        {
+            int n_chr = model_decode(&tok, 1, decode, max_tokens);
+            if(n_chr <= 0) break;
+            std::cout << std::string(decode, n_chr) << std::flush;
+            if(decode[n_chr-1] == '\n') break;
+            tok = model_step(tok, temperature);
+        }
+    }
+    return 0;
+}
diff --git a/src/torch2cpp/include/torch2cpp/tensor.h b/src/torch2cpp/include/torch2cpp/tensor.h
new file mode 100644
--- /dev/null
+++ b/src/torch2cpp/include/torch2cpp/tensor.h
@@ -0,0 +1,323 @@
+#include <cstdint>
+#include <cstddef>
+#include <cmath>
+#include <tuple>
+
+namespace ml {
+
+struct bfloat16 { uint16_t data; };
+
+
+// only a truly insane person would write a tensor math library
+// using c++ template metaprogramming
+
+
+template<int... dims>
+struct tensor;
+
+
+template<int dim0, int... dims>
+struct tensor<dim0, dims...>
+{
+    using SubT = tensor<dims...>;
+
+    template<int d>
+    using StackT = tensor<d, dim0, dims...>;
+
+    SubT data[dim0];
+
+    tensor() {}
+    tensor(bfloat16 const* src)
+    {
+        // little endian only!!!
+        uint16_t * bf = reinterpret_cast<uint16_t*>(ptr());
+        for(int i=0 ; i<numel() ; i++)
+        {
+            // bfloat16 is the top half of a float32: zero the low bits
+            bf[i*2 + 0] = 0;
+            bf[i*2 + 1] = src[i].data;
+        }
+    }
+
+    SubT & operator[](int i) { return data[i]; }
+    SubT const& operator[](int i) const { return data[i]; }
+
+    static constexpr int numel() { return dim0 * SubT::numel(); }
+
+    float * ptr() { return reinterpret_cast<float*>(&data); }
+    float const* ptr() const { return reinterpret_cast<float const*>(&data); }
+
+    tensor & zero_()
+    {
+        for(int i=0 ; i<dim0 ; i++)
+            data[i].zero_();
+        return *this;
+    }
+};
+
+
+template<>
+struct tensor<>
+{
+    float f = 0;
+
+    template<int d>
+    using StackT = tensor<d>;
+
+    tensor() {}
+    static constexpr int numel() { return 1; }
+    float item() const { return f; }
+    operator float() const { return f; }
+    tensor<> & operator=(float x) { f=x; return *this; }
+    tensor<> & operator+=(float x) { f+=x; return *this; }
+
+    tensor & zero_()
+    {
+        f = 0;
+        return *this;
+    }
+};
+
+
+
+
+template<int... Sx>
+struct slice {};
+
+template<int... dims>
+auto getitem(tensor<dims...> & x)
+{
+    return x;
+}
+
+template<int... Sx, int dim0, int... dims, class... Slices>
+auto getitem(tensor<dim0, dims...> & x, slice<Sx...> s, Slices... sss)
+{
+    static_assert(sizeof...(Sx) == 0, "not implemented yet");
+
+    using OutT = typename decltype(getitem(x[0], sss...))::template StackT<dim0>;
+    OutT out;
+    for(int i=0 ; i<dim0 ; i++)
+        out[i] = getitem(x[i], sss...);
+    return out;
+}
+
+template<int dim0, int... dims, class... Slices>
+auto getitem(tensor<dim0, dims...> & x, int i, Slices... sss)
+{
+    if(i < 0) i += dim0;
+    return getitem(x[i], sss...);
+}
+
+
+
+
+template<int nvocab, int dim>
+tensor<dim> embedding(tensor<> const& x, tensor<nvocab, dim> const& m)
+{
+    return m[int(x.item())];
+}
+
+template<int dim0, int... dims, int nvocab, int dim>
+auto embedding(tensor<dim0, dims...> const& x, tensor<nvocab, dim> const& m)
+{
+    using OutT = typename decltype(embedding(x[0], m))::template StackT<dim0>;
+    OutT out;
+    for(int i=0 ; i<dim0 ; i++)
+        out[i] = embedding(x[i], m);
+    return out;
+}
+
+
+
+template<int... dims>
+tensor<dims...> add(tensor<dims...> const& x, tensor<dims...> const& y)
+{
+    tensor<dims...> out;
+    for(int i=0 ; i<out.numel() ; i++)
+        out.ptr()[i] = x.ptr()[i] + y.ptr()[i];
+    return out;
+}
+
+template<int... dims>
+tensor<dims...> mul(tensor<dims...> const& x, tensor<dims...> const& y)
+{
+    tensor<dims...> out;
+    for(int i=0 ; i<out.numel() ; i++)
+        out.ptr()[i] = x.ptr()[i] * y.ptr()[i];
+    return out;
+}
+
+
+template<int... dims>
+tensor<dims...> sigmoid(tensor<dims...> const& x)
+{
+    tensor<dims...> out;
+    for(int i=0 ; i<out.numel() ; i++)
+        out.ptr()[i] = 1.f / (1.f + std::exp(-x.ptr()[i]));
+    return out;
+}
+
+template<int... dims>
+tensor<dims...> softsign(tensor<dims...> const& x)
+{
+    tensor<dims...> out;
+    for(int i=0 ; i<out.numel() ; i++)
+        out.ptr()[i] = x.ptr()[i] / (1.f + std::fabs(x.ptr()[i]));
+    return out;
+}
+
+
+template<int dim>
+tensor<dim> rms_norm(tensor<dim> const& x, tensor<dim> const& m, float eps)
+{
+    float norm = 0;
+    for(int i=0 ; i<dim ; i++)
+        norm += x[i].f * x[i].f;
+    norm = 1.f / std::sqrt(norm / dim + eps);
+    tensor<dim> out;
+    for(int i=0 ; i<dim ; i++)
+        out[i] = x[i].f * norm * m[i].f;
+    return out;
+}
+
+template<int dim0, int... dims, int dim>
+tensor<dim0, dims...> rms_norm(tensor<dim0, dims...> const& x, tensor<dim> const& m, float eps)
+{
+    tensor<dim0, dims...> out;
+    for(int i=0 ; i<dim0 ; i++)
+        out[i] = rms_norm(x[i], m, eps);
+    return out;
+}
+
+
+template<int W, int H>
+tensor<H> linear(tensor<W> const& x, tensor<H, W> const& m, std::nullptr_t const&)
+{
+    // static constexpr int T = 16;
+    // static_assert(W % T == 0, "");
+    tensor<H> out;
+    // TODO optimize
+    for(int i = 0 ; i < H ; i++)
+    {
+        out[i] = 0;
+        for(int j = 0 ; j < W ; j++)
+        {
+            out[i] += x[j].f * m[i][j].f;
+        }
+    }
+    return out;
+}
+template<int W, int H>
+tensor<H> linear(tensor<W> const& x, tensor<H, W> const& m, tensor<H> const& b)
+{
+    tensor<H> out;
+    // TODO optimize
+    for(int i = 0 ; i < H ; i++)
+    {
+        out[i] = b[i];
+        for(int j = 0 ; j < W ; j++)
+        {
+            out[i] += x[j] * m[i][j];
+        }
+    }
+    return out;
+}
+
+template<int dim0, int... dims, int H, int W, class Bias>
+auto linear(tensor<dim0, dims...> const& x, tensor<H, W> const& m, Bias const& b)
+{
+    using OutT = typename decltype(linear(x[0], m, b))::template StackT<dim0>;
+    OutT out;
+    for(int i=0 ; i<dim0 ; i++)
+        out[i] = linear(x[i], m, b);
+    return out;
+}
+
+
+
+template<int dim>
+tensor<dim> sqrll_kernel(
+    tensor<dim> const& x,
+    tensor<dim> const& r,
+    tensor<dim> const& p)
+{
+    tensor<dim> out;
+
+    // elementwise leaky recurrence (body reconstructed; exact SQRLL form assumed)
+    for(int j=0 ; j<dim ; j++)
+        out[j] = p[j].f * r[j].f + x[j].f;
+
+    return out;
+}
+
+template<int dim0, int... dims>
+tensor<dim0, dims...> sqrll_kernel(
+    tensor<dim0, dims...> const& x,
+    tensor<dim0, dims...> const& r,
+    tensor<dims...> const& p)
+{
+    tensor<dim0, dims...> out;
+    // sequential scan over the leading (time) dimension; carry reconstructed
+    for(int i=0 ; i<dim0 ; i++)
+        out[i] = sqrll_kernel(x[i], r[i], i ? out[i-1] : p);
+    return out;
+}
+
+
+
+inline float fast_exp(float x)
+{
+    // Schraudolph-style exp approximation via float bit manipulation
+    // (reconstructed; only valid for moderate |x|, callers guard the range)
+    union { float f; int32_t i; } u;
+    u.i = int32_t(12102203.f * x) + 1064866805;
+    return u.f;
+}
+
+struct rng64
+{
+    // xorshift64 (seed reconstructed; any nonzero value works)
+    uint64_t x = 88172645463325252ull;
+
+    float operator()()
+    {
+        x ^= x << 13;
+        x ^= x >> 7;
+        x ^= x << 17;
+        return (x % int(1e6)) / 1e6;
+    }
+};
+
+int sample_(float * x, int n, ml::rng64 & rng, float temperature=1)
+{
+    if(temperature < 0.01)
+    {
+        int out = 0;
+        for(int i=0 ; i<n ; i++) if(x[i] > x[out]) out=i;
+        return out;
+    }
+
+    float sum_exp = 0;
+    for(int i=0 ; i<n ; i++)
+    {
+        x[i] = (x[i] / temperature > -40) ? ml::fast_exp(x[i] / temperature) : 0;
+        sum_exp += x[i];
+    }
+    float thresh = rng() * sum_exp;
+    float cumprob = 0;
+    for(int i=0 ; i<n ; i++)
+    {
+        cumprob += x[i];
+        if(cumprob > thresh) { return i; }
+    }
+    return n-1;
+}
+
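+// usage sketch (shapes are template parameters, storage is inline):
+//
+//   ml::tensor<4, 8> x;                  // 4x8 activations, zero-initialized
+//   ml::tensor<16, 8> w;                 // weights of an 8 -> 16 linear layer
+//   ml::tensor<16> g;                    // rms_norm gain
+//   auto h = ml::linear(x, w, nullptr);  // -> tensor<4, 16>, no heap allocation
+//   auto y = ml::rms_norm(h, g, 1e-5f);  // -> tensor<4, 16>
+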
+} // namespace ml
diff --git a/src/torch2cpp/include/torch2cpp/tokenizer.h b/src/torch2cpp/include/torch2cpp/tokenizer.h
new file mode 100644
index 0000000..cc621cd
--- /dev/null
+++ b/src/torch2cpp/include/torch2cpp/tokenizer.h
@@ -0,0 +1,123 @@
+#include <cstdint>
+
+#include <algorithm>
+
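+// Greedy longest-match tokenizer over a byte trie.
+//
+// token_pack is a flat blob of length-prefixed token strings (one length
+// byte, then that many bytes). The constructor indexes every token into
+// `tree`; encode() walks input bytes through the trie and emits the
+// longest token that matched before each dead end.
+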
+template<int NTOK, int NTREE>
+struct Tokenizer
+{
+    struct Token
+    {
+        uint8_t length;
+        uint8_t str[];
+    };
+    struct Node
+    {
+        // token_id is the token ending at this node, -1 if none
+        int token_id = -1;
+        int next[256];
+
+        Node() { for(int & i : next) { i = -1; } }
+    };
+    enum Error
+    {
+        ERR_OVERFLOW = -1,
+        ERR_BUG = -999
+    };
+
+
+    Token const* tokens[NTOK];
+    Node tree[NTREE];
+
+    Tokenizer(uint8_t const* token_pack)
+    {
+        for(Token const* & tok : tokens)
+        {
+            tok = reinterpret_cast<Token const*>(token_pack);
+            token_pack += tok->length + 1;
+        }
+
+        // note that null chars are dropped from tokenizations
+        int ntree = 1; // tree[0] is root node
+        for(int tok_id=0 ; tok_id<NTOK ; tok_id++)
+        {
+            int node = 0;
+            for(int i=0 ; i<tokens[tok_id]->length ; i++)
+            {
+                uint8_t c = tokens[tok_id]->str[i];
+                if(tree[node].next[c] < 0)
+                    tree[node].next[c] = ntree++;
+                node = tree[node].next[c];
+            }
+            tree[node].token_id = tok_id;
+        }
+    }
+
+    // return number of output tokens filled
+    // return negative if failed
+    int encode(char const* str, int str_len, int * out, int out_len) const
+    {
+        int out_fill = 0;
+        int node = 0;
+        int token = -1;
+        int token_end = -1;
+        for(int i=0 ; i<str_len ; )
+        {
+            int match = tree[node].token_id;
+            if(match >= 0)
+            {
+                token = match;
+                token_end = i;
+            }
+
+            uint8_t c = str[i];
+            int next = tree[node].next[c];
+            if(next >= 0)
+            {
+                node = next;
+                i ++;
+            }
+            else if(token >= 0)
+            {
+                if(out_fill >= out_len) { return ERR_OVERFLOW; }
+                out[out_fill++] = token;
+                i = token_end;
+                node = 0;
+                token = -1;
+                token_end = -1;
+            }
+            else
+            {
+                return ERR_BUG;
+            }
+        }
+        int match = tree[node].token_id;
+        if(match >= 0)
+        {
+            if(out_fill >= out_len) { return ERR_OVERFLOW; }
+            out[out_fill++] = match;
+        }
+        return out_fill;
+    }
+
+    // return number of output chars filled
+    // return negative if failed
+    int decode(int const* toks, int toks_len, char * out, int out_len) const
+    {
+        int out_fill = 0;
+        for(int i=0 ; i<toks_len ; i++)
+        {
+            Token const& tok = *tokens[toks[i]];
+            if(out_fill + tok.length > out_len) { return ERR_OVERFLOW; }
+            std::copy(tok.str + 0, tok.str + tok.length, out + out_fill);
+            out_fill += tok.length;
+        }
+        return out_fill;
+    }
+};
diff --git a/src/torch2cpp/includes.py b/src/torch2cpp/includes.py
new file mode 100644
index 0000000..794a340
--- /dev/null
+++ b/src/torch2cpp/includes.py
@@ -0,0 +1,2 @@
+from pathlib import Path
+print(Path(__file__).parent / 'include')
\ No newline at end of file
diff --git a/src/torch2cpp/torch2cpp.py b/src/torch2cpp/torch2cpp.py
new file mode 100644
index 0000000..44cb524
--- /dev/null
+++ b/src/torch2cpp/torch2cpp.py
@@ -0,0 +1,305 @@
+import torch
+import operator
+import struct
+
+
+def shape_dtype(shape, ref=False, const=False):
+    shape = ','.join(str(d) for d in shape)
+    dtype = f'ml::tensor<{shape}>'
+    if ref:
+        dtype += '&'
+    if const:
+        dtype = 'const '+dtype
+    return dtype
+
+def val_dtype(val, ref=False, const=False):
+    if isinstance(val, (list, tuple)):
+        subtypes = [val_dtype(a, ref, const) for a in val]
+        subtypes = ','.join(subtypes)
+        return f'std::tuple<{subtypes}>'
+    return shape_dtype(tuple(val.shape), ref, const)
+
+def flatten(val):
+    if isinstance(val, (list, tuple)):
+        return [a for vx in val for a in flatten(vx)]
+    return [val]
+
+
+class CustomTracer(torch.fx.Tracer):
+    def is_leaf_module(self, mod, name):
+        # tracing into Modules like torch.nn.Linear
+        # makes translation easier
+        return False
+
+
+class Interpreter(torch.fx.Interpreter):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.inputs = {}
+        self.output_type = None
+        self.output_ref = None
+        self.weights = {}
+
+        self.tmp_vars = {}
+        self.node_vars = {}
+
+        self.fwds = []
+
+    def get_tmp(self, node, shape):
+        refcount = len(node.users)
+        for name, info in self.tmp_vars.items():
+            tshape, tref = info
+            if shape == tshape and tref == 0:
+                self.node_vars[node] = name
+                info[1] = refcount
+                return name
+        name = f'tmp{len(self.tmp_vars)}'
+        self.tmp_vars[name] = [shape, refcount]
+        self.node_vars[node] = name
+        return name
+
+    def deref(self, node):
+        if node is None:
+            return 'nullptr'
+        if isinstance(node, slice):
+            if node == slice(None, None, None):
+                return 'slice<>()'
+            else:
+                raise ValueError('unsupported '+str(node))
+        if not isinstance(node, torch.fx.Node):
+            return str(node)
+        if node not in self.node_vars:
+            return node.name
+        name = self.node_vars[node]
+        if name in self.tmp_vars:
+            self.tmp_vars[name][1] -= 1
+            assert self.tmp_vars[name][1] >= 0
+        return name
+
+    def nested_refstr(self, arg):
+        if isinstance(arg, (list, tuple)):
+            nest = [self.nested_refstr(a) for a in arg]
+            return '{'+', '.join(nest)+'}'
+        return self.node_vars[arg]
+
+    def alias(self, node, src):
+        src = self.node_vars[src]
+        self.node_vars[node] = src
+        self.tmp_vars[src][1] += len(node.users) - 1
+
+
+    def run_node(self, n):
+        with self._set_current_node(n):
+
+            args, kwargs = self.fetch_args_kwargs_from_env(n)
+            val = getattr(self, n.op)(n.target, args, kwargs)
+
+            if n.op == 'placeholder':
+                self.inputs[n.name] = val_dtype(val)
+                self.node_vars[n] = n.name
+            elif n.op == 'get_attr':
+                self.weights[n.name] = (val_dtype(val, const=True), val)
+                self.node_vars[n] = n.name
+            elif n.op == 'call_function' or n.op == 'call_method':
+
+                fname = n.target
+                if 'fun' in n.op:
+                    fname = fname.__name__
+
+                if n.target == operator.getitem:
+                    if isinstance(args[0], (list, tuple)):
+                        src = self.node_vars[n.args[0]]
+                        self.node_vars[n] = f'get<{args[1]}>({src})'
+                        return val
+
+                no_ops = [
+                    'detach',
+                    'clone',
+                    torch.nn.functional.dropout,
+                ]
+                if n.target in no_ops:
+                    self.alias(n, n.args[0])
+                    return val
+
+                out_var = self.get_tmp(n, tuple(val.shape))
+                flat_arg_nodes = flatten(n.args)
+                flat_arg_vars = [self.deref(a) for a in flat_arg_nodes]
+                fargs = ', '.join(flat_arg_vars)
+
+                self.fwds += [f'{out_var} = {fname}({fargs})']
+            elif n.op == 'call_module':
+                raise ValueError('call_module unsupported')
+            elif n.op == 'output':
+                self.output_type = val_dtype(val, ref=True)
+                self.output_ref = self.nested_refstr(n.args[0])
+            else:
+                print(n.name, ':', n.op, n.target, n.args, n.kwargs)
+                print('->', getattr(val, 'shape', f'{len(val)=}'), len(n.users))
+
+            return val
+
+
+def codegen(
+        model,
+        out_file,
+        args=[],
+        kwargs={},
+        tokenizer=None,
+        autowrap_functions=[],
+        c_prefix='model',
+        skip_weights=False,
+        ):
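+    '''Trace `model` with torch.fx and emit a self-contained C++ source to `out_file`.
+
+    `args`/`kwargs` are example inputs used to run the traced graph and record
+    intermediate shapes. If `tokenizer` is given (tokenizers-style: needs
+    `get_vocab_size()` and `decode_batch()`), `{c_prefix}_encode`/`_decode`
+    entry points are emitted as well. The generated `{c_prefix}_reset`/`_step`
+    assume a model with inputs named `x` and `mem` returning (logits, mem).
+    '''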
+
+    tracer = CustomTracer(autowrap_functions=autowrap_functions)
+
+    graph = tracer.trace(model)
+
+    interp = Interpreter(model, graph=graph)
+
+    out = interp.run(*args, **kwargs)
+
+
+    if tokenizer is not None:
+        n_vocab = tokenizer.get_vocab_size()
+        vocab = tokenizer.decode_batch([[i] for i in range(n_vocab)])
+        vocab = [bytes(t, 'utf8') for t in vocab]
+        token_pack = [struct.pack('B', len(t)) + t for t in vocab]
+        token_pack = [hex(c) for tok in token_pack for c in tok]
+        token_pack = ','.join(token_pack)
+        root_tree = {}
+        n_trees = 1
+
+        def add_tree(tree, txt, i):
+            nonlocal n_trees
+            if len(txt) == 0:
+                tree[''] = i
+                return
+            if txt[0] not in tree:
+                tree[txt[0]] = {}
+                n_trees += 1
+            add_tree(tree[txt[0]], txt[1:], i)
+
+        for i, txt in enumerate(vocab):
+            add_tree(root_tree, txt, i)
+
+    class_name = 'Model'
+    blob = []
+    ivars = []
+    fwds = []
+
+    for name, info in interp.weights.items():
+        dtype, val = info
+        offset = len(blob)
+        blob += val.bfloat16().view(torch.uint16).flatten().tolist()
+        ivars += [f'static {dtype} {name} {{ blob+{offset} }}']
+
+    fwds += interp.fwds
+
+    class Writer:
+        def __init__(self, f):
+            self.f = f
+            self.indent = 0
+
+        def __call__(self, s, nl='\n'):
+            self.f.write(' ' * self.indent)
+            self.f.write(s + nl)
+            return self
+
+        def __enter__(self):
+            self.__call__('{')
+            self.indent += 4
+            return self
+
+        def __exit__(self, *_):
+            self.indent -= 4
+            self.__call__('}', '')
+
+    w = Writer(out_file)
+
+    w('#include "torch2cpp/tensor.h"')
+    if tokenizer is not None:
+        w('#include "torch2cpp/tokenizer.h"')
+    w('\n')
+
+    w('namespace {\n')
+
+    if not skip_weights:
+        w('// weight initializers')
+        with w(f'static const ml::bfloat16 blob[{len(blob)}] = '):
+            w(','.join([hex(x) for x in blob]))
+        w(';')
+
+    if tokenizer is not None:
+        w(f'uint8_t const g_token_pack[] = {{ {token_pack} }};')
+
+    w('// weight tensors')
+    for i in ivars:
+        w(i + ';')
+    w('')
+
+    with w(f'struct {class_name}'):
+
+        w('// inputs')
+        for name, dtype in interp.inputs.items():
+            w(f'{dtype} {name};')
+        w('')
+
+        w('// tmp vars')
+        for name, info in interp.tmp_vars.items():
+            dtype = shape_dtype(info[0])
+            w(f'{dtype} {name};')
+        w('')
+
+        w(f'{interp.output_type}')
+        with w('operator()()'):
+
+            w('using std::get;')
+            w('using namespace ml;')
+
+            for f in fwds:
+                w(f + ';')
+
+            w(f'return {interp.output_ref};')
+
+        w('')
+
+    w(';\n\n')
+
+    w('ml::rng64 g_rng;')
+    w(f'{class_name} g_model;')
+
+    if tokenizer is not None:
+        w(f'Tokenizer<{n_vocab}, {n_trees}> g_tokenizer = {{ g_token_pack }};')
+    w('\n')
+    w('} // namespace\n')
+
+    # NOTE: these entry points assume the traced model has inputs named
+    # `x` and `mem` and returns (logits, mem)
+    w(f'''
+extern "C" {{
+void {c_prefix}_reset()
+{{
+    std::apply([] (auto &&... x) {{ (x.zero_(), ...); }}, g_model.mem);
+}}
+int {c_prefix}_step(int prevtok, float temperature)
+{{
+    g_model.x.ptr()[0] = prevtok;
+    auto outs = g_model();
+    g_model.mem = std::get<1>(outs);
+
+    auto & logits = std::get<0>(outs);
+    return ml::sample_(logits.ptr(), logits.numel(), g_rng, temperature);
+}}
+''')
+
+    if tokenizer is not None:
+        w(f'''
+int {c_prefix}_encode(char const* str, int str_len, int * out, int out_len)
+{{
+    return g_tokenizer.encode(str, str_len, out, out_len);
+}}
+int {c_prefix}_decode(int const* toks, int toks_len, char * out, int out_len)
+{{
+    return g_tokenizer.decode(toks, toks_len, out, out_len);
+}}
+''')
+
+    w('} // extern C\n')
\ No newline at end of file