Vectorset API improvements (#2836)
* Move vectorset API logic to a single module

* Learning should never throw binary errors

* Return a structured response, not a proxied string

* Return better status codes and response

* Use raise HTTPException instead of return (for typing)

* Do not magically proxy content types, we return JSON
jotare authored Feb 4, 2025
1 parent 84d4c14 commit e38eedd
Showing 7 changed files with 158 additions and 165 deletions.
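In client terms, this commit gives the vectorsets endpoints conventional REST semantics: 201 with a JSON body on create, 204 on delete, 409 on conflicts, and JSON error payloads instead of proxied strings. A minimal sketch of what a client now sees (the base URL, KB id and vectorset id are placeholders, not values from this commit):

```python
import httpx

# Placeholder base URL and ids; nothing here comes from the commit itself.
client = httpx.Client(base_url="http://localhost:8080/api/v1")

# Creating a vectorset now returns 201 Created with a structured JSON body
# (serialized from CreatedVectorSet) instead of a bare 200.
resp = client.post("/kb/my-kb/vectorsets/my-model")
print(resp.status_code)  # expected: 201
print(resp.json())       # expected: {"id": "my-model"}

# Errors proxied from the learning service now arrive as JSON, e.g.
# {"detail": "..."} with the upstream status code, never raw bytes.

# Deleting a vectorset now returns 204 No Content instead of 200.
resp = client.delete("/kb/my-kb/vectorsets/my-model")
print(resp.status_code)  # expected: 204
```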
10 changes: 6 additions & 4 deletions nucliadb/src/nucliadb/learning_proxy.py
```diff
@@ -155,21 +155,23 @@ def into_semantic_model_metadata(self) -> knowledgebox_pb2.SemanticModelMetadata
 
 
 class ProxiedLearningConfigError(Exception):
-    def __init__(self, status_code: int, content: bytes, content_type: str):
+    def __init__(self, status_code: int, content: Union[str, dict[str, Any]]):
         self.status_code = status_code
         self.content = content
-        self.content_type = content_type
 
 
 def raise_for_status(response: httpx.Response) -> None:
     try:
         response.raise_for_status()
     except httpx.HTTPStatusError as err:
         content_type = err.response.headers.get("Content-Type", "application/json")
+        if content_type == "application/json":
+            content = err.response.json()
+        else:
+            content = err.response.text
         raise ProxiedLearningConfigError(
             status_code=err.response.status_code,
-            content=err.response.content,
-            content_type=content_type,
+            content=content,
         )
```
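The practical effect is that `ProxiedLearningConfigError.content` is now a parsed dict for JSON bodies or a plain string for anything else, never raw bytes. A small way to exercise the new behavior (assumes the nucliadb package is importable; the fake request and response below are made up for illustration):

```python
import httpx

from nucliadb.learning_proxy import ProxiedLearningConfigError, raise_for_status

# Fabricate a failing upstream response; httpx needs the request attached
# for raise_for_status() to work on a hand-built Response.
request = httpx.Request("GET", "http://learning/config")
response = httpx.Response(500, request=request, json={"detail": "boom"})

try:
    raise_for_status(response)
except ProxiedLearningConfigError as err:
    # content is the parsed JSON dict, not bytes
    print(err.status_code, err.content)  # 500 {'detail': 'boom'}
```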
136 changes: 117 additions & 19 deletions nucliadb/src/nucliadb/writer/api/v1/vectorsets.py
```diff
@@ -18,62 +18,160 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from fastapi import Response
+from fastapi import HTTPException, Response
 from fastapi_versioning import version
 from starlette.requests import Request
 
 from nucliadb import learning_proxy
 from nucliadb.common import datamanagers
 from nucliadb.ingest.orm.exceptions import VectorSetConflict
-from nucliadb.models.responses import HTTPConflict
-from nucliadb.writer import vectorsets
+from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
+from nucliadb.writer import logger
 from nucliadb.writer.api.v1.router import KB_PREFIX, api
 from nucliadb_models.resource import (
     NucliaDBRoles,
 )
+from nucliadb_models.vectorsets import CreatedVectorSet
+from nucliadb_protos import knowledgebox_pb2
+from nucliadb_telemetry import errors
 from nucliadb_utils.authentication import requires_one
+from nucliadb_utils.utilities import get_storage
 
 
 @api.post(
     f"/{KB_PREFIX}/{{kbid}}/vectorsets/{{vectorset_id}}",
-    status_code=200,
+    status_code=201,
     summary="Add a vectorset to Knowledge Box",
-    tags=["Knowledge Boxes"],
+    tags=["VectorSets"],
     # TODO: remove when the feature is mature
     include_in_schema=False,
 )
 @requires_one([NucliaDBRoles.MANAGER, NucliaDBRoles.WRITER])
 @version(1)
-async def add_vectorset(request: Request, kbid: str, vectorset_id: str) -> Response:
+async def add_vectorset(request: Request, kbid: str, vectorset_id: str) -> CreatedVectorSet:
     try:
-        await vectorsets.add(kbid, vectorset_id)
+        await _add_vectorset(kbid, vectorset_id)
+
     except learning_proxy.ProxiedLearningConfigError as err:
-        return Response(
+        raise HTTPException(
             status_code=err.status_code,
-            content=err.content,
-            media_type=err.content_type,
+            detail=err.content,
         )
-    return Response(status_code=200)
+
+    except VectorSetConflict:
+        raise HTTPException(
+            status_code=409,
+            detail="A vectorset with this embedding model already exists in your KB",
+        )
+
+    return CreatedVectorSet(id=vectorset_id)
+
+
+async def _add_vectorset(kbid: str, vectorset_id: str) -> None:
+    # First off, add the vectorset to the learning configuration if it's not already there
+    lconfig = await learning_proxy.get_configuration(kbid)
+    assert lconfig is not None
+    semantic_models = lconfig.model_dump()["semantic_models"]
+    if vectorset_id not in semantic_models:
+        semantic_models.append(vectorset_id)
+        await learning_proxy.update_configuration(kbid, {"semantic_models": semantic_models})
+        lconfig = await learning_proxy.get_configuration(kbid)
+        assert lconfig is not None
+
+    # Then, add the vectorset to the index if it's not already there
+    storage = await get_storage()
+    vectorset_config = get_vectorset_config(lconfig, vectorset_id)
+    async with datamanagers.with_rw_transaction() as txn:
+        kbobj = KnowledgeBox(txn, storage, kbid)
+        await kbobj.create_vectorset(vectorset_config)
+        await txn.commit()
+
+
+def get_vectorset_config(
+    learning_config: learning_proxy.LearningConfiguration, vectorset_id: str
+) -> knowledgebox_pb2.VectorSetConfig:
+    """
+    Create a VectorSetConfig from a LearningConfiguration for a given vectorset_id
+    """
+    vectorset_config = knowledgebox_pb2.VectorSetConfig(vectorset_id=vectorset_id)
+    vectorset_index_config = knowledgebox_pb2.VectorIndexConfig(
+        vector_type=knowledgebox_pb2.VectorType.DENSE_F32,
+    )
+    model_config = learning_config.semantic_model_configs[vectorset_id]
+
+    # Parse similarity function
+    parsed_similarity = learning_proxy.SimilarityFunction(model_config.similarity)
+    if parsed_similarity == learning_proxy.SimilarityFunction.COSINE.value:
+        vectorset_index_config.similarity = knowledgebox_pb2.VectorSimilarity.COSINE
+    elif parsed_similarity == learning_proxy.SimilarityFunction.DOT.value:
+        vectorset_index_config.similarity = knowledgebox_pb2.VectorSimilarity.DOT
+    else:
+        raise ValueError(
+            f"Unknown similarity function {model_config.similarity}, parsed as {parsed_similarity}"
+        )
+
+    # Parse vector dimension
+    vectorset_index_config.vector_dimension = model_config.size
+
+    # Parse matryoshka dimensions
+    if len(model_config.matryoshka_dims) > 0:
+        vectorset_index_config.normalize_vectors = True
+        vectorset_config.matryoshka_dimensions.extend(model_config.matryoshka_dims)
+    else:
+        vectorset_index_config.normalize_vectors = False
+    vectorset_config.vectorset_index_config.CopyFrom(vectorset_index_config)
+    return vectorset_config
 
 
 @api.delete(
     f"/{KB_PREFIX}/{{kbid}}/vectorsets/{{vectorset_id}}",
-    status_code=200,
+    status_code=204,
     summary="Delete vectorset from Knowledge Box",
-    tags=["Knowledge Boxes"],
+    tags=["VectorSets"],
     # TODO: remove when the feature is mature
     include_in_schema=False,
 )
 @requires_one([NucliaDBRoles.MANAGER, NucliaDBRoles.WRITER])
 @version(1)
 async def delete_vectorset(request: Request, kbid: str, vectorset_id: str) -> Response:
     try:
-        await vectorsets.delete(kbid, vectorset_id)
+        await _delete_vectorset(kbid, vectorset_id)
+
     except VectorSetConflict as exc:
-        return HTTPConflict(detail=str(exc))
+        raise HTTPException(
+            status_code=409,
+            detail=str(exc),
+        )
+
     except learning_proxy.ProxiedLearningConfigError as err:
-        return Response(
+        raise HTTPException(
             status_code=err.status_code,
-            content=err.content,
-            media_type=err.content_type,
+            detail=err.content,
         )
-    return Response(status_code=200)
+
+    return Response(status_code=204)
+
+
+async def _delete_vectorset(kbid: str, vectorset_id: str) -> None:
+    lconfig = await learning_proxy.get_configuration(kbid)
+    if lconfig is not None:
+        semantic_models = lconfig.model_dump()["semantic_models"]
+        if vectorset_id in semantic_models:
+            semantic_models.remove(vectorset_id)
+            await learning_proxy.update_configuration(kbid, {"semantic_models": semantic_models})
+
+    storage = await get_storage()
+    try:
+        async with datamanagers.with_rw_transaction() as txn:
+            kbobj = KnowledgeBox(txn, storage, kbid)
+            await kbobj.delete_vectorset(vectorset_id=vectorset_id)
+            await txn.commit()
+
+    except VectorSetConflict:
+        # caller should handle this error
+        raise
+    except Exception as ex:
+        errors.capture_exception(ex)
+        logger.exception(
+            "Could not delete vectorset from index", extra={"kbid": kbid, "vectorset_id": vectorset_id}
+        )
```
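For reference, the `VectorSetConfig` that `get_vectorset_config` builds for a matryoshka-capable model would look roughly like this (a hypothetical sketch: the vectorset id, dimension and matryoshka values are made up; only the field names come from the diff above). Note that `normalize_vectors` is forced on whenever matryoshka dimensions are present, presumably because truncated matryoshka embeddings need re-normalizing before similarity comparisons:

```python
from nucliadb_protos import knowledgebox_pb2

# Hypothetical values; only the field names are taken from this commit.
config = knowledgebox_pb2.VectorSetConfig(vectorset_id="multilingual-2024")
index_config = knowledgebox_pb2.VectorIndexConfig(
    vector_type=knowledgebox_pb2.VectorType.DENSE_F32,
    similarity=knowledgebox_pb2.VectorSimilarity.COSINE,
    vector_dimension=1024,
    normalize_vectors=True,  # forced on because matryoshka dims are set
)
config.matryoshka_dimensions.extend([1024, 512, 256])
config.vectorset_index_config.CopyFrom(index_config)
print(config)
```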
132 changes: 0 additions & 132 deletions nucliadb/src/nucliadb/writer/vectorsets.py

This file was deleted.

14 changes: 7 additions & 7 deletions nucliadb/tests/nucliadb/integration/test_vectorsets_api.py
```diff
@@ -45,7 +45,7 @@
 from tests.utils.dirty_index import mark_dirty, wait_for_sync
 from tests.utils.vectorsets import add_vectorset
 
-MODULE = "nucliadb.writer.vectorsets"
+MODULE = "nucliadb.writer.api.v1.vectorsets"
 
 
 async def test_add_delete_vectorsets(
@@ -100,7 +100,7 @@ async def test_add_delete_vectorsets(
     with patch(f"{MODULE}.learning_proxy.update_configuration", return_value=None):
         # Add the vectorset
         resp = await nucliadb_manager.post(f"/kb/{kbid}/vectorsets/{vectorset_id}")
-        assert resp.status_code == 200, resp.text
+        assert resp.status_code == 201, resp.text
 
         # Check that the vectorset has been created with the correct configuration
         async with datamanagers.with_ro_transaction() as txn:
@@ -117,7 +117,7 @@ async def test_add_delete_vectorsets(
 
     # Delete the vectorset
     resp = await nucliadb_manager.delete(f"/kb/{kbid}/vectorsets/{vectorset_id}")
-    assert resp.status_code == 200, resp.text
+    assert resp.status_code == 204, resp.text
 
     # Check that the vectorset has been deleted
     async with datamanagers.with_ro_transaction() as txn:
@@ -144,16 +144,16 @@ async def test_learning_config_errors_are_proxied_correctly(
     with patch(
         f"{MODULE}.learning_proxy.get_configuration",
         side_effect=ProxiedLearningConfigError(
-            status_code=500, content=b"Learning Internal Server Error", content_type="text/plain"
+            status_code=500, content="Learning Internal Server Error"
         ),
     ):
         resp = await nucliadb_manager.post(f"/kb/{kbid}/vectorsets/foo")
         assert resp.status_code == 500
-        assert resp.text == "Learning Internal Server Error"
+        assert resp.json() == {"detail": "Learning Internal Server Error"}
 
         resp = await nucliadb_manager.delete(f"/kb/{kbid}/vectorsets/foo")
         assert resp.status_code == 500
-        assert resp.text == "Learning Internal Server Error"
+        assert resp.json() == {"detail": "Learning Internal Server Error"}
 
 
 @pytest.mark.parametrize("bwc_with_default_vectorset", [True, False])
@@ -245,7 +245,7 @@ async def test_vectorset_migration(
     resp = await add_vectorset(
         nucliadb_manager, kbid, vectorset_id, similarity=SimilarityFunction.COSINE, vector_dimension=1024
     )
-    assert resp.status_code == 200
+    assert resp.status_code == 201
 
     # Ingest a new broker message as if it was coming from the migration
     bm2 = BrokerMessage(
```
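One detail worth flagging in these tests: the `MODULE` constant changed because mocks must target the module where a name is looked up, and the endpoint logic now lives in `nucliadb.writer.api.v1.vectorsets`. A minimal sketch of the pattern (assumes the nucliadb package is importable):

```python
from unittest.mock import patch

MODULE = "nucliadb.writer.api.v1.vectorsets"  # new home of the endpoint logic

# Patch where the name is used, not where it was originally defined.
with patch(f"{MODULE}.learning_proxy.update_configuration", return_value=None):
    ...  # exercise the endpoint; the real learning backend is never called
```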
