Skip to content

Commit

Permalink
KUBE-637: prevent handling the same CSR on every update from watcher …
Browse files Browse the repository at this point in the history
…or restart watcher (#150)

* fix: prevent handling the same CSR on every update from watcher or restart watcher
  • Loading branch information
ValyaB authored Oct 23, 2024
1 parent feb9d3d commit f7ac9e9
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 16 deletions.
45 changes: 31 additions & 14 deletions internal/actions/csr/csr.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"sort"
"strings"
"time"

"github.com/sirupsen/logrus"
certv1 "k8s.io/api/certificates/v1"
Expand All @@ -25,6 +26,7 @@ import (
const (
// ReasonApproved is the exported reason string "AutoApproved".
// NOTE(review): presumably set as the Reason of the approval condition added
// to a CSR; the code that uses it is not visible here, so confirm.
ReasonApproved = "AutoApproved"
// approvedMessage is the human-readable text that accompanies an approval.
// NOTE(review): presumably the Message of the same approval condition; confirm.
approvedMessage = "This CSR was approved by CAST AI"
// csrTTL is the maximum age of a CSR, measured from its creation timestamp.
// toCertificate rejects anything older with errCSRTooOld, so CSRs replayed by
// a restarted watcher are not handled again.
csrTTL = time.Hour
)

// ErrNodeCertificateNotFound is the sentinel error for a missing node certificate.
// NOTE(review): the code that returns it is not visible in this excerpt; confirm
// the exact condition before matching on it with errors.Is.
var ErrNodeCertificateNotFound = errors.New("node certificate not found")
Expand Down Expand Up @@ -313,16 +315,13 @@ func WatchCastAINodeCSRs(ctx context.Context, log logrus.FieldLogger, client kub
return
}

csrResult, name, request := toCertificate(event)
if csrResult == nil {
cert, name, request, err := toCertificate(event)
if err != nil {
log.Warnf("toCertificate: skipping csr event: %v", err)
continue
}
// We are only interested in kubelet-bootstrap csr. SKIP own CSR due to the infinite loop of deleting->creating new->deleting.
if csrResult.RequestingUser != "kubelet-bootstrap" {
log.WithFields(logrus.Fields{
"csr": name,
"node_name": csrResult.RequestingUser,
}).Infof("skipping csr not from kubelet-bootstrap: %v", csrResult.RequestingUser)

if cert == nil {
continue
}

Expand All @@ -334,7 +333,7 @@ func WatchCastAINodeCSRs(ctx context.Context, log logrus.FieldLogger, client kub
}).Infof("skipping csr unable to get common name: %v", err)
continue
}
if csrResult.Approved() {
if cert.Approved() {
continue
}

Expand All @@ -345,8 +344,8 @@ func WatchCastAINodeCSRs(ctx context.Context, log logrus.FieldLogger, client kub
}).Infof("skipping csr not CAST AI node")
continue
}
csrResult.Name = cn
sendCertificate(ctx, c, csrResult)
cert.Name = cn
sendCertificate(ctx, c, cert)
}
}
}
Expand All @@ -362,21 +361,39 @@ func getWatcher(ctx context.Context, client kubernetes.Interface) (watch.Interfa
return w, nil
}

