From 9e0b3468e160c18e05cea9ca320836dddb9b52ae Mon Sep 17 00:00:00 2001 From: jekalmin Date: Sun, 24 Dec 2023 17:08:25 +0900 Subject: [PATCH] add "skip_authentication" and "model_key" options --- .../extended_openai_conversation/__init__.py | 20 ++- .../config_flow.py | 145 +++++++++++------- .../extended_openai_conversation/const.py | 13 +- .../extended_openai_conversation/helpers.py | 48 +++--- .../extended_openai_conversation/strings.json | 6 +- .../translations/de.json | 7 +- .../translations/en.json | 6 +- .../translations/ko.json | 6 +- .../translations/nl.json | 7 +- 9 files changed, 162 insertions(+), 96 deletions(-) diff --git a/custom_components/extended_openai_conversation/__init__.py b/custom_components/extended_openai_conversation/__init__.py index d39d42b..53d89bc 100644 --- a/custom_components/extended_openai_conversation/__init__.py +++ b/custom_components/extended_openai_conversation/__init__.py @@ -1,7 +1,6 @@ """The OpenAI Conversation integration.""" from __future__ import annotations -import re import logging from typing import Literal import json @@ -41,6 +40,8 @@ CONF_FUNCTIONS, CONF_BASE_URL, CONF_API_VERSION, + CONF_SKIP_AUTHENTICATION, + CONF_MODEL_KEY, DEFAULT_ATTACH_USERNAME, DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, @@ -49,6 +50,8 @@ DEFAULT_TOP_P, DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, DEFAULT_CONF_FUNCTIONS, + DEFAULT_SKIP_AUTHENTICATION, + DEFAULT_MODEL_KEY, DOMAIN, ) @@ -75,6 +78,7 @@ convert_to_template, validate_authentication, get_function_executor, + get_api_type, ) @@ -97,6 +101,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: api_key=entry.data[CONF_API_KEY], base_url=entry.data.get(CONF_BASE_URL), api_version=entry.data.get(CONF_API_VERSION), + skip_authentication=entry.data.get( + CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION + ), ) except error.AuthenticationError as err: _LOGGER.error("Invalid API key: %s", err) @@ -264,9 +271,7 @@ async def query( """Process a sentence.""" api_base = self.entry.data.get(CONF_BASE_URL) api_key = self.entry.data[CONF_API_KEY] - api_type = None - if api_base and re.search(AZURE_DOMAIN_PATTERN, api_base): - api_type = "azure" + api_type = get_api_type(api_base) api_version = self.entry.data.get(CONF_API_VERSION) model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) @@ -279,7 +284,9 @@ async def query( DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, ): function_call = "none" - response_format = {"type": "text"} + model_kwargs = { + self.entry.options.get(CONF_MODEL_KEY, DEFAULT_MODEL_KEY): model + } _LOGGER.info("Prompt for %s: %s", model, messages) @@ -288,7 +295,6 @@ async def query( api_key=api_key, api_type=api_type, api_version=api_version, - model=model, messages=messages, max_tokens=max_tokens, top_p=top_p, @@ -296,7 +302,7 @@ async def query( user=user_input.conversation_id, functions=functions, function_call=function_call, - response_format=response_format, + **model_kwargs, ) _LOGGER.info("Response %s", response) diff --git a/custom_components/extended_openai_conversation/config_flow.py b/custom_components/extended_openai_conversation/config_flow.py index a632e3e..3e2184d 100644 --- a/custom_components/extended_openai_conversation/config_flow.py +++ b/custom_components/extended_openai_conversation/config_flow.py @@ -20,9 +20,13 @@ NumberSelectorConfig, TemplateSelector, AttributeSelector, + SelectSelector, + SelectSelectorConfig, + SelectOptionDict, + SelectSelectorMode, ) -from .helpers import validate_authentication +from .helpers import validate_authentication, get_api_type from .const import ( CONF_ATTACH_USERNAME, @@ -35,6 +39,9 @@ CONF_FUNCTIONS, CONF_BASE_URL, CONF_API_VERSION, + CONF_SKIP_AUTHENTICATION, + CONF_MODEL_KEY, + MODEL_KEYS, DEFAULT_ATTACH_USERNAME, DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, @@ -44,6 +51,8 @@ DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, DEFAULT_CONF_FUNCTIONS, DEFAULT_CONF_BASE_URL, + DEFAULT_SKIP_AUTHENTICATION, + DEFAULT_MODEL_KEY, DOMAIN, DEFAULT_NAME, ) @@ -56,6 +65,9 @@ vol.Required(CONF_API_KEY): str, vol.Optional(CONF_BASE_URL, default=DEFAULT_CONF_BASE_URL): str, vol.Optional(CONF_API_VERSION): str, + vol.Optional( + CONF_SKIP_AUTHENTICATION, default=DEFAULT_SKIP_AUTHENTICATION + ): bool, } ) @@ -83,6 +95,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: api_key = data[CONF_API_KEY] base_url = data.get(CONF_BASE_URL) api_version = data.get(CONF_API_VERSION) + skip_authentication = data.get(CONF_SKIP_AUTHENTICATION) if base_url == DEFAULT_CONF_BASE_URL: # Do not set base_url if using OpenAI for case of OpenAI's base_url change @@ -90,7 +103,11 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: data.pop(CONF_BASE_URL) await validate_authentication( - hass=hass, api_key=api_key, base_url=base_url, api_version=api_version + hass=hass, + api_key=api_key, + base_url=base_url, + api_version=api_version, + skip_authentication=skip_authentication, ) @@ -151,63 +168,77 @@ async def async_step_init( return self.async_create_entry( title=user_input.get(CONF_NAME, DEFAULT_NAME), data=user_input ) - schema = openai_config_option_schema(self.config_entry.options) + schema = self.openai_config_option_schema(self.config_entry.options) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), ) - -def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict: - """Return a schema for OpenAI completion options.""" - if not options: - options = DEFAULT_OPTIONS - return { - vol.Optional( - CONF_PROMPT, - description={"suggested_value": options[CONF_PROMPT]}, - default=DEFAULT_PROMPT, - ): TemplateSelector(), - vol.Optional( - CONF_CHAT_MODEL, - description={ - # New key in HA 2023.4 - "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) - }, - default=DEFAULT_CHAT_MODEL, - ): str, - vol.Optional( - CONF_MAX_TOKENS, - description={"suggested_value": options[CONF_MAX_TOKENS]}, - default=DEFAULT_MAX_TOKENS, - ): int, - vol.Optional( - CONF_TOP_P, - description={"suggested_value": options[CONF_TOP_P]}, - default=DEFAULT_TOP_P, - ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), - vol.Optional( - CONF_TEMPERATURE, - description={"suggested_value": options[CONF_TEMPERATURE]}, - default=DEFAULT_TEMPERATURE, - ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), - vol.Optional( - CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION, - description={ - "suggested_value": options[CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION] - }, - default=DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, - ): int, - vol.Optional( - CONF_FUNCTIONS, - description={"suggested_value": options.get(CONF_FUNCTIONS)}, - default=DEFAULT_CONF_FUNCTIONS_STR, - ): TemplateSelector(), - vol.Optional( - CONF_ATTACH_USERNAME, - description={ - "suggested_value": options.get(CONF_ATTACH_USERNAME) - }, - default=DEFAULT_ATTACH_USERNAME, - ): BooleanSelector(), - } + def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> dict: + """Return a schema for OpenAI completion options.""" + if not options: + options = DEFAULT_OPTIONS + + is_azure = get_api_type(self.config_entry.data.get(CONF_BASE_URL)) == "azure" + + return { + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options[CONF_PROMPT]}, + default=DEFAULT_PROMPT, + ): TemplateSelector(), + vol.Optional( + CONF_CHAT_MODEL, + description={ + # New key in HA 2023.4 + "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + }, + default=DEFAULT_CHAT_MODEL, + ): str, + vol.Optional( + CONF_MAX_TOKENS, + description={"suggested_value": options[CONF_MAX_TOKENS]}, + default=DEFAULT_MAX_TOKENS, + ): int, + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options[CONF_TOP_P]}, + default=DEFAULT_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options[CONF_TEMPERATURE]}, + default=DEFAULT_TEMPERATURE, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION, + description={ + "suggested_value": options[CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION] + }, + default=DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, + ): int, + vol.Optional( + CONF_FUNCTIONS, + description={"suggested_value": options.get(CONF_FUNCTIONS)}, + default=DEFAULT_CONF_FUNCTIONS_STR, + ): TemplateSelector(), + vol.Optional( + CONF_ATTACH_USERNAME, + description={ + "suggested_value": options.get(CONF_ATTACH_USERNAME) + }, + default=DEFAULT_ATTACH_USERNAME, + ): BooleanSelector(), + vol.Optional( + CONF_MODEL_KEY, + description={"suggested_value": options.get(CONF_MODEL_KEY)}, + default="engine" if is_azure else DEFAULT_MODEL_KEY, + ): SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict(value=key, label=key) for key in MODEL_KEYS + ], + mode=SelectSelectorMode.DROPDOWN, + ) + ), + } diff --git a/custom_components/extended_openai_conversation/const.py b/custom_components/extended_openai_conversation/const.py index 778d8fc..56a2233 100644 --- a/custom_components/extended_openai_conversation/const.py +++ b/custom_components/extended_openai_conversation/const.py @@ -2,7 +2,14 @@ DOMAIN = "extended_openai_conversation" DEFAULT_NAME = "Extended OpenAI Conversation" +CONF_BASE_URL = "base_url" +DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1" +CONF_API_VERSION = "api_version" +CONF_SKIP_AUTHENTICATION = "skip_authentication" +DEFAULT_SKIP_AUTHENTICATION = False + EVENT_AUTOMATION_REGISTERED = "automation_registered_via_extended_openai_conversation" + CONF_PROMPT = "prompt" DEFAULT_PROMPT = """I want you to act as smart home manager of Home Assistant. I will provide information of smart home along with a question, you will truthfully make correction or answer using information provided in one sentence in everyday language. @@ -75,8 +82,8 @@ "function": {"type": "native", "name": "execute_service"}, } ] -CONF_BASE_URL = "base_url" -DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1" -CONF_API_VERSION = "api_version" CONF_ATTACH_USERNAME = "attach_username" DEFAULT_ATTACH_USERNAME = False +CONF_MODEL_KEY = "model_key" +DEFAULT_MODEL_KEY = "model" +MODEL_KEYS = ["model", "engine"] diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py index 86f6e58..a500ff1 100644 --- a/custom_components/extended_openai_conversation/helpers.py +++ b/custom_components/extended_openai_conversation/helpers.py @@ -4,11 +4,12 @@ import yaml import time import sqlite3 +import openai +import re import voluptuous as vol +from functools import partial from bs4 import BeautifulSoup from typing import Any -from homeassistant.helpers.aiohttp_client import async_get_clientsession -from openai.error import AuthenticationError from urllib import parse from datetime import timedelta @@ -63,6 +64,9 @@ _LOGGER = logging.getLogger(__name__) +AZURE_DOMAIN_PATTERN = r"\.openai\.azure\.com" + + def get_function_executor(value: str): function_executor = FUNCTION_EXECUTORS.get(value) if function_executor is None: @@ -70,6 +74,12 @@ def get_function_executor(value: str): return function_executor +def get_api_type(base_url: str): + if base_url and re.search(AZURE_DOMAIN_PATTERN, base_url): + return "azure" + return None + + def convert_to_template( settings, template_keys=["data", "event_data", "target", "service"], @@ -122,25 +132,25 @@ def _get_rest_data(hass, rest_config, arguments): async def validate_authentication( - hass: HomeAssistant, api_key: str, base_url: str, api_version: str + hass: HomeAssistant, + api_key: str, + base_url: str, + api_version: str, + skip_authentication=False, ) -> None: - if not base_url: - base_url = DEFAULT_CONF_BASE_URL - - url = f"{base_url}/models" - if api_version: - url = f"{url}?api-version={api_version}" - - session = async_get_clientsession(hass) - response = await session.get( - url, - headers={"Authorization": f"Bearer {api_key}"}, - timeout=10, + if skip_authentication: + return + + await hass.async_add_executor_job( + partial( + openai.Model.list, + api_type=get_api_type(base_url), + api_key=api_key, + api_version=api_version, + api_base=base_url, + request_timeout=10, + ) ) - if response.status == 401: - raise AuthenticationError() - - response.raise_for_status() class FunctionExecutor(ABC): diff --git a/custom_components/extended_openai_conversation/strings.json b/custom_components/extended_openai_conversation/strings.json index 06b5512..752e35f 100644 --- a/custom_components/extended_openai_conversation/strings.json +++ b/custom_components/extended_openai_conversation/strings.json @@ -6,7 +6,8 @@ "name": "[%key:common::config_flow::data::name%]", "api_key": "[%key:common::config_flow::data::api_key%]", "base_url": "[%key:common::config_flow::data::base_url%]", - "api_version": "[%key:common::config_flow::data::api_version%]" + "api_version": "[%key:common::config_flow::data::api_version%]", + "skip_authentication": "[%key:common::config_flow::data::skip_authentication%]" } } }, @@ -27,7 +28,8 @@ "top_p": "Top P", "max_function_calls_per_conversation": "Maximum function calls per conversation", "functions": "Functions", - "attach_username": "Attach Username to Message" + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/de.json b/custom_components/extended_openai_conversation/translations/de.json index 5aa8899..9c27154 100644 --- a/custom_components/extended_openai_conversation/translations/de.json +++ b/custom_components/extended_openai_conversation/translations/de.json @@ -11,7 +11,8 @@ "name": "Name", "api_key": "API Key", "base_url": "Base Url", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -26,7 +27,9 @@ "temperature": "Temperatur", "top_p": "Top P", "max_function_calls_per_conversation": "Maximale Anzahl an Funktionsaufrufen pro Konversation", - "functions": "Funktionen" + "functions": "Funktionen", + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/en.json b/custom_components/extended_openai_conversation/translations/en.json index 61644a3..0a46be2 100644 --- a/custom_components/extended_openai_conversation/translations/en.json +++ b/custom_components/extended_openai_conversation/translations/en.json @@ -11,7 +11,8 @@ "name": "Name", "api_key": "API Key", "base_url": "Base Url", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -27,7 +28,8 @@ "top_p": "Top P", "max_function_calls_per_conversation": "Maximum function calls per conversation", "functions": "Functions", - "attach_username": "Attach Username to Message" + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/ko.json b/custom_components/extended_openai_conversation/translations/ko.json index 61644a3..0a46be2 100644 --- a/custom_components/extended_openai_conversation/translations/ko.json +++ b/custom_components/extended_openai_conversation/translations/ko.json @@ -11,7 +11,8 @@ "name": "Name", "api_key": "API Key", "base_url": "Base Url", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -27,7 +28,8 @@ "top_p": "Top P", "max_function_calls_per_conversation": "Maximum function calls per conversation", "functions": "Functions", - "attach_username": "Attach Username to Message" + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/nl.json b/custom_components/extended_openai_conversation/translations/nl.json index 5e0c42a..3fa6114 100644 --- a/custom_components/extended_openai_conversation/translations/nl.json +++ b/custom_components/extended_openai_conversation/translations/nl.json @@ -11,7 +11,8 @@ "name": "Naam", "api_key": "API-sleutel", "base_url": "Basis-URL", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -26,7 +27,9 @@ "temperature": "Temperatuur", "top_p": "Top P", "max_function_calls_per_conversation": "Maximale keren functies mogen worden aangeroepen per conversatie", - "functions": "Functies" + "functions": "Functies", + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } }