Commit a012611
chore: adjust code after merge
Refs: #50
Signed-off-by: Markus Schuettler <[email protected]>
mschuettlerTNG committed Feb 23, 2025
1 parent 4d3cc4a commit a012611
Showing 5 changed files with 16 additions and 27 deletions.
5 changes: 1 addition & 4 deletions OpenVINO/openvino_adapter.py
@@ -3,12 +3,9 @@
 import json
 import traceback
 from typing import Dict, List, Callable
-#from model_downloader import NotEnoughDiskSpaceException, DownloadException
-#from psutil._common import bytes2human
 from openvino_interface import LLMInterface
 from openvino_params import LLMParams
-
 
 RAG_PROMPT_FORMAT = "Answer the questions based on the information below. \n{context}\n\nQuestion: {prompt}"
 
 class LLM_SSE_Adapter:
@@ -110,7 +107,7 @@ def text_conversation_run(
 
             prompt = params.prompt
             full_prompt = convert_prompt(prompt)
-            self.llm_interface.create_chat_completion(full_prompt, self.stream_function, params.generation_parameters)
+            self.llm_interface.create_chat_completion(full_prompt, self.stream_function, params.max_tokens)
 
         except Exception as ex:
             traceback.print_exc()
9 changes: 2 additions & 7 deletions OpenVINO/openvino_backend.py
@@ -34,14 +34,9 @@ def load_model(self, params: LLMParams, callback: Callable[[str], None] = None):
         callback("finish")
 
 
-
-    def create_chat_completion(self, messages: List[Dict[str, str]], streamer: Callable[[str], None], generation_parameters: Dict[str, Any]):
+    def create_chat_completion(self, messages: List[Dict[str, str]], streamer: Callable[[str], None], max_tokens: int = 1024):
         config = openvino_genai.GenerationConfig()
-        if generation_parameters.get("max_new_tokens"):
-            config.max_new_tokens = generation_parameters["max_new_tokens"]
-        else:
-            # TODO: set default
-            config.max_new_tokens = 1024
+        config.max_new_tokens = max_tokens
 
         full_prompt = self._tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         return self._model.generate(full_prompt, config, streamer)
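
For context, the slimmed-down create_chat_completion maps the single max_tokens argument straight onto openvino_genai.GenerationConfig.max_new_tokens instead of picking it out of a generation_parameters dict. A minimal standalone sketch of that pattern follows; the model directory path and the streamer are illustrative assumptions, not part of this repository:

import openvino_genai

# Hypothetical local OpenVINO GenAI model directory (assumption for this sketch).
pipe = openvino_genai.LLMPipeline("./llm_model", "CPU")
tokenizer = pipe.get_tokenizer()

def stream_chunk(subword: str) -> bool:
    # Print each generated subword as it arrives; returning False keeps generation going.
    print(subword, end="", flush=True)
    return False

messages = [{"role": "user", "content": "What is OpenVINO GenAI?"}]

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 1024  # plays the same role as the new max_tokens parameter's default

full_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
pipe.generate(full_prompt, config, stream_chunk)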
4 changes: 3 additions & 1 deletion OpenVINO/openvino_params.py
@@ -5,13 +5,15 @@ class LLMParams:
     device: int
     enable_rag: bool
     model_repo_id: str
+    max_tokens: int
     generation_parameters: Dict[str, Any]
 
     def __init__(
-        self, prompt: list, device: int, enable_rag: bool, model_repo_id: str, **kwargs
+        self, prompt: list, device: int, enable_rag: bool, model_repo_id: str, max_tokens: int, **kwargs
     ) -> None:
         self.prompt = prompt
         self.device = device
         self.enable_rag = enable_rag
         self.model_repo_id = model_repo_id
+        self.max_tokens = max_tokens
         self.generation_parameters = kwargs
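
With max_tokens promoted to a named field, any remaining keyword arguments from the request body still land in generation_parameters. A small sketch of how a chat payload like the one Answer.vue posts could be turned into LLMParams; the payload values and model id below are illustrative assumptions:

from openvino_params import LLMParams

payload = {
    "prompt": [{"role": "user", "content": "Summarize this conversation in a few words."}],
    "device": 0,
    "enable_rag": False,
    "model_repo_id": "OpenVINO/some-llm-int4-ov",  # hypothetical repo id
    "max_tokens": 8,
    "print_metrics": False,  # not a named parameter, so it falls into **kwargs
}

params = LLMParams(**payload)
print(params.max_tokens)             # 8
print(params.generation_parameters)  # {'print_metrics': False}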
8 changes: 0 additions & 8 deletions WebUI/src/App.vue
@@ -368,14 +368,6 @@ function switchTab(index: number) {
   }
 }
 
-watch(textInference, (newSetting, _oldSetting) => {
-  if (newSetting.backend === 'llamaCPP') {
-    answer.value!.disableRag()
-  } else {
-    answer.value!.restoreRagState()
-  }
-})
-
 function miniWindow() {
   window.electronAPI.miniWindow()
 }
17 changes: 10 additions & 7 deletions WebUI/src/views/Answer.vue
@@ -202,11 +202,7 @@
           class="bg-gray-400 text-black font-sans rounded-md px-1 py-1"
           :class="textInference.nameSizeClass"
         >
-          {{
-            textInference.backend === 'IPEX-LLM'
-              ? globalSetup.modelSettings.llm_model
-              : globalSetup.modelSettings.ggufLLM_model
-          }}
+          {{ textInference.activeModel }}
         </span>
       </div>
       <div
@@ -553,10 +549,9 @@ async function updateTitle(conversation: ChatItem[]) {
     device: globalSetup.modelSettings.graphics,
     prompt: chatContext,
     enable_rag: false,
-    max_tokens: textInference.maxTokens,
+    max_tokens: 8,
     model_repo_id: textInference.activeModel,
     print_metrics: false,
-    max_new_tokens: 7,
   }
   const response = await fetch(`${textInference.currentBackendUrl}/api/llm/chat`, {
     method: 'POST',
@@ -869,6 +864,14 @@ async function disableRag() {
   }
 }
 
+watch(() => textInference.backend, (newBackend, _oldBackend) => {
+  if (newBackend === 'ipexLLM') {
+    restoreRagState()
+  } else {
+    disableRag()
+  }
+})
+
 async function restoreRagState() {
   ragData.processEnable = true
   if (ragData.enable) {
