Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pipeline and model name validation #5872

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions scheduler/pkg/store/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/store/utils"
)

type MemoryStore struct {
Expand Down Expand Up @@ -105,6 +106,13 @@ func (m *MemoryStore) UpdateModel(req *pb.LoadModelRequest) error {
m.mu.Lock()
defer m.mu.Unlock()
modelName := req.GetModel().GetMeta().GetName()
validName := utils.CheckName(modelName)
if !validName {
return fmt.Errorf(
"Model %s does not have a valid name - it must be alphanumeric and not contains dots (.)",
modelName,
)
}
model, ok := m.store.models[modelName]
if !ok {
model = &Model{}
Expand Down Expand Up @@ -150,7 +158,7 @@ func (m *MemoryStore) getModelImpl(key string) *ModelSnapshot {
if ok {
return &ModelSnapshot{
Name: key,
Versions: model.versions, //TODO make a copy for safety?
Versions: model.versions, // TODO make a copy for safety?
Deleted: model.IsDeleted(),
}
} else {
Expand Down Expand Up @@ -540,7 +548,8 @@ func (m *MemoryStore) updateModelStateImpl(
}

func (m *MemoryStore) updateReservedMemory(
modelReplicaState ModelReplicaState, serverKey string, replicaIdx int, memBytes uint64) {
modelReplicaState ModelReplicaState, serverKey string, replicaIdx int, memBytes uint64,
) {
// update reserved memory that is being used for sorting replicas
// do we need to lock replica update?
server, ok := m.store.servers[serverKey]
Expand Down Expand Up @@ -610,7 +619,6 @@ func (m *MemoryStore) addServerReplicaImpl(request *agent.AgentSubscribeRequest)
}

func (m *MemoryStore) RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) {

models, evts, err := m.removeServerReplicaImpl(serverName, replicaIdx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -639,7 +647,7 @@ func (m *MemoryStore) removeServerReplicaImpl(serverName string, replicaIdx int)
return nil, nil, fmt.Errorf("Failed to find replica %d for server %s", replicaIdx, serverName)
}
delete(server.replicas, replicaIdx)
//TODO we should not reschedule models on servers with dedicated models, e.g. non shareable servers
// TODO we should not reschedule models on servers with dedicated models, e.g. non shareable servers
if len(server.replicas) == 0 {
delete(m.store.servers, serverName)
}
Expand Down
52 changes: 39 additions & 13 deletions scheduler/pkg/store/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ the Change License after the Change Date as each is defined in accordance with t
package store

import (
"errors"
"testing"
"time"

Expand All @@ -30,6 +31,7 @@ func TestUpdateModel(t *testing.T) {
store *LocalSchedulerStore
loadModelReq *pb.LoadModelRequest
expectedVersion uint32
err error
}

tests := []test{
Expand Down Expand Up @@ -61,7 +63,8 @@ func TestUpdateModel(t *testing.T) {
},
},
},
}},
},
},
loadModelReq: &pb.LoadModelRequest{
Model: &pb.Model{
Meta: &pb.MetaData{
Expand All @@ -87,7 +90,8 @@ func TestUpdateModel(t *testing.T) {
},
},
},
}},
},
},
loadModelReq: &pb.LoadModelRequest{
Model: &pb.Model{
Meta: &pb.MetaData{
Expand Down Expand Up @@ -122,7 +126,8 @@ func TestUpdateModel(t *testing.T) {
},
},
},
}},
},
},
loadModelReq: &pb.LoadModelRequest{
Model: &pb.Model{
Meta: &pb.MetaData{
Expand Down Expand Up @@ -160,7 +165,8 @@ func TestUpdateModel(t *testing.T) {
},
},
},
}},
},
},
loadModelReq: &pb.LoadModelRequest{
Model: &pb.Model{
Meta: &pb.MetaData{
Expand All @@ -176,6 +182,19 @@ func TestUpdateModel(t *testing.T) {
},
expectedVersion: 2,
},
{
name: "ModelNameIsNotValid",
store: NewLocalSchedulerStore(),
loadModelReq: &pb.LoadModelRequest{
Model: &pb.Model{
Meta: &pb.MetaData{
Name: "this.Name",
},
},
},
expectedVersion: 1,
err: errors.New("Model this.Name does not have a valid name - it must be alphanumeric and not contains dots (.)"),
},
}

for _, test := range tests {
Expand All @@ -185,11 +204,15 @@ func TestUpdateModel(t *testing.T) {
g.Expect(err).To(BeNil())
ms := NewMemoryStore(logger, test.store, eventHub)
err = ms.UpdateModel(test.loadModelReq)
g.Expect(err).To(BeNil())
m := test.store.models[test.loadModelReq.GetModel().GetMeta().GetName()]
latest := m.Latest()
g.Expect(latest.modelDefn).To(Equal(test.loadModelReq.Model))
g.Expect(latest.GetVersion()).To(Equal(test.expectedVersion))
if test.err != nil {
g.Expect(err.Error()).To(BeIdenticalTo(test.err.Error()))
} else {
g.Expect(err).To(BeNil())
m := test.store.models[test.loadModelReq.GetModel().GetMeta().GetName()]
latest := m.Latest()
g.Expect(latest.modelDefn).To(Equal(test.loadModelReq.Model))
g.Expect(latest.GetVersion()).To(Equal(test.expectedVersion))
}
})
}
}
Expand Down Expand Up @@ -228,7 +251,8 @@ func TestGetModel(t *testing.T) {
},
},
},
}},
},
},
key: "model",
versions: 1,
err: nil,
Expand Down Expand Up @@ -285,7 +309,8 @@ func TestRemoveModel(t *testing.T) {
},
},
},
}},
},
},
key: "model",
},
}
Expand Down Expand Up @@ -935,7 +960,7 @@ func TestUpdateModelState(t *testing.T) {
g.Expect(test.store.servers[test.serverKey].replicas[test.replicaIdx].loadedModels[ModelVersionID{Name: test.modelKey, Version: test.version}]).To(Equal(test.modelVersionLoaded))
g.Expect(test.store.servers[test.serverKey].replicas[test.replicaIdx].GetNumLoadedModels()).To(Equal(test.numModelVersionsLoaded))
} else {
//g.Expect(test.store.models[test.modelKey]).To(BeNil())
// g.Expect(test.store.models[test.modelKey]).To(BeNil())
g.Expect(test.store.models[test.modelKey].Latest().state.State).To(Equal(ModelTerminated))
}

