commit 7a99786 (0 parents)
Showing 4 changed files with 325 additions and 0 deletions.
@@ -0,0 +1,25 @@
cmake_minimum_required(VERSION 3.17)
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
project(local_llm C CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

list(INSERT CMAKE_MODULE_PATH 0 ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

find_package(Python COMPONENTS Interpreter Development)

include(FetchContent)

FetchContent_Declare(pybind11 URL "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.zip")
FetchContent_Declare(
    llama_cpp
    GIT_REPOSITORY https://github.com/okdshin/llama.cpp.git
    GIT_TAG add_pfnet_plamo_13b
)
set(BUILD_SHARED_LIBS ON)
FetchContent_MakeAvailable(pybind11 llama_cpp)


pybind11_add_module(infer infer.cpp)
target_link_libraries(infer PRIVATE llama ggml)
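
This CMake file pins pybind11 v2.11.1 and the okdshin/llama.cpp fork (tag add_pfnet_plamo_13b) via FetchContent and builds a pybind11 extension named `infer` from infer.cpp, linked against the fetched llama and ggml shared libraries. A configure-and-build run might look like the sketch below; the `build/` directory name and build type are arbitrary choices, not something the commit prescribes.

import subprocess

# Configure and build out-of-source; assumes cmake, a C++17 compiler, and
# Python development headers are available on this machine.
subprocess.run(["cmake", "-S", ".", "-B", "build", "-DCMAKE_BUILD_TYPE=Release"], check=True)
subprocess.run(["cmake", "--build", "build", "--parallel"], check=True)
# On success the build tree contains the `infer` extension module
# (e.g. build/infer.cpython-*.so) plus the shared llama/ggml libraries it links against.

The built `infer` module has to be importable (on sys.path or next to the server script), because the language server below does a plain `import infer`.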
@@ -0,0 +1,181 @@
import argparse
from typing import Any, Dict, Optional, List
import threading

from lsprotocol import types as lsp

from pygls.server import LanguageServer

import torch
from transformers import (AutoTokenizer, PreTrainedModel,
                          PreTrainedTokenizer, PretrainedConfig, StoppingCriteria)
from transformers.modeling_outputs import CausalLMOutput

import infer


class LlamaCppConfig(PretrainedConfig):  # type: ignore
    model_type: str = "llama_cpp"


class LlamaCppCausalLM(PreTrainedModel):
    def __init__(self, model_name, vocab_size, config: LlamaCppConfig, n_threads: int):
        super().__init__(config)
        self.vocab_size = vocab_size

        # self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.plamo_cpp_model = infer.load_model_from_file(model_name, n_threads)

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")

    @property
    def dtype(self) -> torch.dtype:
        return torch.float32

    def forward(  # type: ignore
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ) -> CausalLMOutput:
        logits = torch.from_numpy(self.plamo_cpp_model.calc_next_token_logits(
            input_ids.numpy(), self.vocab_size))
        return CausalLMOutput(
            loss=None,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        **kwargs,
    ) -> Dict[str, Any]:
        model_inputs = {"input_ids": input_ids}
        return model_inputs


class StopWord(StoppingCriteria):
    def __init__(self, stop_word: str, tokenizer: PreTrainedTokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.stop_tokens_len = len(tokenizer(stop_word).input_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        suffix_text = self.tokenizer.decode(input_ids[0][-self.stop_tokens_len:])
        return suffix_text.endswith(self.stop_word)


class StopCutoffCompletion(StoppingCriteria):
    def __init__(self, latest_completion_id: List[int], latest_completion_id_lock: threading.Lock, completion_id: int, lang_server: LanguageServer):
        super().__init__()
        self.latest_completion_id = latest_completion_id
        self.latest_completion_id_lock = latest_completion_id_lock
        self.completion_id = completion_id
        self.lang_server = lang_server

    def __call__(self, *args, **kwargs) -> bool:
        with self.latest_completion_id_lock:
            if self.latest_completion_id[0] != self.completion_id:
                self.lang_server.show_message(f"stop-cutoff-completion {self.completion_id}", lsp.MessageType.Info)
                return True
            else:
                return False


class LanguageModelForCompletion:
    def __init__(self, lang_server: LanguageServer, model_name: str, max_new_tokens: int, n_threads: int):
        self.lang_server = lang_server

        assert model_name.endswith(".gguf")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "Salesforce/codegen25-7b-multi", trust_remote_code=True)
        self.model = LlamaCppCausalLM(model_name=model_name, vocab_size=len(self.tokenizer),
                                      config=LlamaCppConfig(), n_threads=n_threads)
        self.max_new_tokens = max_new_tokens

        self.latest_completion_id_lock = threading.Lock()
        self.computing_resource_lock = threading.Lock()

        self.latest_completion_id = [0]

        self.stop_word = StopWord("\n", tokenizer=self.tokenizer)

    def generate_completion(self, text: str) -> str:
        with self.latest_completion_id_lock:
            self.latest_completion_id[0] += 1
            stop_cutoff_completion = StopCutoffCompletion(
                latest_completion_id=self.latest_completion_id,
                latest_completion_id_lock=self.latest_completion_id_lock,
                completion_id=self.latest_completion_id[0],
                lang_server=self.lang_server,
            )
        with self.computing_resource_lock:
            if stop_cutoff_completion():
                return "<canceled>"
            tokenized_prompt = self.tokenizer(text).input_ids
            generated_tokens = self.model.generate(inputs=torch.LongTensor(
                [tokenized_prompt]), max_new_tokens=self.max_new_tokens, do_sample=False,
                #stopping_criteria=[stop_cutoff_completion, self.stop_word])[0]
                stopping_criteria=[stop_cutoff_completion])[0]
            generated_text = self.tokenizer.decode(generated_tokens[len(tokenized_prompt):])
            return generated_text


lm_for_completion: Optional[LanguageModelForCompletion] = None


server = LanguageServer("flatline-lsp", "v0.0")


@server.thread()
@server.feature(
    lsp.TEXT_DOCUMENT_COMPLETION,
    lsp.CompletionOptions(trigger_characters=[
        ".", ",", " ", "(", ")", "[", "]", "{", "}"]),
)
def completions(
        ls: LanguageServer, params: Optional[lsp.CompletionParams] = None) -> lsp.CompletionList:
    # assert lm_for_completion is not None
    document = server.workspace.get_document(params.text_document.uri)
    line_index = params.position.line
    character_index = params.position.character
    prompt = "".join(list(document.lines[max(0, line_index-15):line_index]) +
                     [document.lines[line_index][:character_index]])

    if lm_for_completion is None:
        completed_text = "<flatline_lsp_lm_for_completion is not initialized>"
    else:
        completed_text = lm_for_completion.generate_completion(text=prompt)

    return lsp.CompletionList(
        is_incomplete=True,
        items=[lsp.CompletionItem(label="(FL)"+completed_text,
                                  insert_text=completed_text,
                                  insert_text_mode=lsp.InsertTextMode.AdjustIndentation,
                                  documentation=completed_text)],
    )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str,
                        default="/home/okada/flatline2/codegen25-7b-multi/ggml-model-Q4_K.gguf")
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--n-threads", type=int, default=8)
    args = parser.parse_args()

    global lm_for_completion
    lm_for_completion = LanguageModelForCompletion(
        lang_server=server,
        model_name=args.model_name, max_new_tokens=args.max_new_tokens, n_threads=args.n_threads)

    # server.start_tcp("127.0.0.1", 8080)
    server.start_io()


if __name__ == "__main__":
    main()
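
This server registers a single textDocument/completion handler that runs on a worker thread (@server.thread()), builds a prompt from up to 15 lines above the cursor plus the current line up to the cursor, and greedy-decodes up to --max-new-tokens tokens through the LlamaCppCausalLM wrapper. Stale requests are cut off by StopCutoffCompletion: each call to generate_completion bumps a shared counter under a lock, and the stopping criterion aborts generation as soon as the counter no longer matches the id it was created with. The standalone sketch below re-implements that pattern without the model or pygls, just to show the cancellation mechanics; all names in it are local to the sketch, not part of the commit.

import threading
import time

latest_id = [0]
latest_id_lock = threading.Lock()

def should_stop(my_id: int) -> bool:
    # Mirrors StopCutoffCompletion.__call__: stop once a newer request exists.
    with latest_id_lock:
        return latest_id[0] != my_id

def request() -> int:
    # Mirrors the start of generate_completion: claim the next completion id.
    with latest_id_lock:
        latest_id[0] += 1
        return latest_id[0]

def fake_completion(my_id: int) -> str:
    # Stand-in for the generate() loop: the criterion is checked every "token".
    for step in range(50):
        if should_stop(my_id):
            return f"request {my_id}: <canceled> at step {step}"
        time.sleep(0.01)
    return f"request {my_id}: finished"

first = request()
worker = threading.Thread(target=lambda: print(fake_completion(first)))
worker.start()
time.sleep(0.1)
second = request()            # a newer keystroke arrives; the first request is now stale
print(fake_completion(second))  # runs to completion
worker.join()                   # the first worker reports "<canceled> ..."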
@@ -0,0 +1,117 @@
#include <iostream>
#include <cassert>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <llama.h>

namespace {

namespace py = pybind11;

class llama_cpp_model {
public:
  static std::unique_ptr<llama_cpp_model>
  load_from_file(std::string const &model_file_path, size_t n_threads) {

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = 20;
    llama_model *model =
        llama_load_model_from_file(model_file_path.c_str(), model_params);

    llama_context_params ctx_params = llama_context_default_params();

    ctx_params.seed = 1234;
    ctx_params.n_ctx = 2048; // TODO
    ctx_params.n_threads = n_threads;
    ctx_params.n_threads_batch =
        n_threads; // params.n_threads_batch == -1 ? params.n_threads :
                   // params.n_threads_batch;
    llama_context *ctx = llama_new_context_with_model(model, ctx_params);

    return std::make_unique<llama_cpp_model>(
        llama_cpp_model(std::move(model), ctx));
  }

  py::array_t<float> calc_next_token_logits(py::array_t<int> const &input_ids,
                                            size_t vocab_size) {
    assert(input_ids.shape(0) == 1); // batch_size must be 1
    llama_batch batch = llama_batch_init(2048, 0); // TODO
    if (is_first(input_ids)) {
      //py::print("FIRST, input_ids = ", input_ids);
      llama_kv_cache_tokens_rm(ctx_, -1, -1);
      batch.n_tokens = input_ids.shape(1);
      for (size_t i = 0; i < batch.n_tokens; ++i) {
        batch.token[i] = *input_ids.data(0, i);
        batch.pos[i] = i;
        batch.seq_id[i] = 0;
        batch.logits[i] = false;
      }
      batch.logits[batch.n_tokens - 1] = true;
    } else {
      //py::print("input_ids = ", input_ids);
      batch.token[0] = *input_ids.data(0, input_ids.shape(1) - 1);
      batch.pos[0] = input_ids.shape(1) - 1;
      batch.seq_id[0] = 0;
      batch.logits[0] = true;
      batch.n_tokens = 1;
    }
    // if (auto result = llama_decode(ctx_, batch); result != 0) {
    if (auto result = llama_decode(ctx_, batch); result < 0) {
      throw std::runtime_error("llama_decode failed " + std::to_string(result));
    }
    auto *logits_data = llama_get_logits_ith(ctx_, batch.n_tokens - 1);
    py::array_t<float> logits(
        std::vector<size_t>{static_cast<size_t>(input_ids.shape(0)), 1u,
                            vocab_size},
        logits_data);
    //py::print("logits = ", logits);
    return logits;
  }

private:
  llama_cpp_model(llama_model *model, llama_context *ctx)
      : model_(model), ctx_(ctx) {}

  bool is_first(py::array_t<int> const &input_ids) {
    static py::array_t<int> input_ids_before_backup = py::array_t<int>();
    py::array_t<int> input_ids_before = input_ids_before_backup;
    input_ids_before_backup = input_ids;
    if (input_ids_before.ndim() != input_ids.ndim()) {
      return true;
    }
    if (input_ids_before.shape(0) != input_ids.shape(0)) {
      return true;
    }
    if (input_ids_before.shape(1) > input_ids.shape(1)) {
      return true;
    }
    for (size_t i = 0; i < input_ids_before.shape(0); ++i) {
      for (size_t j = 0; j < input_ids_before.shape(1); ++j) {
        if (*input_ids_before.data(i, j) != *input_ids.data(i, j)) {
          return true;
        }
      }
    }
    return false;
  }

  llama_model *model_;
  llama_context *ctx_;
};

} // namespace

PYBIND11_MODULE(infer, m) {
  m.doc() = "infer module";

  m.def("load_model_from_file", &llama_cpp_model::load_from_file, "",
        py::arg("model_file_path"), py::arg("n_threads"));

  py::class_<llama_cpp_model, std::unique_ptr<llama_cpp_model>>(
      m, "llama_cpp_model")
      //.def(py::init<>()) // use load_model_from_file
      .def("calc_next_token_logits", &llama_cpp_model::calc_next_token_logits,
           py::arg("input_ids"), py::arg("vocab_size"));
}
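
infer.cpp exposes exactly two things to Python: a module-level load_model_from_file(model_file_path, n_threads) factory and llama_cpp_model.calc_next_token_logits(input_ids, vocab_size), which feeds the whole prompt (or only the newest token, when the previous call's ids are a prefix of the new ones) through llama_decode and returns the last position's logits as a (batch, 1, vocab_size) float array. A direct greedy next-token query could look like the sketch below; the GGUF path and vocabulary size are placeholders for whatever model and tokenizer are actually used.

import numpy as np
import infer  # the pybind11 module built from infer.cpp

# Placeholder path and thread count; point this at a real GGUF file.
model = infer.load_model_from_file("ggml-model-Q4_K.gguf", 8)

# Batch size must be 1 (the C++ side asserts this).
input_ids = np.array([[1, 2, 3, 4]], dtype=np.int32)
logits = model.calc_next_token_logits(input_ids, 51200)  # shape (1, 1, 51200)

next_token = int(np.argmax(logits[0, -1]))  # greedy choice for the next token
print(next_token)

# Appending next_token and calling again reuses the KV cache: is_first() sees the
# old ids as a prefix of the new ones, so only the newest token is decoded.
input_ids = np.concatenate([input_ids, np.array([[next_token]], dtype=np.int32)], axis=1)
logits = model.calc_next_token_logits(input_ids, 51200)

This is the same call pattern LlamaCppCausalLM.forward uses, just with numpy arrays instead of torch tensors.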
@@ -0,0 +1,2 @@
import sentencepiece as spm
spm.SentencePieceTrainer.train(input="dummy_file", model_prefix='dummy_tokenizer/tokenizer', vocab_size=51200, byte_fallback=True)
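
This two-line script trains a throwaway SentencePiece model with a 51200-token vocabulary and byte fallback from a file named dummy_file, writing dummy_tokenizer/tokenizer.model and dummy_tokenizer/tokenizer.vocab; it assumes dummy_file and the dummy_tokenizer/ directory already exist. Loading the result could look like the following sketch (the sample string is arbitrary).

import sentencepiece as spm

# Load the model produced by the training call above; the path follows from
# model_prefix='dummy_tokenizer/tokenizer'.
sp = spm.SentencePieceProcessor(model_file="dummy_tokenizer/tokenizer.model")
ids = sp.encode("def add(a, b):")
print(ids)
print(sp.decode(ids))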