diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py index 0e904a369f..89db2cbf47 100644 --- a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py +++ b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py @@ -16,11 +16,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import asyncio +from __future__ import annotations + import enum import inspect import io import warnings +from dataclasses import dataclass from json import JSONDecodeError from typing import ( Any, @@ -34,6 +36,7 @@ Optional, Tuple, Type, + TypeVar, Union, ) @@ -95,6 +98,11 @@ ) from nucliadb_sdk.v2 import exceptions +# Generics +OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[BaseModel, None]) + +INPUT_TYPE = TypeVar("INPUT_TYPE", BaseModel, List[InputMessage], None) + class Region(enum.Enum): EUROPE1 = "europe-1" @@ -102,24 +110,307 @@ class Region(enum.Enum): AWS_US_EAST_2_1 = "aws-us-east-2-1" -RawRequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +RawRequestContent = Union[ + str, bytes, Iterable[bytes], AsyncIterable[bytes], dict[str, Any] +] ASK_STATUS_CODE_ERROR = "-1" -def json_response_parser(response: httpx.Response) -> Any: - return orjson.loads(response.content.decode()) +@dataclass +class SDKDefinition: + method: str + path_template: str + path_params: Tuple[str, ...] + + +SDK_DEFINITION = { + # Knowledge Box Endpoints + "create_knowledge_box": SDKDefinition( + path_template="/v1/kbs", + method="POST", + path_params=(), + ), + "delete_knowledge_box": SDKDefinition( + path_template="/v1/kb/{kbid}", + method="DELETE", + path_params=("kbid",), + ), + "get_knowledge_box": SDKDefinition( + path_template="/v1/kb/{kbid}", + method="GET", + path_params=("kbid",), + ), + "get_knowledge_box_by_slug": SDKDefinition( + path_template="/v1/kb/s/{slug}", + method="GET", + path_params=("slug",), + ), + "list_knowledge_boxes": SDKDefinition( + path_template="/v1/kbs", + method="GET", + path_params=(), + ), + # Resource Endpoints + "create_resource": SDKDefinition( + path_template="/v1/kb/{kbid}/resources", + method="POST", + path_params=("kbid",), + ), + "update_resource": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}", + method="PATCH", + path_params=("kbid", "rid"), + ), + "update_resource_by_slug": SDKDefinition( + path_template="/v1/kb/{kbid}/slug/{rslug}", + method="PATCH", + path_params=("kbid", "rslug"), + ), + "delete_resource": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}", + method="DELETE", + path_params=("kbid", "rid"), + ), + "delete_resource_by_slug": SDKDefinition( + path_template="/v1/kb/{kbid}/slug/{rslug}", + method="DELETE", + path_params=("kbid", "rslug"), + ), + "get_resource_by_slug": SDKDefinition( + path_template="/v1/kb/{kbid}/slug/{slug}", + method="GET", + path_params=("kbid", "slug"), + ), + "get_resource_by_id": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}", + method="GET", + path_params=("kbid", "rid"), + ), + "list_resources": SDKDefinition( + path_template="/v1/kb/{kbid}/resources", + method="GET", + path_params=("kbid",), + ), + # reindex/reprocess + "reindex_resource": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}/reindex", + method="POST", + path_params=("kbid", "rid"), + ), + "reindex_resource_by_slug": SDKDefinition( + path_template="/v1/kb/{kbid}/slug/{slug}/reindex", + method="POST", + path_params=("kbid", "slug"), + ), + "reprocess_resource": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}/reprocess", + method="POST", + path_params=("kbid", "rid"), + ), + "reprocess_resource_by_slug": SDKDefinition( + path_template="/v1/kb/{kbid}/slug/{slug}/reprocess", + method="POST", + path_params=("kbid", "slug"), + ), + # Field endpoints + "delete_field_by_id": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}/{field_type}/{field_id}", + method="DELETE", + path_params=("kbid", "rid", "field_type", "field_id"), + ), + # Conversation endpoints + "add_conversation_message": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}/conversation/{field_id}/messages", + method="PUT", + path_params=("kbid", "rid", "field_id"), + ), + "get_resource_field": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}/{field_type}/{field_id}", + method="GET", + path_params=("kbid", "rid", "field_type", "field_id"), + ), + "get_resource_field_by_slug": SDKDefinition( + path_template="/v1/kb/{kbid}/slug/{slug}/{field_type}/{field_id}", + method="GET", + path_params=("kbid", "slug", "field_type", "field_id"), + ), + # Labels + "set_labelset": SDKDefinition( + path_template="/v1/kb/{kbid}/labelset/{labelset}", + method="POST", + path_params=("kbid", "labelset"), + ), + "delete_labelset": SDKDefinition( + path_template="/v1/kb/{kbid}/labelset/{labelset}", + method="DELETE", + path_params=("kbid", "labelset"), + ), + "get_labelsets": SDKDefinition( + path_template="/v1/kb/{kbid}/labelsets", + method="GET", + path_params=("kbid",), + ), + "get_labelset": SDKDefinition( + path_template="/v1/kb/{kbid}/labelset/{labelset}", + method="GET", + path_params=("kbid", "labelset"), + ), + # Entity Groups + "create_entitygroup": SDKDefinition( + path_template="/v1/kb/{kbid}/entitiesgroups", + method="POST", + path_params=("kbid",), + ), + "update_entitygroup": SDKDefinition( + path_template="/v1/kb/{kbid}/entitiesgroup/{group}", + method="PATCH", + path_params=("kbid", "group"), + ), + "delete_entitygroup": SDKDefinition( + path_template="/v1/kb/{kbid}/entitiesgroup/{group}", + method="DELETE", + path_params=("kbid", "group"), + ), + "get_entitygroups": SDKDefinition( + path_template="/v1/kb/{kbid}/entitiesgroups", + method="GET", + path_params=("kbid",), + ), + "get_entitygroup": SDKDefinition( + path_template="/v1/kb/{kbid}/entitiesgroup/{group}", + method="GET", + path_params=("kbid", "group"), + ), + # Search / Find Endpoints + "find": SDKDefinition( + path_template="/v1/kb/{kbid}/find", + method="POST", + path_params=("kbid",), + ), + "search": SDKDefinition( + path_template="/v1/kb/{kbid}/search", + method="POST", + path_params=("kbid",), + ), + "ask": SDKDefinition( + path_template="/v1/kb/{kbid}/ask", + method="POST", + path_params=("kbid",), + ), + "ask_on_resource": SDKDefinition( + path_template="/v1/kb/{kbid}/resource/{rid}/ask", + method="POST", + path_params=("kbid", "rid"), + ), + "ask_on_resource_by_slug": SDKDefinition( + path_template="/v1/kb/{kbid}/slug/{slug}/ask", + method="POST", + path_params=("kbid", "slug"), + ), + "summarize": SDKDefinition( + path_template="/v1/kb/{kbid}/summarize", + method="POST", + path_params=("kbid",), + ), + "feedback": SDKDefinition( + path_template="/v1/kb/{kbid}/feedback", + method="POST", + path_params=("kbid",), + ), + "start_export": SDKDefinition( + path_template="/v1/kb/{kbid}/export", + method="POST", + path_params=("kbid",), + ), + "export_status": SDKDefinition( + path_template="/v1/kb/{kbid}/export/{export_id}/status", + method="GET", + path_params=("kbid", "export_id"), + ), + "download_export": SDKDefinition( + path_template="/v1/kb/{kbid}/export/{export_id}", + method="GET", + path_params=("kbid", "export_id"), + ), + "create_kb_from_import": SDKDefinition( + path_template="/v1/kbs/import", + method="POST", + path_params=(), + ), + "start_import": SDKDefinition( + path_template="/v1/kb/{kbid}/import", + method="POST", + path_params=("kbid",), + ), + "import_status": SDKDefinition( + path_template="/v1/kb/{kbid}/import/{import_id}/status", + method="GET", + path_params=("kbid", "import_id"), + ), + "trainset": SDKDefinition( + path_template="/v1/kb/{kbid}/trainset", + method="GET", + path_params=("kbid",), + ), + # Learning Configuration + "get_configuration": SDKDefinition( + path_template="/v1/kb/{kbid}/configuration", + method="GET", + path_params=("kbid",), + ), + "set_configuration": SDKDefinition( + path_template="/v1/kb/{kbid}/configuration", + method="POST", + path_params=("kbid",), + ), + # Learning models + "download_model": SDKDefinition( + path_template="/v1/kb/{kbid}/models/{model_id}/{filename}", + method="GET", + path_params=("kbid", "model_id", "filename"), + ), + "get_models": SDKDefinition( + path_template="/v1/kb/{kbid}/models", + method="GET", + path_params=("kbid",), + ), + "get_model": SDKDefinition( + path_template="/v1/kb/{kbid}/model/{model_id}", + method="GET", + path_params=("kbid", "model_id"), + ), + # Learning config schema + "get_configuration_schema": SDKDefinition( + path_template="/v1/kb/{kbid}/schema", + method="GET", + path_params=("kbid",), + ), + # Custom synonyms + "set_custom_synonyms": SDKDefinition( + path_template="/v1/kb/{kbid}/custom-synonyms", + method="PUT", + path_params=("kbid",), + ), + "get_custom_synonyms": SDKDefinition( + path_template="/v1/kb/{kbid}/custom-synonyms", + method="GET", + path_params=("kbid",), + ), +} -def ask_response_parser(response: httpx.Response) -> SyncAskResponse: +def ask_response_parser( + response_type: Type[BaseModel], response: httpx.Response +) -> BaseModel: content_type = response.headers.get("Content-Type") if content_type not in ("application/json", "application/x-ndjson"): raise ValueError(f"Unknown content type in response: {content_type}") if content_type == "application/json": # This comes from a request with the X-Synchronous header set to true - return SyncAskResponse.model_validate_json(response.content) + return response_type.model_validate_json(response.content) answer = "" answer_json = None @@ -166,15 +457,17 @@ def ask_response_parser(response: httpx.Response) -> SyncAskResponse: warnings.warn("No retrieval results found in ask response") retrieval_results = KnowledgeboxFindResults(resources={}) - return SyncAskResponse( - answer=answer, - answer_json=answer_json, - status=status, - retrieval_results=retrieval_results, - relations=relations, - learning_id=learning_id, - citations=citations, - metadata=SyncAskMetadata(tokens=tokens, timings=timings), + return response_type.model_validate( + { + "answer": answer, + "answer_json": answer_json, + "status": status, + "retrieval_results": retrieval_results, + "relations": relations, + "learning_id": learning_id, + "citations": citations, + "metadata": SyncAskMetadata(tokens=tokens, timings=timings), + } ) @@ -190,19 +483,6 @@ def _parse_list_of_pydantic( return orjson.dumps(output).decode("utf-8") -def _parse_response( - response_type: Optional[Union[Type[BaseModel], Callable[[httpx.Response], Any]]], - resp: httpx.Response, -) -> Any: - if response_type is not None: - if isinstance(response_type, type) and issubclass(response_type, BaseModel): - return response_type.model_validate_json(resp.content) - else: - return response_type(resp) - else: - return resp.content - - def is_raw_request_content(content: Any) -> bool: return ( isinstance(content, str) @@ -214,78 +494,219 @@ def is_raw_request_content(content: Any) -> bool: ) -def _request_builder( - *, - name: str, - method: str, +def prepare_request_base( + path_template: str, + path_params: Tuple[str, ...], + **kwargs, +): + path_data = {} + for param in path_params: + if param not in kwargs: + raise TypeError(f"Missing required parameter {param}") + path_data[param] = kwargs.pop(param) + + path = path_template.format(**path_data) + + query_params = kwargs.pop("query_params", None) + if len(kwargs) > 0: + raise TypeError(f"Invalid arguments provided: {kwargs}") + return path, query_params + + +def prepare_request( path_template: str, path_params: Tuple[str, ...], - request_type: Optional[Union[Type[BaseModel], List[Any]]], - response_type: Optional[ - Union[ - Type[BaseModel], - Callable[[httpx.Response], BaseModel], - Callable[[httpx.Response], Iterator[bytes]], - ] - ], - stream_response: bool = False, + request_type: Optional[Type[INPUT_TYPE]], + content: Optional[INPUT_TYPE] = None, + **kwargs, ): - def _func(self: "NucliaDB | NucliaDBAsync", content: Optional[Any] = None, **kwargs): - path_data = {} - for param in path_params: - if param not in kwargs: - raise TypeError(f"Missing required parameter {param}") - path_data[param] = kwargs.pop(param) - - path = path_template.format(**path_data) - data = None - raw_content: Optional[RawRequestContent] = None - if request_type is not None: - if content is not None: - try: - if not isinstance(content, request_type): # type: ignore - raise TypeError(f"Expected {request_type}, got {type(content)}") - else: - data = content.model_dump_json(by_alias=True, exclude_unset=True) - except TypeError: - if not isinstance(content, list): - raise - data = _parse_list_of_pydantic(content) + path, query_params = prepare_request_base(path_template, path_params, **kwargs) + data = None + if request_type is not None: + if content is not None: + if not isinstance(content, request_type): + raise TypeError(f"Expected {request_type}, got {type(content)}") + elif isinstance(content, BaseModel): + data = content.model_dump_json(by_alias=True, exclude_unset=True) + elif isinstance(content, list): + data = _parse_list_of_pydantic(content) else: - # pull properties out of kwargs now - content_data = {} + raise TypeError(f"Unknown type {type(content)}") + else: + # pull properties out of kwargs now + content_data: Dict[str, str] = {} + if isinstance(request_type, BaseModel): for key in list(kwargs.keys()): - if key in request_type.model_fields: # type: ignore + if key in request_type.model_fields: content_data[key] = kwargs.pop(key) - data = request_type.model_validate(content_data).model_dump_json( # type: ignore + data = request_type.model_validate(content_data).model_dump_json( by_alias=True, exclude_unset=True ) - elif is_raw_request_content(content): - raw_content = content - query_params = kwargs.pop("query_params", None) - if len(kwargs) > 0: - raise TypeError(f"Invalid arguments provided: {kwargs}") + return path, data, query_params - if not stream_response: - resp = self._request(path, method, data=data, query_params=query_params, content=raw_content) - if asyncio.iscoroutine(resp): - async def _wrapped_resp(): - real_resp = await resp - return _parse_response(response_type, real_resp) +def _request_sync_builder( + name: str, + request_type: Type[INPUT_TYPE], + response_type: Type[OUTPUT_TYPE], +): + sdk_def = SDK_DEFINITION[name] + method = sdk_def.method + path_template = sdk_def.path_template + path_params = sdk_def.path_params + + def _func( + self: NucliaDB, content: Optional[INPUT_TYPE] = None, **kwargs + ) -> OUTPUT_TYPE: + path, data, query_params = prepare_request( + path_template=path_template, + path_params=path_params, + request_type=request_type, + content=content, + **kwargs, + ) + resp = self._request(path, method, data=data, query_params=query_params) + if response_type is not None: + if issubclass(response_type, SyncAskResponse): + return ask_response_parser(response_type, resp) # type: ignore + elif issubclass(response_type, BaseModel): + return response_type.model_validate_json(resp.content) # type: ignore + return None # type: ignore + + return _func - return _wrapped_resp() - else: - return _parse_response(response_type, resp) # type: ignore - else: - resp = self._stream_request(path, method, data=data, query_params=query_params) - return resp + +def _request_json_sync_builder( + name: str, +): + sdk_def = SDK_DEFINITION[name] + method = sdk_def.method + path_template = sdk_def.path_template + path_params = sdk_def.path_params + + def _func( + self: NucliaDB, content: Optional[Dict[str, Any]] = None, **kwargs + ) -> Optional[Dict[str, Any]]: + path, query_params = prepare_request_base( + path_template=path_template, + path_params=path_params, + **kwargs, + ) + resp = self._request(path, method, query_params=query_params, content=content) + try: + return orjson.loads(resp.content.decode()) + except orjson.JSONDecodeError: + return None + + return _func + + +def _request_iterator_sync_builder( + name: str, +): + sdk_def = SDK_DEFINITION[name] + method = sdk_def.method + path_template = sdk_def.path_template + path_params = sdk_def.path_params + + def _func(self: NucliaDB, **kwargs) -> Callable[[Optional[int]], Iterator[bytes]]: + path, query_params = prepare_request_base( + path_template=path_template, + path_params=path_params, + **kwargs, + ) + return self._stream_request(path, method, query_params=query_params) + + return _func + + +def _request_async_builder( + name: str, + request_type: Type[INPUT_TYPE], + response_type: Type[OUTPUT_TYPE], +): + sdk_def = SDK_DEFINITION[name] + method = sdk_def.method + path_template = sdk_def.path_template + path_params = sdk_def.path_params + + async def _func( + self: NucliaDBAsync, content: Optional[INPUT_TYPE] = None, **kwargs + ) -> OUTPUT_TYPE: + path, data, query_params = prepare_request( + path_template=path_template, + path_params=path_params, + request_type=request_type, + content=content, + **kwargs, + ) + resp = await self._request(path, method, data=data, query_params=query_params) + if response_type is not None: + if isinstance(response_type, type) and issubclass( + response_type, SyncAskResponse + ): + return ask_response_parser(response_type, resp) # type: ignore + elif isinstance(response_type, type) and issubclass( + response_type, BaseModel + ): + return response_type.model_validate_json(resp.content) # type: ignore + return None # type: ignore + + return _func + + +def _request_json_async_builder( + name: str, +): + sdk_def = SDK_DEFINITION[name] + method = sdk_def.method + path_template = sdk_def.path_template + path_params = sdk_def.path_params + + async def _func( + self: NucliaDBAsync, content: Optional[Dict[str, Any]] = None, **kwargs + ) -> Optional[Dict[str, Any]]: + path, query_params = prepare_request_base( + path_template=path_template, + path_params=path_params, + **kwargs, + ) + resp = await self._request( + path, method, query_params=query_params, content=content + ) + try: + return orjson.loads(resp.content.decode()) + except orjson.JSONDecodeError: + return None + + return _func + + +def _request_iterator_async_builder( + name: str, +): + sdk_def = SDK_DEFINITION[name] + method = sdk_def.method + path_template = sdk_def.path_template + path_params = sdk_def.path_params + + async def _func( + self: NucliaDBAsync, **kwargs + ) -> Callable[[Optional[int]], AsyncGenerator[bytes, None]]: + path, query_params = prepare_request_base( + path_template=path_template, + path_params=path_params, + **kwargs, + ) + return self._stream_request(path, method, query_params=query_params) return _func class _NucliaDBBase: + sync: bool = True + def __init__( self, *, @@ -345,7 +766,9 @@ def _check_response(self, response: httpx.Response): if response.status_code < 300: return response elif response.status_code in (401, 403): - raise exceptions.AuthError(f"Auth error {response.status_code}: {response.text}") + raise exceptions.AuthError( + f"Auth error {response.status_code}: {response.text}" + ) elif response.status_code == 402: raise exceptions.AccountLimitError( f"Account limits exceeded error {response.status_code}: {response.text}" @@ -364,471 +787,14 @@ def _check_response(self, response: httpx.Response): ): # 419 is a custom error code for kb creation conflict raise exceptions.ConflictError(response.text) elif response.status_code == 404: - raise exceptions.NotFoundError(f"Resource not found at url {response.url}: {response.text}") + raise exceptions.NotFoundError( + f"Resource not found at url {response.url}: {response.text}" + ) else: raise exceptions.UnknownError( f"Unknown error connecting to API: {response.status_code}: {response.text}" ) - # Knowledge Box Endpoints - create_knowledge_box = _request_builder( - name="create_knowledge_box", - path_template="/v1/kbs", - method="POST", - path_params=(), - request_type=KnowledgeBoxConfig, - response_type=KnowledgeBoxObj, - ) - delete_knowledge_box = _request_builder( - name="delete_knowledge_box", - path_template="/v1/kb/{kbid}", - method="DELETE", - path_params=("kbid",), - request_type=None, - response_type=KnowledgeBoxObj, - ) - get_knowledge_box = _request_builder( - name="get_knowledge_box", - path_template="/v1/kb/{kbid}", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=KnowledgeBoxObj, - ) - get_knowledge_box_by_slug = _request_builder( - name="get_knowledge_box_by_slug", - path_template="/v1/kb/s/{slug}", - method="GET", - path_params=("slug",), - request_type=None, - response_type=KnowledgeBoxObj, - ) - list_knowledge_boxes = _request_builder( - name="list_knowledge_boxes", - path_template="/v1/kbs", - method="GET", - path_params=(), - request_type=None, - response_type=KnowledgeBoxList, - ) - - # Resource Endpoints - create_resource = _request_builder( - name="create_resource", - path_template="/v1/kb/{kbid}/resources", - method="POST", - path_params=("kbid",), - request_type=CreateResourcePayload, - response_type=ResourceCreated, - ) - update_resource = _request_builder( - name="update_resource", - path_template="/v1/kb/{kbid}/resource/{rid}", - method="PATCH", - path_params=("kbid", "rid"), - request_type=UpdateResourcePayload, - response_type=ResourceUpdated, - ) - update_resource_by_slug = _request_builder( - name="update_resource_by_slug", - path_template="/v1/kb/{kbid}/slug/{rslug}", - method="PATCH", - path_params=("kbid", "rslug"), - request_type=UpdateResourcePayload, - response_type=ResourceUpdated, - ) - delete_resource = _request_builder( - name="delete_resource", - path_template="/v1/kb/{kbid}/resource/{rid}", - method="DELETE", - path_params=("kbid", "rid"), - request_type=None, - response_type=None, - ) - delete_resource_by_slug = _request_builder( - name="delete_resource_by_slug", - path_template="/v1/kb/{kbid}/slug/{rslug}", - method="DELETE", - path_params=("kbid", "rslug"), - request_type=None, - response_type=None, - ) - get_resource_by_slug = _request_builder( - name="get_resource_by_slug", - path_template="/v1/kb/{kbid}/slug/{slug}", - method="GET", - path_params=("kbid", "slug"), - request_type=None, - response_type=Resource, - ) - get_resource_by_id = _request_builder( - name="get_resource_by_id", - path_template="/v1/kb/{kbid}/resource/{rid}", - method="GET", - path_params=("kbid", "rid"), - request_type=None, - response_type=Resource, - ) - list_resources = _request_builder( - name="list_resources", - path_template="/v1/kb/{kbid}/resources", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=ResourceList, - ) - - # reindex/reprocess - reindex_resource = _request_builder( - name="reindex_resource", - path_template="/v1/kb/{kbid}/resource/{rid}/reindex", - method="POST", - path_params=("kbid", "rid"), - request_type=None, - response_type=None, - ) - reindex_resource_by_slug = _request_builder( - name="reindex_resource_by_slug", - path_template="/v1/kb/{kbid}/slug/{slug}/reindex", - method="POST", - path_params=("kbid", "slug"), - request_type=None, - response_type=None, - ) - reprocess_resource = _request_builder( - name="reprocess_resource", - path_template="/v1/kb/{kbid}/resource/{rid}/reprocess", - method="POST", - path_params=("kbid", "rid"), - request_type=None, - response_type=None, - ) - reprocess_resource_by_slug = _request_builder( - name="reprocess_resource_by_slug", - path_template="/v1/kb/{kbid}/slug/{slug}/reprocess", - method="POST", - path_params=("kbid", "slug"), - request_type=None, - response_type=None, - ) - - # Field endpoints - delete_field_by_id = _request_builder( - name="delete_field_by_id", - path_template="/v1/kb/{kbid}/resource/{rid}/{field_type}/{field_id}", - method="DELETE", - path_params=("kbid", "rid", "field_type", "field_id"), - request_type=None, - response_type=None, - ) - - # Conversation endpoints - add_conversation_message = _request_builder( - name="add_conversation_message", - path_template="/v1/kb/{kbid}/resource/{rid}/conversation/{field_id}/messages", - method="PUT", - path_params=("kbid", "rid", "field_id"), - request_type=List[InputMessage], # type: ignore - response_type=ResourceFieldAdded, - ) - - get_resource_field = _request_builder( - name="get_resource_field", - path_template="/v1/kb/{kbid}/resource/{rid}/{field_type}/{field_id}", - method="GET", - path_params=("kbid", "rid", "field_type", "field_id"), - request_type=None, - response_type=ResourceField, - ) - - get_resource_field_by_slug = _request_builder( - name="get_resource_field_by_slug", - path_template="/v1/kb/{kbid}/slug/{slug}/{field_type}/{field_id}", - method="GET", - path_params=("kbid", "slug", "field_type", "field_id"), - request_type=None, - response_type=ResourceField, - ) - - # Labels - set_labelset = _request_builder( - name="set_labelset", - path_template="/v1/kb/{kbid}/labelset/{labelset}", - method="POST", - path_params=("kbid", "labelset"), - request_type=LabelSet, - response_type=None, - ) - delete_labelset = _request_builder( - name="delete_labelset", - path_template="/v1/kb/{kbid}/labelset/{labelset}", - method="DELETE", - path_params=("kbid", "labelset"), - request_type=None, - response_type=None, - ) - get_labelsets = _request_builder( - name="get_labelsets", - path_template="/v1/kb/{kbid}/labelsets", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=KnowledgeBoxLabels, - ) - get_labelset = _request_builder( - name="get_labelset", - path_template="/v1/kb/{kbid}/labelset/{labelset}", - method="GET", - path_params=("kbid", "labelset"), - request_type=None, - response_type=LabelSet, - ) - - # Entity Groups - create_entitygroup = _request_builder( - name="create_entitygroup", - path_template="/v1/kb/{kbid}/entitiesgroups", - method="POST", - path_params=("kbid",), - request_type=CreateEntitiesGroupPayload, - response_type=None, - ) - update_entitygroup = _request_builder( - name="update_entitygroup", - path_template="/v1/kb/{kbid}/entitiesgroup/{group}", - method="PATCH", - path_params=("kbid", "group"), - request_type=UpdateEntitiesGroupPayload, - response_type=None, - ) - delete_entitygroup = _request_builder( - name="delete_entitygroup", - path_template="/v1/kb/{kbid}/entitiesgroup/{group}", - method="DELETE", - path_params=("kbid", "group"), - request_type=None, - response_type=None, - ) - get_entitygroups = _request_builder( - name="get_entitygroups", - path_template="/v1/kb/{kbid}/entitiesgroups", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=KnowledgeBoxEntities, - ) - get_entitygroup = _request_builder( - name="get_entitygroup", - path_template="/v1/kb/{kbid}/entitiesgroup/{group}", - method="GET", - path_params=("kbid", "group"), - request_type=None, - response_type=EntitiesGroup, - ) - - # Search / Find Endpoints - find = _request_builder( - name="find", - path_template="/v1/kb/{kbid}/find", - method="POST", - path_params=("kbid",), - request_type=FindRequest, - response_type=KnowledgeboxFindResults, - ) - search = _request_builder( - name="search", - path_template="/v1/kb/{kbid}/search", - method="POST", - path_params=("kbid",), - request_type=SearchRequest, - response_type=KnowledgeboxSearchResults, - ) - - ask = _request_builder( - name="ask", - path_template="/v1/kb/{kbid}/ask", - method="POST", - path_params=("kbid",), - request_type=AskRequest, - response_type=ask_response_parser, - ) - - ask_on_resource = _request_builder( - name="ask_on_resource", - path_template="/v1/kb/{kbid}/resource/{rid}/ask", - method="POST", - path_params=("kbid", "rid"), - request_type=AskRequest, - response_type=ask_response_parser, - ) - - ask_on_resource_by_slug = _request_builder( - name="ask_on_resource_by_slug", - path_template="/v1/kb/{kbid}/slug/{slug}/ask", - method="POST", - path_params=("kbid", "slug"), - request_type=AskRequest, - response_type=ask_response_parser, - ) - - summarize = _request_builder( - name="summarize", - path_template="/v1/kb/{kbid}/summarize", - method="POST", - path_params=("kbid",), - request_type=SummarizeRequest, - response_type=SummarizedResponse, - ) - - feedback = _request_builder( - name="feedback", - path_template="/v1/kb/{kbid}/feedback", - method="POST", - path_params=("kbid",), - request_type=FeedbackRequest, - response_type=None, - ) - - start_export = _request_builder( - name="start_export", - path_template="/v1/kb/{kbid}/export", - method="POST", - path_params=("kbid",), - request_type=None, - response_type=CreateExportResponse, - ) - - export_status = _request_builder( - name="export_status", - path_template="/v1/kb/{kbid}/export/{export_id}/status", - method="GET", - path_params=("kbid", "export_id"), - request_type=None, - response_type=StatusResponse, - ) - - download_export = _request_builder( - name="download_export", - path_template="/v1/kb/{kbid}/export/{export_id}", - method="GET", - path_params=("kbid", "export_id"), - request_type=None, - response_type=None, - stream_response=True, - ) - - create_kb_from_import = _request_builder( - name="create_kb_from_import", - path_template="/v1/kbs/import", - method="POST", - path_params=(), - request_type=None, - response_type=NewImportedKbResponse, - ) - - start_import = _request_builder( - name="start_import", - path_template="/v1/kb/{kbid}/import", - method="POST", - path_params=("kbid",), - request_type=None, - response_type=CreateImportResponse, - ) - - import_status = _request_builder( - name="import_status", - path_template="/v1/kb/{kbid}/import/{import_id}/status", - method="GET", - path_params=("kbid", "import_id"), - request_type=None, - response_type=StatusResponse, - ) - - trainset = _request_builder( - name="trainset", - path_template="/v1/kb/{kbid}/trainset", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=TrainSetPartitions, - ) - - # Learning Configuration - get_configuration = _request_builder( - name="get_configuration", - path_template="/v1/kb/{kbid}/configuration", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=json_response_parser, - ) - set_configuration = _request_builder( - name="set_configuration", - path_template="/v1/kb/{kbid}/configuration", - method="POST", - path_params=("kbid",), - request_type=None, - response_type=None, - ) - - # Learning models - download_model = _request_builder( - name="download_model", - path_template="/v1/kb/{kbid}/models/{model_id}/{filename}", - method="GET", - path_params=("kbid", "model_id", "filename"), - stream_response=True, - request_type=None, - response_type=None, - ) - - get_models = _request_builder( - name="get_models", - path_template="/v1/kb/{kbid}/models", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=json_response_parser, - ) - - get_model = _request_builder( - name="get_model", - path_template="/v1/kb/{kbid}/model/{model_id}", - method="GET", - path_params=("kbid", "model_id"), - request_type=None, - response_type=json_response_parser, - ) - - # Learning config schema - get_configuration_schema = _request_builder( - name="get_configuration_schema", - path_template="/v1/kb/{kbid}/schema", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=json_response_parser, - ) - - # Custom synonyms - set_custom_synonyms = _request_builder( - name="set_custom_synonyms", - path_template="/v1/kb/{kbid}/custom-synonyms", - method="PUT", - path_params=("kbid",), - request_type=KnowledgeBoxSynonyms, - response_type=None, - ) - - get_custom_synonyms = _request_builder( - name="get_custom_synonyms", - path_template="/v1/kb/{kbid}/custom-synonyms", - method="GET", - path_params=("kbid",), - request_type=None, - response_type=KnowledgeBoxSynonyms, - ) - class NucliaDB(_NucliaDBBase): """ @@ -878,7 +844,9 @@ def __init__( >>> sdk = NucliaDB(api_key="api-key", region=Region.ON_PREM, url=\"http://localhost:8080\") """ # noqa super().__init__(region=region, api_key=api_key, url=url, headers=headers) - self.session = httpx.Client(headers=self.headers, base_url=self.base_url, timeout=timeout) + self.session = httpx.Client( + headers=self.headers, base_url=self.base_url, timeout=timeout + ) def _request( self, @@ -918,13 +886,143 @@ def _stream_request( opts["params"] = query_params def iter_bytes(chunk_size=None) -> Iterator[bytes]: - with self.session.stream(method.lower(), url=url, **opts, timeout=30.0) as response: + with self.session.stream( + method.lower(), url=url, **opts, timeout=30.0 + ) as response: self._check_response(response) for chunk in response.iter_raw(chunk_size=chunk_size): yield chunk return iter_bytes + create_knowledge_box = _request_sync_builder( + "create_knowledge_box", KnowledgeBoxConfig, KnowledgeBoxObj + ) + delete_knowledge_box = _request_sync_builder( + "delete_knowledge_box", type(None), KnowledgeBoxObj + ) + get_knowledge_box = _request_sync_builder( + "get_knowledge_box", type(None), KnowledgeBoxObj + ) + get_knowledge_box_by_slug = _request_sync_builder( + "get_knowledge_box_by_slug", type(None), KnowledgeBoxObj + ) + list_knowledge_boxes = _request_sync_builder( + "list_knowledge_boxes", type(None), KnowledgeBoxList + ) + # Resource Endpoints + create_resource = _request_sync_builder( + "create_resource", CreateResourcePayload, ResourceCreated + ) + update_resource = _request_sync_builder( + "update_resource", UpdateResourcePayload, ResourceUpdated + ) + update_resource_by_slug = _request_sync_builder( + "update_resource_by_slug", UpdateResourcePayload, ResourceUpdated + ) + delete_resource = _request_sync_builder("delete_resource", type(None), type(None)) + delete_resource_by_slug = _request_sync_builder( + "delete_resource_by_slug", type(None), type(None) + ) + get_resource_by_slug = _request_sync_builder( + "get_resource_by_slug", type(None), Resource + ) + get_resource_by_id = _request_sync_builder( + "get_resource_by_id", type(None), Resource + ) + list_resources = _request_sync_builder("list_resources", type(None), ResourceList) + # reindex/reprocess + reindex_resource = _request_sync_builder("reindex_resource", type(None), type(None)) + reindex_resource_by_slug = _request_sync_builder( + "reindex_resource_by_slug", type(None), type(None) + ) + reprocess_resource = _request_sync_builder( + "reprocess_resource", type(None), type(None) + ) + reprocess_resource_by_slug = _request_sync_builder( + "reprocess_resource_by_slug", type(None), type(None) + ) + # Field endpoints + delete_field_by_id = _request_sync_builder( + "delete_field_by_id", type(None), type(None) + ) + # Conversation endpoints + add_conversation_message = _request_sync_builder( + "add_conversation_message", List[InputMessage], ResourceFieldAdded + ) + get_resource_field = _request_sync_builder( + "get_resource_field", type(None), ResourceField + ) + get_resource_field_by_slug = _request_sync_builder( + "get_resource_field_by_slug", type(None), ResourceField + ) + # Labels + set_labelset = _request_sync_builder("set_labelset", LabelSet, type(None)) + delete_labelset = _request_sync_builder("delete_labelset", type(None), type(None)) + get_labelsets = _request_sync_builder( + "get_labelsets", type(None), KnowledgeBoxLabels + ) + get_labelset = _request_sync_builder("get_labelset", type(None), LabelSet) + # Entity Groups + create_entitygroup = _request_sync_builder( + "create_entitygroup", CreateEntitiesGroupPayload, type(None) + ) + update_entitygroup = _request_sync_builder( + "update_entitygroup", UpdateEntitiesGroupPayload, type(None) + ) + delete_entitygroup = _request_sync_builder( + "delete_entitygroup", type(None), type(None) + ) + get_entitygroups = _request_sync_builder( + "get_entitygroups", type(None), KnowledgeBoxEntities + ) + get_entitygroup = _request_sync_builder( + "get_entitygroup", type(None), EntitiesGroup + ) + # Search / Find Endpoints + find = _request_sync_builder("find", FindRequest, KnowledgeboxFindResults) + search = _request_sync_builder("search", SearchRequest, KnowledgeboxSearchResults) + ask = _request_sync_builder("ask", AskRequest, SyncAskResponse) + ask_on_resource = _request_sync_builder( + "ask_on_resource", AskRequest, SyncAskResponse + ) + ask_on_resource_by_slug = _request_sync_builder( + "ask_on_resource_by_slug", AskRequest, SyncAskResponse + ) + summarize = _request_sync_builder("summarize", SummarizeRequest, SummarizedResponse) + feedback = _request_sync_builder("feedback", FeedbackRequest, type(None)) + start_export = _request_sync_builder( + "start_export", type(None), CreateExportResponse + ) + export_status = _request_sync_builder("export_status", type(None), StatusResponse) + download_export = _request_iterator_sync_builder("download_export") + create_kb_from_import = _request_sync_builder( + "create_kb_from_import", type(None), NewImportedKbResponse + ) + start_import = _request_sync_builder( + "start_import", type(None), CreateImportResponse + ) + import_status = _request_sync_builder("import_status", type(None), StatusResponse) + trainset = _request_sync_builder("trainset", type(None), TrainSetPartitions) + # Learning Configuration + get_configuration = _request_json_sync_builder("get_configuration") + set_configuration = _request_json_sync_builder("set_configuration") + + # Learning models + download_model = _request_iterator_sync_builder("download_model") + get_models = _request_json_sync_builder("get_models") + get_model = _request_json_sync_builder("get_model") + + # Learning config schema + get_configuration_schema = _request_json_sync_builder("get_configuration_schema") + # Custom synonyms + set_custom_synonyms = _request_sync_builder( + "set_custom_synonyms", KnowledgeBoxSynonyms, type(None) + ) + get_custom_synonyms = _request_sync_builder( + "get_custom_synonyms", type(None), KnowledgeBoxSynonyms + ) + class NucliaDBAsync(_NucliaDBBase): """ @@ -972,7 +1070,9 @@ def __init__( >>> sdk = NucliaDBAsync(api_key="api-key", region=Region.ON_PREM, url="https://mycompany.api.com/api/nucliadb") """ # noqa super().__init__(region=region, api_key=api_key, url=url, headers=headers) - self.session = httpx.AsyncClient(headers=self.headers, base_url=self.base_url, timeout=timeout) + self.session = httpx.AsyncClient( + headers=self.headers, base_url=self.base_url, timeout=timeout + ) async def _request( self, @@ -994,7 +1094,9 @@ async def _request( opts["content"] = content if query_params is not None: opts["params"] = query_params - response: httpx.Response = await getattr(self.session, method.lower())(url, **opts) + response: httpx.Response = await getattr(self.session, method.lower())( + url, **opts + ) return self._check_response(response) def _stream_request( @@ -1018,3 +1120,135 @@ async def iter_bytes(chunk_size=None) -> AsyncGenerator[bytes, None]: yield chunk return iter_bytes + + create_knowledge_box = _request_async_builder( + "create_knowledge_box", KnowledgeBoxConfig, KnowledgeBoxObj + ) + delete_knowledge_box = _request_async_builder( + "delete_knowledge_box", type(None), KnowledgeBoxObj + ) + get_knowledge_box = _request_async_builder( + "get_knowledge_box", type(None), KnowledgeBoxObj + ) + get_knowledge_box_by_slug = _request_async_builder( + "get_knowledge_box_by_slug", type(None), KnowledgeBoxObj + ) + list_knowledge_boxes = _request_async_builder( + "list_knowledge_boxes", type(None), KnowledgeBoxList + ) + # Resource Endpoints + create_resource = _request_async_builder( + "create_resource", CreateResourcePayload, ResourceCreated + ) + update_resource = _request_async_builder( + "update_resource", UpdateResourcePayload, ResourceUpdated + ) + update_resource_by_slug = _request_async_builder( + "update_resource_by_slug", UpdateResourcePayload, ResourceUpdated + ) + delete_resource = _request_async_builder("delete_resource", type(None), type(None)) + delete_resource_by_slug = _request_async_builder( + "delete_resource_by_slug", type(None), type(None) + ) + get_resource_by_slug = _request_async_builder( + "get_resource_by_slug", type(None), Resource + ) + get_resource_by_id = _request_async_builder( + "get_resource_by_id", type(None), Resource + ) + list_resources = _request_async_builder("list_resources", type(None), ResourceList) + # reindex/reprocess + reindex_resource = _request_async_builder( + "reindex_resource", type(None), type(None) + ) + reindex_resource_by_slug = _request_async_builder( + "reindex_resource_by_slug", type(None), type(None) + ) + reprocess_resource = _request_async_builder( + "reprocess_resource", type(None), type(None) + ) + reprocess_resource_by_slug = _request_async_builder( + "reprocess_resource_by_slug", type(None), type(None) + ) + # Field endpoints + delete_field_by_id = _request_async_builder( + "delete_field_by_id", type(None), type(None) + ) + # Conversation endpoints + add_conversation_message = _request_async_builder( + "add_conversation_message", List[InputMessage], ResourceFieldAdded + ) + get_resource_field = _request_async_builder( + "get_resource_field", type(None), ResourceField + ) + get_resource_field_by_slug = _request_async_builder( + "get_resource_field_by_slug", type(None), ResourceField + ) + # Labels + set_labelset = _request_async_builder("set_labelset", LabelSet, type(None)) + delete_labelset = _request_async_builder("delete_labelset", type(None), type(None)) + get_labelsets = _request_async_builder( + "get_labelsets", type(None), KnowledgeBoxLabels + ) + get_labelset = _request_async_builder("get_labelset", type(None), LabelSet) + # Entity Groups + create_entitygroup = _request_async_builder( + "create_entitygroup", CreateEntitiesGroupPayload, type(None) + ) + update_entitygroup = _request_async_builder( + "update_entitygroup", UpdateEntitiesGroupPayload, type(None) + ) + delete_entitygroup = _request_async_builder( + "delete_entitygroup", type(None), type(None) + ) + get_entitygroups = _request_async_builder( + "get_entitygroups", type(None), KnowledgeBoxEntities + ) + get_entitygroup = _request_async_builder( + "get_entitygroup", type(None), EntitiesGroup + ) + # Search / Find Endpoints + find = _request_async_builder("find", FindRequest, KnowledgeboxFindResults) + search = _request_async_builder("search", SearchRequest, KnowledgeboxSearchResults) + ask = _request_async_builder("ask", AskRequest, SyncAskResponse) + ask_on_resource = _request_async_builder( + "ask_on_resource", AskRequest, SyncAskResponse + ) + ask_on_resource_by_slug = _request_async_builder( + "ask_on_resource_by_slug", AskRequest, SyncAskResponse + ) + summarize = _request_async_builder( + "summarize", SummarizeRequest, SummarizedResponse + ) + feedback = _request_async_builder("feedback", FeedbackRequest, type(None)) + start_export = _request_async_builder( + "start_export", type(None), CreateExportResponse + ) + export_status = _request_async_builder("export_status", type(None), StatusResponse) + download_export = _request_iterator_async_builder("download_export") + create_kb_from_import = _request_async_builder( + "create_kb_from_import", type(None), NewImportedKbResponse + ) + start_import = _request_async_builder( + "start_import", type(None), CreateImportResponse + ) + import_status = _request_async_builder("import_status", type(None), StatusResponse) + trainset = _request_async_builder("trainset", type(None), TrainSetPartitions) + # Learning Configuration + get_configuration = _request_json_async_builder("get_configuration") + set_configuration = _request_json_async_builder("set_configuration") + + # Learning models + download_model = _request_iterator_async_builder("download_model") + get_models = _request_json_async_builder("get_models") + get_model = _request_json_async_builder("get_model") + + # Learning config schema + get_configuration_schema = _request_json_async_builder("get_configuration_schema") + # Custom synonyms + set_custom_synonyms = _request_async_builder( + "set_custom_synonyms", KnowledgeBoxSynonyms, type(None) + ) + get_custom_synonyms = _request_async_builder( + "get_custom_synonyms", type(None), KnowledgeBoxSynonyms + )