Skip to content

Commit

Permalink
[#94] add energy, statistics function
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin authored and jekalmin committed Feb 19, 2024
1 parent aadf0fe commit 752786c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 6 deletions.
4 changes: 3 additions & 1 deletion custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
FunctionNotFound,
InvalidFunction,
ParseArgumentsFailed,
TokenLengthExceededError,
)
from .helpers import (
get_function_executor,
Expand Down Expand Up @@ -383,9 +384,10 @@ async def query(
return await self.execute_tool_calls(
user_input, messages, message, exposed_entities, n_requests + 1
)
if choice.finish_reason == "length":
raise TokenLengthExceededError(response.usage.completion_tokens)

return OpenAIQueryResponse(response=response, message=message)
# return message

async def execute_function_call(
self,
Expand Down
16 changes: 16 additions & 0 deletions custom_components/extended_openai_conversation/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ def __str__(self) -> str:
return f"failed to parse arguments `{self.arguments}`. Increase maximum token to avoid the issue."


class TokenLengthExceededError(HomeAssistantError):
"""When openai return 'length' as 'finish_reason'."""

def __init__(self, token: int) -> None:
"""Initialize error."""
super().__init__(
self,
f"token length(`{token}`) exceeded. Increase maximum token to avoid the issue.",
)
self.token = token

def __str__(self) -> str:
"""Return string representation."""
return f"token length(`{self.token}`) exceeded. Increase maximum token to avoid the issue."


class InvalidFunction(HomeAssistantError):
"""When function validation failed."""

Expand Down
66 changes: 61 additions & 5 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,24 @@
import voluptuous as vol
import yaml

from homeassistant.components import automation, conversation, recorder, rest, scrape
from homeassistant.components import (
automation,
conversation,
energy,
recorder,
rest,
scrape,
)
from homeassistant.components.automation.config import _async_validate_config_item
from homeassistant.components.script.config import SCRIPT_ENTITY_SCHEMA
from homeassistant.config import AUTOMATION_CONFIG_PATH
from homeassistant.const import (
CONF_ATTRIBUTE,
CONF_METHOD,
CONF_NAME,
CONF_PAYLOAD,
CONF_RESOURCE,
CONF_RESOURCE_TEMPLATE,
CONF_PAYLOAD,
CONF_TIMEOUT,
CONF_VALUE_TEMPLATE,
CONF_VERIFY_SSL,
Expand All @@ -36,7 +43,7 @@
from homeassistant.helpers.template import Template
import homeassistant.util.dt as dt_util

from .const import DOMAIN, EVENT_AUTOMATION_REGISTERED, CONF_PAYLOAD_TEMPLATE
from .const import CONF_PAYLOAD_TEMPLATE, DOMAIN, EVENT_AUTOMATION_REGISTERED
from .exceptions import (
CallServiceError,
EntityNotExposed,
Expand Down Expand Up @@ -208,6 +215,14 @@ async def execute(
return await self.get_history(
hass, function, arguments, user_input, exposed_entities
)
if name == "get_energy":
return await self.get_energy(
hass, function, arguments, user_input, exposed_entities
)
if name == "get_statistics":
return await self.get_statistics(
hass, function, arguments, user_input, exposed_entities
)

raise NativeNotFound(name)

Expand Down Expand Up @@ -346,6 +361,40 @@ async def get_history(

return [[self.as_dict(item) for item in sublist] for sublist in result.values()]

async def get_energy(
self,
hass: HomeAssistant,
function,
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
):
energy_manager: energy.data.EnergyManager = await energy.async_get_manager(hass)
return energy_manager.data

async def get_statistics(
self,
hass: HomeAssistant,
function,
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
):
statistic_ids = arguments.get("statistic_ids", [])
start_time = dt_util.as_utc(dt_util.parse_datetime(arguments["start_time"]))
end_time = dt_util.as_utc(dt_util.parse_datetime(arguments["end_time"]))

return await recorder.get_instance(hass).async_add_executor_job(
recorder.statistics.statistics_during_period,
hass,
start_time,
end_time,
statistic_ids,
arguments.get("period", "day"),
arguments.get("units"),
arguments.get("types", {"change"}),
)

def as_utc(self, value: str, default_value, parse_error_message: str):
if value is None:
return default_value
Expand Down Expand Up @@ -393,7 +442,14 @@ async def execute(
class TemplateFunctionExecutor(FunctionExecutor):
def __init__(self) -> None:
"""initialize template function"""
super().__init__(vol.Schema({vol.Required("value_template"): cv.template}))
super().__init__(
vol.Schema(
{
vol.Required("value_template"): cv.template,
vol.Optional("parse_result"): bool,
}
)
)

async def execute(
self,
Expand All @@ -405,7 +461,7 @@ async def execute(
):
return function["value_template"].async_render(
arguments,
parse_result=False,
parse_result=function.get("parse_result", False),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"config_flow": true,
"dependencies": [
"conversation",
"energy",
"history",
"recorder",
"rest",
Expand Down

0 comments on commit 752786c

Please sign in to comment.