Skip to content

Commit

Permalink
Improvements to debug logging and CORS handling
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy committed Feb 4, 2025
1 parent 5cbf43c commit 03b8e39
Show file tree
Hide file tree
Showing 13 changed files with 277 additions and 145 deletions.
4 changes: 2 additions & 2 deletions clients/ui/bff/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ test: fmt vet envtest ## Runs the full test suite.

.PHONY: build
build: fmt vet test ## Builds the project to produce a binary executable.
go build -o bin/bff cmd/main.go
go build -o bin/bff cmd/*.go

.PHONY: run
run: fmt vet envtest ## Runs the project.
ENVTEST_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" \
go run ./cmd/main.go --port=$(PORT) --static-assets-dir=$(STATIC_ASSETS_DIR) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT) --dev-mode=$(DEV_MODE) --dev-mode-port=$(DEV_MODE_PORT) --standalone-mode=$(STANDALONE_MODE) --log-level=$(LOG_LEVEL) --allowed-origins=$(ALLOWED_ORIGINS)
go run ./cmd/main.go ./cmd/helpers.go --port=$(PORT) --static-assets-dir=$(STATIC_ASSETS_DIR) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT) --dev-mode=$(DEV_MODE) --dev-mode-port=$(DEV_MODE_PORT) --standalone-mode=$(STANDALONE_MODE) --log-level=$(LOG_LEVEL) --allowed-origins=$(ALLOWED_ORIGINS)

##@ Dependencies

Expand Down
55 changes: 55 additions & 0 deletions clients/ui/bff/cmd/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package main

import (
"fmt"
"log/slog"
"os"
"strconv"
"strings"
)

func getEnvAsInt(name string, defaultVal int) int {
if value, exists := os.LookupEnv(name); exists {
if intValue, err := strconv.Atoi(value); err == nil {
return intValue
}
}
return defaultVal
}

func getEnvAsString(name string, defaultVal string) string {
if value, exists := os.LookupEnv(name); exists {
return value
}
return defaultVal
}

func parseLevel(s string) slog.Level {
var level slog.Level
err := level.UnmarshalText([]byte(s))
if err != nil {
panic(fmt.Errorf("invalid log level: %s, valid levels are: error, warn, info, debug", s))
}
return level
}

func newOriginParser(allowList *[]string, defaultVal string) func(s string) error {
return func(s string) error {
value := defaultVal

if s != "" {
value = s
}

if value == "" {
allowList = nil
return nil
}

for _, str := range strings.Split(s, ",") {
*allowList = append(*allowList, strings.TrimSpace(str))
}

return nil
}
}
66 changes: 66 additions & 0 deletions clients/ui/bff/cmd/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package main

import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"testing"
)

var _ = Describe("newOriginParser helper function", func() {
var originParser func(s string) error
var allowList []string

BeforeEach(func() {
allowList = []string{}
originParser = newOriginParser(&allowList, "")
})

It("should parse a valid string list with 1 item", func() {
expected := []string{"https://test.com"}

err := originParser("https://test.com")

Expect(err).NotTo(HaveOccurred())
Expect(allowList).To(Equal(expected))
})

It("should parse a valid string list with 2 items", func() {
expected := []string{"https://test.com", "https://test2.com"}

err := originParser("https://test.com,https://test2.com")

Expect(err).NotTo(HaveOccurred())
Expect(allowList).To(Equal(expected))
})

It("should parse a valid string list with 2 items and extra spaces", func() {
expected := []string{"https://test.com", "https://test2.com"}

err := originParser("https://test.com, https://test2.com")

Expect(err).NotTo(HaveOccurred())
Expect(allowList).To(Equal(expected))
})

It("should parse an empty string", func() {
err := originParser("")

Expect(err).NotTo(HaveOccurred())
Expect(allowList).To(BeEmpty())
})

It("should parse the wildcard string", func() {
expected := []string{"*"}

err := originParser("*")

Expect(err).NotTo(HaveOccurred())
Expect(allowList).To(Equal(expected))
})

})

func TestMainHelpers(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Main helpers suite")
}
47 changes: 7 additions & 40 deletions clients/ui/bff/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"flag"
"fmt"
"os/signal"
"strings"
"syscall"

"github.com/kubeflow/model-registry/ui/bff/internal/api"
Expand All @@ -14,7 +13,6 @@ import (
"log/slog"
"net/http"
"os"
"strconv"
"time"
)

Expand All @@ -27,15 +25,18 @@ func main() {
flag.IntVar(&cfg.DevModePort, "dev-mode-port", getEnvAsInt("DEV_MODE_PORT", 8080), "Use port when in development mode")
flag.BoolVar(&cfg.StandaloneMode, "standalone-mode", false, "Use standalone mode for enabling endpoints in standalone mode")
flag.StringVar(&cfg.StaticAssetsDir, "static-assets-dir", "./static", "Configure frontend static assets root directory")
flag.StringVar(&cfg.LogLevel, "log-level", getEnvAsString("LOG_LEVEL", "info"), "Sets server log level, possible values: debug, info, warn, error, fatal")
flag.StringVar(&cfg.AllowedOrigins, "allowed-origins", getEnvAsString("ALLOWED_ORIGINS", ""), "Sets allowed origins for CORS purposes, accepts a comma separated list of origins or * to allow all, default none")
flag.TextVar(&cfg.LogLevel, "log-level", parseLevel(getEnvAsString("LOG_LEVEL", "INFO")), "Sets server log level, possible values: error, warn, info, debug")
flag.Func("allowed-origins", "Sets allowed origins for CORS purposes, accepts a comma separated list of origins or * to allow all, default none", newOriginParser(&cfg.AllowedOrigins, getEnvAsString("ALLOWED_ORIGINS", "")))
flag.Parse()

logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: getLogLevelFromString(cfg.LogLevel),
Level: cfg.LogLevel,
}))

app, err := api.NewApp(cfg, logger)
// Only use for logging errors about logging configuration.
slog.SetDefault(logger)

app, err := api.NewApp(cfg, slog.New(logger.Handler()))
if err != nil {
logger.Error(err.Error())
os.Exit(1)
Expand Down Expand Up @@ -82,38 +83,4 @@ func main() {

logger.Info("server stopped")
os.Exit(0)

}

func getEnvAsInt(name string, defaultVal int) int {
if value, exists := os.LookupEnv(name); exists {
if intValue, err := strconv.Atoi(value); err == nil {
return intValue
}
}
return defaultVal
}

func getEnvAsString(name string, defaultVal string) string {
if value, exists := os.LookupEnv(name); exists {
return value
}
return defaultVal
}

func getLogLevelFromString(level string) slog.Level {
switch strings.ToLower(level) {
case "debug":
return slog.LevelDebug
case "info":
return slog.LevelInfo
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
case "fatal":
return slog.LevelError

}
return slog.LevelInfo
}
5 changes: 5 additions & 0 deletions clients/ui/bff/internal/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"fmt"
helper "github.com/kubeflow/model-registry/ui/bff/internal/helpers"
"log/slog"
"net/http"
"path"
Expand Down Expand Up @@ -50,6 +51,7 @@ type App struct {
}

func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
logger.Debug("Initializing app with config", slog.Any("config", cfg))
var k8sClient integrations.KubernetesClientInterface
var err error
if cfg.MockK8Client {
Expand Down Expand Up @@ -136,14 +138,17 @@ func (app *App) Routes() http.Handler {
staticDir := http.Dir(app.config.StaticAssetsDir)
fileServer := http.FileServer(staticDir)
appMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
ctxLogger := helper.GetContextLoggerFromReq(r)
// Check if the requested file exists
if _, err := staticDir.Open(r.URL.Path); err == nil {
ctxLogger.Debug("Serving static file", slog.String("path", r.URL.Path))
// Serve the file if it exists
fileServer.ServeHTTP(w, r)
return
}

// Fallback to index.html for SPA routes
ctxLogger.Debug("Static asset not found, serving index.html", slog.String("path", r.URL.Path))
http.ServeFile(w, r, path.Join(app.config.StaticAssetsDir, "index.html"))
})

Expand Down
12 changes: 0 additions & 12 deletions clients/ui/bff/internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,3 @@ func ParseURLTemplate(tmpl string, params map[string]string) string {

return r.Replace(tmpl)
}

func ParseOriginList(origins string) ([]string, bool) {
if origins == "" {
return []string{}, false
}

if origins == "*" {
return []string{"*"}, true
}
originsTrimmed := strings.ReplaceAll(origins, " ", "")
return strings.Split(originsTrimmed, ","), true
}
39 changes: 0 additions & 39 deletions clients/ui/bff/internal/api/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,3 @@ func TestParseURLTemplateWhenEmpty(t *testing.T) {
actual := ParseURLTemplate("", nil)
assert.Empty(t, actual)
}

func TestParseOriginListAllowAll(t *testing.T) {
expected := []string{"*"}

actual, ok := ParseOriginList("*")

assert.True(t, ok)
assert.Equal(t, expected, actual)
}

func TestParseOriginListEmpty(t *testing.T) {
actual, ok := ParseOriginList("")

assert.False(t, ok)
assert.Empty(t, actual)
}

func TestParseOriginListSingle(t *testing.T) {
expected := []string{"http://test.com"}

actual, ok := ParseOriginList("http://test.com")

assert.True(t, ok)
assert.Equal(t, expected, actual)
}

func TestParseOriginListMultiple(t *testing.T) {
expected := []string{"http://test.com", "http://test2.com"}
actual, ok := ParseOriginList("http://test.com,http://test2.com")
assert.True(t, ok)
assert.Equal(t, expected, actual)
}

func TestParseOriginListMultipleAndSpaces(t *testing.T) {
expected := []string{"http://test.com", "http://test2.com"}
actual, ok := ParseOriginList("http://test.com, http://test2.com")
assert.True(t, ok)
assert.Equal(t, expected, actual)
}
36 changes: 20 additions & 16 deletions clients/ui/bff/internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/internal/config"
"github.com/kubeflow/model-registry/ui/bff/internal/constants"
helper "github.com/kubeflow/model-registry/ui/bff/internal/helpers"
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"github.com/rs/cors"
"log/slog"
Expand All @@ -22,7 +23,7 @@ func (app *App) RecoverPanic(next http.Handler) http.Handler {
if err := recover(); err != nil {
w.Header().Set("Connection", "close")
app.serverErrorResponse(w, r, fmt.Errorf("%s", err))
app.logger.Error("Recover from panic: " + string(debug.Stack()))
app.logger.Error("Recovered from panic", slog.String("stack_trace", string(debug.Stack())))
}
}()

Expand All @@ -32,7 +33,6 @@ func (app *App) RecoverPanic(next http.Handler) http.Handler {

func (app *App) InjectUserHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

//skip use headers check if we are not on /api/v1
if !strings.HasPrefix(r.URL.Path, PathPrefix) {
next.ServeHTTP(w, r)
Expand Down Expand Up @@ -68,18 +68,17 @@ func (app *App) InjectUserHeaders(next http.Handler) http.Handler {
}

func (app *App) EnableCORS(next http.Handler) http.Handler {
allowedOrigins, ok := ParseOriginList(app.config.AllowedOrigins)

if !ok {
if len(app.config.AllowedOrigins) == 0 {
// CORS is disabled, this middleware becomes a noop.
return next
}

c := cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
AllowedOrigins: app.config.AllowedOrigins,
AllowCredentials: true,
AllowedMethods: []string{"GET", "PUT", "POST", "PATCH", "DELETE"},
AllowedHeaders: []string{constants.KubeflowUserIDHeader, constants.KubeflowUserGroupsIdHeader},
Debug: strings.ToLower(app.config.LogLevel) == "debug",
Debug: app.config.LogLevel == slog.LevelDebug,
OptionsPassthrough: false,
})

Expand All @@ -97,16 +96,21 @@ func (app *App) EnableTelemetry(next http.Handler) http.Handler {
traceLogger := app.logger.With(slog.String("trace_id", traceId))
ctx = context.WithValue(ctx, constants.TraceLoggerKey, traceLogger)

if traceLogger.Enabled(ctx, slog.LevelDebug) {
cloneBody, err := integrations.CloneBody(r)
if err != nil {
traceLogger.Debug("Error reading request body for debug logging", "error", err)
}
////TODO (Alex) Log headers, BUT we must ensure we don't log confidential data like tokens etc.
traceLogger.Debug("Incoming HTTP request", "method", r.Method, "url", r.URL.String(), "body", cloneBody)
}
traceLogger.Debug("Incoming HTTP request", slog.Any("request", helper.RequestLogValuer{Request: r}))

//if traceLogger.Enabled(ctx, slog.LevelDebug) {
// cloneBody, err := helper.CloneBody(r)
// if err != nil {
// traceLogger.Debug("Error reading request body for debug logging", "error", err)
// }
// traceLogger.Debug("Incoming HTTP request",
// slog.Group("request",
// slog.String("method", r.Method),
// slog.String("url", r.URL.String()),
// slog.String("body", string(cloneBody)),
// slog.Any("headers", helper.HeaderLogValuer{Header: r.Header})))
//}
}

next.ServeHTTP(w, r.WithContext(ctx))
})
}
Expand Down
Loading

0 comments on commit 03b8e39

Please sign in to comment.