Skip to content

Commit

Permalink
fix(agent): Skip agent out of order control messages (#5969)
Browse files Browse the repository at this point in the history
* add logic for dealing with out of order messages

* fix existing test

* add test for utils

* add test for load/unload order of messages

* tidy up comments

* extend server draining grace period to 5s

* fix test after merge

* use milli ticks instead

* improve logging

* revert logger change

* reduce server drain grace period to 3s

* fix flaky test

* use static ticks in tests

* use monotonic ticks

* add comment
  • Loading branch information
sakoush authored Oct 11, 2024
1 parent 86c7c02 commit e278790
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 16 deletions.
36 changes: 29 additions & 7 deletions scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"context"
"fmt"
"math"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -54,6 +55,8 @@ type Client struct {
stop atomic.Bool
modelScalingClientStream agent.AgentService_ModelScalingTriggerClient
settings *ClientSettings
modelTimestamps sync.Map
startTime time.Time
ClientServices
SchedulerGrpcClientOptions
KubernetesOptions
Expand Down Expand Up @@ -189,9 +192,11 @@ func NewClient(
KubernetesOptions: KubernetesOptions{
namespace: namespace,
},
isDraining: atomic.Bool{},
stop: atomic.Bool{},
settings: settings,
isDraining: atomic.Bool{},
stop: atomic.Bool{},
settings: settings,
modelTimestamps: sync.Map{},
startTime: time.Now(),
}
}

Expand Down Expand Up @@ -474,12 +479,15 @@ func (c *Client) StartService() error {

c.logger.Infof("Received operation")

// Get the time since the start of the agent, this is monotonic as time.Now contains a monotonic clock
ticksSinceStart := time.Since(c.startTime).Milliseconds()

switch operation.Operation {
case agent.ModelOperationMessage_LOAD_MODEL:
c.logger.Infof("calling load model")

go func() {
err := c.LoadModel(operation)
err := c.LoadModel(operation, ticksSinceStart)
if err != nil {
c.logger.WithError(err).Errorf(
"Failed to handle load model %s:%d",
Expand All @@ -493,7 +501,7 @@ func (c *Client) StartService() error {
c.logger.Infof("calling unload model")

go func() {
err := c.UnloadModel(operation)
err := c.UnloadModel(operation, ticksSinceStart)
if err != nil {
c.logger.WithError(err).Errorf(
"Failed to handle unload model %s:%d",
Expand Down Expand Up @@ -560,7 +568,7 @@ func (c *Client) getArtifactConfig(request *agent.ModelOperationMessage) ([]byte
return nil, nil
}

func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
func (c *Client) LoadModel(request *agent.ModelOperationMessage, timestamp int64) error {
if request == nil || request.ModelVersion == nil {
return fmt.Errorf("empty request received for load model")
}
Expand All @@ -576,6 +584,13 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
defer c.stateManager.cache.Unlock(modelWithVersion)

logger.Infof("Load model %s:%d", modelName, modelVersion)
// if it is out of order message, ignore it
ignore := ignoreIfOutOfOrder(modelWithVersion, timestamp, &c.modelTimestamps)
if ignore {
logger.Warnf("Ignoring out of order message for model %s:%d", modelName, modelVersion)
return nil
}
defer c.modelTimestamps.Store(modelWithVersion, timestamp)

// Get Rclone configuration
config, err := c.getArtifactConfig(request)
Expand Down Expand Up @@ -627,7 +642,7 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_LOADED)
}

func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int64) error {
if request == nil || request.GetModelVersion() == nil {
return fmt.Errorf("Empty request received for unload model")
}
Expand All @@ -648,6 +663,13 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
defer c.stateManager.cache.Unlock(modelWithVersion)

logger.Infof("Unload model %s:%d", modelName, modelVersion)
// if it is out of order message, ignore it
ignore := ignoreIfOutOfOrder(modelWithVersion, timestamp, &c.modelTimestamps)
if ignore {
logger.Warnf("Ignoring out of order message for model %s:%d", modelName, modelVersion)
return nil
}
defer c.modelTimestamps.Store(modelWithVersion, timestamp)

// we do not care about model versions here
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion())
Expand Down
151 changes: 145 additions & 6 deletions scheduler/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func TestLoadModel(t *testing.T) {
time.Sleep(50 * time.Millisecond)

// Do the actual function call that is being tested
err := client.LoadModel(test.op)
err := client.LoadModel(test.op, 1)

if test.success {
g.Expect(err).To(BeNil())
Expand Down Expand Up @@ -543,8 +543,10 @@ parameters:
go func() {
_ = client.Start()
}()
// Give the client time to start (?)
time.Sleep(50 * time.Millisecond)
err := client.LoadModel(test.op)

err := client.LoadModel(test.op, 1)
if test.success {
g.Expect(err).To(BeNil())
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
Expand Down Expand Up @@ -595,7 +597,7 @@ func TestUnloadModel(t *testing.T) {
},
},
unloadOp: &pb.ModelOperationMessage{
Operation: pb.ModelOperationMessage_LOAD_MODEL,
Operation: pb.ModelOperationMessage_UNLOAD_MODEL,
ModelVersion: &pb.ModelVersion{
Model: &pbs.Model{
Meta: &pbs.MetaData{
Expand Down Expand Up @@ -625,7 +627,7 @@ func TestUnloadModel(t *testing.T) {
},
},
unloadOp: &pb.ModelOperationMessage{
Operation: pb.ModelOperationMessage_LOAD_MODEL,
Operation: pb.ModelOperationMessage_UNLOAD_MODEL,
ModelVersion: &pb.ModelVersion{
Model: &pbs.Model{
Meta: &pbs.MetaData{
Expand Down Expand Up @@ -680,10 +682,12 @@ func TestUnloadModel(t *testing.T) {
go func() {
_ = client.Start()
}()
// Give the client time to start (?)
time.Sleep(50 * time.Millisecond)
err := client.LoadModel(test.loadOp)

err := client.LoadModel(test.loadOp, 1)
g.Expect(err).To(BeNil())
err = client.UnloadModel(test.unloadOp)
err = client.UnloadModel(test.unloadOp, 2)
if test.success {
g.Expect(err).To(BeNil())
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
Expand Down Expand Up @@ -864,3 +868,138 @@ func TestAgentStopOnSubServicesFailure(t *testing.T) {
})
}
}

// TestUnloadModelOutOfOrder checks how the agent client treats an unload
// request relative to a preceding load request for the same model, based on
// the tick value passed alongside each operation.
//
// Two scenarios are covered:
//   - "in-order": the unload carries a larger tick than the load, so both the
//     load and the unload are expected to reach the mock agent server.
//   - "out-of-order": the unload carries a smaller tick than the load, so the
//     unload is expected to be dropped (no unloaded event), while UnloadModel
//     itself still returns a nil error.
func TestUnloadModelOutOfOrder(t *testing.T) {
	t.Logf("Started")
	logger := log.New()
	log.SetLevel(log.DebugLevel)

	type test struct {
		name        string
		models      []string
		loadOp      *pb.ModelOperationMessage
		loadTicks   int64 // tick passed to LoadModel
		unloadOp    *pb.ModelOperationMessage
		unloadTicks int64 // tick passed to UnloadModel
		success     bool  // true when the unload is expected to be applied
	}
	smallMemory := uint64(500)
	tests := []test{
		{
			name:   "in-order",
			models: []string{"iris"},
			loadOp: &pb.ModelOperationMessage{
				Operation: pb.ModelOperationMessage_LOAD_MODEL,
				ModelVersion: &pb.ModelVersion{
					Model: &pbs.Model{
						Meta: &pbs.MetaData{
							Name: "iris",
						},
						ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
					},
				},
			},
			loadTicks: 1,
			unloadOp: &pb.ModelOperationMessage{
				Operation: pb.ModelOperationMessage_UNLOAD_MODEL,
				ModelVersion: &pb.ModelVersion{
					Model: &pbs.Model{
						Meta: &pbs.MetaData{
							Name: "iris",
						},
						ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
					},
				},
			},
			unloadTicks: 2,
			success:     true,
		},
		{
			name:   "out-of-order",
			models: []string{"iris"},
			loadOp: &pb.ModelOperationMessage{
				Operation: pb.ModelOperationMessage_LOAD_MODEL,
				ModelVersion: &pb.ModelVersion{
					Model: &pbs.Model{
						Meta: &pbs.MetaData{
							Name: "iris",
						},
						ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
					},
				},
			},
			loadTicks: 2,
			unloadOp: &pb.ModelOperationMessage{
				// An unload request must carry the UNLOAD_MODEL operation, as in
				// the in-order case above.
				Operation: pb.ModelOperationMessage_UNLOAD_MODEL,
				ModelVersion: &pb.ModelVersion{
					Model: &pbs.Model{
						Meta: &pbs.MetaData{
							Name: "iris",
						},
						ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
					},
				},
			},
			unloadTicks: 1,
			success:     false,
		},
	}

	for tidx, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Bind the assertions to the subtest so a failure is attributed to
			// (and only aborts) the case that produced it.
			g := NewGomegaWithT(t)

			t.Logf("Test #%d", tidx)
			v2Client := createTestV2Client(addVerionToModels(test.models, 0), 200)
			httpmock.ActivateNonDefault(v2Client.(*testing_utils.V2RestClientForTest).HttpClient)
			// Deferred so the HTTP mock is reset even when an expectation fails.
			defer httpmock.DeactivateAndReset()

			modelRepository := &FakeModelRepository{}
			rpHTTP := FakeDependencyService{err: nil}
			rpGRPC := FakeDependencyService{err: nil}
			agentDebug := FakeDependencyService{err: nil}
			lags := modelscaling.ModelScalingStatsWrapper{
				Stats:     modelscaling.NewModelReplicaLagsKeeper(),
				Operator:  interfaces.Gte,
				Threshold: 10,
				Reset:     true,
				EventType: modelscaling.ScaleUpEvent,
			}
			lastUsed := modelscaling.ModelScalingStatsWrapper{
				Stats:     modelscaling.NewModelReplicaLastUsedKeeper(),
				Operator:  interfaces.Gte,
				Threshold: 10,
				Reset:     false,
				EventType: modelscaling.ScaleDownEvent,
			}
			modelScalingService := modelscaling.NewStatsAnalyserService(
				[]modelscaling.ModelScalingStatsWrapper{lags, lastUsed}, logger, 10)
			drainerServicePort, _ := testing_utils2.GetFreePortForTest()
			drainerService := drainservice.NewDrainerService(logger, uint(drainerServicePort))
			client := NewClient(
				NewClientSettings("mlserver", 1, "scheduler", 9002, 9055, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1*time.Minute, 1, 1, 1),
				logger, modelRepository, v2Client, &pb.ReplicaConfig{MemoryBytes: 1000}, "default",
				rpHTTP, rpGRPC, agentDebug, modelScalingService, drainerService, newFakeMetricsHandler())
			mockAgentV2Server := &mockAgentV2Server{models: []string{}}
			conn, cerr := grpc.NewClient("passthrough://", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(dialerv2(mockAgentV2Server)))
			g.Expect(cerr).To(BeNil())
			client.conn = conn
			// Deferred after the mock reset above, so on exit the client is
			// stopped first and the HTTP mock is reset second.
			defer client.Stop()

			go func() {
				_ = client.Start()
			}()
			// Give the client time to start (?)
			time.Sleep(50 * time.Millisecond)

			err := client.LoadModel(test.loadOp, test.loadTicks)
			g.Expect(err).To(BeNil())
			// A dropped (out-of-order) unload is not reported as an error.
			err = client.UnloadModel(test.unloadOp, test.unloadTicks)
			g.Expect(err).To(BeNil())

			// The load is always applied; the unload only when it is in order.
			expectedUnloadedEvents := 0
			if test.success {
				expectedUnloadedEvents = 1
			}
			g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
			g.Expect(mockAgentV2Server.unloadedEvents).To(Equal(expectedUnloadedEvents))
		})
	}
}
13 changes: 13 additions & 0 deletions scheduler/pkg/agent/client_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package agent

import (
"fmt"
"sync"
"time"

backoff "github.com/cenkalti/backoff/v4"
Expand Down Expand Up @@ -118,3 +119,15 @@ func (b *backOffWithMaxCount) NextBackOff() time.Duration {
return b.backoffPolicy.NextBackOff()
}
}

// ignoreIfOutOfOrder reports whether a message for key carrying timestamp is
// older than the timestamp already recorded for key in timestamps, i.e. whether
// the message arrived out of order and should be skipped.
//
// If key has no recorded timestamp yet, timestamp is recorded for it and the
// message is accepted (false is returned). If key is already present, the
// stored value is left untouched by this function, whatever the outcome. A
// timestamp equal to the stored one is treated as in order.
//
// Values stored in timestamps must be int64; any other stored type makes the
// type assertion below panic.
func ignoreIfOutOfOrder(key string, timestamp int64, timestamps *sync.Map) bool {
	// LoadOrStore performs the "first sighting" check and the initial record as
	// one atomic step, instead of a separate Load followed by Store.
	last, seen := timestamps.LoadOrStore(key, timestamp)
	if !seen {
		return false
	}
	return timestamp < last.(int64)
}
45 changes: 45 additions & 0 deletions scheduler/pkg/agent/client_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package agent

import (
"fmt"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -114,3 +115,47 @@ func TestFnWrapperWithMax(t *testing.T) {
})
}
}

// TestOutOfOrderUtil is a table-driven test for ignoreIfOutOfOrder. Each case
// builds its own timestamp map from seed, so cases cannot influence one
// another, then checks whether the given timestamp is reported as out of order
// for key.
func TestOutOfOrderUtil(t *testing.T) {
	type test struct {
		name         string
		seed         map[string]int64 // timestamps recorded before the call; nil means an empty map
		key          string
		timestamp    int64
		isOutOfOrder bool
	}
	tests := []test{
		{
			name:         "empty",
			seed:         nil,
			key:          "key",
			timestamp:    2, // dummy
			isOutOfOrder: false,
		},
		{
			name:         "in order",
			seed:         map[string]int64{"key": 1},
			key:          "key",
			timestamp:    3,
			isOutOfOrder: false,
		},
		{
			// Boundary of the strict "<" comparison: an equal timestamp is accepted.
			name:         "same timestamp",
			seed:         map[string]int64{"key": 1},
			key:          "key",
			timestamp:    1,
			isOutOfOrder: false,
		},
		{
			name:         "out of order",
			seed:         map[string]int64{"key": 1},
			key:          "key",
			timestamp:    0,
			isOutOfOrder: true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			g := NewGomegaWithT(t)

			// Fresh map per case, pre-populated with int64 values.
			ticks := &sync.Map{}
			for k, v := range test.seed {
				ticks.Store(k, v)
			}

			outOfOrder := ignoreIfOutOfOrder(test.key, test.timestamp, ticks)
			g.Expect(outOfOrder).To(Equal(test.isOutOfOrder))
		})
	}
}
2 changes: 1 addition & 1 deletion scheduler/pkg/agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const (
pendingSyncsQueueSize int = 1000
modelEventHandlerName = "agent.server.models"
modelScalingCoolingDownSeconds = 60 // this is currently used in scale down events
serverDrainingExtraWaitMillis = 2000
serverDrainingExtraWaitMillis = 3000
)

type modelRelocatedWaiter struct {
Expand Down
4 changes: 2 additions & 2 deletions scheduler/pkg/synchroniser/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func TestSimpleSynchroniser(t *testing.T) {
signal: true,
},
{
name: "No timer",
timeout: 0 * time.Millisecond,
name: "Small timer",
timeout: 1 * time.Millisecond,
signal: true,
},
{
Expand Down

0 comments on commit e278790

Please sign in to comment.