Expand Down Expand Up @@ -1223,7 +1248,8 @@ func TestAddModelVersionIfNotExists(t *testing.T) {
{
name: "Add new version when none exist",
store: &LocalSchedulerStore{
models: map[string]*Model{}},
models: map[string]*Model{},
},
modelVersion: &agent.ModelVersion{
Version: 1,
Model: &pb.Model{
Expand Down
8 changes: 8 additions & 0 deletions scheduler/pkg/store/pipeline/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,11 @@ type PipelineInputErr struct {
func (pie *PipelineInputErr) Error() string {
return fmt.Sprintf("pipeline %s input %s is invalid. %s", pie.pipeline, pie.input, pie.reason)
}

type PipelineNameValidationErr struct {
pipeline string
}

func (pnve *PipelineNameValidationErr) Error() string {
return fmt.Sprintf("pipeline %s does not have a valid name - it must be alphanmumeric and cannot contain dots (.)", pnve.pipeline)
}
12 changes: 12 additions & 0 deletions scheduler/pkg/store/pipeline/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ package pipeline

import (
"strings"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/store/utils"
)

// Step inputs can be reference a previous step name and tensor output/input
Expand All @@ -27,6 +29,9 @@ const (
)

func validate(pv *PipelineVersion) error {
if err := checkName(pv); err != nil {
return err
}
if err := checkStepsExist(pv); err != nil {
return err
}
Expand Down Expand Up @@ -57,6 +62,13 @@ func validate(pv *PipelineVersion) error {
return nil
}

func checkName(pv *PipelineVersion) error {
if ok := utils.CheckName(pv.Name); !ok {
return &PipelineNameValidationErr{pipeline: pv.Name}
}
return nil
}

func checkStepsExist(pv *PipelineVersion) error {
if len(pv.Steps) == 0 {
return &PipelineStepsEmptyErr{pipeline: pv.Name}
Expand Down
55 changes: 55 additions & 0 deletions scheduler/pkg/store/pipeline/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,58 @@ func TestCheckStepOutputs(t *testing.T) {
})
}
}

func TestCheckName(t *testing.T) {
g := NewGomegaWithT(t)
tests := []validateTest{
{
name: "a valid name",
pipelineVersion: &PipelineVersion{
Name: "1-name-that-isva1id0",
Steps: map[string]*PipelineStep{
"a": {
Name: "a",
},
"b": {
Name: "b",
Inputs: []string{"a.outputs.t1", "a.inputs", "a.outputs"},
},
"c": {
Name: "c",
Inputs: []string{"a.outputs.t1"},
},
},
},
},
{
name: "a invalid name with dots",
pipelineVersion: &PipelineVersion{
Name: "a-name-that-is-not-valid.10.1",
},
err: &PipelineNameValidationErr{pipeline: "a-name-that-is-not-valid.10.1"},
},
{
name: "a invalid name with a special character",
pipelineVersion: &PipelineVersion{
Name: "aNameThatIs%notValid",
},
err: &PipelineNameValidationErr{pipeline: "aNameThatIs%notValid"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := checkName(test.pipelineVersion)
if test.err == nil {
g.Expect(err).To(BeNil())
} else {
g.Expect(err.Error()).To(Equal(test.err.Error()))
}
err = validate(test.pipelineVersion)
if test.err == nil {
g.Expect(err).To(BeNil())
} else {
g.Expect(err.Error()).To(Equal(test.err.Error()))
}
})
}
}
14 changes: 14 additions & 0 deletions scheduler/pkg/store/utils/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package utils

import (
"regexp"
"strings"
)

func CheckName(name string) bool {
ok, err := regexp.MatchString("^[a-z0-9]([-a-z0-9]*[a-z0-9])?$", name)
if !ok || err != nil || strings.Contains(name, ".") {
return false
}
return true
}
62 changes: 62 additions & 0 deletions scheduler/pkg/store/utils/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package utils

import (
"testing"

. "github.com/onsi/gomega"
)

func TestCheckName(t *testing.T) {
g := NewGomegaWithT(t)
tests := []struct {
name string
value string
expectedResult bool
}{
{
"a valid name",
"this-name-is-valid1234",
true,
},
{
"a valid numerical name",
"111111111111111111",
true,
},
{
"a valid name that begins and ends with something alphanumeric",
"a-a-a",
true,
},
{
"an invalid name that doesn't begin and end with something alphanumeric",
"--",
false,
},
{
"an invalid name that doesn't end with something alphanumeric",
"1--",
false,
},
{
"an invalid name that doesn't begin with something alphanumeric",
"--a",
false,
},
{
"an invalid name with an uppercase letter",
"this-name-is-not-Valid1234",
false,
},
{
"an invalid name with a dot",
"not.valid",
false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
g.Expect(CheckName(test.value)).To(BeIdenticalTo(test.expectedResult))
})
}
}
Loading