Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jenting committed Nov 24, 2024
1 parent ad9b3c2 commit a4bc5aa
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 164 deletions.
296 changes: 132 additions & 164 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,210 +20,119 @@ import (

func main() {
config := ctrl.GetConfigOrDie()
namespace := "argocd"
configMapName := "argocd-rbac-cm"

clientset, err := kubernetes.NewForConfig(config)
if err != nil {
log.Fatalf("Failed to create Kubernetes client: %v", err)
}

userToObjectPatternMapping, groupToObjectPatternMapping := loadRBACPolicyFromConfigMap(clientset, "argocd", "argocd-rbac-cm")

redisClient := initializeRedis("localhost:16379", "", 1)

proxy := createReverseProxy("http://localhost:8443")

http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
handleRequest(w, r, proxy, redisClient, userToObjectPatternMapping, groupToObjectPatternMapping)
})

log.Println("Proxy server running on :8081")
log.Fatal(http.ListenAndServe(":8081", nil))
}

func loadRBACPolicyFromConfigMap(clientset *kubernetes.Clientset, namespace, configMapName string) (map[string][]string, map[string][]string) {
cm, err := clientset.CoreV1().ConfigMaps(namespace).Get(context.Background(), configMapName, metav1.GetOptions{})
if err != nil {
log.Fatalf("Failed to fetch ConfigMap %s: %v", configMapName, err)
fmt.Printf("Failed to fetch ConfigMap %s: %v", configMapName, err)
return nil, nil
}

// Extract policy data from ConfigMap
var userToObjectPatternMapping map[string][]string = make(map[string][]string)
var groupToObjectPatternMapping map[string][]string = make(map[string][]string)

policyCSV, ok := cm.Data["policy.csv"]
if !ok {
fmt.Printf("policy.csv not found in ConfigMap %s\n", configMapName)
} else {
fmt.Printf("Policy csv data: %s\n", policyCSV)

// Parse the policy.csv content and build the map
userToObjectPatternMapping, groupToObjectPatternMapping = parsePolicyCSV(policyCSV)

// Print the user permissions
fmt.Println("User Permissions:")
for user, objectPattern := range userToObjectPatternMapping {
fmt.Printf("User: %s, Permissions: %v\n", user, objectPattern)
}

// Print the group permissions
fmt.Println("Group Permissions:")
for group, objectPattern := range groupToObjectPatternMapping {
fmt.Printf("Group: %s, Permissions: %v\n", group, objectPattern)
}
fmt.Printf("policy.csv not found in ConfigMap %s", configMapName)
return nil, nil
}

// Redis configuration
redisAddr := "localhost:16379" // Redis service DNS
redisPassword := "" // Set the password if Redis authentication is enabled
return parsePolicyCSV(policyCSV)
}

// Initialize Redis client
redisClient := redis.NewClient(&redis.Options{
Addr: redisAddr,
Password: redisPassword,
DB: 1,
func initializeRedis(addr, password string, db int) *redis.Client {
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
DB: db,
DialTimeout: 5 * time.Second,
})

// Test connection
pong, err := redisClient.Ping().Result()
if err != nil {
if _, err := client.Ping().Result(); err != nil {
log.Fatalf("Failed to connect to Redis: %v", err)
}
fmt.Println("Connected to Redis successfully")
return client
}

fmt.Printf("Connected to Redis: %s\n", pong)

// ArgoCD server URL
argocdServerURL := "http://localhost:8443" // Update this to your actual ArgoCD server URL
func createReverseProxy(target string) *httputil.ReverseProxy {
parsedURL, err := url.Parse(target)
if err != nil {
log.Fatalf("Invalid ArgoCD server URL: %v", err)
}

// Proxy handler
proxy := &httputil.ReverseProxy{
return &httputil.ReverseProxy{
Director: func(req *http.Request) {
argoURL, _ := url.Parse(argocdServerURL)
req.URL.Scheme = argoURL.Scheme
req.URL.Host = argoURL.Host
req.Host = argoURL.Host
req.URL.Scheme = parsedURL.Scheme
req.URL.Host = parsedURL.Host
req.Host = parsedURL.Host
},
}
}

http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
token := func() string {
if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
if cookie, err := r.Cookie("argocd.token"); err == nil {
return cookie.Value
}
return ""
}()
if token == "" {
proxy.ServeHTTP(w, r)
return
}

// Capture GET requests to /api/v1/applications
if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/api/v1/applications") {
fmt.Printf("Request: %s %s\n", r.Method, r.URL.Path)

payload, err := decodeJWTPayload(token)
if err != nil {
proxy.ServeHTTP(w, r)
return
}

// Extract the "email" and "groups" from the payload
email, _ := payload["email"].(string)
groups, _ := payload["groups"].([]string)
fmt.Printf("Email: %s, Groups: %v\n", email, groups)

// Collect all unique object patterns for the user and groups
objectPatterns := make(map[string]struct{})
if objectPattern, ok := userToObjectPatternMapping[email]; ok {
for _, pattern := range objectPattern {
objectPatterns[pattern] = struct{}{}
}
}
for _, group := range groups {
if objectPattern, ok := groupToObjectPatternMapping[group]; ok {
for _, pattern := range objectPattern {
objectPatterns[pattern] = struct{}{}
}
}
}

// Prepare a list of keys to fetch from Redis
keyPatterns := make([]string, 0, len(objectPatterns))
for pattern := range objectPatterns {
keyPatterns = append(keyPatterns, fmt.Sprintf("%s|*", pattern))
}

// Batch process keys using Redis MGET
allKeys := make([]string, 0)
for _, keyPattern := range keyPatterns {
keys, err := redisClient.Keys(keyPattern).Result()
if err != nil {
log.Printf("Failed to fetch keys for pattern %s: %v", keyPattern, err)
continue
}
allKeys = append(allKeys, keys...)
}

// Fetch all values for the keys in a single batch
keyValuePairs := make(map[string]string)
if len(allKeys) > 0 {
pipe := redisClient.Pipeline()
cmds := make([]*redis.StringCmd, len(allKeys))
for i, key := range allKeys {
cmds[i] = pipe.Get(key)
}
_, err := pipe.Exec()
if err != nil && err != redis.Nil {
log.Fatalf("Failed to fetch values for keys: %v", err)
proxy.ServeHTTP(w, r)
return
}
func handleRequest(w http.ResponseWriter, r *http.Request, proxy *httputil.ReverseProxy, redisClient *redis.Client, userToObjectPatternMapping, groupToObjectPatternMapping map[string][]string) {
token := extractToken(r)
if token == "" || (r.Method != http.MethodGet || !strings.HasPrefix(r.URL.Path, "/api/v1/applications")) {
proxy.ServeHTTP(w, r)
return
}

for i, cmd := range cmds {
if cmd.Err() == nil {
keyValuePairs[allKeys[i]] = cmd.Val()
} else {
log.Printf("Failed to fetch value for key %s: %v", allKeys[i], cmd.Err())
}
}
}
payload, err := decodeJWTPayload(token)
if err != nil {
proxy.ServeHTTP(w, r)
return
}

// Unmarshal values and build the response
var resp struct {
Items []interface{} `json:"items"`
}
resp.Items = make([]interface{}, 0, len(keyValuePairs))
for key, value := range keyValuePairs {
var rawJson interface{}
if err := json.Unmarshal([]byte(value), &rawJson); err != nil {
log.Printf("Failed to unmarshal value for key %s: %v", key, err)
continue
}
resp.Items = append(resp.Items, rawJson)
}
email, _ := payload["email"].(string)
groups, _ := payload["groups"].([]string)

if len(resp.Items) == 0 {
proxy.ServeHTTP(w, r)
return
}
objectPatterns := resolveObjectPatterns(email, groups, userToObjectPatternMapping, groupToObjectPatternMapping)

// Serialize the key-value pairs as JSON
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Failed to write response: %v", err)
proxy.ServeHTTP(w, r)
return
}
return
}
resp := fetchApplicationsFromRedis(redisClient, objectPatterns)
if len(resp.Items) == 0 {
proxy.ServeHTTP(w, r)
return
}

// Proxy other requests to the ArgoCD server
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Failed to write response: %v", err)
proxy.ServeHTTP(w, r)
})
}
}

