Skip to content

Commit

Permalink
add "skip_authentication" and "model_key" options
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin committed Dec 24, 2023
1 parent e3aebbd commit 9e0b346
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 96 deletions.
20 changes: 13 additions & 7 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""The OpenAI Conversation integration."""
from __future__ import annotations

import re
import logging
from typing import Literal
import json
Expand Down Expand Up @@ -41,6 +40,8 @@
CONF_FUNCTIONS,
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_MODEL_KEY,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand All @@ -49,6 +50,8 @@
DEFAULT_TOP_P,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_MODEL_KEY,
DOMAIN,
)

Expand All @@ -75,6 +78,7 @@
convert_to_template,
validate_authentication,
get_function_executor,
get_api_type,
)


Expand All @@ -97,6 +101,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
api_key=entry.data[CONF_API_KEY],
base_url=entry.data.get(CONF_BASE_URL),
api_version=entry.data.get(CONF_API_VERSION),
skip_authentication=entry.data.get(
CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION
),
)
except error.AuthenticationError as err:
_LOGGER.error("Invalid API key: %s", err)
Expand Down Expand Up @@ -264,9 +271,7 @@ async def query(
"""Process a sentence."""
api_base = self.entry.data.get(CONF_BASE_URL)
api_key = self.entry.data[CONF_API_KEY]
api_type = None
if api_base and re.search(AZURE_DOMAIN_PATTERN, api_base):
api_type = "azure"
api_type = get_api_type(api_base)
api_version = self.entry.data.get(CONF_API_VERSION)
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
Expand All @@ -279,7 +284,9 @@ async def query(
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
):
function_call = "none"
response_format = {"type": "text"}
model_kwargs = {
self.entry.options.get(CONF_MODEL_KEY, DEFAULT_MODEL_KEY): model
}

_LOGGER.info("Prompt for %s: %s", model, messages)

Expand All @@ -288,15 +295,14 @@ async def query(
api_key=api_key,
api_type=api_type,
api_version=api_version,
model=model,
messages=messages,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=user_input.conversation_id,
functions=functions,
function_call=function_call,
response_format=response_format,
**model_kwargs,
)

_LOGGER.info("Response %s", response)
Expand Down
145 changes: 88 additions & 57 deletions custom_components/extended_openai_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@
NumberSelectorConfig,
TemplateSelector,
AttributeSelector,
SelectSelector,
SelectSelectorConfig,
SelectOptionDict,
SelectSelectorMode,
)

from .helpers import validate_authentication
from .helpers import validate_authentication, get_api_type

from .const import (
CONF_ATTACH_USERNAME,
Expand All @@ -35,6 +39,9 @@
CONF_FUNCTIONS,
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_MODEL_KEY,
MODEL_KEYS,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand All @@ -44,6 +51,8 @@
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONF_BASE_URL,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_MODEL_KEY,
DOMAIN,
DEFAULT_NAME,
)
Expand All @@ -56,6 +65,9 @@
vol.Required(CONF_API_KEY): str,
vol.Optional(CONF_BASE_URL, default=DEFAULT_CONF_BASE_URL): str,
vol.Optional(CONF_API_VERSION): str,
vol.Optional(
CONF_SKIP_AUTHENTICATION, default=DEFAULT_SKIP_AUTHENTICATION
): bool,
}
)

Expand Down Expand Up @@ -83,14 +95,19 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
api_key = data[CONF_API_KEY]
base_url = data.get(CONF_BASE_URL)
api_version = data.get(CONF_API_VERSION)
skip_authentication = data.get(CONF_SKIP_AUTHENTICATION)

if base_url == DEFAULT_CONF_BASE_URL:
# Do not set base_url if using OpenAI for case of OpenAI's base_url change
base_url = None
data.pop(CONF_BASE_URL)

await validate_authentication(
hass=hass, api_key=api_key, base_url=base_url, api_version=api_version
hass=hass,
api_key=api_key,
base_url=base_url,
api_version=api_version,
skip_authentication=skip_authentication,
)


Expand Down Expand Up @@ -151,63 +168,77 @@ async def async_step_init(
return self.async_create_entry(
title=user_input.get(CONF_NAME, DEFAULT_NAME), data=user_input
)
schema = openai_config_option_schema(self.config_entry.options)
schema = self.openai_config_option_schema(self.config_entry.options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
)


def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
"""Return a schema for OpenAI completion options."""
if not options:
options = DEFAULT_OPTIONS
return {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options[CONF_PROMPT]},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_CHAT_MODEL,
description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options[CONF_MAX_TOKENS]},
default=DEFAULT_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options[CONF_TOP_P]},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options[CONF_TEMPERATURE]},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
description={
"suggested_value": options[CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION]
},
default=DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
): int,
vol.Optional(
CONF_FUNCTIONS,
description={"suggested_value": options.get(CONF_FUNCTIONS)},
default=DEFAULT_CONF_FUNCTIONS_STR,
): TemplateSelector(),
vol.Optional(
CONF_ATTACH_USERNAME,
description={
"suggested_value": options.get(CONF_ATTACH_USERNAME)
},
default=DEFAULT_ATTACH_USERNAME,
): BooleanSelector(),
}
def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> dict:
"""Return a schema for OpenAI completion options."""
if not options:
options = DEFAULT_OPTIONS

is_azure = get_api_type(self.config_entry.data.get(CONF_BASE_URL)) == "azure"

return {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options[CONF_PROMPT]},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_CHAT_MODEL,
description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options[CONF_MAX_TOKENS]},
default=DEFAULT_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options[CONF_TOP_P]},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options[CONF_TEMPERATURE]},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
description={
"suggested_value": options[CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION]
},
default=DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
): int,
vol.Optional(
CONF_FUNCTIONS,
description={"suggested_value": options.get(CONF_FUNCTIONS)},
default=DEFAULT_CONF_FUNCTIONS_STR,
): TemplateSelector(),
vol.Optional(
CONF_ATTACH_USERNAME,
description={
"suggested_value": options.get(CONF_ATTACH_USERNAME)
},
default=DEFAULT_ATTACH_USERNAME,
): BooleanSelector(),
vol.Optional(
CONF_MODEL_KEY,
description={"suggested_value": options.get(CONF_MODEL_KEY)},
default="engine" if is_azure else DEFAULT_MODEL_KEY,
): SelectSelector(
SelectSelectorConfig(
options=[
SelectOptionDict(value=key, label=key) for key in MODEL_KEYS
],
mode=SelectSelectorMode.DROPDOWN,
)
),
}
13 changes: 10 additions & 3 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

DOMAIN = "extended_openai_conversation"
DEFAULT_NAME = "Extended OpenAI Conversation"
CONF_BASE_URL = "base_url"
DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1"
CONF_API_VERSION = "api_version"
CONF_SKIP_AUTHENTICATION = "skip_authentication"
DEFAULT_SKIP_AUTHENTICATION = False

EVENT_AUTOMATION_REGISTERED = "automation_registered_via_extended_openai_conversation"

CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """I want you to act as smart home manager of Home Assistant.
I will provide information of smart home along with a question, you will truthfully make correction or answer using information provided in one sentence in everyday language.
Expand Down Expand Up @@ -75,8 +82,8 @@
"function": {"type": "native", "name": "execute_service"},
}
]
CONF_BASE_URL = "base_url"
DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1"
CONF_API_VERSION = "api_version"
CONF_ATTACH_USERNAME = "attach_username"
DEFAULT_ATTACH_USERNAME = False
CONF_MODEL_KEY = "model_key"
DEFAULT_MODEL_KEY = "model"
MODEL_KEYS = ["model", "engine"]
48 changes: 29 additions & 19 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import yaml
import time
import sqlite3
import openai
import re
import voluptuous as vol
from functools import partial
from bs4 import BeautifulSoup
from typing import Any
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from openai.error import AuthenticationError
from urllib import parse

