Skip to content

Commit

Permalink
add composite function
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin committed Nov 1, 2023
1 parent 89db578 commit 41e0edc
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
13 changes: 3 additions & 10 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@
)

from .helpers import (
FUNCTION_EXECUTORS,
FunctionExecutor,
NativeFunctionExecutor,
ScriptFunctionExecutor,
TemplateFunctionExecutor,
RestFunctionExecutor,
ScrapeFunctionExecutor,
CompositeFunctionExecutor,
convert_to_template,
)

Expand All @@ -72,15 +74,6 @@
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)


FUNCTION_EXECUTORS: dict[str, FunctionExecutor] = {
"predefined": NativeFunctionExecutor(),
"native": NativeFunctionExecutor(),
"script": ScriptFunctionExecutor(),
"template": TemplateFunctionExecutor(),
"rest": RestFunctionExecutor(),
"scrape": ScrapeFunctionExecutor(),
}

# hass.data key for agent.
DATA_AGENT = "agent"

Expand Down Expand Up @@ -326,7 +319,7 @@ async def execute_function(
arguments = json.loads(message["function_call"]["arguments"])

result = await function_executor.execute(
self.hass, function, arguments, user_input, exposed_entities
self.hass, function["function"], arguments, user_input, exposed_entities
)

messages.append(
Expand Down
51 changes: 45 additions & 6 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def execute(
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
name = function["function"]["name"]
name = function["name"]
if name == "execute_service":
return await self.execute_service(
hass, function, arguments, user_input, exposed_entities
Expand Down Expand Up @@ -236,10 +236,10 @@ async def execute(
) -> str:
script = Script(
hass,
function["function"]["sequence"],
function["sequence"],
"extended_openai_conversation",
DOMAIN,
running_description=f"""[extended_openai_conversation] function {function.get("spec", {}).get("name")}""",
running_description="[extended_openai_conversation] function",
logger=_LOGGER,
)

Expand All @@ -261,7 +261,7 @@ async def execute(
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
return Template(function["function"]["value_template"], hass).async_render(
return Template(function["value_template"], hass).async_render(
arguments,
parse_result=False,
)
Expand All @@ -279,7 +279,7 @@ async def execute(
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
config = function["function"]
config = function
rest_data = _get_rest_data(hass, config, arguments)

await rest_data.async_update()
Expand All @@ -306,7 +306,7 @@ async def execute(
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
config = function["function"]
config = function
rest_data = _get_rest_data(hass, config, arguments)
coordinator = scrape.coordinator.ScrapeCoordinator(
hass,
Expand Down Expand Up @@ -376,3 +376,42 @@ def _extract_value(self, data: BeautifulSoup, sensor_config: dict[str, Any]) ->
value = None
_LOGGER.debug("Parsed value: %s", value)
return value


class CompositeFunctionExecutor(FunctionExecutor):
def __init__(self) -> None:
"""initialize composite function"""

async def execute(
self,
hass: HomeAssistant,
function,
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
config = function
sequence = config["sequence"]

for executor_config in sequence:
function_executor = FUNCTION_EXECUTORS[executor_config["type"]]
result = await function_executor.execute(
hass, executor_config, arguments, user_input, exposed_entities
)

response_variable = executor_config.get("response_variable")
if response_variable:
arguments[response_variable] = str(result)

return result


FUNCTION_EXECUTORS: dict[str, FunctionExecutor] = {
"predefined": NativeFunctionExecutor(),
"native": NativeFunctionExecutor(),
"script": ScriptFunctionExecutor(),
"template": TemplateFunctionExecutor(),
"rest": RestFunctionExecutor(),
"scrape": ScrapeFunctionExecutor(),
"composite": CompositeFunctionExecutor(),
}

0 comments on commit 41e0edc

Please sign in to comment.