diff --git a/pkg/audit/audit.go b/pkg/audit/audit.go index 25e06d2..b32e6f4 100644 --- a/pkg/audit/audit.go +++ b/pkg/audit/audit.go @@ -8,6 +8,7 @@ type QueryData struct { Namespace string Pod string Timestamp int64 + DBName string } type Audit interface { diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index b72e0c0..b3aa5e7 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -92,7 +92,7 @@ func Run(logger *zap.SugaredLogger) error { r := mux.NewRouter() r.Handle("/healthcheck", logHandler(healthLogOutput, handlers.Healthcheck(cfg))).Methods("GET") r.Handle("/query", logHandler(defaultLogOutput, queryHandler)).Methods("POST") - r.Handle("/dbname", logHandler(defaultLogOutput, handlers.GetCurrentDBName(cfg))).Methods("GET") + r.Handle("/dbname", logHandler(defaultLogOutput, handlers.GetDBName(cfg))).Methods("GET") r.Handle("/dbname/switch", logHandler(defaultLogOutput, handlers.SwitchDBName(cfg))).Methods("POST") port := 8080 diff --git a/pkg/handlers/getdbname.go b/pkg/handlers/getdbname.go index 26ac1ef..3e0ede8 100644 --- a/pkg/handlers/getdbname.go +++ b/pkg/handlers/getdbname.go @@ -3,15 +3,26 @@ package handlers import ( "encoding/json" "net/http" + "os" gabi "github.com/app-sre/gabi/pkg" "github.com/app-sre/gabi/pkg/models" ) -func GetCurrentDBName(cfg *gabi.Config) http.Handler { +func GetDBName(cfg *gabi.Config) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { dbName := cfg.DBEnv.GetCurrentDBName() + defaultDBName := os.Getenv("DB_NAME") + + response := models.DBNameResponse{DBName: dbName} + + if dbName != defaultDBName { + warning := "Current database differs from the default" + cfg.Logger.Warnf(warning) + response.Warnings = append(response.Warnings, warning) + } + w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(models.DBNameResponse{DBName: dbName}) + json.NewEncoder(w).Encode(response) }) } diff --git a/pkg/handlers/getdbname_test.go b/pkg/handlers/getdbname_test.go index 1e32548..8207d1f 100644 --- a/pkg/handlers/getdbname_test.go +++ b/pkg/handlers/getdbname_test.go @@ -3,10 +3,10 @@ package handlers import ( "bytes" "context" - "encoding/json" "io" "net/http" "net/http/httptest" + "os" "testing" "github.com/app-sre/gabi/internal/test" @@ -16,53 +16,68 @@ import ( "github.com/stretchr/testify/require" ) -func TestGetCurrentDBName(t *testing.T) { +func TestGetDBName(t *testing.T) { cases := []struct { - description string - dbName string - code int - body map[string]string + description string + dbName string + defaultDBName string + expectedStatus int + expectedBody string + want string }{ { "returns current database name", "test_db", + "test_db", 200, - map[string]string{"db_name": "test_db"}, + `{"db_name":"test_db"}`, + "", }, { "returns empty database name", "", + "", + 200, + `{"db_name":""}`, + "", + }, + { + "returns warning when current db name is different from default", + "test_db", + "default_db", 200, - map[string]string{"db_name": ""}, + `{"db_name":"test_db","warnings":["Current database differs from the default"]}`, + "Current database differs from the default", }, } for _, tc := range cases { tc := tc t.Run(tc.description, func(t *testing.T) { - var body bytes.Buffer + var output bytes.Buffer + + os.Setenv("DB_NAME", tc.defaultDBName) + defer os.Unsetenv("DB_NAME") dbEnv := &db.Env{Name: tc.dbName} w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", &bytes.Buffer{}) - logger := test.DummyLogger(io.Discard).Sugar() + logger := test.DummyLogger(&output).Sugar() expected := &gabi.Config{DBEnv: dbEnv, Logger: logger} - GetCurrentDBName(expected).ServeHTTP(w, r.WithContext(context.TODO())) + GetDBName(expected).ServeHTTP(w, r.WithContext(context.TODO())) actual := w.Result() defer func() { _ = actual.Body.Close() }() - _, _ = io.Copy(&body, actual.Body) - - var responseBody map[string]string - err := json.Unmarshal(body.Bytes(), &responseBody) - + body, err := io.ReadAll(actual.Body) require.NoError(t, err) - assert.Equal(t, tc.code, actual.StatusCode) - assert.Equal(t, tc.body, responseBody) + + assert.Equal(t, tc.expectedStatus, actual.StatusCode) + assert.JSONEq(t, tc.expectedBody, string(body)) + assert.Contains(t, output.String(), tc.want) }) } } diff --git a/pkg/handlers/healthcheck.go b/pkg/handlers/healthcheck.go index 393ecd4..c4b0b05 100644 --- a/pkg/handlers/healthcheck.go +++ b/pkg/handlers/healthcheck.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "os" "time" "github.com/etherlabsio/healthcheck/v2" @@ -14,17 +15,25 @@ import ( const healthcheckTimeout = 5 * time.Second func Healthcheck(cfg *gabi.Config) http.Handler { + defaultDBName := os.Getenv("DB_NAME") return healthcheck.Handler( healthcheck.WithTimeout(healthcheckTimeout), healthcheck.WithChecker( "database", healthcheck.CheckerFunc( func(ctx context.Context) error { + dbName := cfg.DBEnv.GetCurrentDBName() err := cfg.DB.PingContext(ctx) if err != nil { l := "Unable to connect to the database" cfg.Logger.Errorf("%s: %s", l, err) return errors.New(l) } + + if dbName != defaultDBName { + l := "Current database differs from the default" + cfg.Logger.Warnf(l) + } + return nil }, ), diff --git a/pkg/handlers/healthcheck_test.go b/pkg/handlers/healthcheck_test.go index 56dbae4..e97f64e 100644 --- a/pkg/handlers/healthcheck_test.go +++ b/pkg/handlers/healthcheck_test.go @@ -6,11 +6,13 @@ import ( "io" "net/http" "net/http/httptest" + "os" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/app-sre/gabi/internal/test" gabi "github.com/app-sre/gabi/pkg" + "github.com/app-sre/gabi/pkg/env/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,16 +21,20 @@ func TestHealthcheck(t *testing.T) { t.Parallel() cases := []struct { - description string - given func(sqlmock.Sqlmock) - code int - body string + description string + given func(sqlmock.Sqlmock) + dbName string + defaultDBName string + code int + body string }{ { "database is accessible and returns ping reply", func(mock sqlmock.Sqlmock) { mock.ExpectPing() }, + "default_db", + "default_db", 200, `{"status":"OK"}`, }, @@ -37,9 +43,21 @@ func TestHealthcheck(t *testing.T) { func(mock sqlmock.Sqlmock) { mock.ExpectPing().WillReturnError(errors.New("test")) }, + "default_db", + "default_db", 503, `{"database":"Unable to connect to the database"}`, }, + { + "database name differs from the default", + func(mock sqlmock.Sqlmock) { + mock.ExpectPing() + }, + "test_db", + "default_db", + 200, + ``, + }, } for _, tc := range cases { @@ -49,6 +67,11 @@ func TestHealthcheck(t *testing.T) { var body bytes.Buffer + os.Setenv("DB_NAME", tc.defaultDBName) + defer os.Unsetenv("DB_NAME") + + dbEnv := &db.Env{Name: tc.dbName} + db, mock, _ := sqlmock.New(sqlmock.MonitorPingsOption(true)) defer func() { _ = db.Close() }() @@ -59,7 +82,7 @@ func TestHealthcheck(t *testing.T) { tc.given(mock) - expected := &gabi.Config{DB: db, Logger: logger} + expected := &gabi.Config{DB: db, Logger: logger, DBEnv: dbEnv} Healthcheck(expected).ServeHTTP(w, r) actual := w.Result() diff --git a/pkg/handlers/query.go b/pkg/handlers/query.go index b54463e..adda757 100644 --- a/pkg/handlers/query.go +++ b/pkg/handlers/query.go @@ -29,8 +29,17 @@ func Query(cfg *gabi.Config) http.HandlerFunc { var ( base64Mode byte request models.QueryRequest + warnings []string ) + defaultDBName := os.Getenv("DB_NAME") + currentDBName := cfg.DBEnv.GetCurrentDBName() + if currentDBName != defaultDBName { + l := "Current database differs from the default" + cfg.Logger.Warnf(l) + warnings = append(warnings, l) + } + if s := r.URL.Query().Get("base64_results"); s != "" { if ok, err := strconv.ParseBool(s); err == nil && ok { base64Mode |= base64EncodeResults @@ -159,7 +168,8 @@ func Query(cfg *gabi.Config) http.HandlerFunc { w.Header().Set("Cache-Control", "private, no-store") w.Header().Set("Content-Type", "application/json; charset=utf-8") _ = json.NewEncoder(w).Encode(&models.QueryResponse{ - Result: result, + Result: result, + Warnings: warnings, }) } } diff --git a/pkg/handlers/query_test.go b/pkg/handlers/query_test.go index 36da059..d103fb1 100644 --- a/pkg/handlers/query_test.go +++ b/pkg/handlers/query_test.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/http/httptest" + "os" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -22,18 +23,19 @@ import ( ) func TestQuery(t *testing.T) { - t.Parallel() cases := []struct { - description string - database func() (*sql.DB, sqlmock.Sqlmock) - mock func(sqlmock.Sqlmock) - context func() context.Context - parameters func(*http.Request) - request func() *bytes.Buffer - code int - body string - want string + description string + database func() (*sql.DB, sqlmock.Sqlmock) + mock func(sqlmock.Sqlmock) + context func() context.Context + parameters func(*http.Request) + request func() *bytes.Buffer + defaultDBName string + dbName string + code int + body string + want string }{ { "valid query", @@ -56,6 +58,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 200, `{"result":[["?column?"],["1"]],"error":""}`, ``, @@ -82,6 +86,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 200, `{"result":[["?column?"],["2"]],"error":""}`, ``, @@ -108,6 +114,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 200, `{"result":[["?column?"],["1"]],"error":""}`, ``, @@ -135,6 +143,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "c2VsZWN0IDE7"}`) }, + "default_db", + "default_db", 200, `{"result":[["?column?"],["1"]],"error":""}`, ``, @@ -162,6 +172,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 200, `{"result":[["?column?"],["1"]],"error":""}`, ``, @@ -189,6 +201,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 200, `{"result":[["?column?"],["MQ=="]],"error":""}`, ``, @@ -216,6 +230,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 200, `{"result":[["?column?"],["1"]],"error":""}`, ``, @@ -241,6 +257,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": ""}`) }, + "default_db", + "default_db", 200, `{"result":[null],"error":""}`, ``, @@ -265,6 +283,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query":"select 1;"}`) }, + "default_db", + "default_db", 400, `{"result":null,"error":"test"}`, ``, @@ -289,6 +309,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query":"select 1;"}`) }, + "default_db", + "default_db", 400, ``, `Unable to query database: test`, @@ -311,6 +333,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query":"select 1;"}`) }, + "default_db", + "default_db", 400, ``, `Unable to start database transaction: test`, @@ -336,6 +360,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 400, ``, `Unable to commit database changes: test`, @@ -363,6 +389,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select * from test;"}`) }, + "default_db", + "default_db", 400, ``, `Unable to process database rows: test`, @@ -386,6 +414,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "select 1;"}`) }, + "default_db", + "default_db", 503, `Unable to connect to the database`, ``, @@ -408,6 +438,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return &bytes.Buffer{} }, + "default_db", + "default_db", 400, `Request body cannot be empty`, ``, @@ -430,6 +462,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query: "select 1;"}`) }, + "default_db", + "default_db", 400, ``, `Unable to decode request body`, @@ -452,6 +486,8 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": -1}`) }, + "default_db", + "default_db", 400, ``, `Unable to decode request body`, @@ -476,19 +512,52 @@ func TestQuery(t *testing.T) { func() *bytes.Buffer { return bytes.NewBufferString(`{"query": "dGhpcyBpcyBhIHRlc3Q=="}`) }, + "default_db", + "default_db", 400, `Unable to decode Base64-encoded query`, ``, }, + { + "database name differs from the default", + func() (*sql.DB, sqlmock.Sqlmock) { + db, mock, _ := sqlmock.New() + return db, mock + }, + func(mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"?column?"}).AddRow("1") + mock.ExpectBegin() + mock.ExpectQuery(`select 1;`).WillReturnRows(rows) + mock.ExpectCommit() + }, + func() context.Context { + return context.TODO() + }, + func(r *http.Request) { + // No-op. + }, + func() *bytes.Buffer { + return bytes.NewBufferString(`{"query": "select 1;"}`) + }, + "default_db", + "test_db", + 200, + `{"result":[["?column?"],["1"]],"warnings":["Current database differs from the default"],"error":""}`, + ``, + }, } for _, tc := range cases { tc := tc t.Run(tc.description, func(t *testing.T) { - t.Parallel() var body, output bytes.Buffer + os.Setenv("DB_NAME", tc.defaultDBName) + defer os.Unsetenv("DB_NAME") + + dbEnv := &gabidb.Env{Name: tc.dbName} + w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", tc.request()) @@ -501,7 +570,7 @@ func TestQuery(t *testing.T) { tc.mock(mock) tc.parameters(r) - expected := &gabi.Config{DB: db, DBEnv: &gabidb.Env{}, Logger: logger, Encoder: encoder} + expected := &gabi.Config{DB: db, DBEnv: dbEnv, Logger: logger, Encoder: encoder} Query(expected).ServeHTTP(w, r.WithContext(tc.context())) actual := w.Result() diff --git a/pkg/handlers/switchdbname.go b/pkg/handlers/switchdbname.go index ac2131b..06af766 100644 --- a/pkg/handlers/switchdbname.go +++ b/pkg/handlers/switchdbname.go @@ -1,8 +1,10 @@ package handlers import ( + "context" "encoding/json" "net/http" + "time" gabi "github.com/app-sre/gabi/pkg" "github.com/app-sre/gabi/pkg/models" @@ -16,10 +18,25 @@ func SwitchDBName(cfg *gabi.Config) http.Handler { return } + oldDBName := cfg.DBEnv.GetCurrentDBName() cfg.DBEnv.OverrideDBName(req.DBName) + newDBName := cfg.DBEnv.GetCurrentDBName() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := cfg.DB.PingContext(ctx); err != nil { + cfg.Logger.Errorf("Failed to ping new database %s, falling back to %s: %s", newDBName, oldDBName, err) + cfg.DBEnv.OverrideDBName(oldDBName) + newDBName = oldDBName + } else { + cfg.Logger.Infof("Database name switched from %s to %s", oldDBName, newDBName) + } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{"db_name": cfg.DBEnv.GetCurrentDBName()}) + if err := json.NewEncoder(w).Encode(map[string]string{"db_name": newDBName}); err != nil { + cfg.Logger.Errorf("Failed to encode response: %s", err) + } }) } diff --git a/pkg/handlers/switchdbname_test.go b/pkg/handlers/switchdbname_test.go index 5712b0d..56d0d9e 100644 --- a/pkg/handlers/switchdbname_test.go +++ b/pkg/handlers/switchdbname_test.go @@ -4,11 +4,13 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/app-sre/gabi/internal/test" gabi "github.com/app-sre/gabi/pkg" "github.com/app-sre/gabi/pkg/env/db" @@ -17,12 +19,15 @@ import ( ) func TestSwitchDBName(t *testing.T) { + t.Parallel() + cases := []struct { description string initialDBName string newDBName string code int body map[string]string + given func(sqlmock.Sqlmock) }{ { "override database name", @@ -30,6 +35,9 @@ func TestSwitchDBName(t *testing.T) { "new_db", 200, map[string]string{"db_name": "new_db"}, + func(mock sqlmock.Sqlmock) { + mock.ExpectPing() + }, }, { "empty database name", @@ -37,6 +45,9 @@ func TestSwitchDBName(t *testing.T) { "", 200, map[string]string{"db_name": ""}, + func(mock sqlmock.Sqlmock) { + mock.ExpectPing() + }, }, { "invalid request payload", @@ -44,16 +55,32 @@ func TestSwitchDBName(t *testing.T) { "", 400, map[string]string{"error": "Invalid request payload"}, + nil, + }, + { + "ping new database fails", + "initial_db", + "new_db", + 200, + map[string]string{"db_name": "initial_db"}, + func(mock sqlmock.Sqlmock) { + mock.ExpectPing().WillReturnError(errors.New("ping failed")) + }, }, } for _, tc := range cases { tc := tc t.Run(tc.description, func(t *testing.T) { + t.Parallel() + var body bytes.Buffer dbEnv := &db.Env{Name: tc.initialDBName} + db, mock, _ := sqlmock.New(sqlmock.MonitorPingsOption(true)) + defer func() { _ = db.Close() }() + w := httptest.NewRecorder() var requestBody []byte if tc.description == "invalid request payload" { @@ -65,7 +92,12 @@ func TestSwitchDBName(t *testing.T) { logger := test.DummyLogger(io.Discard).Sugar() - expected := &gabi.Config{DBEnv: dbEnv, Logger: logger} + expected := &gabi.Config{DBEnv: dbEnv, DB: db, Logger: logger} + + if tc.given != nil { + tc.given(mock) + } + SwitchDBName(expected).ServeHTTP(w, r.WithContext(context.TODO())) actual := w.Result() @@ -84,8 +116,14 @@ func TestSwitchDBName(t *testing.T) { require.NoError(t, err) assert.Equal(t, tc.code, actual.StatusCode) assert.Equal(t, tc.body, responseBody) - assert.Equal(t, tc.newDBName, dbEnv.GetCurrentDBName()) + if tc.description == "ping new database fails" { + assert.Equal(t, tc.initialDBName, dbEnv.GetCurrentDBName()) + } else { + assert.Equal(t, tc.newDBName, dbEnv.GetCurrentDBName()) + } } + + require.NoError(t, mock.ExpectationsWereMet()) }) } } diff --git a/pkg/middleware/audit.go b/pkg/middleware/audit.go index 1a894a9..4ec3db0 100644 --- a/pkg/middleware/audit.go +++ b/pkg/middleware/audit.go @@ -84,6 +84,7 @@ func Audit(cfg *gabi.Config) Middleware { Query: request.Query, User: user, Timestamp: now.Unix(), + DBName: cfg.DBEnv.GetCurrentDBName(), } _ = cfg.LoggerAudit.Write(ctx, query) diff --git a/pkg/middleware/audit_test.go b/pkg/middleware/audit_test.go index 00675b1..4986f7a 100644 --- a/pkg/middleware/audit_test.go +++ b/pkg/middleware/audit_test.go @@ -14,6 +14,7 @@ import ( "github.com/app-sre/gabi/internal/test" gabi "github.com/app-sre/gabi/pkg" "github.com/app-sre/gabi/pkg/audit" + "github.com/app-sre/gabi/pkg/env/db" "github.com/app-sre/gabi/pkg/env/splunk" "github.com/stretchr/testify/assert" ) @@ -33,6 +34,7 @@ func TestAudit(t *testing.T) { response string want *regexp.Regexp query string + dbName string }{ { "valid query", @@ -67,6 +69,7 @@ func TestAudit(t *testing.T) { `{"query":"select 1;","user":"test","namespace":"test","pod":"test"}`, regexp.MustCompile(`AUDIT\s{"Query": "select 1;", "User": "test", "Timestamp": \d{10}}`), `select 1;`, + "test_db", }, { "valid Base64-encoded query", @@ -104,6 +107,7 @@ func TestAudit(t *testing.T) { `{"query":"select 1;","user":"test","namespace":"test","pod":"test"}`, regexp.MustCompile(`AUDIT\s{"Query": "select 1;", "User": "test", "Timestamp": \d{10}}`), `select 1;`, + "test_db", }, { "valid query with user passed via context", @@ -138,6 +142,7 @@ func TestAudit(t *testing.T) { `{"query":"select 1;","user":"test2","namespace":"test","pod":"test"}`, regexp.MustCompile(`AUDIT\s{"Query": "select 1;", "User": "test2", "Timestamp": \d{10}}`), `select 1;`, + "test_db", }, { "valid query with empty HTTP query parameters provided", @@ -175,6 +180,7 @@ func TestAudit(t *testing.T) { `{"query":"select 1;","user":"test","namespace":"test","pod":"test"}`, regexp.MustCompile(`AUDIT\s{"Query": "select 1;", "User": "test", "Timestamp": \d{10}}`), `select 1;`, + "test_db", }, { "valid query with no SQL statements provided", @@ -206,6 +212,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(`AUDIT\s{"Query": "", "User": "test", "Timestamp": \d{10}}`), ``, + "test_db", }, { "valid query with no Splunk endpoint configured", @@ -234,6 +241,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(``), ``, + "test_db", }, { "valid query with invalid Splunk endpoint configured", @@ -264,6 +272,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(``), ``, + "test_db", }, { "valid query with an error in Splunk response", @@ -295,6 +304,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(``), ``, + "test_db", }, { "valid query with malformed JSON in Splunk response", @@ -326,6 +336,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(``), ``, + "test_db", }, { "invalid query with empty body", @@ -356,6 +367,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(`Unable to unmarshal request body`), ``, + "test_db", }, { "invalid query with malformed JSON in the body", @@ -386,6 +398,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(`Unable to unmarshal request body`), ``, + "test_db", }, { "invalid query with no required headers set", @@ -415,6 +428,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(``), ``, + "test_db", }, { "invalid query with no required user header set", @@ -444,6 +458,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(``), ``, + "test_db", }, { "invalid query with malformed Base64-encoded value in the body", @@ -481,6 +496,7 @@ func TestAudit(t *testing.T) { ``, regexp.MustCompile(``), ``, + "test_db", }, } @@ -495,6 +511,8 @@ func TestAudit(t *testing.T) { query string ) + dbEnv := &db.Env{Name: tc.dbName} + w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", tc.request()) @@ -510,7 +528,7 @@ func TestAudit(t *testing.T) { tc.headers(tc.request())(r) - expected := &gabi.Config{LoggerAudit: la, SplunkAudit: sa, Logger: logger, Encoder: encoder} + expected := &gabi.Config{LoggerAudit: la, SplunkAudit: sa, Logger: logger, Encoder: encoder, DBEnv: dbEnv} Audit(expected)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { query, _ = r.Context().Value(ContextKeyQuery).(string) })).ServeHTTP(w, r.WithContext(tc.context())) diff --git a/pkg/models/dbname.go b/pkg/models/dbname.go index bb6d34d..4b86e45 100644 --- a/pkg/models/dbname.go +++ b/pkg/models/dbname.go @@ -5,5 +5,6 @@ type SwitchDBNameRequest struct { } type DBNameResponse struct { - DBName string `json:"db_name"` + DBName string `json:"db_name"` + Warnings []string `json:"warnings,omitempty"` } diff --git a/pkg/models/query.go b/pkg/models/query.go index a6270d7..26932f5 100644 --- a/pkg/models/query.go +++ b/pkg/models/query.go @@ -5,6 +5,7 @@ type QueryRequest struct { } type QueryResponse struct { - Result [][]string `json:"result"` - Error string `json:"error"` + Result [][]string `json:"result"` + Warnings []string `json:"warnings,omitempty"` + Error string `json:"error"` }