diff --git a/clients/python/tests/regression_test.py b/clients/python/tests/regression_test.py index 0310f3879..4645833ea 100644 --- a/clients/python/tests/regression_test.py +++ b/clients/python/tests/regression_test.py @@ -1,4 +1,5 @@ import pytest +import requests from model_registry import ModelRegistry from model_registry.types.artifacts import ModelArtifact @@ -99,3 +100,35 @@ async def test_create_standalone_model_artifact(client: ModelRegistry): assert mv.id mv_ma = await client._api.upsert_model_version_artifact(new_ma, mv.id) assert mv_ma.id == new_ma.id + +@pytest.mark.e2e +async def test_patch_model_artifacts_artifact_type(client: ModelRegistry): + """Patching Artifacts makes the model registry server panic. + + reported with https://issues.redhat.com/browse/RHOAIENG-16932 + """ + name = "test_model" + version = "1.0.0" + rm = client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version=version, + ) + assert rm.id + mv = client.get_model_version(name, version) + assert mv + assert mv.id + ma = client.get_model_artifact(name, version) + assert ma + assert ma.id + + payload = { "modelFormatName": "foo", "artifactType": "model-artifact" } + from .conftest import REGISTRY_HOST, REGISTRY_PORT + response = requests.patch(url=f"{REGISTRY_HOST}:{REGISTRY_PORT}/api/model_registry/v1alpha3/artifacts/{ma.id}", json=payload, timeout=10, headers={"Content-Type": "application/json"}) + assert response.status_code == 200 + ma = client.get_model_artifact(name, version) + assert ma + assert ma.id + assert ma.model_format_name == "foo" diff --git a/internal/converter/openapi_reconciler_util.go b/internal/converter/openapi_reconciler_util.go index e2b8544c8..071912a7f 100644 --- a/internal/converter/openapi_reconciler_util.go +++ b/internal/converter/openapi_reconciler_util.go @@ -6,18 +6,28 @@ import ( func UpdateExistingArtifact(genc OpenAPIReconciler, source OpenapiUpdateWrapper[openapi.Artifact]) (openapi.Artifact, error) { art := InitWithExisting(source) + if source.Update == nil { return art, nil } - ma, err := genc.UpdateExistingModelArtifact(OpenapiUpdateWrapper[openapi.ModelArtifact]{Existing: art.ModelArtifact, Update: source.Update.ModelArtifact}) - if err != nil { - return art, err + + if source.Update.ModelArtifact != nil { + ma, err := genc.UpdateExistingModelArtifact(OpenapiUpdateWrapper[openapi.ModelArtifact]{Existing: art.ModelArtifact, Update: source.Update.ModelArtifact}) + if err != nil { + return art, err + } + + art.ModelArtifact = &ma } - da, err := genc.UpdateExistingDocArtifact(OpenapiUpdateWrapper[openapi.DocArtifact]{Existing: art.DocArtifact, Update: source.Update.DocArtifact}) - if err != nil { - return art, err + + if source.Update.DocArtifact != nil { + da, err := genc.UpdateExistingDocArtifact(OpenapiUpdateWrapper[openapi.DocArtifact]{Existing: art.DocArtifact, Update: source.Update.DocArtifact}) + if err != nil { + return art, err + } + + art.DocArtifact = &da } - art.DocArtifact = &da - art.ModelArtifact = &ma + return art, nil } diff --git a/internal/server/openapi/api_model_registry_service_service.go b/internal/server/openapi/api_model_registry_service_service.go index d22b38c25..3d5211b4a 100644 --- a/internal/server/openapi/api_model_registry_service_service.go +++ b/internal/server/openapi/api_model_registry_service_service.go @@ -492,7 +492,8 @@ func (s *ModelRegistryServiceAPIService) UpdateArtifact(ctx context.Context, art } if artifactUpdate.DocArtifactUpdate != nil { entity.DocArtifact.Id = &artifactId - } else { + } + if artifactUpdate.ModelArtifactUpdate != nil { entity.ModelArtifact.Id = &artifactId } existing, err := s.coreApi.GetArtifactById(artifactId)