func toCertificate(event watch.Event) (cert *Certificate, name string, request []byte) {
// Sentinel errors returned (directly or wrapped) by toCertificate to explain
// why a watch event was not converted into a Certificate.
var (
// errUnexpectedObjectType: the event object is neither a v1 nor a v1beta1
// CertificateSigningRequest.
errUnexpectedObjectType = errors.New("unexpected object type")
// errCSRTooOld: the CSR's creation timestamp plus csrTTL is already in the past.
errCSRTooOld = errors.New("csr is too old")
// errOwner: the CSR was not requested by the "kubelet-bootstrap" user.
errOwner = errors.New("owner is not bootstrap")
)

// toCertificate converts a CSR watch event into a Certificate.
//
// It accepts both the v1 and the v1beta1 CertificateSigningRequest types and
// returns the wrapped certificate, the CSR object name and the raw request
// bytes taken from the CSR spec.
//
// A non-nil error means the event must be skipped; in that case every other
// result is its zero value:
//   - errUnexpectedObjectType: the event object is not a supported CSR type;
//   - errCSRTooOld (wrapped): the CSR was created more than csrTTL ago;
//   - errOwner (wrapped): the CSR was not requested by "kubelet-bootstrap".
func toCertificate(event watch.Event) (cert *Certificate, name string, request []byte, err error) {
	// Both rejection errors share the same context prefix; only the wrapped
	// sentinel differs.
	const rejectFormat = "csr with certificate Name: %v RequestingUser: %v %w"

	var createdAt time.Time
	switch e := event.Object.(type) {
	case *certv1.CertificateSigningRequest:
		name, request = e.Name, e.Spec.Request
		cert = &Certificate{Name: name, v1: e, RequestingUser: e.Spec.Username}
		createdAt = e.CreationTimestamp.Time
	case *certv1beta1.CertificateSigningRequest:
		name, request = e.Name, e.Spec.Request
		cert = &Certificate{Name: name, v1Beta1: e, RequestingUser: e.Spec.Username}
		createdAt = e.CreationTimestamp.Time
	default:
		return nil, "", nil, errUnexpectedObjectType
	}

	// Drop CSRs older than csrTTL so that the same CSR is not handled again on
	// every watcher update or after a watcher restart.
	if createdAt.Add(csrTTL).Before(time.Now()) {
		return nil, "", nil, fmt.Errorf(rejectFormat, cert.Name, cert.RequestingUser, errCSRTooOld)
	}

	// We are only interested in kubelet-bootstrap CSRs. Our own CSRs are skipped
	// to avoid an infinite loop of deleting -> creating new -> deleting.
	if cert.RequestingUser != "kubelet-bootstrap" {
		return nil, "", nil, fmt.Errorf(rejectFormat, cert.Name, cert.RequestingUser, errOwner)
	}

	return cert, name, request, nil
}

func isCastAINodeCsr(subjectCommonName string) bool {
Expand Down
122 changes: 122 additions & 0 deletions internal/actions/csr/csr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ package csr
import (
"context"
"path/filepath"
"reflect"
"testing"
"time"

"github.com/stretchr/testify/require"
certv1 "k8s.io/api/certificates/v1"
certv1beta1 "k8s.io/api/certificates/v1beta1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
_ "k8s.io/client-go/plugin/pkg/client/auth/gcp"
"k8s.io/client-go/tools/clientcmd"
Expand Down Expand Up @@ -95,3 +100,120 @@ func Test_isCastAINodeCsr(t *testing.T) {
})
}
}

