Skip to content

Commit

Permalink
feat(scheduler): account for multiple instances of a model per server…
Browse files Browse the repository at this point in the history
… when scheduling (#6054)

* just checking in whatever I have

* testing all the code

* remove comment

* linting

* document unused param

* changing the proto around

* use parallelWorkers instead of instanceCount for mlserver

* comma

* rename ModelConfig

* use modelWithVersion as param
  • Loading branch information
driev authored Dec 10, 2024
1 parent a7bfb00 commit c1d320e
Show file tree
Hide file tree
Showing 14 changed files with 837 additions and 234 deletions.
607 changes: 469 additions & 138 deletions apis/go/mlops/agent/agent.pb.go

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions apis/mlops/agent/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ message ModelEventMessage {
Event event = 5;
string message = 6;
uint64 availableMemoryBytes = 7;
ModelRuntimeInfo runtimeInfo = 8;
}

message ModelEventResponse {
Expand Down Expand Up @@ -92,8 +93,29 @@ message ModelOperationMessage {
// ModelVersion pairs a scheduler model definition with a version number and
// the runtime information reported for that version.
message ModelVersion {
scheduler.Model model = 1;
uint32 version = 2;
// Runtime-specific settings for this model version (see ModelRuntimeInfo).
ModelRuntimeInfo runtimeInfo = 3;
}

// ModelRuntimeInfo carries inference-runtime-specific settings for a model.
// At most one of the variants in the oneof below is set.
message ModelRuntimeInfo {
oneof modelRuntimeInfo {
// Settings for a model served by MLServer.
MLServerModelSettings mlserver = 1;
// Configuration for a model served by Triton.
TritonModelConfig triton = 2;
}
}

// MLServerModelSettings holds the MLServer settings reported for a model.
message MLServerModelSettings {
// Number of parallel workers configured for the model.
uint32 parallelWorkers = 1;
}

// TritonModelConfig holds the Triton configuration reported for a model.
message TritonModelConfig {
// CPU entries for the model, each carrying its own instance count.
repeated TritonCPU cpu = 1;
}

// TritonCPU describes one CPU entry of a Triton model configuration.
message TritonCPU {
// Number of model instances for this entry.
uint32 instanceCount = 1;
}


// [END Messages]

// [START Services]
Expand Down
5 changes: 3 additions & 2 deletions scheduler/pkg/agent/agent_debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func setupService(numModels int, modelPrefix string, capacity int) *agentDebug {
}

func TestAgentDebugServiceSmoke(t *testing.T) {
//TODO break this down in proper tests
// TODO break this down in proper tests
g := NewGomegaWithT(t)

service := setupService(10, "dummy", 10)
Expand All @@ -60,6 +60,7 @@ func TestAgentDebugServiceSmoke(t *testing.T) {
MemoryBytes: &mem,
},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
)
g.Expect(err).To(BeNil())
Expand Down Expand Up @@ -87,7 +88,7 @@ func TestAgentDebugServiceSmoke(t *testing.T) {
}

func TestAgentDebugEarlyStop(t *testing.T) {
//TODO break this down in proper tests
// TODO break this down in proper tests
g := NewGomegaWithT(t)

service := setupService(10, "dummy", 10)
Expand Down
17 changes: 14 additions & 3 deletions scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,15 +615,23 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage, timestamp int64
}
logger.Infof("Chose path %s for model %s:%d", *chosenVersionPath, modelName, modelVersion)

modelConfig, err := c.ModelRepository.GetModelRuntimeInfo(modelWithVersion)
if err != nil {
logger.Errorf("there was a problem getting the config for model: %s", modelName)
}

// TODO: consider whether we need the actual protos being sent to `LoadModelVersion`?
modifiedModelVersionRequest := getModifiedModelVersion(
modelWithVersion,
pinnedModelVersion,
request.GetModelVersion(),
modelConfig,
)

loaderFn := func() error {
return c.stateManager.LoadModelVersion(modifiedModelVersionRequest)
}

if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, c.settings.maxLoadElapsedTime, logger); err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
c.cleanup(modelWithVersion)
Expand All @@ -641,7 +649,8 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage, timestamp int64
}

logger.Infof("Load model %s:%d success", modelName, modelVersion)
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_LOADED)

