Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work in progress: Idiomatic interface helm #112

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 51 additions & 31 deletions actions/actions.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:generate mockgen -destination ./mock_actions/helm_client_mock.go github.com/castai/cluster-controller/actions HelmClient
//go:generate mockgen -destination ./mock_actions/client_mock.go github.com/castai/cluster-controller/actions Client
package actions

import (
Expand All @@ -12,18 +14,22 @@ import (

"github.com/cenkalti/backoff/v4"
"github.com/sirupsen/logrus"
"helm.sh/helm/v3/pkg/release"
"k8s.io/client-go/dynamic"
"k8s.io/client-go/kubernetes"

"github.com/castai/cluster-controller/castai"
"github.com/castai/cluster-controller/actions/types"
"github.com/castai/cluster-controller/health"
"github.com/castai/cluster-controller/helm"
)

const (
// actionIDLogField is the log field name for action ID.
// This field is used by the backend to detect action IDs in logs.
actionIDLogField = "id"
actionIDLogField = "id"
labelNodeID = "provisioner.cast.ai/node-id"
actionCheckNodeStatus_READY types.ActionCheckNodeStatus_Status = "NodeStatus_READY"
actionCheckNodeStatus_DELETED types.ActionCheckNodeStatus_Status = "NodeStatus_DELETED"
)

func newUnexpectedTypeErr(value interface{}, expectedType interface{}) error {
Expand All @@ -46,7 +52,21 @@ type Service interface {
}

type ActionHandler interface {
Handle(ctx context.Context, action *castai.ClusterAction) error
Handle(ctx context.Context, action *types.ClusterAction) error
}

// Client is the API client the actions service depends on. It is declared in
// this package (the consumer) and a mock is generated for it by the mockgen
// go:generate directive at the top of the file.
type Client interface {
// GetActions returns the cluster actions to process. The service calls it
// while polling, passing the Kubernetes version it was constructed with.
GetActions(ctx context.Context, k8sVersion string) ([]*types.ClusterAction, error)
// AckAction acknowledges the action identified by actionID. The service
// sets req.Error from the handler's error when acknowledging.
AckAction(ctx context.Context, actionID string, req *types.AckClusterActionRequest) error
// SendAKSInitData sends req; this client is handed to the handler
// registered for ActionSendAKSInitData.
SendAKSInitData(ctx context.Context, req *types.AKSInitDataRequest) error
}

// HelmClient is the set of Helm operations used by this package. It is
// passed to the chart upsert, uninstall and rollback handlers, and a mock is
// generated for it by the mockgen go:generate directive at the top of the file.
type HelmClient interface {
// Install runs an install described by opts and returns the resulting release.
Install(ctx context.Context, opts helm.InstallOptions) (*release.Release, error)
// Uninstall runs an uninstall described by opts and returns Helm's
// uninstall response. Unlike Install and Upgrade, it takes no context.
Uninstall(opts helm.UninstallOptions) (*release.UninstallReleaseResponse, error)
// Upgrade runs an upgrade described by opts and returns the resulting release.
Upgrade(ctx context.Context, opts helm.UpgradeOptions) (*release.Release, error)
// Rollback runs a rollback described by opts. It takes no context.
Rollback(opts helm.RollbackOptions) error
// GetRelease returns the release selected by opts. It takes no context.
GetRelease(opts helm.GetReleaseOptions) (*release.Release, error)
}

func NewService(
Expand All @@ -55,41 +75,41 @@ func NewService(
k8sVersion string,
clientset *kubernetes.Clientset,
dynamicClient dynamic.Interface,
castaiClient castai.Client,
helmClient helm.Client,
client Client,
helmClient HelmClient,
healthCheck *health.HealthzProvider,
) Service {
return &service{
log: log,
cfg: cfg,
k8sVersion: k8sVersion,
castAIClient: castaiClient,
client: client,
startedActions: map[string]struct{}{},
actionHandlers: map[reflect.Type]ActionHandler{
reflect.TypeOf(&castai.ActionDeleteNode{}): newDeleteNodeHandler(log, clientset),
reflect.TypeOf(&castai.ActionDrainNode{}): newDrainNodeHandler(log, clientset, cfg.Namespace),
reflect.TypeOf(&castai.ActionPatchNode{}): newPatchNodeHandler(log, clientset),
reflect.TypeOf(&castai.ActionCreateEvent{}): newCreateEventHandler(log, clientset),
reflect.TypeOf(&castai.ActionApproveCSR{}): newApproveCSRHandler(log, clientset),
reflect.TypeOf(&castai.ActionChartUpsert{}): newChartUpsertHandler(log, helmClient),
reflect.TypeOf(&castai.ActionChartUninstall{}): newChartUninstallHandler(log, helmClient),
reflect.TypeOf(&castai.ActionChartRollback{}): newChartRollbackHandler(log, helmClient, cfg.Version),
reflect.TypeOf(&castai.ActionDisconnectCluster{}): newDisconnectClusterHandler(log, clientset),
reflect.TypeOf(&castai.ActionSendAKSInitData{}): newSendAKSInitDataHandler(log, castaiClient),
reflect.TypeOf(&castai.ActionCheckNodeDeleted{}): newCheckNodeDeletedHandler(log, clientset),
reflect.TypeOf(&castai.ActionCheckNodeStatus{}): newCheckNodeStatusHandler(log, clientset),
reflect.TypeOf(&castai.ActionPatch{}): newPatchHandler(log, dynamicClient),
reflect.TypeOf(&castai.ActionCreate{}): newCreateHandler(log, dynamicClient),
reflect.TypeOf(&castai.ActionDelete{}): newDeleteHandler(log, dynamicClient),
reflect.TypeOf(&types.ActionDeleteNode{}): newDeleteNodeHandler(log, clientset),
reflect.TypeOf(&types.ActionDrainNode{}): newDrainNodeHandler(log, clientset, cfg.Namespace),
reflect.TypeOf(&types.ActionPatchNode{}): newPatchNodeHandler(log, clientset),
reflect.TypeOf(&types.ActionCreateEvent{}): newCreateEventHandler(log, clientset),
reflect.TypeOf(&types.ActionApproveCSR{}): newApproveCSRHandler(log, clientset),
reflect.TypeOf(&types.ActionChartUpsert{}): newChartUpsertHandler(log, helmClient),
reflect.TypeOf(&types.ActionChartUninstall{}): newChartUninstallHandler(log, helmClient),
reflect.TypeOf(&types.ActionChartRollback{}): newChartRollbackHandler(log, helmClient, cfg.Version),
reflect.TypeOf(&types.ActionDisconnectCluster{}): newDisconnectClusterHandler(log, clientset),
reflect.TypeOf(&types.ActionSendAKSInitData{}): newSendAKSInitDataHandler(log, client),
reflect.TypeOf(&types.ActionCheckNodeDeleted{}): newCheckNodeDeletedHandler(log, clientset),
reflect.TypeOf(&types.ActionCheckNodeStatus{}): newCheckNodeStatusHandler(log, clientset),
reflect.TypeOf(&types.ActionPatch{}): newPatchHandler(log, dynamicClient),
reflect.TypeOf(&types.ActionCreate{}): newCreateHandler(log, dynamicClient),
reflect.TypeOf(&types.ActionDelete{}): newDeleteHandler(log, dynamicClient),
},
healthCheck: healthCheck,
}
}

type service struct {
log logrus.FieldLogger
cfg Config
castAIClient castai.Client
log logrus.FieldLogger
cfg Config
client Client

k8sVersion string

Expand Down Expand Up @@ -127,15 +147,15 @@ func (s *service) doWork(ctx context.Context) error {
s.log.Info("polling actions")
start := time.Now()
var (
actions []*castai.ClusterAction
actions []*types.ClusterAction
err error
iteration int
)

b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(5*time.Second), 3), ctx)
errR := backoff.Retry(func() error {
iteration++
actions, err = s.castAIClient.GetActions(ctx, s.k8sVersion)
actions, err = s.client.GetActions(ctx, s.k8sVersion)
if err != nil {
s.log.Errorf("polling actions: get action request failed: iteration: %v %v", iteration, err)
return err
Expand All @@ -158,13 +178,13 @@ func (s *service) doWork(ctx context.Context) error {
return nil
}

func (s *service) handleActions(ctx context.Context, actions []*castai.ClusterAction) {
func (s *service) handleActions(ctx context.Context, actions []*types.ClusterAction) {
for _, action := range actions {
if !s.startProcessing(action.ID) {
continue
}

go func(action *castai.ClusterAction) {
go func(action *types.ClusterAction) {
defer s.finishProcessing(action.ID)

var err error
Expand Down Expand Up @@ -211,7 +231,7 @@ func (s *service) startProcessing(actionID string) bool {
return true
}

func (s *service) handleAction(ctx context.Context, action *castai.ClusterAction) (err error) {
func (s *service) handleAction(ctx context.Context, action *types.ClusterAction) (err error) {
actionType := reflect.TypeOf(action.Data())

defer func() {
Expand All @@ -235,7 +255,7 @@ func (s *service) handleAction(ctx context.Context, action *castai.ClusterAction
return nil
}

func (s *service) ackAction(ctx context.Context, action *castai.ClusterAction, handleErr error) error {
func (s *service) ackAction(ctx context.Context, action *types.ClusterAction, handleErr error) error {
actionType := reflect.TypeOf(action.Data())
s.log.WithFields(logrus.Fields{
actionIDLogField: action.ID,
Expand All @@ -245,7 +265,7 @@ func (s *service) ackAction(ctx context.Context, action *castai.ClusterAction, h
return backoff.RetryNotify(func() error {
ctx, cancel := context.WithTimeout(ctx, s.cfg.AckTimeout)
defer cancel()
return s.castAIClient.AckAction(ctx, action.ID, &castai.AckClusterActionRequest{
return s.client.AckAction(ctx, action.ID, &types.AckClusterActionRequest{
Error: getHandlerError(handleErr),
})
}, backoff.WithContext(
Expand Down
41 changes: 20 additions & 21 deletions actions/actions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/castai/cluster-controller/castai"
"github.com/castai/cluster-controller/castai/mock"
"github.com/castai/cluster-controller/actions/types"
"github.com/castai/cluster-controller/health"
)

Expand All @@ -33,7 +32,7 @@ func TestActions(t *testing.T) {
ClusterID: uuid.New().String(),
}

newTestService := func(handler ActionHandler, client castai.Client) *service {
newTestService := func(handler ActionHandler, client Client) *service {
svc := NewService(
log,
cfg,
Expand All @@ -55,30 +54,30 @@ func TestActions(t *testing.T) {
t.Run("poll handle and ack", func(t *testing.T) {
r := require.New(t)

apiActions := []*castai.ClusterAction{
apiActions := []*types.ClusterAction{
{
ID: "a1",
CreatedAt: time.Now(),
ActionDeleteNode: &castai.ActionDeleteNode{
ActionDeleteNode: &types.ActionDeleteNode{
NodeName: "n1",
},
},
{
ID: "a2",
CreatedAt: time.Now(),
ActionDrainNode: &castai.ActionDrainNode{
ActionDrainNode: &types.ActionDrainNode{
NodeName: "n1",
},
},
{
ID: "a3",
CreatedAt: time.Now(),
ActionPatchNode: &castai.ActionPatchNode{
ActionPatchNode: &types.ActionPatchNode{
NodeName: "n1",
},
},
}
client := mock.NewMockAPIClient(apiActions)
client := newMockAPIClient(apiActions)
handler := &mockAgentActionHandler{handleDelay: 2 * time.Millisecond}
svc := newTestService(handler, client)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
Expand All @@ -102,7 +101,7 @@ func TestActions(t *testing.T) {
t.Run("continue polling on api error", func(t *testing.T) {
r := require.New(t)

client := mock.NewMockAPIClient([]*castai.ClusterAction{})
client := newMockAPIClient([]*types.ClusterAction{})
client.GetActionsErr = errors.New("ups")
handler := &mockAgentActionHandler{err: errors.New("ups")}
svc := newTestService(handler, client)
Expand All @@ -119,16 +118,16 @@ func TestActions(t *testing.T) {
t.Run("do not ack action on context canceled error", func(t *testing.T) {
r := require.New(t)

apiActions := []*castai.ClusterAction{
apiActions := []*types.ClusterAction{
{
ID: "a1",
CreatedAt: time.Now(),
ActionPatchNode: &castai.ActionPatchNode{
ActionPatchNode: &types.ActionPatchNode{
NodeName: "n1",
},
},
}
client := mock.NewMockAPIClient(apiActions)
client := newMockAPIClient(apiActions)
handler := &mockAgentActionHandler{err: context.Canceled}
svc := newTestService(handler, client)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
Expand All @@ -145,16 +144,16 @@ func TestActions(t *testing.T) {
t.Run("ack with error when action handler failed", func(t *testing.T) {
r := require.New(t)

apiActions := []*castai.ClusterAction{
apiActions := []*types.ClusterAction{
{
ID: "a1",
CreatedAt: time.Now(),
ActionPatchNode: &castai.ActionPatchNode{
ActionPatchNode: &types.ActionPatchNode{
NodeName: "n1",
},
},
}
client := mock.NewMockAPIClient(apiActions)
client := newMockAPIClient(apiActions)
handler := &mockAgentActionHandler{err: errors.New("ups")}
svc := newTestService(handler, client)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
Expand All @@ -165,24 +164,24 @@ func TestActions(t *testing.T) {
r.Empty(client.Actions)
r.Len(client.Acks, 1)
r.Equal("a1", client.Acks[0].ActionID)
r.Equal("handling action *castai.ActionPatchNode: ups", *client.Acks[0].Err)
r.Equal("handling action *ActionPatchNode: ups", *client.Acks[0].Err)
}()
svc.Run(ctx)
})

t.Run("ack with error when action handler panic occurred", func(t *testing.T) {
r := require.New(t)

apiActions := []*castai.ClusterAction{
apiActions := []*types.ClusterAction{
{
ID: "a1",
CreatedAt: time.Now(),
ActionPatchNode: &castai.ActionPatchNode{
ActionPatchNode: &types.ActionPatchNode{
NodeName: "n1",
},
},
}
client := mock.NewMockAPIClient(apiActions)
client := newMockAPIClient(apiActions)
handler := &mockAgentActionHandler{panicErr: errors.New("ups")}
svc := newTestService(handler, client)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
Expand All @@ -193,7 +192,7 @@ func TestActions(t *testing.T) {
r.Empty(client.Actions)
r.Len(client.Acks, 1)
r.Equal("a1", client.Acks[0].ActionID)
r.Contains(*client.Acks[0].Err, "panic: handling action *castai.ActionPatchNode: ups: goroutine")
r.Contains(*client.Acks[0].Err, "panic: handling action *ActionPatchNode: ups: goroutine")
}()
svc.Run(ctx)
})
Expand All @@ -205,7 +204,7 @@ type mockAgentActionHandler struct {
handleDelay time.Duration
}

func (m *mockAgentActionHandler) Handle(ctx context.Context, action *castai.ClusterAction) error {
func (m *mockAgentActionHandler) Handle(ctx context.Context, action *types.ClusterAction) error {
time.Sleep(m.handleDelay)
if m.panicErr != nil {
panic(m.panicErr)
Expand Down
8 changes: 4 additions & 4 deletions actions/approve_csr_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/sirupsen/logrus"
"k8s.io/client-go/kubernetes"

"github.com/castai/cluster-controller/castai"
"github.com/castai/cluster-controller/actions/types"
"github.com/castai/cluster-controller/csr"
)

Expand All @@ -31,15 +31,15 @@ type approveCSRHandler struct {
csrFetchInterval time.Duration
}

func (h *approveCSRHandler) Handle(ctx context.Context, action *castai.ClusterAction) error {
req, ok := action.Data().(*castai.ActionApproveCSR)
func (h *approveCSRHandler) Handle(ctx context.Context, action *types.ClusterAction) error {
req, ok := action.Data().(*types.ActionApproveCSR)
if !ok {
return fmt.Errorf("unexpected type %T for approve csr handler", action.Data())
}
log := h.log.WithFields(logrus.Fields{
"node_name": req.NodeName,
"node_id": req.NodeID,
"type": reflect.TypeOf(action.Data().(*castai.ActionApproveCSR)).String(),
"type": reflect.TypeOf(action.Data().(*types.ActionApproveCSR)).String(),
actionIDLogField: action.ID,
})

Expand Down
Loading
Loading