Skip to content

Commit

Permalink
feat: 支持并完善服务提供商默认配置模板接口
Browse files Browse the repository at this point in the history
  • Loading branch information
Soulter committed Jan 12, 2025
1 parent bb05927 commit 7c68181
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 10 deletions.
11 changes: 8 additions & 3 deletions astrbot/core/platform/platform_metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from dataclasses import dataclass
@dataclass
class PlatformMetadata():
name: str # 平台的名称
description: str # 平台的描述
name: str
'''平台的名称'''
description: str
'''平台的描述'''

default_config_tmpl: dict = None # 平台的默认配置模板
default_config_tmpl: dict = None
'''平台的默认配置模板'''
adapter_display_name: str = None
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''
10 changes: 8 additions & 2 deletions astrbot/core/platform/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''

def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
def register_platform_adapter(
adapter_name: str,
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None
):
'''用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
Expand All @@ -26,7 +31,8 @@ def decorator(cls):
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
Expand Down
5 changes: 5 additions & 0 deletions astrbot/core/provider/entites.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ class ProviderMetaData():
'''提供商适配器描述.'''
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None

default_config_tmpl: dict = None
'''平台的默认配置模板'''
provider_display_name: str = None
'''显示在 WebUI 配置页中的提供商名称,如空则是 type'''

@dataclass
class ProviderRequest():
Expand Down
17 changes: 14 additions & 3 deletions astrbot/core/provider/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,33 @@
def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
default_config_tmpl: dict = None,
provider_display_name: str = None
):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")

# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = provider_type_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False

pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
provider_type=provider_type,
cls_type=cls
cls_type=cls,
default_config_tmpl=default_config_tmpl,
provider_display_name=provider_display_name
)
provider_registry.append(pm)
provider_cls_map[provider_type_name] = pm
logger.debug(f"Provider {provider_type_name} 已注册")
logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
return cls

return decorator
8 changes: 8 additions & 0 deletions astrbot/dashboard/routes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry

def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
Expand Down Expand Up @@ -123,11 +124,18 @@ async def post_extension_configs(self):
async def _get_astrbot_config(self):
config = self.config

# 平台适配器的默认配置模板注入
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl

# 服务提供商的默认配置模板注入
provider_default_tmpl = CONFIG_METADATA_2['provider_group']['metadata']['provider']['config_template']
for provider in provider_registry:
if provider.default_config_tmpl:
provider_default_tmpl[provider.type] = provider.default_config_tmpl

return {
"metadata": CONFIG_METADATA_2,
"config": config
Expand Down
4 changes: 2 additions & 2 deletions packages/python_interpreter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def get_image_name(self) -> str:
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
return self.config["sandbox"]["image"]

async def _save_config(self):
def _save_config(self):
with open(PATH, "w") as f:
json.dump(self.config, f)

Expand Down Expand Up @@ -207,7 +207,7 @@ async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
""")
else:
self.config["sandbox"]["docker_mirror"] = url
await self._save_config()
self._save_config()
yield event.plain_result("设置 Docker 镜像地址成功。")

@pi.command("repull")
Expand Down

0 comments on commit 7c68181

Please sign in to comment.