Skip to content

Commit

Permalink
Add first draft of Data Access Layer, with test
Browse files Browse the repository at this point in the history
Co-authored-by: Andrea Lamparelli <[email protected]>
  • Loading branch information
tarilabs and lampajr committed Sep 26, 2023
1 parent 0ac30c6 commit 2181a5a
Show file tree
Hide file tree
Showing 2 changed files with 299 additions and 0 deletions.
150 changes: 150 additions & 0 deletions internal/service/db_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package service

import (
"context"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"

"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
"github.com/opendatahub-io/model-registry/internal/model/db"
"github.com/opendatahub-io/model-registry/internal/server"
)

type ArtifactTypeHandler interface {
CreateArtifactType(name string, version *string) (result *db.Type, err error)
CreateArtifactTypeAll(name string,
version *string,
description *string,
inputType *string,
outputType *string,
externalID *string,
properties map[string]proto.PropertyType) (result *db.Type, err error)
GetArtifactTypes(id *int64, name *string, version *string) ([]db.Type, error)
GetArtifactType(id *int64, name *string, version *string) (*db.Type, error)

DeleteArtifactTypesByName(name string) ([]db.Type, error)
DeleteArtifactTypeByNameVersion(name string, version string) (*db.Type, error)
UpdateArtifactType(artifactType db.Type) (*db.Type, error)
}

type ArtifactHandler interface {
CreateArtifact(name string, version string, artifactType string) (result *db.Artifact, err error)
CreateArtifactAll(name string, // TODO: make this optional and auto uuid the name?
version *string, // TODO: has no mapping on RDBMS?
artifactType string,
uri *string,
state *int64, // TODO: relation with grpc?
externalID *string,
properties map[string]*proto.Value,
customProperties map[string]*proto.Value,
) (result *db.Artifact, err error)
GetArtifacts(id *int64, name *string, artifactType *string, version *string) ([]db.Artifact, error)
GetArtifact(id *int64, name *string, artifactType *string, version *string) (*db.Artifact, error)

// ...Delete, Update, etc.
}

// ...entityHandler(s) to be repeated for other Entities

type TypeKind int32

// TODO: move from grpc to just here
const (
EXECUTION_TYPE TypeKind = iota
ARTIFACT_TYPE
CONTEXT_TYPE
)

type Handle struct {
db *gorm.DB
}

// create if not existing (or update (by name) if already existing)
func (h *Handle) CreateArtifactType(name string, version *string) (result *db.Type, err error) {
return h.CreateArtifactTypeAll(name, version, nil, nil, nil, nil, nil)
}

func (h *Handle) CreateArtifactTypeAll(name string,
version *string,
description *string,
inputType *string,
outputType *string,
externalID *string,
properties map[string]proto.PropertyType) (result *db.Type, err error) {
ctx, _ := server.Begin(context.Background(), h.db)
defer handleTransaction(ctx, &err)

value := &db.Type{
Name: name,
Version: version,
TypeKind: int32(ARTIFACT_TYPE),
}
if err := h.db.Where("name = ?", value.Name).Assign(value).FirstOrCreate(value).Error; err != nil {
err = fmt.Errorf("error creating type %s: %v", value.Name, err)
return nil, err
}
// TODO handle remaining attributes/properties
return value, nil
}

func (h *Handle) GetArtifactTypes(id *int64, name *string, version *string) ([]db.Type, error) {
by := db.Type{TypeKind: int32(ARTIFACT_TYPE), Version: version}
if id != nil {
by.ID = *id
}
if name != nil {
by.Name = *name
}
var results []db.Type
rx := h.db.Find(&results, by)
if rx.Error != nil {
return nil, rx.Error
}
return results, nil
}

func (h *Handle) GetArtifactType(id *int64, name *string, version *string) (*db.Type, error) {
return nil, nil
}

func (h *Handle) DeleteArtifactType() (*int64, error) {
return nil, nil
}

func (h *Handle) UpdateArtifactType() (*db.Type, error) {
return nil, nil
}

func (h *Handle) CreateArtifactAll(name string,
version *string,
artifactType string,
uri *string,
state *int64,
externalID *string,
properties map[string]*proto.Value,
customProperties map[string]*proto.Value,
) (result *db.Artifact, err error) {
// CreateTimeSinceEpoch int64
// LastUpdateTimeSinceEpoch int64
return nil, nil
}

