diff --git a/internal/server/grpc_server.go b/internal/server/grpc_server.go index f2617659..00b401b2 100644 --- a/internal/server/grpc_server.go +++ b/internal/server/grpc_server.go @@ -266,9 +266,38 @@ func (g grpcServer) PutParentContexts(ctx context.Context, request *proto.PutPar panic("implement me") } -func (g grpcServer) GetArtifactType(ctx context.Context, request *proto.GetArtifactTypeRequest) (*proto.GetArtifactTypeResponse, error) { - //TODO implement me - panic("implement me") +func (g grpcServer) GetArtifactType(ctx context.Context, request *proto.GetArtifactTypeRequest) (resp *proto.GetArtifactTypeResponse, err error) { + ctx, dbConn := Begin(ctx, g.dbConnection) + defer handleTransaction(ctx, &err) + + err = requiredFields(REQUIRED_TYPE_FIELDS, request.TypeName) + response := &proto.GetArtifactTypeResponse{} + + var results []db.Type + rx := dbConn.Find(&results, db.Type{Name: *request.TypeName, TypeKind: int32(ARTIFACT_TYPE), Version: request.TypeVersion}) + if rx.Error != nil { + return nil, rx.Error + } + if len(results) > 1 { + return nil, fmt.Errorf("more than one type found: %v", len(results)) + } + if len(results) == 0 { + return response, nil + } + + r0 := results[0] + artifactType := proto.ArtifactType{ + Id: &r0.ID, + Name: &r0.Name, + Version: r0.Version, + Description: r0.Description, + ExternalId: r0.ExternalID, + } + for _, v := range r0.Properties { + artifactType.Properties[v.Name] = proto.PropertyType(v.DataType) + } + response.ArtifactType = &artifactType + return response, nil } func (g grpcServer) GetArtifactTypesByID(ctx context.Context, request *proto.GetArtifactTypesByIDRequest) (*proto.GetArtifactTypesByIDResponse, error) { diff --git a/test/python/test_mlmetadata.py b/test/python/test_mlmetadata.py index aa018b3c..d01bc3ba 100644 --- a/test/python/test_mlmetadata.py +++ b/test/python/test_mlmetadata.py @@ -45,6 +45,12 @@ def main(): response = store.PutArtifactType(request) model_type_id = response.type_id + request = metadata_store_service_pb2.GetArtifactTypeRequest() + request.type_name = "SavedModel" + response = store.GetArtifactType(request) + assert response.artifact_type.id == 2 + assert response.artifact_type.name == "SavedModel" + # Query all registered Artifact types. # artifact_types = store.GetArtifactTypes()