Skip to content

Commit

Permalink
fix: patch requests to artifact endpoint make mr panic (#718)
Browse files Browse the repository at this point in the history
* fix: patch requests to artifact endpoint make mr panic

Signed-off-by: Alessio Pragliola <[email protected]>

* chore: remove commented code

Signed-off-by: Alessio Pragliola <[email protected]>

---------

Signed-off-by: Alessio Pragliola <[email protected]>
  • Loading branch information
Al-Pragliola authored Jan 20, 2025
1 parent 35079eb commit 8a39ad5
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
33 changes: 33 additions & 0 deletions clients/python/tests/regression_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import requests

from model_registry import ModelRegistry
from model_registry.types.artifacts import ModelArtifact
Expand Down Expand Up @@ -99,3 +100,35 @@ async def test_create_standalone_model_artifact(client: ModelRegistry):
assert mv.id
mv_ma = await client._api.upsert_model_version_artifact(new_ma, mv.id)
assert mv_ma.id == new_ma.id

@pytest.mark.e2e
async def test_patch_model_artifacts_artifact_type(client: ModelRegistry):
"""Patching Artifacts makes the model registry server panic.
reported with https://issues.redhat.com/browse/RHOAIENG-16932
"""
name = "test_model"
version = "1.0.0"
rm = client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version=version,
)
assert rm.id
mv = client.get_model_version(name, version)
assert mv
assert mv.id
ma = client.get_model_artifact(name, version)
assert ma
assert ma.id

payload = { "modelFormatName": "foo", "artifactType": "model-artifact" }
from .conftest import REGISTRY_HOST, REGISTRY_PORT
response = requests.patch(url=f"{REGISTRY_HOST}:{REGISTRY_PORT}/api/model_registry/v1alpha3/artifacts/{ma.id}", json=payload, timeout=10, headers={"Content-Type": "application/json"})
assert response.status_code == 200
ma = client.get_model_artifact(name, version)
assert ma
assert ma.id
assert ma.model_format_name == "foo"
26 changes: 18 additions & 8 deletions internal/converter/openapi_reconciler_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,28 @@ import (

func UpdateExistingArtifact(genc OpenAPIReconciler, source OpenapiUpdateWrapper[openapi.Artifact]) (openapi.Artifact, error) {
art := InitWithExisting(source)

if source.Update == nil {
return art, nil
}
ma, err := genc.UpdateExistingModelArtifact(OpenapiUpdateWrapper[openapi.ModelArtifact]{Existing: art.ModelArtifact, Update: source.Update.ModelArtifact})
if err != nil {
return art, err

if source.Update.ModelArtifact != nil {
ma, err := genc.UpdateExistingModelArtifact(OpenapiUpdateWrapper[openapi.ModelArtifact]{Existing: art.ModelArtifact, Update: source.Update.ModelArtifact})
if err != nil {
return art, err
}

art.ModelArtifact = &ma
}
da, err := genc.UpdateExistingDocArtifact(OpenapiUpdateWrapper[openapi.DocArtifact]{Existing: art.DocArtifact, Update: source.Update.DocArtifact})
if err != nil {
return art, err

if source.Update.DocArtifact != nil {
da, err := genc.UpdateExistingDocArtifact(OpenapiUpdateWrapper[openapi.DocArtifact]{Existing: art.DocArtifact, Update: source.Update.DocArtifact})
if err != nil {
return art, err
}

art.DocArtifact = &da
}
art.DocArtifact = &da
art.ModelArtifact = &ma

return art, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ func (s *ModelRegistryServiceAPIService) UpdateArtifact(ctx context.Context, art
}
if artifactUpdate.DocArtifactUpdate != nil {
entity.DocArtifact.Id = &artifactId
} else {
}
if artifactUpdate.ModelArtifactUpdate != nil {
entity.ModelArtifact.Id = &artifactId
}
existing, err := s.coreApi.GetArtifactById(artifactId)
Expand Down

0 comments on commit 8a39ad5

Please sign in to comment.