func handleTransaction(ctx context.Context, err *error) {
// handle panic
if perr := recover(); perr != nil {
_ = server.Rollback(ctx)
*err = status.Errorf(codes.Internal, "server panic: %v", perr)
return
}
if err == nil || *err == nil {
*err = server.Commit(ctx)
} else {
_ = server.Rollback(ctx)
if _, ok := status.FromError(*err); !ok {
*err = status.Errorf(codes.Internal, "internal error: %v", *err)
}
}
}
149 changes: 149 additions & 0 deletions internal/service/db_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package service

import (
"fmt"
"log"
"os"
"testing"

"github.com/opendatahub-io/model-registry/internal/model/db"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

func migrateDatabase(dbConn *gorm.DB) error {
// using only needed RDBMS type for the scope under test
err := dbConn.AutoMigrate(
db.Type{},
db.TypeProperty{},
// TODO: add as needed.
)
if err != nil {
return fmt.Errorf("db migration failed: %w", err)
}
return nil
}

func setup(tmpFile *os.File) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(tmpFile.Name()), &gorm.Config{})
if err != nil {
return nil, err
}
err = migrateDatabase(db)
if err != nil {
return nil, err
}
return db, nil
}

// Bare minimal test of PutArtifactType with a given Name, and Get.
func TestPutArtifactTypeThenGet(t *testing.T) {
f, err := os.CreateTemp("", "model-registry-db")
if err != nil {
log.Fatal(err)
}
defer os.Remove(f.Name())
db, err := setup(f)
if err != nil {
t.Errorf("Should expect DB connection: %v", err)
}
dal := Handle{
db: db,
}
artifactName := "John Doe"
at, err := dal.CreateArtifactType(artifactName, nil)
if err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}
if at.ID < 0 {
t.Errorf("Should have ID for ArtifactType: %v", at.ID)
}
if at.Name != artifactName {
t.Errorf("Should have Name for ArtifactType per constant: %v", at.Name)
}

ats, err2 := dal.GetArtifactTypes(nil, &artifactName, nil)
if err2 != nil {
t.Errorf("Should get ArtifactType: %v", err2)
}
if len(ats) != 1 { // TODO if temp file is okay, this is superfluos
t.Errorf("The test is running under different assumption")
}
at0 := ats[0]
t.Logf("at0: %v", at0)
if at0.ID != at.ID {
t.Errorf("Should have same ID")
}
if at0.Name != at.Name {
t.Errorf("Should have same Name")
}

}

func TestGetArtifactTypesByCommonCriteria(t *testing.T) {
f, err := os.CreateTemp("", "model-registry-db")
if err != nil {
log.Fatal(err)
}
defer os.Remove(f.Name())
db, err := setup(f)
if err != nil {
t.Errorf("Should expect DB connection: %v", err)
}
dal := Handle{
db: db,
}

fixVersion := "version"
if _, err := dal.CreateArtifactType("at0", &fixVersion); err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}
if _, err := dal.CreateArtifactType("at1", &fixVersion); err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}

// TODO here only demonstrating criteria using "version", but likely more meaningful to use property as criteria
results, err := dal.GetArtifactTypes(nil, nil, &fixVersion)
t.Logf("results: %v", results)
if err != nil {
t.Errorf("Should get ArtifactTypes: %v", err)
}
if len(results) != 2 {
t.Errorf("Should have retrieved 2 artifactTypes")
}
}

func TestPutArtifactTypeSameNameDiffVersion(t *testing.T) {
f, err := os.CreateTemp("", "model-registry-db")
if err != nil {
log.Fatal(err)
}
defer os.Remove(f.Name())
db, err := setup(f)
if err != nil {
t.Errorf("Should expect DB connection: %v", err)
}
dal := Handle{
db: db,
}

artifactName := "John Doe"
v0 := "v0"
v1 := "v1"
at0, err := dal.CreateArtifactType(artifactName, &v0)
if err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}
at1, err := dal.CreateArtifactType(artifactName, &v1)
if err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}
if at0.ID > at1.ID {
t.Errorf("ID invariant does not hold")
}

// TODO implement validation logic or RDBMS constraint/key
// if _, err := dal.CreateArtifactType(artifactName, &v1); err == nil {
// t.Errorf("Created multiple artifact with the same version")
// }
}

0 comments on commit 2181a5a

Please sign in to comment.