From 6ed4ca4032a4a1dfc7e19882bd79c3f7d7b3ed37 Mon Sep 17 00:00:00 2001 From: Shubham Sureka Date: Fri, 29 Mar 2024 13:24:53 +0530 Subject: [PATCH] code-refactor --- spacy_llm/models/rest/mistral/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spacy_llm/models/rest/mistral/model.py b/spacy_llm/models/rest/mistral/model.py index 3bef2bab..83918171 100644 --- a/spacy_llm/models/rest/mistral/model.py +++ b/spacy_llm/models/rest/mistral/model.py @@ -3,6 +3,9 @@ from typing import Iterable, Optional, Any, Dict from ..base import REST +from mistralai.client import MistralClient +from mistralai.models.chat_completion import ChatMessage + class AzureMistral(REST): def __init__( @@ -29,13 +32,10 @@ def __init__( def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: all_resps = [] + api_key = self._credentials.get("api-key") for prompts_doc in prompts: doc_resps = [] for prompt in prompts_doc: - from mistralai.client import MistralClient - from mistralai.models.chat_completion import ChatMessage - - api_key = self._credentials.get("api-key") client = MistralClient(endpoint=self._endpoint, api_key=api_key) chat_response = client.chat(