Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.1.x (Use tools instead of functions) #31

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,13 @@ async def query(
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
functions = list(map(lambda s: s["spec"], self.get_functions()))
function_call = "auto"
functions = list(map(lambda s: {'type': 'function', 'function': s["spec"]}, self.get_functions()))
tool_choice = "auto"
if n_requests == self.entry.options.get(
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
):
function_call = "none"
tool_choice = "none"

_LOGGER.info("Prompt for %s: %s", model, messages)

Expand All @@ -270,63 +270,64 @@ async def query(
top_p=top_p,
temperature=temperature,
user=user_input.conversation_id,
functions=functions,
function_call=function_call,
tools=functions,
tool_choice=tool_choice
)

_LOGGER.info("Response %s", response)
message = response["choices"][0]["message"]
if message.get("function_call"):
message = await self.execute_function_call(
if message.get("tool_calls"):
message = await self.execute_tool_calls(
user_input, messages, message, exposed_entities, n_requests + 1
)
return message

def execute_function_call(
async def execute_tool_calls(
self,
user_input: conversation.ConversationInput,
messages,
message,
exposed_entities,
n_requests,
):
function_name = message["function_call"]["name"]
function = next(
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
None,
)
if function is not None:
return self.execute_function(
user_input,
messages,
message,
exposed_entities,
n_requests,
function,
messages.append(message)
for tool in message['tool_calls']:
function_name = tool["function"]["name"]
function = next(
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
None,
)
raise FunctionNotFound(message["function_call"]["name"])
if function is not None:
result = await self.execute_function(
user_input,
tool,
exposed_entities,
function,
)

messages.append(
{
"tool_call_id": tool['id'],
"role": "tool",
"name": function_name,
"content": str(result),
}
)
else:
raise FunctionNotFound(function_name)
return await self.query(user_input, messages, exposed_entities, n_requests)

async def execute_function(
self,
user_input: conversation.ConversationInput,
messages,
message,
tool,
exposed_entities,
n_requests,
function,
):
function_executor = FUNCTION_EXECUTORS[function["function"]["type"]]
arguments = json.loads(message["function_call"]["arguments"])
arguments = json.loads(tool["function"]["arguments"])

result = await function_executor.execute(
self.hass, function["function"], arguments, user_input, exposed_entities
)

messages.append(
{
"role": "function",
"name": message["function_call"]["name"],
"content": str(result),
}
)
return await self.query(user_input, messages, exposed_entities, n_requests)
return result
54 changes: 23 additions & 31 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
```

The current state of devices is provided in available devices.
Use execute_services function only for requested action, not for current states.
Use the execute_service function only for requested action, not for current states.
Do not execute service without user's confirmation.
Do not restate or appreciate what user says, rather make a quick inquiry.
"""
Expand All @@ -36,43 +36,35 @@
DEFAULT_CONF_FUNCTIONS = [
{
"spec": {
"name": "execute_services",
"description": "Use this function to execute service of devices in Home Assistant.",
"name": "execute_service",
"description": "Use this function to execute a service of devices in Home Assistant.",
"parameters": {
"type": "object",
"properties": {
"list": {
"type": "array",
"items": {
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": "The domain of the service",
},
"service": {
"type": "string",
"description": "The service to be called",
},
"service_data": {
"type": "object",
"description": "The service data object to indicate what to control.",
"properties": {
"entity_id": {
"type": "string",
"description": "The entity_id retrieved from available devices. It must start with domain, followed by dot character.",
}
},
"required": ["entity_id"],
},
},
"required": ["domain", "service", "service_data"],
"domain": {
"type": "string",
"description": "The domain of the service",
},
"service": {
"type": "string",
"description": "The service to be called",
},
"service_data": {
"type": "object",
"description": "The service data object to indicate what to control.",
"properties": {
"entity_id": {
"type": "string",
"description": "The entity_id retrieved from available devices. It must start with domain, followed by dot character.",
}
},
}
"required": ["entity_id"],
},
},
"required": ["domain", "service", "service_data"],
},
},
"function": {"type": "native", "name": "execute_service"},
"function": {"type": "native", "name": "execute_service_single"},
}
]
CONF_BASE_URL = "base_url"
Expand Down
85 changes: 50 additions & 35 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import yaml
import time
import voluptuous
import json
from bs4 import BeautifulSoup
from typing import Any
from homeassistant.helpers.aiohttp_client import async_get_clientsession
Expand Down Expand Up @@ -146,16 +148,60 @@ async def execute(
) -> str:
name = function["name"]
if name == "execute_service":
return await self.execute_service(
return json.dumps(await self.execute_service(
hass, function, arguments, user_input, exposed_entities
)
))
if name == "execute_service_single":
return json.dumps(await self.execute_service_single(
hass, function, arguments, user_input, exposed_entities
))
if name == "add_automation":
return await self.add_automation(
hass, function, arguments, user_input, exposed_entities
)

raise NativeNotFound(name)

async def execute_service_single(
self,
hass: HomeAssistant,
function,
service_argument,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
domain = service_argument["domain"]
service = service_argument["service"]
service_data = service_argument.get(
"service_data", service_argument.get("data", {})
)
entity_id = service_data.get("entity_id", service_argument.get("entity_id"))

if isinstance(entity_id, str):
entity_id = [e.strip() for e in entity_id.split(",")]
service_data["entity_id"] = entity_id

if entity_id is None:
raise CallServiceError(domain, service, service_data)
if not hass.services.has_service(domain, service):
raise ServiceNotFound(domain, service)
if any(hass.states.get(entity) is None for entity in entity_id):
raise EntityNotFound(entity_id)
exposed_entity_ids = map(lambda e: e["entity_id"], exposed_entities)
if not set(entity_id).issubset(exposed_entity_ids):
raise EntityNotExposed(entity_id)

try:
await hass.services.async_call(
domain=domain,
service=service,
service_data=service_data,
)
return {'success': True}
except (HomeAssistantError, voluptuous.Error) as e:
_LOGGER.error(e)
return {'error': str(e)}

async def execute_service(
self,
hass: HomeAssistant,
Expand All @@ -166,39 +212,8 @@ async def execute_service(
) -> str:
result = []
for service_argument in arguments.get("list", []):
domain = service_argument["domain"]
service = service_argument["service"]
service_data = service_argument.get(
"service_data", service_argument.get("data", {})
)
entity_id = service_data.get("entity_id", service_argument.get("entity_id"))

if isinstance(entity_id, str):
entity_id = [e.strip() for e in entity_id.split(",")]
service_data["entity_id"] = entity_id

if entity_id is None:
raise CallServiceError(domain, service, service_data)
if not hass.services.has_service(domain, service):
raise ServiceNotFound(domain, service)
if any(hass.states.get(entity) is None for entity in entity_id):
raise EntityNotFound(entity_id)
exposed_entity_ids = map(lambda e: e["entity_id"], exposed_entities)
if not set(entity_id).issubset(exposed_entity_ids):
raise EntityNotExposed(entity_id)

try:
await hass.services.async_call(
domain=domain,
service=service,
service_data=service_data,
)
result.append(True)
except HomeAssistantError:
_LOGGER.error(e)
result.append(False)

return str(result)
result.append(await self.execute_service_single(hass, function, service_argument, user_input, exposed_entities))
return result

async def add_automation(
self,
Expand Down