diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py
index 765960d4..126df503 100644
--- a/aixplain/modules/model/__init__.py
+++ b/aixplain/modules/model/__init__.py
@@ -188,7 +188,7 @@ def run(
         data: Union[Text, Dict],
         name: Text = "model_process",
         timeout: float = 300,
-        parameters: Optional[Dict] = {},
+        parameters: Optional[Dict] = None,
         wait_time: float = 0.5,
     ) -> Dict:
         """Runs a model call.
@@ -197,7 +197,7 @@ def run(
             data (Union[Text, Dict]): link to the input data
             name (Text, optional): ID given to a call. Defaults to "model_process".
             timeout (float, optional): total polling time. Defaults to 300.
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.
             wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
 
         Returns:
@@ -220,13 +220,13 @@ def run(
             response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
             return response
 
-    def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> Dict:
+    def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None) -> Dict:
         """Runs asynchronously a model call.
 
         Args:
             data (Union[Text, Dict]): link to the input data
             name (Text, optional): ID given to a call. Defaults to "model_process".
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.
 
         Returns:
             dict: polling URL in response
diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py
index f48a3068..f0b4cef6 100644
--- a/aixplain/modules/model/llm_model.py
+++ b/aixplain/modules/model/llm_model.py
@@ -102,7 +102,7 @@ def run(
         top_p: float = 1.0,
         name: Text = "model_process",
         timeout: float = 300,
-        parameters: Optional[Dict] = {},
+        parameters: Optional[Dict] = None,
         wait_time: float = 0.5,
     ) -> Dict:
         """Synchronously running a Large Language Model (LLM) model.
@@ -117,21 +117,23 @@ def run(
             top_p (float, optional): Top P. Defaults to 1.0.
             name (Text, optional): ID given to a call. Defaults to "model_process".
             timeout (float, optional): total polling time. Defaults to 300.
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.
             wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
 
         Returns:
             Dict: parsed output from model
         """
         start = time.time()
+        if parameters is None:
+            parameters = {}
         parameters.update(
             {
-                "context": parameters["context"] if "context" in parameters else context,
-                "prompt": parameters["prompt"] if "prompt" in parameters else prompt,
-                "history": parameters["history"] if "history" in parameters else history,
-                "temperature": parameters["temperature"] if "temperature" in parameters else temperature,
-                "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
-                "top_p": parameters["top_p"] if "top_p" in parameters else top_p,
+                "context": parameters.get("context", context),
+                "prompt": parameters.get("prompt", prompt),
+                "history": parameters.get("history", history),
+                "temperature": parameters.get("temperature", temperature),
+                "max_tokens": parameters.get("max_tokens", max_tokens),
+                "top_p": parameters.get("top_p", top_p),
             }
         )
         payload = build_payload(data=data, parameters=parameters)
@@ -160,7 +162,7 @@ def run_async(
         max_tokens: int = 128,
         top_p: float = 1.0,
         name: Text = "model_process",
-        parameters: Optional[Dict] = {},
+        parameters: Optional[Dict] = None,
     ) -> Dict:
         """Runs asynchronously a model call.
 
@@ -173,21 +175,23 @@ def run_async(
             max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
             top_p (float, optional): Top P. Defaults to 1.0.
             name (Text, optional): ID given to a call. Defaults to "model_process".
-            parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
+            parameters (Dict, optional): optional parameters to the model. Defaults to None.
 
         Returns:
             dict: polling URL in response
         """
         url = f"{self.url}/{self.id}"
         logging.debug(f"Model Run Async: Start service for {name} - {url}")
+        if parameters is None:
+            parameters = {}
         parameters.update(
             {
-                "context": parameters["context"] if "context" in parameters else context,
-                "prompt": parameters["prompt"] if "prompt" in parameters else prompt,
-                "history": parameters["history"] if "history" in parameters else history,
-                "temperature": parameters["temperature"] if "temperature" in parameters else temperature,
-                "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
-                "top_p": parameters["top_p"] if "top_p" in parameters else top_p,
+                "context": parameters.get("context", context),
+                "prompt": parameters.get("prompt", prompt),
+                "history": parameters.get("history", history),
+                "temperature": parameters.get("temperature", temperature),
+                "max_tokens": parameters.get("max_tokens", max_tokens),
+                "top_p": parameters.get("top_p", top_p),
             }
         )
         payload = build_payload(data=data, parameters=parameters)
diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py
index 2235b35a..13cc1f7c 100644
--- a/aixplain/modules/model/utils.py
+++ b/aixplain/modules/model/utils.py
@@ -3,12 +3,15 @@
 import json
 import logging
 from aixplain.utils.file_utils import _request_with_retry
-from typing import Dict, Text, Union
+from typing import Dict, Text, Union, Optional
 
 
-def build_payload(data: Union[Text, Dict], parameters: Dict = {}):
+def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None):
     from aixplain.factories import FileFactory
 
+    if parameters is None:
+        parameters = {}
+
     data = FileFactory.to_link(data)
     if isinstance(data, dict):
         payload = data