Skip to content

Commit

Permalink
Parameters as None by default (#299)
Browse files Browse the repository at this point in the history
  • Loading branch information
thiago-aixplain authored Nov 4, 2024
1 parent e7eff8b commit 6bbf8e9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
8 changes: 4 additions & 4 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def run(
data: Union[Text, Dict],
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> Dict:
"""Runs a model call.
Expand All @@ -197,7 +197,7 @@ def run(
data (Union[Text, Dict]): link to the input data
name (Text, optional): ID given to a call. Defaults to "model_process".
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
Returns:
Expand All @@ -220,13 +220,13 @@ def run(
response = {"status": "FAILED", "error": msg, "elapsed_time": end - start}
return response

def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> Dict:
def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None) -> Dict:
"""Runs asynchronously a model call.
Args:
data (Union[Text, Dict]): link to the input data
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
Returns:
dict: polling URL in response
Expand Down
36 changes: 20 additions & 16 deletions aixplain/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run(
top_p: float = 1.0,
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
) -> Dict:
"""Synchronously running a Large Language Model (LLM) model.
Expand All @@ -117,21 +117,23 @@ def run(
top_p (float, optional): Top P. Defaults to 1.0.
name (Text, optional): ID given to a call. Defaults to "model_process".
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
Returns:
Dict: parsed output from model
"""
start = time.time()
if parameters is None:
parameters = {}
parameters.update(
{
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
"context": parameters.get("context", context),
"prompt": parameters.get("prompt", prompt),
"history": parameters.get("history", history),
"temperature": parameters.get("temperature", temperature),
"max_tokens": parameters.get("max_tokens", max_tokens),
"top_p": parameters.get("top_p", top_p),
}
)
payload = build_payload(data=data, parameters=parameters)
Expand Down Expand Up @@ -160,7 +162,7 @@ def run_async(
max_tokens: int = 128,
top_p: float = 1.0,
name: Text = "model_process",
parameters: Optional[Dict] = {},
parameters: Optional[Dict] = None,
) -> Dict:
"""Runs asynchronously a model call.
Expand All @@ -173,21 +175,23 @@ def run_async(
max_tokens (int, optional): Maximum Generation Tokens. Defaults to 128.
top_p (float, optional): Top P. Defaults to 1.0.
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
parameters (Dict, optional): optional parameters to the model. Defaults to None.
Returns:
dict: polling URL in response
"""
url = f"{self.url}/{self.id}"
logging.debug(f"Model Run Async: Start service for {name} - {url}")
if parameters is None:
parameters = {}
parameters.update(
{
"context": parameters["context"] if "context" in parameters else context,
"prompt": parameters["prompt"] if "prompt" in parameters else prompt,
"history": parameters["history"] if "history" in parameters else history,
"temperature": parameters["temperature"] if "temperature" in parameters else temperature,
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"top_p": parameters["top_p"] if "top_p" in parameters else top_p,
"context": parameters.get("context", context),
"prompt": parameters.get("prompt", prompt),
"history": parameters.get("history", history),
"temperature": parameters.get("temperature", temperature),
"max_tokens": parameters.get("max_tokens", max_tokens),
"top_p": parameters.get("top_p", top_p),
}
)
payload = build_payload(data=data, parameters=parameters)
Expand Down
7 changes: 5 additions & 2 deletions aixplain/modules/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import json
import logging
from aixplain.utils.file_utils import _request_with_retry
from typing import Dict, Text, Union
from typing import Dict, Text, Union, Optional


def build_payload(data: Union[Text, Dict], parameters: Dict = {}):
def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None):
from aixplain.factories import FileFactory

if parameters is None:
parameters = {}

data = FileFactory.to_link(data)
if isinstance(data, dict):
payload = data
Expand Down

0 comments on commit 6bbf8e9

Please sign in to comment.