Skip to content

Commit

Permalink
Merge pull request #265 from jekalmin/v1.0.4
Browse files Browse the repository at this point in the history
1.0.4
  • Loading branch information
jekalmin authored Nov 12, 2024
2 parents 1b20b56 + 94d293f commit a7ac05d
Show file tree
Hide file tree
Showing 10 changed files with 332 additions and 27 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Derived from [OpenAI Conversation](https://www.home-assistant.io/integrations/op
## How it works
Extended OpenAI Conversation uses OpenAI API's feature of [function calling](https://platform.openai.com/docs/guides/function-calling) to call service of Home Assistant.

Since "gpt-3.5-turbo" model already knows how to call service of Home Assistant in general, you just have to let model know what devices you have by [exposing entities](https://github.com/jekalmin/extended_openai_conversation#preparation)
Since OpenAI models already know how to call service of Home Assistant in general, you just have to let model know what devices you have by [exposing entities](https://github.com/jekalmin/extended_openai_conversation#preparation)

## Installation
1. Install via registering as a custom repository of HACS or by copying `extended_openai_conversation` folder into `<config directory>/custom_components`
Expand All @@ -22,7 +22,7 @@ Since "gpt-3.5-turbo" model already knows how to call service of Home Assistant
4. In the bottom right corner, select the Add Integration button.
5. Follow the instructions on screen to complete the setup (API Key is required).
- [Generating an API Key](https://www.home-assistant.io/integrations/openai_conversation/#generate-an-api-key)
- Specify "Base Url" if using OpenAI compatible servers like LocalAI, otherwise leave as it is.
- Specify "Base Url" if using OpenAI compatible servers like Azure OpenAI (also with APIM), LocalAI, otherwise leave as it is.
6. Go to Settings > [Voice Assistants](https://my.home-assistant.io/redirect/voice_assistants/).
7. Click to edit Assistant (named "Home Assistant" by default).
8. Select "Extended OpenAI Conversation" from "Conversation agent" tab.
Expand Down Expand Up @@ -245,12 +245,14 @@ In order to pass result of calling service to OpenAI, set response variable to `
function:
type: script
sequence:
- service: calendar.list_events
- service: calendar.get_events
data:
start_date_time: "{{start_date_time}}"
end_date_time: "{{end_date_time}}"
target:
entity_id: calendar.test
entity_id:
- calendar.[YourCalendarHere]
- calendar.[MoreCalendarsArePossible]
response_variable: _function_result
```

Expand Down Expand Up @@ -513,7 +515,7 @@ When using [ytube_music_player](https://github.com/KoljaWindeler/ytube_music_pla
#### 7-1. Let model generate a query
- Without examples, a query tries to fetch data only from "states" table like below
> Question: When did bedroom light turn on? <br/>
Query(generated by gpt-3.5): SELECT * FROM states WHERE entity_id = 'input_boolean.livingroom_light_2' AND state = 'on' ORDER BY last_changed DESC LIMIT 1
Query(generated by gpt): SELECT * FROM states WHERE entity_id = 'input_boolean.livingroom_light_2' AND state = 'on' ORDER BY last_changed DESC LIMIT 1
- Since "entity_id" is stored in "states_meta" table, we need to give examples of question and query.
- Not secured, but flexible way

Expand Down
13 changes: 8 additions & 5 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
intent,
template,
)
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid

Expand Down Expand Up @@ -145,12 +146,14 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
azure_endpoint=base_url,
api_version=entry.data.get(CONF_API_VERSION),
organization=entry.data.get(CONF_ORGANIZATION),
http_client=get_async_client(hass),
)
else:
self.client = AsyncOpenAI(
api_key=entry.data[CONF_API_KEY],
base_url=base_url,
organization=entry.data.get(CONF_ORGANIZATION),
http_client=get_async_client(hass),
)

@property
Expand Down Expand Up @@ -186,9 +189,9 @@ async def async_process(
messages = [system_message]
user_message = {"role": "user", "content": user_input.text}
if self.entry.options.get(CONF_ATTACH_USERNAME, DEFAULT_ATTACH_USERNAME):
user = await self.hass.auth.async_get_user(user_input.context.user_id)
if user is not None and user.name is not None:
user_message[ATTR_NAME] = user.name
user = user_input.context.user_id
if user is not None:
user_message[ATTR_NAME] = user

messages.append(user_message)

Expand Down Expand Up @@ -356,7 +359,7 @@ async def query(
if len(functions) == 0:
tool_kwargs = {}

_LOGGER.info("Prompt for %s: %s", model, messages)
_LOGGER.info("Prompt for %s: %s", model, json.dumps(messages))

response: ChatCompletion = await self.client.chat.completions.create(
model=model,
Expand All @@ -368,7 +371,7 @@ async def query(
**tool_kwargs,
)

_LOGGER.info("Response %s", response.model_dump(exclude_none=True))
_LOGGER.info("Response %s", json.dumps(response.model_dump(exclude_none=True)))

if response.usage.total_tokens > context_threshold:
await self.truncate_message_history(messages, exposed_entities, user_input)
Expand Down
2 changes: 1 addition & 1 deletion custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
Do not restate or appreciate what user says, rather make a quick inquiry.
"""
CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "gpt-3.5-turbo-1106"
DEFAULT_CHAT_MODEL = "gpt-4o-mini"
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150
CONF_TOP_P = "top_p"
Expand Down
27 changes: 24 additions & 3 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from datetime import timedelta
from functools import partial
import logging
import os
import re
Expand Down Expand Up @@ -39,6 +40,7 @@
from homeassistant.core import HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.script import Script
from homeassistant.helpers.template import Template
import homeassistant.util.dt as dt_util
Expand All @@ -56,7 +58,7 @@
_LOGGER = logging.getLogger(__name__)


AZURE_DOMAIN_PATTERN = r"\.openai\.azure\.com"
AZURE_DOMAIN_PATTERN = r"\.(openai\.azure\.com|azure-api\.net)"


def get_function_executor(value: str):
Expand Down Expand Up @@ -141,13 +143,17 @@ async def validate_authentication(
azure_endpoint=base_url,
api_version=api_version,
organization=organization,
http_client=get_async_client(hass),
)
else:
client = AsyncOpenAI(
api_key=api_key, base_url=base_url, organization=organization
api_key=api_key,
base_url=base_url,
organization=organization,
http_client=get_async_client(hass),
)

await client.models.list(timeout=10)
await hass.async_add_executor_job(partial(client.models.list, timeout=10))


class FunctionExecutor(ABC):
Expand Down Expand Up @@ -223,6 +229,10 @@ async def execute(
return await self.get_statistics(
hass, function, arguments, user_input, exposed_entities
)
if name == "get_user_from_user_id":
return await self.get_user_from_user_id(
hass, function, arguments, user_input, exposed_entities
)

raise NativeNotFound(name)

Expand Down Expand Up @@ -372,6 +382,17 @@ async def get_energy(
energy_manager: energy.data.EnergyManager = await energy.async_get_manager(hass)
return energy_manager.data

async def get_user_from_user_id(
self,
hass: HomeAssistant,
function,
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
):
user = await hass.auth.async_get_user(user_input.context.user_id)
return {'name': user.name if user and hasattr(user, 'name') else 'Unknown'}

async def get_statistics(
self,
hass: HomeAssistant,
Expand Down
44 changes: 40 additions & 4 deletions custom_components/extended_openai_conversation/services.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import base64
import logging
import mimetypes
from pathlib import Path
from urllib.parse import urlparse

import voluptuous as vol
from openai import AsyncOpenAI
from openai._exceptions import OpenAIError
from openai.types.chat.chat_completion_content_part_image_param import (
ChatCompletionContentPartImageParam,
)
import voluptuous as vol

from homeassistant.core import (
HomeAssistant,
Expand All @@ -11,8 +18,8 @@
SupportsResponse,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, selector
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers import selector, config_validation as cv

from .const import DOMAIN, SERVICE_QUERY_IMAGE

Expand All @@ -25,7 +32,7 @@
),
vol.Required("model", default="gpt-4-vision-preview"): cv.string,
vol.Required("prompt"): cv.string,
vol.Required("images"): vol.All(cv.ensure_list, [{"url": cv.url}]),
vol.Required("images"): vol.All(cv.ensure_list, [{"url": cv.string}]),
vol.Optional("max_tokens", default=300): cv.positive_int,
}
)
Expand All @@ -41,7 +48,7 @@ async def query_image(call: ServiceCall) -> ServiceResponse:
try:
model = call.data["model"]
images = [
{"type": "image_url", "image_url": image}
{"type": "image_url", "image_url": to_image_param(hass, image)}
for image in call.data["images"]
]

Expand Down Expand Up @@ -74,3 +81,32 @@ async def query_image(call: ServiceCall) -> ServiceResponse:
schema=QUERY_IMAGE_SCHEMA,
supports_response=SupportsResponse.ONLY,
)


def to_image_param(hass: HomeAssistant, image) -> ChatCompletionContentPartImageParam:
"""Convert url to base64 encoded image if local."""
url = image["url"]

if urlparse(url).scheme in cv.EXTERNAL_URL_PROTOCOL_SCHEMA_LIST:
return image

if not hass.config.is_allowed_path(url):
raise HomeAssistantError(
f"Cannot read `{url}`, no access to path; "
"`allowlist_external_dirs` may need to be adjusted in "
"`configuration.yaml`"
)
if not Path(url).exists():
raise HomeAssistantError(f"`{url}` does not exist")
mime_type, _ = mimetypes.guess_type(url)
if mime_type is None or not mime_type.startswith("image"):
raise HomeAssistantError(f"`{url}` is not an image")

image["url"] = f"data:{mime_type};base64,{encode_image(url)}"
return image


def encode_image(image_path):
"""Convert to base64 encoded image."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
{
"config": {
"error": {
"cannot_connect": "Não é possível conectar",
"invalid_auth": "Autenticação inválida",
"unknown": "Erro desconhecido"
},
"step": {
"user": {
"data": {
"name": "Nome",
"api_key": "Chave API",
"base_url": "Base Url",
"api_version": "Versão da API",
"organization": "Organização",
"skip_authentication": "Pular autenticação"
}
}
}
},
"options": {
"step": {
"init": {
"data": {
"max_tokens": "Número máximo de tokens da resposta",
"model": "Modelo da Conclusão",
"prompt": "Template do Prompt",
"temperature": "Temperatura",
"top_p": "Top P",
"max_function_calls_per_conversation": "Quantidade máxima de chamadas por conversação",
"functions": "Funções",
"attach_username": "Anexar nome do usuário na mensagem",
"use_tools": "Use ferramentas",
"context_threshold": "Limite do contexto",
"context_truncate_strategy": "Estratégia de truncamento de contexto quando o limite é excedido"
}
}
}
},
"services": {
"query_image": {
"name": "Consultar imagem",
"description": "Receba imagens e responda perguntas sobre elas",
"fields": {
"config_entry": {
"name": "Registro de configuração",
"description": "O registro de configuração para utilizar neste serviço"
},
"model": {
"name": "Modelo",
"description": "Especificar modelo",
"example": "gpt-4-vision-preview"
},
"prompt": {
"name": "Prompt",
"description": "O texto para fazer a pergunta sobre a imagem",
"example": "O que tem nesta imagem?"
},
"images": {
"name": "Imagens",
"description": "Uma lista de imagens que serão analisadas",
"example": "{\"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\"}"
},
"max_tokens": {
"name": "Max Tokens",
"description": "Quantidade máxima de tokens",
"example": "300"
}
}
}
}
}
Loading

0 comments on commit a7ac05d

Please sign in to comment.