Skip to content

Commit

Permalink
[#100] Add "organization" on setup
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin committed Feb 8, 2024
1 parent d2abd50 commit 263d87a
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 54 deletions.
63 changes: 32 additions & 31 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,82 @@
"""The OpenAI Conversation integration."""
from __future__ import annotations

import json
import logging
from typing import Literal
import json
import yaml

from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai._exceptions import AuthenticationError, OpenAIError
from openai.types.chat.chat_completion import (
Choice,
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai._exceptions import OpenAIError, AuthenticationError
import yaml

from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, MATCH_ALL, ATTR_NAME
from homeassistant.const import ATTR_NAME, CONF_API_KEY, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.exceptions import (
ConfigEntryNotReady,
HomeAssistantError,
TemplateError,
)

from homeassistant.helpers import (
config_validation as cv,
entity_registry as er,
intent,
template,
entity_registry as er,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid

from .const import (
CONF_API_VERSION,
CONF_ATTACH_USERNAME,
CONF_BASE_URL,
CONF_CHAT_MODEL,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
CONF_FUNCTIONS,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_MAX_TOKENS,
CONF_ORGANIZATION,
CONF_PROMPT,
CONF_SKIP_AUTHENTICATION,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_FUNCTIONS,
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_USE_TOOLS,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_USE_TOOLS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
DOMAIN,
)

from .exceptions import (
FunctionNotFound,
FunctionLoadFailed,
ParseArgumentsFailed,
FunctionNotFound,
InvalidFunction,
ParseArgumentsFailed,
)

from .helpers import (
validate_authentication,
get_function_executor,
is_azure,
is_exposed,
validate_authentication,
)

from .services import async_setup_services


_LOGGER = logging.getLogger(__name__)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
Expand All @@ -104,6 +101,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
api_key=entry.data[CONF_API_KEY],
base_url=entry.data.get(CONF_BASE_URL),
api_version=entry.data.get(CONF_API_VERSION),
organization=entry.data.get(CONF_ORGANIZATION),
skip_authentication=entry.data.get(
CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION
),
Expand Down Expand Up @@ -145,10 +143,13 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
api_key=entry.data[CONF_API_KEY],
azure_endpoint=base_url,
api_version=entry.data.get(CONF_API_VERSION),
organization=entry.data.get(CONF_ORGANIZATION),
)
else:
self.client = AsyncOpenAI(
api_key=entry.data[CONF_API_KEY], base_url=base_url
api_key=entry.data[CONF_API_KEY],
base_url=base_url,
organization=entry.data.get(CONF_ORGANIZATION),
)

@property
Expand Down
45 changes: 24 additions & 21 deletions custom_components/extended_openai_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,62 @@

import logging
import types
import yaml
from types import MappingProxyType
from typing import Any

from openai._exceptions import APIConnectionError, AuthenticationError
import voluptuous as vol
import yaml

from homeassistant import config_entries
from homeassistant.const import CONF_NAME, CONF_API_KEY
from homeassistant.const import CONF_API_KEY, CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.selector import (
BooleanSelector,
NumberSelector,
NumberSelectorConfig,
TemplateSelector,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectOptionDict,
SelectSelectorMode,
TemplateSelector,
)

from .helpers import validate_authentication

from .const import (
CONF_API_VERSION,
CONF_ATTACH_USERNAME,
CONF_BASE_URL,
CONF_CHAT_MODEL,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
CONF_FUNCTIONS,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_MAX_TOKENS,
CONF_ORGANIZATION,
CONF_PROMPT,
CONF_SKIP_AUTHENTICATION,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_FUNCTIONS,
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_USE_TOOLS,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
CONTEXT_TRUNCATE_STRATEGIES,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_CONF_BASE_URL,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_MAX_TOKENS,
DEFAULT_NAME,
DEFAULT_PROMPT,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONF_BASE_URL,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_USE_TOOLS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
CONTEXT_TRUNCATE_STRATEGIES,
DOMAIN,
DEFAULT_NAME,
)
from .helpers import validate_authentication

_LOGGER = logging.getLogger(__name__)

Expand All @@ -68,6 +68,7 @@
vol.Required(CONF_API_KEY): str,
vol.Optional(CONF_BASE_URL, default=DEFAULT_CONF_BASE_URL): str,
vol.Optional(CONF_API_VERSION): str,
vol.Optional(CONF_ORGANIZATION): str,
vol.Optional(
CONF_SKIP_AUTHENTICATION, default=DEFAULT_SKIP_AUTHENTICATION
): bool,
Expand Down Expand Up @@ -101,6 +102,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
api_key = data[CONF_API_KEY]
base_url = data.get(CONF_BASE_URL)
api_version = data.get(CONF_API_VERSION)
organization = data.get(CONF_ORGANIZATION)
skip_authentication = data.get(CONF_SKIP_AUTHENTICATION)

if base_url == DEFAULT_CONF_BASE_URL:
Expand All @@ -113,6 +115,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
api_key=api_key,
base_url=base_url,
api_version=api_version,
organization=organization,
skip_authentication=skip_authentication,
)

Expand Down
1 change: 1 addition & 0 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

DOMAIN = "extended_openai_conversation"
DEFAULT_NAME = "Extended OpenAI Conversation"
CONF_ORGANIZATION = "organization"
CONF_BASE_URL = "base_url"
DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1"
CONF_API_VERSION = "api_version"
Expand Down
10 changes: 8 additions & 2 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,23 @@ async def validate_authentication(
api_key: str,
base_url: str,
api_version: str,
organization: str = None,
skip_authentication=False,
) -> None:
if skip_authentication:
return

if is_azure(base_url):
client = AsyncAzureOpenAI(
api_key=api_key, azure_endpoint=base_url, api_version=api_version
api_key=api_key,
azure_endpoint=base_url,
api_version=api_version,
organization=organization,
)
else:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
client = AsyncOpenAI(
api_key=api_key, base_url=base_url, organization=organization
)

await client.models.list(timeout=10)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"api_key": "[%key:common::config_flow::data::api_key%]",
"base_url": "[%key:common::config_flow::data::base_url%]",
"api_version": "[%key:common::config_flow::data::api_version%]",
"organization": "[%key:common::config_flow::data::organization%]",
"skip_authentication": "[%key:common::config_flow::data::skip_authentication%]"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"api_key": "API Key",
"base_url": "Base Url",
"api_version": "Api Version",
"organization": "Organization",
"skip_authentication": "Authentifizierung überspringen"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"api_key": "API Key",
"base_url": "Base Url",
"api_version": "Api Version",
"organization": "Organization",
"skip_authentication": "Skip Authentication"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"api_key": "Clé d'API",
"base_url": "Base de l'URL",
"api_version": "Version de l'API",
"organization": "Organization",
"skip_authentication": "Ignorer l'authentification"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"api_key": "API Kulcs",
"base_url": "Base Url",
"api_version": "API Verzió",
"organization": "Organization",
"skip_authentication": "Azonosítás átugrása"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"api_key": "API Key",
"base_url": "Base Url",
"api_version": "Api Version",
"organization": "Organization",
"skip_authentication": "Skip Authentication"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"api_key": "API Sleutel",
"base_url": "Basis URL",
"api_version": "API Version",
"organization": "Organization",
"skip_authentication": "Authenticatie overslaan"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"api_key": "Klucz API",
"base_url": "Bazowy URL",
"api_version": "Wersja API",
"organization": "Organization",
"skip_authentication": "Pomiń authentykację"
}
}
Expand Down

0 comments on commit 263d87a

Please sign in to comment.