return c.sendAgentEvent(modelName, modelVersion, modelConfig, agent.ModelEventMessage_LOADED)
}

func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int64) error {
Expand Down Expand Up @@ -674,7 +683,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int
defer c.modelTimestamps.Store(modelWithVersion, timestamp)

// we do not care about model versions here
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion())
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion(), nil)

unloaderFn := func() error {
return c.stateManager.UnloadModelVersion(modifiedModelVersionRequest)
Expand Down Expand Up @@ -702,7 +711,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int
}

logger.Infof("Unload model %s:%d success", modelName, modelVersion)
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_UNLOADED)
return c.sendAgentEvent(modelName, modelVersion, nil, agent.ModelEventMessage_UNLOADED)
}

func (c *Client) cleanup(modelWithVersion string) {
Expand Down Expand Up @@ -742,6 +751,7 @@ func (c *Client) sendModelEventError(
func (c *Client) sendAgentEvent(
modelName string,
modelVersion uint32,
modelRuntimeInfo *agent.ModelRuntimeInfo,
event agent.ModelEventMessage_Event,
) error {
// if the server is draining and the model load has succeeded, we need to "cancel"
Expand All @@ -765,6 +775,7 @@ func (c *Client) sendAgentEvent(
ModelVersion: modelVersion,
Event: event,
AvailableMemoryBytes: c.stateManager.GetAvailableMemoryBytesWithOverCommit(),
RuntimeInfo: modelRuntimeInfo,
})
return err
}
Expand Down
19 changes: 19 additions & 0 deletions scheduler/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

Expand All @@ -51,6 +52,7 @@ type mockAgentV2Server struct {
unloadFailedEvents int
otherEvents int
errors int
events []*pb.ModelEventMessage
}

type FakeModelRepository struct {
Expand All @@ -64,6 +66,10 @@ func (f *FakeModelRepository) RemoveModelVersion(modelName string) error {
return nil
}

// GetModelRuntimeInfo is a test stub: it ignores modelName and always reports
// MLServer runtime settings with a single parallel worker and a nil error.
func (f *FakeModelRepository) GetModelRuntimeInfo(modelName string) (*pb.ModelRuntimeInfo, error) {
	settings := &agent.MLServerModelSettings{ParallelWorkers: uint32(1)}
	info := &pb.ModelRuntimeInfo{
		ModelRuntimeInfo: &pb.ModelRuntimeInfo_Mlserver{Mlserver: settings},
	}
	return info, nil
}

func (f *FakeModelRepository) DownloadModelVersion(modelName string, version uint32, modelSpec *pbs.ModelSpec, config []byte) (*string, error) {
f.modelDownloads++
if f.err != nil {
Expand Down Expand Up @@ -147,6 +153,7 @@ func (m *mockAgentV2Server) AgentEvent(ctx context.Context, message *pb.ModelEve
default:
m.otherEvents++
}
m.events = append(m.events, message)
return &pb.ModelEventResponse{}, nil
}

Expand Down Expand Up @@ -247,6 +254,7 @@ func TestLoadModel(t *testing.T) {
models []string
replicaConfig *pb.ReplicaConfig
op *pb.ModelOperationMessage
modelConfig *pb.ModelRuntimeInfo
expectedAvailableMemory uint64
v2Status int
modelRepoErr error
Expand All @@ -270,9 +278,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -289,10 +299,12 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
AutoscalingEnabled: true,
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -310,9 +322,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 1000,
v2Status: 400,
success: false,
Expand All @@ -329,9 +343,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &largeMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 1000,
v2Status: 200,
success: false,
Expand Down Expand Up @@ -399,6 +415,9 @@ func TestLoadModel(t *testing.T) {
g.Expect(err).To(BeNil())
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
g.Expect(mockAgentV2Server.loadFailedEvents).To(Equal(0))
g.Expect(len(mockAgentV2Server.events)).To(Equal(1))
g.Expect(mockAgentV2Server.events[0].RuntimeInfo).ToNot(BeNil())
g.Expect(mockAgentV2Server.events[0].RuntimeInfo.GetMlserver().ParallelWorkers).To(Equal(uint32(1)))
g.Expect(client.stateManager.GetAvailableMemoryBytes()).To(Equal(test.expectedAvailableMemory))
g.Expect(modelRepository.modelRemovals).To(Equal(0))
loadedVersions := client.stateManager.modelVersions.getVersionsForAllModels()
Expand Down
3 changes: 2 additions & 1 deletion scheduler/pkg/agent/client_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ func isReady(service interfaces.DependencyServiceInterface, logger *log.Entry, m
return backoff.RetryNotify(readyToError, backoffWithMax, logFailure)
}

func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion) *agent.ModelVersion {
func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion, modelRuntimeInfo *agent.ModelRuntimeInfo) *agent.ModelVersion {
mv := proto.Clone(originalModelVersion)
mv.(*agent.ModelVersion).Model.Meta.Name = modelId
mv.(*agent.ModelVersion).Version = version
mv.(*agent.ModelVersion).RuntimeInfo = modelRuntimeInfo
return mv.(*agent.ModelVersion)
}

Expand Down
19 changes: 15 additions & 4 deletions scheduler/pkg/agent/model_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func (modelState *ModelState) addModelVersionImpl(modelVersionDetails *agent.Mod
modelName, versionId, exsistingVersion.getVersion())
}
}

}

// Remove model version and return true if no versions left (in which case we remove from map)
Expand All @@ -70,7 +69,6 @@ func (modelState *ModelState) removeModelVersion(modelVersionDetails *agent.Mode
}

func (modelState *ModelState) removeModelVersionImpl(modelVersionDetails *agent.ModelVersion) (bool, error) {

modelName := modelVersionDetails.GetModel().GetMeta().GetName()
versionId := modelVersionDetails.GetVersion()

Expand Down Expand Up @@ -143,7 +141,8 @@ func (modelState *ModelState) getVersionsForAllModels() []*agent.ModelVersion {
mv := version.get()
versionedModelName := mv.Model.GetMeta().Name
originalModelName, originalModelVersion, _ := util.GetOrignalModelNameAndVersion(versionedModelName)
loadedModels = append(loadedModels, getModifiedModelVersion(originalModelName, originalModelVersion, mv))
modelRuntimeInfo := mv.RuntimeInfo
loadedModels = append(loadedModels, getModifiedModelVersion(originalModelName, originalModelVersion, mv, modelRuntimeInfo))
}
return loadedModels
}
Expand All @@ -153,7 +152,19 @@ type modelVersion struct {
}

// getVersionMemory returns the memory footprint of this model version in
// bytes: the memory declared in the model spec multiplied by the number of
// model instances reported by getInstanceCount.
func (version *modelVersion) getVersionMemory() uint64 {
	memoryPerInstance := version.versionInfo.GetModel().GetModelSpec().GetMemoryBytes()
	return memoryPerInstance * getInstanceCount(version)
}

// getInstanceCount returns how many copies of the model the inference server
// runs for this version, as reported in the version's runtime info:
// parallelWorkers for MLServer, or the instance count of the first CPU entry
// for Triton.
//
// It returns 1 when the runtime info is missing or of an unknown kind, when a
// Triton config carries no CPU entries, or when the reported count is zero.
// Callers that multiply by the result therefore never panic on partially
// populated messages and never scale a model's memory down to zero.
func getInstanceCount(version *modelVersion) uint64 {
	const defaultInstanceCount uint64 = 1

	// The generated getters are nil-safe, unlike direct field access, which
	// panics when RuntimeInfo is nil.
	runtimeInfo := version.versionInfo.GetRuntimeInfo()

	var count uint64
	switch runtimeInfo.GetModelRuntimeInfo().(type) {
	case *agent.ModelRuntimeInfo_Mlserver:
		count = uint64(runtimeInfo.GetMlserver().GetParallelWorkers())
	case *agent.ModelRuntimeInfo_Triton:
		cpus := runtimeInfo.GetTriton().GetCpu()
		if len(cpus) == 0 {
			// No CPU entry to read a count from; assume a single instance.
			return defaultInstanceCount
		}
		count = uint64(cpus[0].GetInstanceCount())
	default:
		return defaultInstanceCount
	}

	// A loaded model always occupies at least one instance worth of memory.
	if count == 0 {
		return defaultInstanceCount
	}
	return count
}

func (version *modelVersion) getVersion() uint32 {
Expand Down
Loading

0 comments on commit c1d320e

Please sign in to comment.