Commit 739943d

feat: add spot rebalance recommendation handling for aws (#13)

r0kas authored Feb 21, 2023
1 parent 24fde4a commit 739943d

Showing 9 changed files with 126 additions and 30 deletions.
15 changes: 13 additions & 2 deletions handler/aws.go
@@ -6,7 +6,7 @@ import (
 	"github.com/aws/aws-node-termination-handler/pkg/ec2metadata"
 )
 
-func NewAWSInterruptChecker() InterruptChecker {
+func NewAWSInterruptChecker() MetadataChecker {
 	return &awsInterruptChecker{
 		imds: ec2metadata.New("http://169.254.169.254", 3),
 	}
@@ -16,7 +16,18 @@ type awsInterruptChecker struct {
 	imds *ec2metadata.Service
 }
 
-func (c *awsInterruptChecker) Check(_ context.Context) (bool, error) {
+func (c *awsInterruptChecker) CheckRebalanceRecommendation(_ context.Context) (bool, error) {
+	rebalanceRecommendation, err := c.imds.GetRebalanceRecommendationEvent()
+	if err != nil {
+		return false, err
+	}
+	if rebalanceRecommendation == nil {
+		return false, nil
+	}
+	return true, nil
+}
+
+func (c *awsInterruptChecker) CheckInterrupt(_ context.Context) (bool, error) {
 	instanceAction, err := c.imds.GetSpotITNEvent()
 	if instanceAction == nil && err == nil {
 		// if there are no spot itns and no errors
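For reference, the IMDS event that GetRebalanceRecommendationEvent wraps is served from a fixed metadata path. A minimal standalone sketch of querying it directly (documented IMDS path; IMDSv2 session-token handling is omitted, so this assumes IMDSv1 access):

package main

import (
	"fmt"
	"net/http"
)

// Sketch only: IMDS returns 404 at this documented path until a rebalance
// recommendation is issued, then 200 with a JSON body such as
// {"noticeTime": "2023-02-21T14:05:02Z"}.
func main() {
	resp, err := http.Get("http://169.254.169.254/latest/meta-data/events/recommendations/rebalance")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("rebalance recommended:", resp.StatusCode == http.StatusOK)
}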
2 changes: 1 addition & 1 deletion handler/aws_test.go
@@ -40,7 +40,7 @@ func TestAwsInterruptChecker(t *testing.T) {
 		imds: ec2metadata.New(s.URL, 3),
 	}
 
-	interrupted, err := checker.Check(context.Background())
+	interrupted, err := checker.CheckInterrupt(context.Background())
 	require.NoError(t, err)
 	require.True(t, interrupted)
 }
19 changes: 12 additions & 7 deletions handler/azure.go
@@ -10,7 +10,7 @@ import (
 
 // NewAzureInterruptChecker checks for azure spot interrupt event from metadata server.
 // See https://docs.microsoft.com/en-us/azure/virtual-machines/linux/scheduled-events#endpoint-discovery
-func NewAzureInterruptChecker() InterruptChecker {
+func NewAzureInterruptChecker() MetadataChecker {
 	client := resty.New()
 	// Times out if set to 1 second, after 2 we will try again soon anyway
 	client.SetTimeout(time.Second * 2)
@@ -26,7 +26,14 @@ type azureInterruptChecker struct {
 	metadataServerURL string
 }
 
-func (c *azureInterruptChecker) Check(ctx context.Context) (bool, error) {
+type azureSpotScheduledEvent struct {
+	EventType string
+}
+type azureSpotScheduledEvents struct {
+	Events []azureSpotScheduledEvent
+}
+
+func (c *azureInterruptChecker) CheckInterrupt(ctx context.Context) (bool, error) {
 	responseBody := azureSpotScheduledEvents{}
 
 	req := c.client.NewRequest().SetContext(ctx).SetResult(&responseBody)
@@ -49,9 +56,7 @@ func (c *azureInterruptChecker) Check(ctx context.Context) (bool, error) {
 	return false, nil
 }
 
-type azureSpotScheduledEvent struct {
-	EventType string
-}
-type azureSpotScheduledEvents struct {
-	Events []azureSpotScheduledEvent
+func (c *azureInterruptChecker) CheckRebalanceRecommendation(ctx context.Context) (bool, error) {
+	// Applicable only for AWS for now.
+	return false, nil
 }
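The relocated types decode the Azure Scheduled Events payload. A small self-contained example of the JSON shape they match, trimmed to the single field used here (per Azure docs, a spot eviction arrives as EventType "Preempt"; the real payload carries more fields):

package main

import (
	"encoding/json"
	"fmt"
)

type azureSpotScheduledEvent struct {
	EventType string
}
type azureSpotScheduledEvents struct {
	Events []azureSpotScheduledEvent
}

func main() {
	// Trimmed-down Scheduled Events response; real payloads also carry
	// EventId, ResourceType, Resources, EventStatus, NotBefore, etc.
	payload := `{"Events":[{"EventType":"Preempt"}]}`

	var events azureSpotScheduledEvents
	if err := json.Unmarshal([]byte(payload), &events); err != nil {
		panic(err)
	}
	for _, e := range events.Events {
		fmt.Println(e.EventType) // "Preempt" signals a spot eviction
	}
}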
2 changes: 1 addition & 1 deletion handler/azure_test.go
@@ -37,7 +37,7 @@ func TestAzureInterruptChecker(t *testing.T) {
 		metadataServerURL: s.URL,
 	}
 
-	interrupted, err := checker.Check(context.Background())
+	interrupted, err := checker.CheckInterrupt(context.Background())
 	require.NoError(t, err)
 	require.True(t, interrupted)
 }
9 changes: 7 additions & 2 deletions handler/gcp.go
@@ -19,7 +19,7 @@ type metadataGetter interface {
 }
 
 // NewGCPChecker checks for gcp spot interrupt event from metadata server.
-func NewGCPChecker() InterruptChecker {
+func NewGCPChecker() MetadataChecker {
 	return &gcpInterruptChecker{
 		metadata: metadata.NewClient(nil),
 	}
@@ -29,7 +29,7 @@ type gcpInterruptChecker struct {
 	metadata metadataGetter
 }
 
-func (c *gcpInterruptChecker) Check(ctx context.Context) (bool, error) {
+func (c *gcpInterruptChecker) CheckInterrupt(ctx context.Context) (bool, error) {
 	m, err := c.metadata.Get(maintenanceSuffix)
 	if err != nil {
 		return false, err
@@ -41,3 +41,8 @@ func (c *gcpInterruptChecker) Check(ctx context.Context) (bool, error) {
 
 	return m == maintenanceEventTerminate || p == preemptionEventTrue, nil
 }
+
+func (c *gcpInterruptChecker) CheckRebalanceRecommendation(ctx context.Context) (bool, error) {
+	// Applicable only for AWS for now.
+	return false, nil
+}
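The maintenanceSuffix and preemption constants sit outside this diff. Based on the documented GCE metadata entries, the two reads likely resolve to the paths below (a sketch under that assumption, using the same metadata client):

package main

import (
	"fmt"

	"cloud.google.com/go/compute/metadata"
)

// Sketch of the two GCE metadata reads the checker combines. The suffix
// values are assumptions taken from the GCE docs, not from this diff.
func main() {
	c := metadata.NewClient(nil)

	m, err := c.Get("instance/maintenance-event") // "NONE" or "TERMINATE"
	if err != nil {
		panic(err)
	}
	p, err := c.Get("instance/preempted") // "TRUE" once the VM is preempted
	if err != nil {
		panic(err)
	}
	fmt.Println("interrupted:", m == "TERMINATE" || p == "TRUE")
}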
2 changes: 1 addition & 1 deletion handler/gcp_test.go
@@ -12,7 +12,7 @@ func TestGCPInterruptChecker(t *testing.T) {
 		metadata: mockMetadata{},
 	}
 
-	interrupted, err := checker.Check(context.Background())
+	interrupted, err := checker.CheckInterrupt(context.Background())
 	require.NoError(t, err)
 	require.True(t, interrupted)
 }
57 changes: 47 additions & 10 deletions handler/handler.go
@@ -20,14 +20,20 @@ import (
 
 const CastNodeIDLabel = "provisioner.cast.ai/node-id"
 
-type InterruptChecker interface {
-	Check(ctx context.Context) (bool, error)
+const (
+	cloudEventInterrupted             = "interrupted"
+	cloudEventRebalanceRecommendation = "rebalanceRecommendation"
+)
+
+type MetadataChecker interface {
+	CheckInterrupt(ctx context.Context) (bool, error)
+	CheckRebalanceRecommendation(ctx context.Context) (bool, error)
 }
 
 type SpotHandler struct {
 	castClient       castai.Client
 	clientset        kubernetes.Interface
-	interruptChecker InterruptChecker
+	metadataChecker  MetadataChecker
 	nodeName         string
 	pollWaitInterval time.Duration
 	log              logrus.FieldLogger
@@ -38,14 +44,14 @@ func NewSpotHandler(
 	log logrus.FieldLogger,
 	castClient castai.Client,
 	clientset kubernetes.Interface,
-	interruptChecker InterruptChecker,
+	metadataChecker MetadataChecker,
 	pollWaitInterval time.Duration,
 	nodeName string,
 ) *SpotHandler {
 	return &SpotHandler{
 		castClient:       castClient,
 		clientset:        clientset,
-		interruptChecker: interruptChecker,
+		metadataChecker:  metadataChecker,
 		log:              log,
 		nodeName:         nodeName,
 		pollWaitInterval: pollWaitInterval,
@@ -60,15 +66,17 @@ func (g *SpotHandler) Run(ctx context.Context) error {
 	var once sync.Once
 	deadline := time.NewTimer(24 * 365 * time.Hour)
 
+	// Once rebalance recommendation is set by cloud it stays there permanently. It needs to be sent only once.
+	var rebalanceRecommendationSent bool
+
 	for {
 		select {
 		case <-t.C:
-			// Check interruption.
 			err := func() error {
-				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+				ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 				defer cancel()
 
-				interrupted, err := g.interruptChecker.Check(ctx)
+				interrupted, err := g.metadataChecker.CheckInterrupt(ctx)
 				if err != nil {
 					return err
 				}
@@ -80,11 +88,26 @@ func (g *SpotHandler) Run(ctx context.Context) error {
 					// Stop after ACK.
 					t.Stop()
 				}
+
+				if !rebalanceRecommendationSent {
+					rebalanceRecommendation, err := g.metadataChecker.CheckRebalanceRecommendation(ctx)
+					if err != nil {
+						return err
+					}
+					if rebalanceRecommendation {
+						g.log.Infof("rebalance recommendation notice received")
+						if err := g.handleRebalanceRecommendation(ctx); err != nil {
+							return err
+						}
+						rebalanceRecommendationSent = true
+					}
+				}
+
 				return nil
 			}()
 
 			if err != nil {
-				g.log.Errorf("checking for interruption: %v", err)
+				g.log.Errorf("checking for cloud events: %v", err)
 			}
 		case <-deadline.C:
 			return nil
@@ -104,7 +127,7 @@ func (g *SpotHandler) handleInterruption(ctx context.Context) error {
 	}
 
 	req := &castai.CloudEventRequest{
-		EventType: "interrupted",
+		EventType: cloudEventInterrupted,
 		NodeID:    node.Labels[CastNodeIDLabel],
 	}
 	if err = g.castClient.SendCloudEvent(ctx, req); err != nil {
@@ -163,3 +186,17 @@ func (g *SpotHandler) patchNode(ctx context.Context, node *v1.Node, changeFn fun
 func defaultBackoff(ctx context.Context) backoff.BackOffContext {
 	return backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(1*time.Second), 5), ctx)
 }
+
+func (g *SpotHandler) handleRebalanceRecommendation(ctx context.Context) error {
+	node, err := g.clientset.CoreV1().Nodes().Get(ctx, g.nodeName, metav1.GetOptions{})
+	if err != nil {
+		return err
+	}
+
+	req := &castai.CloudEventRequest{
+		EventType: cloudEventRebalanceRecommendation,
+		NodeID:    node.Labels[CastNodeIDLabel],
+	}
+
+	return g.castClient.SendCloudEvent(ctx, req)
+}
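The new rebalanceRecommendationSent flag implements a send-once pattern: the cloud keeps the recommendation flag set permanently, so the handler must not re-send the event on every tick. A standalone sketch of that pattern, with illustrative stand-ins for the checker and the CAST AI client:

package main

import (
	"context"
	"fmt"
	"time"
)

// check and send are stand-ins; the real code calls
// CheckRebalanceRecommendation and SendCloudEvent.
func main() {
	check := func(ctx context.Context) (bool, error) { return true, nil }
	send := func(ctx context.Context) error { fmt.Println("event sent"); return nil }

	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()

	t := time.NewTicker(100 * time.Millisecond)
	defer t.Stop()

	var sent bool
	for {
		select {
		case <-t.C:
			if sent {
				continue // flag stays set in metadata; send only once
			}
			if ok, err := check(ctx); err == nil && ok {
				if err := send(ctx); err == nil {
					sent = true // "event sent" prints exactly once
				}
			}
		case <-ctx.Done():
			return
		}
	}
}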
48 changes: 43 additions & 5 deletions handler/handler_test.go
@@ -57,7 +57,7 @@ func TestRunLoop(t *testing.T) {
 		mockInterrupt := &mockInterruptChecker{interrupted: true}
 		handler := SpotHandler{
 			pollWaitInterval: 100 * time.Millisecond,
-			interruptChecker: mockInterrupt,
+			metadataChecker:  mockInterrupt,
 			castClient:       mockCastClient,
 			nodeName:         nodeName,
 			clientset:        fakeApi,
@@ -93,7 +93,7 @@
 		mockInterrupt := &mockInterruptChecker{interrupted: true}
 		handler := SpotHandler{
 			pollWaitInterval: 1 * time.Second,
-			interruptChecker: mockInterrupt,
+			metadataChecker:  mockInterrupt,
 			castClient:       mockCastClient,
 			nodeName:         nodeName,
 			clientset:        fakeApi,
@@ -144,7 +144,7 @@
 		mockInterrupt := &mockInterruptChecker{interrupted: true}
 		handler := SpotHandler{
 			pollWaitInterval: time.Millisecond * 100,
-			interruptChecker: mockInterrupt,
+			metadataChecker:  mockInterrupt,
 			castClient:       mockCastClient,
 			nodeName:         nodeName,
 			clientset:        fakeApi,
@@ -160,12 +160,50 @@ func TestRunLoop(t *testing.T) {
 			r.Equal(3, mothershipCalls)
 		}()
 	})
+
+	t.Run("handle successful mock rebalance recommendation", func(t *testing.T) {
+		mothershipCalls := 0
+		castS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, re *http.Request) {
+			mothershipCalls++
+			var req castai.CloudEventRequest
+			r.NoError(json.NewDecoder(re.Body).Decode(&req))
+			r.Equal(req.NodeID, castNodeID)
+			w.WriteHeader(http.StatusOK)
+		}))
+		defer castS.Close()
+
+		fakeApi := fake.NewSimpleClientset(node)
+		castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, 100*time.Millisecond, "0.0.0")
+		mockCastClient := castai.NewClient(log, castHttp, "test1")
+
+		mockRecommendation := &mockInterruptChecker{rebalanceRecommendation: true}
+		handler := SpotHandler{
+			pollWaitInterval: 100 * time.Millisecond,
+			metadataChecker:  mockRecommendation,
+			castClient:       mockCastClient,
+			nodeName:         nodeName,
+			clientset:        fakeApi,
+			log:              log,
+		}
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+
+		err := handler.Run(ctx)
+		require.NoError(t, err)
+		r.Equal(1, mothershipCalls)
+	})
 }
 
 type mockInterruptChecker struct {
-	interrupted bool
+	interrupted             bool
+	rebalanceRecommendation bool
 }
 
-func (m *mockInterruptChecker) Check(ctx context.Context) (bool, error) {
+func (m *mockInterruptChecker) CheckInterrupt(ctx context.Context) (bool, error) {
 	return m.interrupted, nil
 }
+
+func (m *mockInterruptChecker) CheckRebalanceRecommendation(ctx context.Context) (bool, error) {
+	return m.rebalanceRecommendation, nil
+}
2 changes: 1 addition & 1 deletion main.go
@@ -100,7 +100,7 @@ func main() {
 	}
 }
 
-func buildInterruptChecker(provider string) (handler.InterruptChecker, error) {
+func buildInterruptChecker(provider string) (handler.MetadataChecker, error) {
 	switch provider {
 	case "azure":
 		return handler.NewAzureInterruptChecker(), nil
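The remaining switch cases are collapsed in this view. A plausible completion based on the constructors this commit touches (the "aws" and "gcp" provider strings and the default branch are assumptions, not shown in the diff):

func buildInterruptChecker(provider string) (handler.MetadataChecker, error) {
	switch provider {
	case "azure":
		return handler.NewAzureInterruptChecker(), nil
	case "aws": // assumed provider key
		return handler.NewAWSInterruptChecker(), nil
	case "gcp": // assumed provider key
		return handler.NewGCPChecker(), nil
	default:
		return nil, fmt.Errorf("unknown provider %q", provider)
	}
}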
