From a4bc5aa973d475742810f4caa7e87757fc4de1f3 Mon Sep 17 00:00:00 2001 From: JenTing Hsiao Date: Sun, 24 Nov 2024 13:54:38 +0800 Subject: [PATCH] add unit tests --- main.go | 296 +++++++++++++++++++++++---------------------------- main_test.go | 124 +++++++++++++++++++++ 2 files changed, 256 insertions(+), 164 deletions(-) diff --git a/main.go b/main.go index 3537e54..5cc04f5 100644 --- a/main.go +++ b/main.go @@ -20,210 +20,119 @@ import ( func main() { config := ctrl.GetConfigOrDie() - namespace := "argocd" - configMapName := "argocd-rbac-cm" clientset, err := kubernetes.NewForConfig(config) if err != nil { log.Fatalf("Failed to create Kubernetes client: %v", err) } + userToObjectPatternMapping, groupToObjectPatternMapping := loadRBACPolicyFromConfigMap(clientset, "argocd", "argocd-rbac-cm") + + redisClient := initializeRedis("localhost:16379", "", 1) + + proxy := createReverseProxy("http://localhost:8443") + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + handleRequest(w, r, proxy, redisClient, userToObjectPatternMapping, groupToObjectPatternMapping) + }) + + log.Println("Proxy server running on :8081") + log.Fatal(http.ListenAndServe(":8081", nil)) +} + +func loadRBACPolicyFromConfigMap(clientset *kubernetes.Clientset, namespace, configMapName string) (map[string][]string, map[string][]string) { cm, err := clientset.CoreV1().ConfigMaps(namespace).Get(context.Background(), configMapName, metav1.GetOptions{}) if err != nil { - log.Fatalf("Failed to fetch ConfigMap %s: %v", configMapName, err) + fmt.Printf("Failed to fetch ConfigMap %s: %v", configMapName, err) + return nil, nil } - // Extract policy data from ConfigMap - var userToObjectPatternMapping map[string][]string = make(map[string][]string) - var groupToObjectPatternMapping map[string][]string = make(map[string][]string) - policyCSV, ok := cm.Data["policy.csv"] if !ok { - fmt.Printf("policy.csv not found in ConfigMap %s\n", configMapName) - } else { - fmt.Printf("Policy csv data: %s\n", policyCSV) - - // Parse the policy.csv content and build the map - userToObjectPatternMapping, groupToObjectPatternMapping = parsePolicyCSV(policyCSV) - - // Print the user permissions - fmt.Println("User Permissions:") - for user, objectPattern := range userToObjectPatternMapping { - fmt.Printf("User: %s, Permissions: %v\n", user, objectPattern) - } - - // Print the group permissions - fmt.Println("Group Permissions:") - for group, objectPattern := range groupToObjectPatternMapping { - fmt.Printf("Group: %s, Permissions: %v\n", group, objectPattern) - } + fmt.Printf("policy.csv not found in ConfigMap %s", configMapName) + return nil, nil } - // Redis configuration - redisAddr := "localhost:16379" // Redis service DNS - redisPassword := "" // Set the password if Redis authentication is enabled + return parsePolicyCSV(policyCSV) +} - // Initialize Redis client - redisClient := redis.NewClient(&redis.Options{ - Addr: redisAddr, - Password: redisPassword, - DB: 1, +func initializeRedis(addr, password string, db int) *redis.Client { + client := redis.NewClient(&redis.Options{ + Addr: addr, + Password: password, + DB: db, DialTimeout: 5 * time.Second, }) - // Test connection - pong, err := redisClient.Ping().Result() - if err != nil { + if _, err := client.Ping().Result(); err != nil { log.Fatalf("Failed to connect to Redis: %v", err) } + fmt.Println("Connected to Redis successfully") + return client +} - fmt.Printf("Connected to Redis: %s\n", pong) - - // ArgoCD server URL - argocdServerURL := "http://localhost:8443" // Update this to your actual ArgoCD server URL +func createReverseProxy(target string) *httputil.ReverseProxy { + parsedURL, err := url.Parse(target) + if err != nil { + log.Fatalf("Invalid ArgoCD server URL: %v", err) + } - // Proxy handler - proxy := &httputil.ReverseProxy{ + return &httputil.ReverseProxy{ Director: func(req *http.Request) { - argoURL, _ := url.Parse(argocdServerURL) - req.URL.Scheme = argoURL.Scheme - req.URL.Host = argoURL.Host - req.Host = argoURL.Host + req.URL.Scheme = parsedURL.Scheme + req.URL.Host = parsedURL.Host + req.Host = parsedURL.Host }, } +} - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - token := func() string { - if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimPrefix(authHeader, "Bearer ") - } - if cookie, err := r.Cookie("argocd.token"); err == nil { - return cookie.Value - } - return "" - }() - if token == "" { - proxy.ServeHTTP(w, r) - return - } - - // Capture GET requests to /api/v1/applications - if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/api/v1/applications") { - fmt.Printf("Request: %s %s\n", r.Method, r.URL.Path) - - payload, err := decodeJWTPayload(token) - if err != nil { - proxy.ServeHTTP(w, r) - return - } - - // Extract the "email" and "groups" from the payload - email, _ := payload["email"].(string) - groups, _ := payload["groups"].([]string) - fmt.Printf("Email: %s, Groups: %v\n", email, groups) - - // Collect all unique object patterns for the user and groups - objectPatterns := make(map[string]struct{}) - if objectPattern, ok := userToObjectPatternMapping[email]; ok { - for _, pattern := range objectPattern { - objectPatterns[pattern] = struct{}{} - } - } - for _, group := range groups { - if objectPattern, ok := groupToObjectPatternMapping[group]; ok { - for _, pattern := range objectPattern { - objectPatterns[pattern] = struct{}{} - } - } - } - - // Prepare a list of keys to fetch from Redis - keyPatterns := make([]string, 0, len(objectPatterns)) - for pattern := range objectPatterns { - keyPatterns = append(keyPatterns, fmt.Sprintf("%s|*", pattern)) - } - - // Batch process keys using Redis MGET - allKeys := make([]string, 0) - for _, keyPattern := range keyPatterns { - keys, err := redisClient.Keys(keyPattern).Result() - if err != nil { - log.Printf("Failed to fetch keys for pattern %s: %v", keyPattern, err) - continue - } - allKeys = append(allKeys, keys...) - } - - // Fetch all values for the keys in a single batch - keyValuePairs := make(map[string]string) - if len(allKeys) > 0 { - pipe := redisClient.Pipeline() - cmds := make([]*redis.StringCmd, len(allKeys)) - for i, key := range allKeys { - cmds[i] = pipe.Get(key) - } - _, err := pipe.Exec() - if err != nil && err != redis.Nil { - log.Fatalf("Failed to fetch values for keys: %v", err) - proxy.ServeHTTP(w, r) - return - } +func handleRequest(w http.ResponseWriter, r *http.Request, proxy *httputil.ReverseProxy, redisClient *redis.Client, userToObjectPatternMapping, groupToObjectPatternMapping map[string][]string) { + token := extractToken(r) + if token == "" || (r.Method != http.MethodGet || !strings.HasPrefix(r.URL.Path, "/api/v1/applications")) { + proxy.ServeHTTP(w, r) + return + } - for i, cmd := range cmds { - if cmd.Err() == nil { - keyValuePairs[allKeys[i]] = cmd.Val() - } else { - log.Printf("Failed to fetch value for key %s: %v", allKeys[i], cmd.Err()) - } - } - } + payload, err := decodeJWTPayload(token) + if err != nil { + proxy.ServeHTTP(w, r) + return + } - // Unmarshal values and build the response - var resp struct { - Items []interface{} `json:"items"` - } - resp.Items = make([]interface{}, 0, len(keyValuePairs)) - for key, value := range keyValuePairs { - var rawJson interface{} - if err := json.Unmarshal([]byte(value), &rawJson); err != nil { - log.Printf("Failed to unmarshal value for key %s: %v", key, err) - continue - } - resp.Items = append(resp.Items, rawJson) - } + email, _ := payload["email"].(string) + groups, _ := payload["groups"].([]string) - if len(resp.Items) == 0 { - proxy.ServeHTTP(w, r) - return - } + objectPatterns := resolveObjectPatterns(email, groups, userToObjectPatternMapping, groupToObjectPatternMapping) - // Serialize the key-value pairs as JSON - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("Failed to write response: %v", err) - proxy.ServeHTTP(w, r) - return - } - return - } + resp := fetchApplicationsFromRedis(redisClient, objectPatterns) + if len(resp.Items) == 0 { + proxy.ServeHTTP(w, r) + return + } - // Proxy other requests to the ArgoCD server + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Printf("Failed to write response: %v", err) proxy.ServeHTTP(w, r) - }) + } +} - log.Println("Proxy server running on :8081") - log.Fatal(http.ListenAndServe(":8081", nil)) +func extractToken(r *http.Request) string { + if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimPrefix(authHeader, "Bearer ") + } + if cookie, err := r.Cookie("argocd.token"); err == nil { + return cookie.Value + } + return "" } -// decodeJWTPayload decodes the payload of a JWT token without validating it func decodeJWTPayload(token string) (map[string]interface{}, error) { - // Split the token into its parts (header, payload, signature) parts := strings.Split(token, ".") if len(parts) < 2 { return nil, fmt.Errorf("invalid token format") } - // Decode the payload (second part of the JWT) payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("failed to decode payload: %v", err) @@ -233,10 +142,69 @@ func decodeJWTPayload(token string) (map[string]interface{}, error) { if err := json.Unmarshal(payloadBytes, &payload); err != nil { return nil, fmt.Errorf("failed to unmarshal payload: %v", err) } - return payload, nil } +func resolveObjectPatterns(email string, groups []string, userToObjectPatternMapping, groupToObjectPatternMapping map[string][]string) map[string]struct{} { + objectPatterns := make(map[string]struct{}) + + for _, pattern := range userToObjectPatternMapping[email] { + objectPatterns[pattern] = struct{}{} + } + + for _, group := range groups { + for _, pattern := range groupToObjectPatternMapping[group] { + objectPatterns[pattern] = struct{}{} + } + } + + return objectPatterns +} + +func fetchApplicationsFromRedis(redisClient *redis.Client, objectPatterns map[string]struct{}) struct { + Items []interface{} `json:"items"` +} { + resp := struct { + Items []interface{} `json:"items"` + }{Items: []interface{}{}} + + var allKeys []string + for pattern := range objectPatterns { + keys, err := redisClient.Keys(fmt.Sprintf("%s|*", pattern)).Result() + if err != nil { + log.Printf("Failed to fetch keys for pattern %s: %v", pattern, err) + continue + } + allKeys = append(allKeys, keys...) + } + + if len(allKeys) > 0 { + pipe := redisClient.Pipeline() + cmds := make([]*redis.StringCmd, len(allKeys)) + for i, key := range allKeys { + cmds[i] = pipe.Get(key) + } + _, err := pipe.Exec() + if err != nil && err != redis.Nil { + log.Printf("Failed to fetch values for keys: %v", err) + } + + for i, cmd := range cmds { + if cmd.Err() == nil { + var rawJson interface{} + if err := json.Unmarshal([]byte(cmd.Val()), &rawJson); err == nil { + resp.Items = append(resp.Items, rawJson) + } else { + log.Printf("Failed to unmarshal value for key %s: %v", allKeys[i], err) + } + } else { + log.Printf("Failed to fetch value for key %s: %v", allKeys[i], cmd.Err()) + } + } + } + return resp +} + func parsePolicyCSV(policyCSV string) (map[string][]string, map[string][]string) { userToRoleMapping := make(map[string][]string) groupToRoleMapping := make(map[string][]string) diff --git a/main_test.go b/main_test.go index 70dac42..2d49ff7 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,10 @@ package main import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" "reflect" "testing" ) @@ -131,3 +135,123 @@ func TestParsePolicyCSV(t *testing.T) { }) } } + +func TestExtractToken(t *testing.T) { + tests := []struct { + name string + setupRequest func() *http.Request + expectedToken string + }{ + { + name: "Valid Bearer Token in Authorization Header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer valid_token") + return req + }, + expectedToken: "valid_token", + }, + { + name: "Token in Cookie", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "argocd.token", Value: "cookie_token"}) + return req + }, + expectedToken: "cookie_token", + }, + { + name: "No Token in Header or Cookie", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + return req + }, + expectedToken: "", + }, + { + name: "Invalid Authorization Header", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "InvalidAuth valid_token") + return req + }, + expectedToken: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := tt.setupRequest() + token := extractToken(req) + + if token != tt.expectedToken { + t.Errorf("extractToken() = %q, want %q", token, tt.expectedToken) + } + }) + } +} + +func TestDecodeJWTPayload(t *testing.T) { + tests := []struct { + name string + token string + expected map[string]interface{} + expectingError bool + }{ + { + name: "Valid JWT Token", + token: createTestJWT(map[string]interface{}{"email": "test@example.com", "role": "admin"}), + expected: map[string]interface{}{"email": "test@example.com", "role": "admin"}, + expectingError: false, + }, + { + name: "Invalid JWT Token Format", + token: "invalid.token", + expected: nil, + expectingError: true, + }, + { + name: "Invalid Payload Encoding", + token: "header." + base64.RawURLEncoding.EncodeToString([]byte("invalid payload")) + ".signature", + expected: nil, + expectingError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := decodeJWTPayload(tt.token) + + if tt.expectingError && err == nil { + t.Errorf("Expected an error but got none") + } + + if !tt.expectingError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if !tt.expectingError && !compareMaps(payload, tt.expected) { + t.Errorf("Payload mismatch. Expected: %v, Got: %v", tt.expected, payload) + } + }) + } +} + +func createTestJWT(payload map[string]interface{}) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + payloadBytes, _ := json.Marshal(payload) + encodedPayload := base64.RawURLEncoding.EncodeToString(payloadBytes) + return header + "." + encodedPayload + ".signature" +} + +func compareMaps(a, b map[string]interface{}) bool { + if len(a) != len(b) { + return false + } + for key, valueA := range a { + if valueB, exists := b[key]; !exists || valueA != valueB { + return false + } + } + return true +}