
Commit

fix: construct the tokenizer and model in main() and inject them into LanguageModelForCompletion, with an AutoModelForCausalLM fallback for non-GGUF checkpoints
okdshin committed Dec 9, 2023
1 parent 266c750 commit 8746635
Showing 1 changed file with 41 additions and 46 deletions.
87 changes: 41 additions & 46 deletions flatline_lsp.py
@@ -14,6 +14,7 @@
 import torch
 from transformers import (
     AutoTokenizer,
+    AutoModelForCausalLM,
     PreTrainedModel,
     PreTrainedTokenizer,
     PretrainedConfig,
@@ -30,10 +31,10 @@ class LlamaCppCausalLM(PreTrainedModel):
     def __init__(
         self,
         config: LlamaCppConfig,
+        model_name: str,
         backend_server_bin: str,
         backend_server_host: str,
         backend_server_port: int,
-        model_name: str,
         n_threads: int,
         n_gpu_layers: int,
     ):
@@ -172,30 +173,14 @@ def __init__(
         self,
         lang_server: LanguageServer,
         max_new_tokens: int,
-        backend_server_bin: str,
-        backend_server_host: str,
-        backend_server_port: int,
-        tokenizer_name: str,
-        model_name: str,
-        n_threads: int,
-        n_gpu_layers: int,
+        tokenizer: PreTrainedTokenizer,
+        model: PreTrainedModel,
     ):
         self.lang_server = lang_server
         self.max_new_tokens = max_new_tokens
 
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name, trust_remote_code=True
-        )
-        assert model_name.endswith(".gguf")
-        self.model = LlamaCppCausalLM(
-            config=LlamaCppConfig(),
-            backend_server_bin=backend_server_bin,
-            backend_server_host=backend_server_host,
-            backend_server_port=backend_server_port,
-            model_name=model_name,
-            n_threads=n_threads,
-            n_gpu_layers=n_gpu_layers,
-        )
+        self.tokenizer = tokenizer
+        self.model = model
 
         self.latest_completion_id_lock = threading.Lock()
         self.computing_resource_lock = threading.Lock()
@@ -255,7 +240,7 @@ def completions(
     ls: LanguageServer, params: Optional[lsp.CompletionParams] = None
 ) -> lsp.CompletionList:
     # assert lm_for_completion is not None
-    document = server.workspace.get_document(params.text_document.uri)
+    document = ls.workspace.get_text_document(params.text_document.uri)
     line_index = params.position.line
     character_index = params.position.character
     prompt = "".join(
@@ -297,6 +282,21 @@ def main() -> None:
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
+    parser.add_argument("--max-new-tokens", type=int, default=256)
+    parser.add_argument(
+        "--tokenizer-name",
+        type=str,
+        help="tokenizer name or path",
+        default=resource_path("./flatline/model_data/codegen25-7b-multi"),
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        help="model name or path",
+        default=resource_path(
+            "./flatline/model_data/codegen25-7b-multi/ggml-model-Q4_K.gguf"
+        ),
+    )
     parser.add_argument(
         "--backend-server-bin",
         type=str,
@@ -315,36 +315,31 @@ def main() -> None:
         help="llm inference backend server port number",
         default=57045,
     )
-    parser.add_argument(
-        "--tokenizer-name",
-        type=str,
-        help="tokenizer name or path",
-        default=resource_path("./flatline/model_data/codegen25-7b-multi"),
-    )
-    parser.add_argument(
-        "--model-name",
-        type=str,
-        help="model name or path",
-        default=resource_path(
-            "./flatline/model_data/codegen25-7b-multi/ggml-model-Q4_K.gguf"
-        ),
-    )
-    parser.add_argument("--max-new-tokens", type=int, default=256)
-    parser.add_argument("--n-threads", type=int, default=-1)
-    parser.add_argument("--n-gpu-layers", type=int, default=35)
+    parser.add_argument("--backend-server-n-threads", type=int, default=-1)
+    parser.add_argument("--backend-server-n-gpu-layers", type=int, default=35)
     args = parser.parse_args()
 
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
+
+    if args.model_name.endswith(".gguf"):
+        model = LlamaCppCausalLM(
+            config=LlamaCppConfig(),
+            model_name=args.model_name,
+            backend_server_bin=args.backend_server_bin,
+            backend_server_host=args.backend_server_host,
+            backend_server_port=args.backend_server_port,
+            n_threads=args.backend_server_n_threads,
+            n_gpu_layers=args.backend_server_n_gpu_layers,
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True)
+
     global lm_for_completion
     lm_for_completion = LanguageModelForCompletion(
         lang_server=server,
         max_new_tokens=args.max_new_tokens,
-        backend_server_bin=args.backend_server_bin,
-        backend_server_host=args.backend_server_host,
-        backend_server_port=args.backend_server_port,
-        tokenizer_name=args.tokenizer_name,
-        model_name=args.model_name,
-        n_threads=args.n_threads,
-        n_gpu_layers=args.n_gpu_layers,
+        tokenizer=tokenizer,
+        model=model,
     )
 
     server.start_io()
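
For context, a minimal sketch of how a caller wires things up after this refactor, assuming the definitions in flatline_lsp.py (LanguageModelForCompletion, LlamaCppCausalLM, LlamaCppConfig, server) are importable as a module; the paths, backend binary, and host below are illustrative placeholders, not values from the commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

from flatline_lsp import (  # assumed import path: the module changed in this commit
    LanguageModelForCompletion,
    LlamaCppCausalLM,
    LlamaCppConfig,
    server,
)

# The tokenizer and model are now built by the caller and injected.
tokenizer = AutoTokenizer.from_pretrained(
    "model_data/codegen25-7b-multi",  # illustrative tokenizer path
    trust_remote_code=True,
)

model_name = "model_data/codegen25-7b-multi/ggml-model-Q4_K.gguf"  # illustrative
if model_name.endswith(".gguf"):
    # GGUF weights are served through the llama.cpp backend server.
    model = LlamaCppCausalLM(
        config=LlamaCppConfig(),
        model_name=model_name,
        backend_server_bin="./flatline-server",  # illustrative binary path
        backend_server_host="localhost",  # assumed default host
        backend_server_port=57045,
        n_threads=-1,
        n_gpu_layers=35,
    )
else:
    # Anything else is loaded as an ordinary Hugging Face checkpoint.
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

lm_for_completion = LanguageModelForCompletion(
    lang_server=server,
    max_new_tokens=256,
    tokenizer=tokenizer,
    model=model,
)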