from datetime import timedelta
Expand Down Expand Up @@ -63,13 +64,22 @@
_LOGGER = logging.getLogger(__name__)


AZURE_DOMAIN_PATTERN = r"\.openai\.azure\.com"


def get_function_executor(value: str):
function_executor = FUNCTION_EXECUTORS.get(value)
if function_executor is None:
raise FunctionNotFound(value)
return function_executor


def get_api_type(base_url: str):
if base_url and re.search(AZURE_DOMAIN_PATTERN, base_url):
return "azure"
return None


def convert_to_template(
settings,
template_keys=["data", "event_data", "target", "service"],
Expand Down Expand Up @@ -122,25 +132,25 @@ def _get_rest_data(hass, rest_config, arguments):


async def validate_authentication(
hass: HomeAssistant, api_key: str, base_url: str, api_version: str
hass: HomeAssistant,
api_key: str,
base_url: str,
api_version: str,
skip_authentication=False,
) -> None:
if not base_url:
base_url = DEFAULT_CONF_BASE_URL

url = f"{base_url}/models"
if api_version:
url = f"{url}?api-version={api_version}"

session = async_get_clientsession(hass)
response = await session.get(
url,
headers={"Authorization": f"Bearer {api_key}"},
timeout=10,
if skip_authentication:
return

await hass.async_add_executor_job(
partial(
openai.Model.list,
api_type=get_api_type(base_url),
api_key=api_key,
api_version=api_version,
api_base=base_url,
request_timeout=10,
)
)
if response.status == 401:
raise AuthenticationError()

response.raise_for_status()


class FunctionExecutor(ABC):
Expand Down
Loading

0 comments on commit 9e0b346

Please sign in to comment.