diff --git a/internal/actions/csr/csr.go b/internal/actions/csr/csr.go index 15043074..9986304e 100644 --- a/internal/actions/csr/csr.go +++ b/internal/actions/csr/csr.go @@ -8,6 +8,7 @@ import ( "fmt" "sort" "strings" + "time" "github.com/sirupsen/logrus" certv1 "k8s.io/api/certificates/v1" @@ -25,6 +26,7 @@ import ( const ( ReasonApproved = "AutoApproved" approvedMessage = "This CSR was approved by CAST AI" + csrTTL = time.Hour ) var ErrNodeCertificateNotFound = errors.New("node certificate not found") @@ -313,16 +315,13 @@ func WatchCastAINodeCSRs(ctx context.Context, log logrus.FieldLogger, client kub return } - csrResult, name, request := toCertificate(event) - if csrResult == nil { + cert, name, request, err := toCertificate(event) + if err != nil { + log.Warnf("toCertificate: skipping csr event: %v", err) continue } - // We are only interested in kubelet-bootstrap csr. SKIP own CSR due to the infinite loop of deleting->creating new->deleting. - if csrResult.RequestingUser != "kubelet-bootstrap" { - log.WithFields(logrus.Fields{ - "csr": name, - "node_name": csrResult.RequestingUser, - }).Infof("skipping csr not from kubelet-bootstrap: %v", csrResult.RequestingUser) + + if cert == nil { continue } @@ -334,7 +333,7 @@ func WatchCastAINodeCSRs(ctx context.Context, log logrus.FieldLogger, client kub }).Infof("skipping csr unable to get common name: %v", err) continue } - if csrResult.Approved() { + if cert.Approved() { continue } @@ -345,8 +344,8 @@ func WatchCastAINodeCSRs(ctx context.Context, log logrus.FieldLogger, client kub }).Infof("skipping csr not CAST AI node") continue } - csrResult.Name = cn - sendCertificate(ctx, c, csrResult) + cert.Name = cn + sendCertificate(ctx, c, cert) } } } @@ -362,21 +361,39 @@ func getWatcher(ctx context.Context, client kubernetes.Interface) (watch.Interfa return w, nil } -func toCertificate(event watch.Event) (cert *Certificate, name string, request []byte) { +var ( + errUnexpectedObjectType = errors.New("unexpected object type") + errCSRTooOld = errors.New("csr is too old") + errOwner = errors.New("owner is not bootstrap") +) + +func toCertificate(event watch.Event) (cert *Certificate, name string, request []byte, err error) { + isOutdated := false switch e := event.Object.(type) { case *certv1.CertificateSigningRequest: name = e.Name request = e.Spec.Request cert = &Certificate{Name: name, v1: e, RequestingUser: e.Spec.Username} + isOutdated = e.CreationTimestamp.Add(csrTTL).Before(time.Now()) case *certv1beta1.CertificateSigningRequest: name = e.Name request = e.Spec.Request cert = &Certificate{Name: name, v1Beta1: e, RequestingUser: e.Spec.Username} + isOutdated = e.CreationTimestamp.Add(csrTTL).Before(time.Now()) default: - return nil, "", nil + return nil, "", nil, errUnexpectedObjectType + } + + if isOutdated { + return nil, "", nil, fmt.Errorf("csr with certificate Name: %v RequestingUser: %v %w", cert.Name, cert.RequestingUser, errCSRTooOld) + } + + // We are only interested in kubelet-bootstrap csr. SKIP own CSR due to the infinite loop of deleting->creating new->deleting. + if cert.RequestingUser != "kubelet-bootstrap" { + return nil, "", nil, fmt.Errorf("csr with certificate Name: %v RequestingUser: %v %w", cert.Name, cert.RequestingUser, errOwner) } - return cert, name, request + return cert, name, request, nil } func isCastAINodeCsr(subjectCommonName string) bool { diff --git a/internal/actions/csr/csr_test.go b/internal/actions/csr/csr_test.go index 51313c2e..ed1c156e 100644 --- a/internal/actions/csr/csr_test.go +++ b/internal/actions/csr/csr_test.go @@ -3,10 +3,15 @@ package csr import ( "context" "path/filepath" + "reflect" "testing" "time" "github.com/stretchr/testify/require" + certv1 "k8s.io/api/certificates/v1" + certv1beta1 "k8s.io/api/certificates/v1beta1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes" _ "k8s.io/client-go/plugin/pkg/client/auth/gcp" "k8s.io/client-go/tools/clientcmd" @@ -95,3 +100,120 @@ func Test_isCastAINodeCsr(t *testing.T) { }) } } + +func Test_toCertificate(t *testing.T) { + testCSRv1 := &certv1.CertificateSigningRequest{ + Spec: certv1.CertificateSigningRequestSpec{ + Username: "kubelet-bootstrap", + }, + ObjectMeta: metav1.ObjectMeta{ + CreationTimestamp: metav1.Time{Time: time.Now().Add(csrTTL)}, + Name: "test", + }, + } + testCSRv1beta1 := &certv1beta1.CertificateSigningRequest{ + Spec: certv1beta1.CertificateSigningRequestSpec{ + Username: "kubelet-bootstrap", + }, + ObjectMeta: metav1.ObjectMeta{ + CreationTimestamp: metav1.Time{Time: time.Now().Add(csrTTL)}, + Name: "test", + }, + } + type args struct { + event watch.Event + } + tests := []struct { + name string + args args + wantCert *Certificate + wantName string + wantRequest []byte + wantErr bool + }{ + { + name: "empty event", + args: args{ + event: watch.Event{}, + }, + wantErr: true, + }, + { + name: "outdated event", + args: args{ + event: watch.Event{ + Object: &certv1.CertificateSigningRequest{ + ObjectMeta: metav1.ObjectMeta{ + CreationTimestamp: metav1.Time{Time: time.Now().Add(-csrTTL)}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "bad owner", + args: args{ + event: watch.Event{ + Object: &certv1.CertificateSigningRequest{ + Spec: certv1.CertificateSigningRequestSpec{ + Username: "test", + }, + ObjectMeta: metav1.ObjectMeta{ + CreationTimestamp: metav1.Time{Time: time.Now().Add(csrTTL)}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ok v1", + args: args{ + event: watch.Event{ + Object: testCSRv1, + }, + }, + wantErr: false, + wantName: "test", + wantCert: &Certificate{ + Name: "test", + RequestingUser: "kubelet-bootstrap", + v1: testCSRv1, + }, + }, + { + name: "ok v1beta1", + args: args{ + event: watch.Event{ + Object: testCSRv1beta1, + }, + }, + wantErr: false, + wantName: "test", + wantCert: &Certificate{ + Name: "test", + RequestingUser: "kubelet-bootstrap", + v1Beta1: testCSRv1beta1, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCert, gotName, gotRequest, err := toCertificate(tt.args.event) + if (err != nil) != tt.wantErr { + t.Errorf("toCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotCert, tt.wantCert) { + t.Errorf("toCertificate() gotCert = %v, want %v", gotCert, tt.wantCert) + } + if gotName != tt.wantName { + t.Errorf("toCertificate() gotName = %v, want %v", gotName, tt.wantName) + } + if !reflect.DeepEqual(gotRequest, tt.wantRequest) { + t.Errorf("toCertificate() gotRequest = %v, want %v", gotRequest, tt.wantRequest) + } + }) + } +} diff --git a/internal/actions/csr/svc.go b/internal/actions/csr/svc.go index a527ff90..1adf806b 100644 --- a/internal/actions/csr/svc.go +++ b/internal/actions/csr/svc.go @@ -30,7 +30,9 @@ type ApprovalManager struct { log logrus.FieldLogger clientset kubernetes.Interface cancelAutoApprove context.CancelFunc - m sync.Mutex // Used to make sure there is just one watcher running. + + inProgress map[string]struct{} // one handler per csr/certificate Name. + m sync.Mutex // Used to make sure there is just one watcher running. } func (h *ApprovalManager) Start(ctx context.Context) { @@ -114,7 +116,13 @@ func (h *ApprovalManager) runAutoApproveForCastAINodes(ctx context.Context) { if cert == nil { continue } + // prevent starting goroutine for the same node certificate + if !h.addInProgress(cert.Name) { + continue + } go func(cert *Certificate) { + defer h.removeInProgress(cert.Name) + log := log.WithField("node_name", cert.Name) log.Info("auto approving csr") err := h.handleWithRetry(ctx, log, cert) @@ -157,3 +165,24 @@ func newApproveCSRExponentialBackoff() wait.Backoff { b.Factor = 2 return b } + +func (h *ApprovalManager) addInProgress(nodeName string) bool { + h.m.Lock() + defer h.m.Unlock() + if h.inProgress == nil { + h.inProgress = make(map[string]struct{}) + } + _, ok := h.inProgress[nodeName] + if ok { + return false + } + h.inProgress[nodeName] = struct{}{} + return true +} + +func (h *ApprovalManager) removeInProgress(nodeName string) { + h.m.Lock() + defer h.m.Unlock() + + delete(h.inProgress, nodeName) +} diff --git a/internal/actions/csr/svc_test.go b/internal/actions/csr/svc_test.go index ad0b7dd3..73678658 100644 --- a/internal/actions/csr/svc_test.go +++ b/internal/actions/csr/svc_test.go @@ -17,7 +17,10 @@ import ( func getCSR(name, username string) *certv1.CertificateSigningRequest { return &certv1.CertificateSigningRequest{ - ObjectMeta: metav1.ObjectMeta{Name: name}, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + CreationTimestamp: metav1.Now(), + }, Spec: certv1.CertificateSigningRequestSpec{ Request: []byte(`-----BEGIN CERTIFICATE REQUEST----- MIIBLTCB0wIBADBPMRUwEwYDVQQKEwxzeXN0ZW06bm9kZXMxNjA0BgNVBAMTLXN5 @@ -69,6 +72,7 @@ func TestCSRApprove(t *testing.T) { csrResult, err := client.CertificatesV1().CertificateSigningRequests().Get(ctx, csrName, metav1.GetOptions{}) r.NoError(err) + r.Equal(csrResult.Status.Conditions[0].Type, certv1.CertificateApproved) })