func Test_toCertificate(t *testing.T) {
testCSRv1 := &certv1.CertificateSigningRequest{
Spec: certv1.CertificateSigningRequestSpec{
Username: "kubelet-bootstrap",
},
ObjectMeta: metav1.ObjectMeta{
CreationTimestamp: metav1.Time{Time: time.Now().Add(csrTTL)},
Name: "test",
},
}
testCSRv1beta1 := &certv1beta1.CertificateSigningRequest{
Spec: certv1beta1.CertificateSigningRequestSpec{
Username: "kubelet-bootstrap",
},
ObjectMeta: metav1.ObjectMeta{
CreationTimestamp: metav1.Time{Time: time.Now().Add(csrTTL)},
Name: "test",
},
}
type args struct {
event watch.Event
}
tests := []struct {
name string
args args
wantCert *Certificate
wantName string
wantRequest []byte
wantErr bool
}{
{
name: "empty event",
args: args{
event: watch.Event{},
},
wantErr: true,
},
{
name: "outdated event",
args: args{
event: watch.Event{
Object: &certv1.CertificateSigningRequest{
ObjectMeta: metav1.ObjectMeta{
CreationTimestamp: metav1.Time{Time: time.Now().Add(-csrTTL)},
},
},
},
},
wantErr: true,
},
{
name: "bad owner",
args: args{
event: watch.Event{
Object: &certv1.CertificateSigningRequest{
Spec: certv1.CertificateSigningRequestSpec{
Username: "test",
},
ObjectMeta: metav1.ObjectMeta{
CreationTimestamp: metav1.Time{Time: time.Now().Add(csrTTL)},
},
},
},
},
wantErr: true,
},
{
name: "ok v1",
args: args{
event: watch.Event{
Object: testCSRv1,
},
},
wantErr: false,
wantName: "test",
wantCert: &Certificate{
Name: "test",
RequestingUser: "kubelet-bootstrap",
v1: testCSRv1,
},
},
{
name: "ok v1beta1",
args: args{
event: watch.Event{
Object: testCSRv1beta1,
},
},
wantErr: false,
wantName: "test",
wantCert: &Certificate{
Name: "test",
RequestingUser: "kubelet-bootstrap",
v1Beta1: testCSRv1beta1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotCert, gotName, gotRequest, err := toCertificate(tt.args.event)
if (err != nil) != tt.wantErr {
t.Errorf("toCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(gotCert, tt.wantCert) {
t.Errorf("toCertificate() gotCert = %v, want %v", gotCert, tt.wantCert)
}
if gotName != tt.wantName {
t.Errorf("toCertificate() gotName = %v, want %v", gotName, tt.wantName)
}
if !reflect.DeepEqual(gotRequest, tt.wantRequest) {
t.Errorf("toCertificate() gotRequest = %v, want %v", gotRequest, tt.wantRequest)
}
})
}
}
31 changes: 30 additions & 1 deletion internal/actions/csr/svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ type ApprovalManager struct {
log logrus.FieldLogger
clientset kubernetes.Interface
cancelAutoApprove context.CancelFunc
m sync.Mutex // Used to make sure there is just one watcher running.

inProgress map[string]struct{} // one handler per csr/certificate Name.
m sync.Mutex // Used to make sure there is just one watcher running.
}

func (h *ApprovalManager) Start(ctx context.Context) {
Expand Down Expand Up @@ -114,7 +116,13 @@ func (h *ApprovalManager) runAutoApproveForCastAINodes(ctx context.Context) {
if cert == nil {
continue
}
// prevent starting goroutine for the same node certificate
if !h.addInProgress(cert.Name) {
continue
}
go func(cert *Certificate) {
defer h.removeInProgress(cert.Name)

log := log.WithField("node_name", cert.Name)
log.Info("auto approving csr")
err := h.handleWithRetry(ctx, log, cert)
Expand Down Expand Up @@ -157,3 +165,24 @@ func newApproveCSRExponentialBackoff() wait.Backoff {
b.Factor = 2
return b
}

// addInProgress records nodeName as having a handler in flight.
//
// It returns true when the name was not tracked yet and has now been added,
// and false when an entry for it already exists, in which case the caller is
// expected to skip starting another handler. The tracking map is created
// lazily on first insert.
//
// NOTE(review): this takes h.m, the same mutex the struct documents as
// guarding the single-watcher invariant; confirm no caller invokes it while
// already holding h.m.
func (h *ApprovalManager) addInProgress(nodeName string) bool {
	h.m.Lock()
	defer h.m.Unlock()

	// Looking up a key in a nil map is safe and simply reports "absent".
	if _, tracked := h.inProgress[nodeName]; tracked {
		return false
	}

	if h.inProgress == nil {
		h.inProgress = make(map[string]struct{})
	}
	h.inProgress[nodeName] = struct{}{}
	return true
}

// removeInProgress forgets nodeName so that a later CSR for the same name can
// be handled again. Removing a name that is not tracked is a no-op.
func (h *ApprovalManager) removeInProgress(nodeName string) {
	h.m.Lock()
	// delete on a nil map or a missing key does nothing and cannot panic, so
	// the lock is released explicitly instead of through defer.
	delete(h.inProgress, nodeName)
	h.m.Unlock()
}
6 changes: 5 additions & 1 deletion internal/actions/csr/svc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ import (

func getCSR(name, username string) *certv1.CertificateSigningRequest {
return &certv1.CertificateSigningRequest{
ObjectMeta: metav1.ObjectMeta{Name: name},
ObjectMeta: metav1.ObjectMeta{
Name: name,
CreationTimestamp: metav1.Now(),
},
Spec: certv1.CertificateSigningRequestSpec{
Request: []byte(`-----BEGIN CERTIFICATE REQUEST-----
MIIBLTCB0wIBADBPMRUwEwYDVQQKEwxzeXN0ZW06bm9kZXMxNjA0BgNVBAMTLXN5
Expand Down Expand Up @@ -69,6 +72,7 @@ func TestCSRApprove(t *testing.T) {

csrResult, err := client.CertificatesV1().CertificateSigningRequests().Get(ctx, csrName, metav1.GetOptions{})
r.NoError(err)

r.Equal(csrResult.Status.Conditions[0].Type, certv1.CertificateApproved)
})

Expand Down

0 comments on commit f7ac9e9

Please sign in to comment.