log.Println("Proxy server running on :8081")
log.Fatal(http.ListenAndServe(":8081", nil))
func extractToken(r *http.Request) string {
if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
if cookie, err := r.Cookie("argocd.token"); err == nil {
return cookie.Value
}
return ""
}

// decodeJWTPayload decodes the payload of a JWT token without validating it
func decodeJWTPayload(token string) (map[string]interface{}, error) {
// Split the token into its parts (header, payload, signature)
parts := strings.Split(token, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("invalid token format")
}

// Decode the payload (second part of the JWT)
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %v", err)
Expand All @@ -233,10 +142,69 @@ func decodeJWTPayload(token string) (map[string]interface{}, error) {
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
return nil, fmt.Errorf("failed to unmarshal payload: %v", err)
}

return payload, nil
}

func resolveObjectPatterns(email string, groups []string, userToObjectPatternMapping, groupToObjectPatternMapping map[string][]string) map[string]struct{} {
objectPatterns := make(map[string]struct{})

for _, pattern := range userToObjectPatternMapping[email] {
objectPatterns[pattern] = struct{}{}
}

for _, group := range groups {
for _, pattern := range groupToObjectPatternMapping[group] {
objectPatterns[pattern] = struct{}{}
}
}

return objectPatterns
}

func fetchApplicationsFromRedis(redisClient *redis.Client, objectPatterns map[string]struct{}) struct {
Items []interface{} `json:"items"`
} {
resp := struct {
Items []interface{} `json:"items"`
}{Items: []interface{}{}}

var allKeys []string
for pattern := range objectPatterns {
keys, err := redisClient.Keys(fmt.Sprintf("%s|*", pattern)).Result()
if err != nil {
log.Printf("Failed to fetch keys for pattern %s: %v", pattern, err)
continue
}
allKeys = append(allKeys, keys...)
}

if len(allKeys) > 0 {
pipe := redisClient.Pipeline()
cmds := make([]*redis.StringCmd, len(allKeys))
for i, key := range allKeys {
cmds[i] = pipe.Get(key)
}
_, err := pipe.Exec()
if err != nil && err != redis.Nil {
log.Printf("Failed to fetch values for keys: %v", err)
}

for i, cmd := range cmds {
if cmd.Err() == nil {
var rawJson interface{}
if err := json.Unmarshal([]byte(cmd.Val()), &rawJson); err == nil {
resp.Items = append(resp.Items, rawJson)
} else {
log.Printf("Failed to unmarshal value for key %s: %v", allKeys[i], err)
}
} else {
log.Printf("Failed to fetch value for key %s: %v", allKeys[i], cmd.Err())
}
}
}
return resp
}

func parsePolicyCSV(policyCSV string) (map[string][]string, map[string][]string) {
userToRoleMapping := make(map[string][]string)
groupToRoleMapping := make(map[string][]string)
Expand Down
Loading

0 comments on commit a4bc5aa

Please sign in to comment.