From 4c00a26f7f4f6184c0a4951c0fe64114d6186e24 Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Sun, 24 Nov 2024 09:09:18 -0600 Subject: [PATCH] Replica Server (#34) --- .env.template | 4 + .gitignore | 4 + cmd/honudb/main.go | 46 ++- go.mod | 5 +- go.sum | 14 +- pkg/api/v1/api.go | 37 +++ pkg/api/v1/client.go | 348 +++++++++++++++++++++++ pkg/api/v1/client_test.go | 80 ++++++ pkg/api/v1/context.go | 34 +++ pkg/api/v1/credentials/credentials.go | 14 + pkg/api/v1/credentials/errors.go | 7 + pkg/api/v1/errors.go | 206 ++++++++++++++ pkg/api/v1/errors_test.go | 200 +++++++++++++ pkg/api/v1/options.go | 13 + pkg/api/v1/testdata/statusOK.json | 5 + pkg/config/config.go | 14 +- pkg/config/config_test.go | 2 + pkg/server/maintenance.go | 29 ++ pkg/server/middleware/middleware.go | 20 ++ pkg/server/middleware/middleware_test.go | 144 ++++++++++ pkg/server/render/json.go | 21 ++ pkg/server/render/render.go | 34 +++ pkg/server/routes.go | 30 ++ pkg/server/server.go | 187 ++++++++++++ pkg/server/status.go | 67 +++++ pkg/store/decode.go | 2 +- pkg/store/decode_test.go | 2 +- pkg/store/encode.go | 2 +- pkg/store/encode_test.go | 2 +- pkg/store/object.go | 2 +- pkg/store/object_test.go | 2 +- 31 files changed, 1560 insertions(+), 17 deletions(-) create mode 100644 .env.template create mode 100644 pkg/api/v1/api.go create mode 100644 pkg/api/v1/client.go create mode 100644 pkg/api/v1/client_test.go create mode 100644 pkg/api/v1/context.go create mode 100644 pkg/api/v1/credentials/credentials.go create mode 100644 pkg/api/v1/credentials/errors.go create mode 100644 pkg/api/v1/errors.go create mode 100644 pkg/api/v1/errors_test.go create mode 100644 pkg/api/v1/options.go create mode 100644 pkg/api/v1/testdata/statusOK.json create mode 100644 pkg/server/maintenance.go create mode 100644 pkg/server/middleware/middleware.go create mode 100644 pkg/server/middleware/middleware_test.go create mode 100644 pkg/server/render/json.go create mode 100644 pkg/server/render/render.go create mode 100644 pkg/server/routes.go create mode 100644 pkg/server/server.go create mode 100644 pkg/server/status.go diff --git a/.env.template b/.env.template new file mode 100644 index 0000000..ad9e76e --- /dev/null +++ b/.env.template @@ -0,0 +1,4 @@ +HONU_MAINTENANCE=false +HONU_LOG_LEVEL=info +HONU_CONSOLE_LOG=true +HONU_BIND_ADDR=127.0.0.1:3264 diff --git a/.gitignore b/.gitignore index 66fd13c..78b3838 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,7 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +# Local environment variables and secrets +.env +.secret diff --git a/cmd/honudb/main.go b/cmd/honudb/main.go index 434231f..085812e 100644 --- a/cmd/honudb/main.go +++ b/cmd/honudb/main.go @@ -1,13 +1,16 @@ package main import ( - "fmt" "log" "os" + "text/tabwriter" - "github.com/joho/godotenv" "github.com/rotationalio/honu/pkg" "github.com/rotationalio/honu/pkg/config" + "github.com/rotationalio/honu/pkg/server" + + "github.com/joho/godotenv" + confire "github.com/rotationalio/confire/usage" "github.com/urfave/cli/v2" ) @@ -29,6 +32,19 @@ func main() { Action: serve, Flags: []cli.Flag{}, }, + { + Name: "config", + Usage: "print honu database replica configuration guide", + Category: "server", + Action: usage, + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "list", + Aliases: []string{"l"}, + Usage: "print in list mode instead of table mode", + }, + }, + }, } if err := app.Run(os.Args); err != nil { @@ -41,12 +57,34 @@ func main() { //=========================================================================== func serve(c *cli.Context) (err error) { - // Load the configuration from a file or from the environment. var conf config.Config if conf, err = config.New(); err != nil { return cli.Exit(err, 1) } - fmt.Println(conf) + var honu *server.Server + if honu, err = server.New(conf); err != nil { + return cli.Exit(err, 1) + } + + if err = honu.Serve(); err != nil { + return cli.Exit(err, 1) + } + return nil +} + +func usage(c *cli.Context) error { + tabs := tabwriter.NewWriter(os.Stdout, 1, 0, 4, ' ', 0) + format := confire.DefaultTableFormat + if c.Bool("list") { + format = confire.DefaultListFormat + } + + var conf config.Config + if err := confire.Usagef(config.Prefix, &conf, tabs, format); err != nil { + return cli.Exit(err, 1) + } + + tabs.Flush() return nil } diff --git a/go.mod b/go.mod index 1ad1560..8677962 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,11 @@ module github.com/rotationalio/honu go 1.23.1 require ( + github.com/cenkalti/backoff v2.2.1+incompatible + github.com/google/go-querystring v1.1.0 github.com/joho/godotenv v1.5.1 - github.com/oklog/ulid v1.3.1 + github.com/julienschmidt/httprouter v1.3.0 + github.com/oklog/ulid/v2 v2.1.0 github.com/rotationalio/confire v1.1.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index cc29280..6cdd062 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,27 @@ +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.5 h1:ZtcqGrnekaHpVLArFSe4HK5DoKx1T0rq2DwVB0alcyc= github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= -github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU= +github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= +github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -33,6 +42,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/api/v1/api.go b/pkg/api/v1/api.go new file mode 100644 index 0000000..22be2d5 --- /dev/null +++ b/pkg/api/v1/api.go @@ -0,0 +1,37 @@ +package api + +import "context" + +//=========================================================================== +// Service Interface +//=========================================================================== + +// Client defines the service interface for interacting with the HonuDB service. +type Client interface { + Status(context.Context) (*StatusReply, error) +} + +//=========================================================================== +// Top Level Requests and Responses +//=========================================================================== + +// Reply contains standard fields that are used for generic API responses and errors. +type Reply struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + ErrorDetail ErrorDetail `json:"errors,omitempty"` +} + +// Returned on status requests. +type StatusReply struct { + Status string `json:"status"` + Uptime string `json:"uptime,omitempty"` + Version string `json:"version,omitempty"` +} + +// PageQuery manages paginated list requests. +type PageQuery struct { + PageSize int `json:"page_size,omitempty" url:"page_size,omitempty" form:"page_size"` + NextPageToken string `json:"next_page_token,omitempty" url:"next_page_token,omitempty" form:"next_page_token"` + PrevPageToken string `json:"prev_page_token,omitempty" url:"prev_page_token,omitempty" form:"prev_page_token"` +} diff --git a/pkg/api/v1/client.go b/pkg/api/v1/client.go new file mode 100644 index 0000000..52a5fef --- /dev/null +++ b/pkg/api/v1/client.go @@ -0,0 +1,348 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/http/cookiejar" + "net/url" + "time" + + "github.com/cenkalti/backoff" + "github.com/google/go-querystring/query" + "github.com/oklog/ulid/v2" + "github.com/rotationalio/honu/pkg/api/v1/credentials" + "github.com/rs/zerolog/log" +) + +const ( + userAgent = "HonuDB API Client/v1" + accept = "application/json" + acceptLang = "en-US,en" + acceptEncode = "gzip, deflate, br" + contentType = "application/json; charset=utf-8" +) + +// New creates a new APIv1 client that implements the Client interface. +func New(endpoint string, opts ...ClientOption) (_ Client, err error) { + c := &APIv1{} + if c.endpoint, err = url.Parse(endpoint); err != nil { + return nil, fmt.Errorf("could not parse endpoint: %s", err) + } + + // Apply our options + for _, opt := range opts { + if err = opt(c); err != nil { + return nil, err + } + } + + // If an http client isn't specified, create a default client. + if c.client == nil { + c.client = &http.Client{ + Transport: nil, + CheckRedirect: nil, + Timeout: 30 * time.Second, + } + + // Create cookie jar for CSRF + if c.client.Jar, err = cookiejar.New(nil); err != nil { + return nil, fmt.Errorf("could not create cookiejar: %w", err) + } + } + + return c, nil +} + +// APIv1 implements the v1 Client interface for making requests to a honudb replica. +type APIv1 struct { + endpoint *url.URL // the base url for all requests + client *http.Client // used to make http requests to the server + creds credentials.Credentials // used to authenticate requests with the server +} + +// Ensure the APIv1 implements the Client interface +var _ Client = &APIv1{} + +//=========================================================================== +// Client Methods +//=========================================================================== + +const statusEP = "/v1/status" + +func (s *APIv1) Status(ctx context.Context) (out *StatusReply, err error) { + // Make the HTTP request + var req *http.Request + if req, err = s.NewRequest(ctx, http.MethodGet, statusEP, nil, nil); err != nil { + return nil, err + } + + // NOTE: we cannot use s.Do because we want to parse 503 Unavailable errors + var rep *http.Response + if rep, err = s.client.Do(req); err != nil { + return nil, err + } + defer rep.Body.Close() + + // Detect other errors + if rep.StatusCode != http.StatusOK && rep.StatusCode != http.StatusServiceUnavailable { + return nil, fmt.Errorf("%s", rep.Status) + } + + // Deserialize the JSON data from the response + out = &StatusReply{} + if err = json.NewDecoder(rep.Body).Decode(out); err != nil { + return nil, fmt.Errorf("could not deserialize status reply: %s", err) + } + return out, nil +} + +//=========================================================================== +// Client Utility Methods +//=========================================================================== + +// Wait for ready polls the node's status endpoint until it responds with an 200 +// response, retrying with exponential backoff or until the context deadline is expired. +// If the user does not supply a context with a deadline, then a default deadline of +// 5 minutes is used so that this method does not block indefinitely. If the node API +// service is ready (e.g. responds to a status request) then no error is returned, +// otherwise an error is returned if the node never responds. +// +// NOTE: if the node returns a 503 Service Unavailable because it is in maintenance +// mode, this method will continue to wait until the deadline for the node to exit +// from maintenance mode and be ready again. +func (s *APIv1) WaitForReady(ctx context.Context) (err error) { + // If context does not have a deadline, create a context with a default deadline. + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + } + + // Create the status request to send until ready + var req *http.Request + if req, err = s.NewRequest(ctx, http.MethodGet, "/v1/status", nil, nil); err != nil { + return err + } + + // Create a closure to repeatedly call the status endpoint + checkReady := func() (err error) { + var rep *http.Response + if rep, err = s.client.Do(req); err != nil { + return err + } + defer rep.Body.Close() + + if rep.StatusCode < 200 || rep.StatusCode >= 300 { + return &StatusError{StatusCode: rep.StatusCode, Reply: Reply{Success: false, Error: http.StatusText(rep.StatusCode)}} + } + return nil + } + + // Create exponential backoff ticker for retries + ticker := backoff.NewExponentialBackOff() + + // Keep checking if the node is ready until it is ready or until the context expires. + for { + // Execute the status request + if err = checkReady(); err == nil { + // Success - node is ready for requests! + return nil + } + + // Log the error warning that we're still waiting to connect to the node + log.Warn().Err(err).Str("endpoint", s.endpoint.String()).Msg("waiting to connect to TRISA node") + wait := time.After(ticker.NextBackOff()) + + // Wait for the context to be done or for the ticker to move to the next backoff. + select { + case <-ctx.Done(): + return ctx.Err() + case <-wait: + } + } +} + +//=========================================================================== +// REST Resource Methods +//=========================================================================== + +func (s *APIv1) List(ctx context.Context, endpoint string, in *PageQuery, out interface{}) (err error) { + var params url.Values + if params, err = query.Values(in); err != nil { + return fmt.Errorf("could not encode page query: %w", err) + } + + var req *http.Request + if req, err = s.NewRequest(ctx, http.MethodGet, endpoint, nil, ¶ms); err != nil { + return err + } + + if _, err = s.Do(req, &out, true); err != nil { + return err + } + + return nil +} + +func (s *APIv1) Create(ctx context.Context, endpoint string, in, out interface{}) (err error) { + var req *http.Request + if req, err = s.NewRequest(ctx, http.MethodPost, endpoint, in, nil); err != nil { + return err + } + + if _, err = s.Do(req, &out, true); err != nil { + return err + } + return nil +} + +func (s *APIv1) Detail(ctx context.Context, endpoint string, out interface{}) (err error) { + var req *http.Request + if req, err = s.NewRequest(ctx, http.MethodGet, endpoint, nil, nil); err != nil { + return err + } + + if _, err = s.Do(req, &out, true); err != nil { + return err + } + return nil +} + +func (s *APIv1) Update(ctx context.Context, endpoint string, in, out interface{}) (err error) { + var req *http.Request + if req, err = s.NewRequest(ctx, http.MethodPut, endpoint, in, nil); err != nil { + return err + } + + if _, err = s.Do(req, &out, true); err != nil { + return err + } + return nil +} + +func (s *APIv1) Delete(ctx context.Context, endpoint string) (err error) { + var req *http.Request + if req, err = s.NewRequest(ctx, http.MethodDelete, endpoint, nil, nil); err != nil { + return err + } + + if _, err = s.Do(req, nil, true); err != nil { + return err + } + return nil +} + +//=========================================================================== +// Helper Methods +//=========================================================================== + +func (s *APIv1) NewRequest(ctx context.Context, method, path string, data interface{}, params *url.Values) (req *http.Request, err error) { + // Resolve the URL reference from the path + url := s.endpoint.ResolveReference(&url.URL{Path: path}) + if params != nil && len(*params) > 0 { + url.RawQuery = params.Encode() + } + + var body io.ReadWriter + switch { + case data == nil: + body = nil + default: + body = &bytes.Buffer{} + if err = json.NewEncoder(body).Encode(data); err != nil { + return nil, fmt.Errorf("could not serialize request data as json: %s", err) + } + } + + // Create the http request + if req, err = http.NewRequestWithContext(ctx, method, url.String(), body); err != nil { + return nil, fmt.Errorf("could not create request: %s", err) + } + + // Set the headers on the request + req.Header.Add("User-Agent", userAgent) + req.Header.Add("Accept", accept) + req.Header.Add("Accept-Language", acceptLang) + req.Header.Add("Accept-Encoding", acceptEncode) + req.Header.Add("Content-Type", contentType) + + // If there is a request ID on the context, set it on the request, otherwise generate one + var requestID string + if requestID, _ = RequestIDFromContext(ctx); requestID == "" { + requestID = ulid.Make().String() + } + req.Header.Add("X-Request-ID", requestID) + + // Add authentication and authorization header. + if s.creds != nil { + var token string + if token, err = s.creds.AccessToken(); err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + } + + // Add CSRF protection if its available + if s.client.Jar != nil { + cookies := s.client.Jar.Cookies(url) + for _, cookie := range cookies { + if cookie.Name == "csrf_token" { + req.Header.Add("X-CSRF-TOKEN", cookie.Value) + } + } + } + + return req, nil +} + +// Do executes an http request against the server, performs error checking, and +// deserializes the response data into the specified struct. +func (s *APIv1) Do(req *http.Request, data interface{}, checkStatus bool) (rep *http.Response, err error) { + if rep, err = s.client.Do(req); err != nil { + return rep, fmt.Errorf("could not execute request: %s", err) + } + defer rep.Body.Close() + + // Detect http status errors if they've occurred + if checkStatus { + if rep.StatusCode < 200 || rep.StatusCode >= 300 { + // Attempt to read the error response from JSON, if available + serr := &StatusError{ + StatusCode: rep.StatusCode, + } + + if err = json.NewDecoder(rep.Body).Decode(&serr.Reply); err == nil { + return rep, serr + } + + serr.Reply = Unsuccessful + return rep, serr + } + } + + // Deserialize the JSON data from the body + if data != nil && rep.StatusCode >= 200 && rep.StatusCode < 300 && rep.StatusCode != http.StatusNoContent { + ct := rep.Header.Get("Content-Type") + if ct != "" { + mt, _, err := mime.ParseMediaType(ct) + if err != nil { + return nil, fmt.Errorf("malformed content-type header: %w", err) + } + + if mt != accept { + return nil, fmt.Errorf("unexpected content type: %q", mt) + } + } + + if err = json.NewDecoder(rep.Body).Decode(data); err != nil { + return nil, fmt.Errorf("could not deserialize response data: %s", err) + } + } + + return rep, nil +} diff --git a/pkg/api/v1/client_test.go b/pkg/api/v1/client_test.go new file mode 100644 index 0000000..66b1df3 --- /dev/null +++ b/pkg/api/v1/client_test.go @@ -0,0 +1,80 @@ +package api_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/rotationalio/honu/pkg/api/v1" + "github.com/stretchr/testify/require" +) + +var ( + ctx = context.Background() +) + +func TestStatus(t *testing.T) { + fixture := &api.StatusReply{} + err := loadFixture("testdata/statusOK.json", fixture) + require.NoError(t, err, "could not load status ok fixture") + + _, client := testServer(t, &testServerConfig{ + expectedMethod: http.MethodGet, + expectedPath: "/v1/status", + fixture: fixture, + statusCode: http.StatusOK, + }) + + rep, err := client.Status(ctx) + require.NoError(t, err, "could not execute status request") + require.Equal(t, fixture, rep, "expected reply to be equal to fixture") +} + +type testServerConfig struct { + expectedMethod string + expectedPath string + statusCode int + fixture interface{} +} + +func testServer(t *testing.T, conf *testServerConfig) (*httptest.Server, api.Client) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != conf.expectedMethod { + http.Error(w, fmt.Sprintf("expected method %s got %s", conf.expectedMethod, r.Method), http.StatusExpectationFailed) + return + } + + if r.URL.Path != conf.expectedPath { + http.Error(w, fmt.Sprintf("expected path %s got %s", conf.expectedPath, r.URL.Path), http.StatusExpectationFailed) + return + } + + if conf.statusCode == 0 { + conf.statusCode = http.StatusOK + } + + w.Header().Add("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(conf.statusCode) + json.NewEncoder(w).Encode(conf.fixture) + })) + + // Ensure the server is closed when the test is complete + t.Cleanup(ts.Close) + + client, err := api.New(ts.URL) + require.NoError(t, err, "could not create api client") + return ts, client +} + +func loadFixture(path string, v interface{}) (err error) { + var f *os.File + if f, err = os.Open(path); err != nil { + return err + } + defer f.Close() + return json.NewDecoder(f).Decode(v) +} diff --git a/pkg/api/v1/context.go b/pkg/api/v1/context.go new file mode 100644 index 0000000..ea8aa66 --- /dev/null +++ b/pkg/api/v1/context.go @@ -0,0 +1,34 @@ +package api + +import "context" + +// API-specific context keys for passing values to requests via the context. These keys +// are unexported to reduce the size of the public interface an prevent incorrect handling. +type contextKey uint8 + +// Allocate context keys to simplify context key usage in helper functions. +const ( + contextKeyUnknown contextKey = iota + contextKeyRequestID +) + +// Adds a request ID to the context which is sent with the request in the X-Request-ID header. +func ContextWithRequestID(parent context.Context, requestID string) context.Context { + return context.WithValue(parent, contextKeyRequestID, requestID) +} + +// Extracts a request ID from the context. +func RequestIDFromContext(ctx context.Context) (string, bool) { + requestID, ok := ctx.Value(contextKeyRequestID).(string) + return requestID, ok +} + +var contextKeyNames = []string{"unknown", "requestID"} + +// String returns a human readable representation of the context key for easier debugging. +func (c contextKey) String() string { + if int(c) < len(contextKeyNames) { + return contextKeyNames[c] + } + return contextKeyNames[0] +} diff --git a/pkg/api/v1/credentials/credentials.go b/pkg/api/v1/credentials/credentials.go new file mode 100644 index 0000000..3bd6c81 --- /dev/null +++ b/pkg/api/v1/credentials/credentials.go @@ -0,0 +1,14 @@ +package credentials + +type Credentials interface { + AccessToken() (string, error) +} + +type Token string + +func (t Token) AccessToken() (string, error) { + if string(t) == "" { + return "", ErrInvalidCredentials + } + return string(t), nil +} diff --git a/pkg/api/v1/credentials/errors.go b/pkg/api/v1/credentials/errors.go new file mode 100644 index 0000000..2676c08 --- /dev/null +++ b/pkg/api/v1/credentials/errors.go @@ -0,0 +1,7 @@ +package credentials + +import "errors" + +var ( + ErrInvalidCredentials = errors.New("missing, invalid or expired credentials") +) diff --git a/pkg/api/v1/errors.go b/pkg/api/v1/errors.go new file mode 100644 index 0000000..7c3da90 --- /dev/null +++ b/pkg/api/v1/errors.go @@ -0,0 +1,206 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" +) + +//=========================================================================== +// Standard Error Handling +//=========================================================================== + +var ( + Unsuccessful = Reply{Success: false} + NotFound = Reply{Success: false, Error: "resource not found"} + NotAllowed = Reply{Success: false, Error: "method not allowed"} +) + +// Construct a new response for an error or simply return unsuccessful. +func Error(err interface{}) Reply { + if err == nil { + return Unsuccessful + } + + rep := Reply{Success: false} + switch err := err.(type) { + case ValidationErrors: + if len(err) == 1 { + rep.Error = err.Error() + } else { + rep.Error = fmt.Sprintf("%d validation errors occurred", len(err)) + rep.ErrorDetail = make(ErrorDetail, 0, len(err)) + for _, verr := range err { + rep.ErrorDetail = append(rep.ErrorDetail, &DetailError{ + Field: verr.field, + Error: verr.Error(), + }) + } + } + case error: + rep.Error = err.Error() + case string: + rep.Error = err + case fmt.Stringer: + rep.Error = err.String() + case json.Marshaler: + data, e := err.MarshalJSON() + if e != nil { + panic(err) + } + rep.Error = string(data) + default: + rep.Error = "unhandled error response" + } + + return rep +} + +//=========================================================================== +// Status Errors +//=========================================================================== + +// StatusError decodes an error response from the TRISA API. +type StatusError struct { + StatusCode int + Reply Reply +} + +func (e *StatusError) Error() string { + return fmt.Sprintf("[%d] %s", e.StatusCode, e.Reply.Error) +} + +// ErrorStatus returns the HTTP status code from an error or 500 if the error is not a StatusError. +func ErrorStatus(err error) int { + if err == nil { + return http.StatusOK + } + + if e, ok := err.(*StatusError); !ok || e.StatusCode < 100 || e.StatusCode >= 600 { + return http.StatusInternalServerError + } else { + return e.StatusCode + } +} + +//=========================================================================== +// Detail Error +//=========================================================================== + +type ErrorDetail []*DetailError + +type DetailError struct { + Field string `json:"field"` + Error string `json:"error"` +} + +//=========================================================================== +// Field Validation Errors +//=========================================================================== + +func MissingField(field string) *FieldError { + return &FieldError{verb: "missing", field: field, issue: "this field is required"} +} + +func IncorrectField(field, issue string) *FieldError { + return &FieldError{verb: "invalid field", field: field, issue: issue} +} + +func ReadOnlyField(field string) *FieldError { + return &FieldError{verb: "read-only field", field: field, issue: "this field cannot be written by the user"} +} + +func OneOfMissing(fields ...string) *FieldError { + var fieldstr string + switch len(fields) { + case 0: + panic("no fields specified for one of") + case 1: + return MissingField(fields[0]) + default: + fieldstr = fieldList(fields...) + } + + return &FieldError{verb: "missing one of", field: fieldstr, issue: "at most one of these fields is required"} +} + +func OneOfTooMany(fields ...string) *FieldError { + if len(fields) < 2 { + panic("must specify at least two fields for one of too many") + } + return &FieldError{verb: "specify only one of", field: fieldList(fields...), issue: "at most one of these fields may be specified"} +} + +func ValidationError(err error, errs ...*FieldError) error { + var verr ValidationErrors + if err == nil { + verr = make(ValidationErrors, 0, len(errs)) + } else { + var ok bool + if verr, ok = err.(ValidationErrors); !ok { + verr = make(ValidationErrors, 0, len(errs)+1) + verr = append(verr, &FieldError{verb: "invalid", field: "input", issue: err.Error()}) + } + } + + for _, e := range errs { + if e != nil { + verr = append(verr, e) + } + } + + if len(verr) == 0 { + return nil + } + return verr +} + +type FieldError struct { + verb string + field string + issue string +} + +func (e *FieldError) Error() string { + return fmt.Sprintf("%s %s: %s", e.verb, e.field, e.issue) +} + +func (e *FieldError) Subfield(parent string) *FieldError { + e.field = fmt.Sprintf("%s.%s", parent, e.field) + return e +} + +func (e *FieldError) SubfieldArray(parent string, index int) *FieldError { + e.field = fmt.Sprintf("%s[%d].%s", parent, index, e.field) + return e +} + +type ValidationErrors []*FieldError + +func (e ValidationErrors) Error() string { + if len(e) == 1 { + return e[0].Error() + } + + errs := make([]string, 0, len(e)) + for _, err := range e { + errs = append(errs, err.Error()) + } + + return fmt.Sprintf("%d validation errors occurred:\n %s", len(e), strings.Join(errs, "\n ")) +} + +func fieldList(fields ...string) string { + switch len(fields) { + case 0: + return "" + case 1: + return fields[0] + case 2: + return fmt.Sprintf("%s or %s", fields[0], fields[1]) + default: + last := len(fields) - 1 + return fmt.Sprintf("%s, or %s", strings.Join(fields[0:last], ", "), fields[last]) + } +} diff --git a/pkg/api/v1/errors_test.go b/pkg/api/v1/errors_test.go new file mode 100644 index 0000000..5001a3a --- /dev/null +++ b/pkg/api/v1/errors_test.go @@ -0,0 +1,200 @@ +package api_test + +import ( + "fmt" + "testing" + + "github.com/rotationalio/honu/pkg/api/v1" + + "github.com/stretchr/testify/require" +) + +func TestValidationErrors(t *testing.T) { + + t.Run("Nil", func(t *testing.T) { + require.NoError(t, api.ValidationError(nil, nil, nil, nil, nil)) + }) + + t.Run("Single", func(t *testing.T) { + testCases := []struct { + err error + errs []*api.FieldError + expected string + }{ + { + nil, + []*api.FieldError{api.MissingField("foo")}, + "missing foo: this field is required", + }, + { + make(api.ValidationErrors, 0), + []*api.FieldError{api.MissingField("foo")}, + "missing foo: this field is required", + }, + { + nil, + []*api.FieldError{nil, api.MissingField("foo"), nil}, + "missing foo: this field is required", + }, + } + + for i, tc := range testCases { + err := api.ValidationError(tc.err, tc.errs...) + require.EqualError(t, err, tc.expected, "test case %d failed", i) + } + }) + + t.Run("Multi", func(t *testing.T) { + testCases := []struct { + err error + errs []*api.FieldError + expected string + }{ + { + nil, + []*api.FieldError{api.MissingField("foo"), api.MissingField("bar")}, + "2 validation errors occurred:\n missing foo: this field is required\n missing bar: this field is required", + }, + { + nil, + []*api.FieldError{nil, api.MissingField("foo"), nil, api.MissingField("bar"), nil}, + "2 validation errors occurred:\n missing foo: this field is required\n missing bar: this field is required", + }, + { + api.ValidationErrors([]*api.FieldError{api.MissingField("foo")}), + []*api.FieldError{nil, api.MissingField("bar"), nil}, + "2 validation errors occurred:\n missing foo: this field is required\n missing bar: this field is required", + }, + } + + for i, tc := range testCases { + err := api.ValidationError(tc.err, tc.errs...) + require.EqualError(t, err, tc.expected, "test case %d failed", i) + } + }) + + t.Run("OneOfMissing", func(t *testing.T) { + testCases := []struct { + fields []string + expected string + }{ + { + []string{"foo"}, + "missing foo: this field is required", + }, + { + []string{"foo", "bar"}, + "missing one of foo or bar: at most one of these fields is required", + }, + { + []string{"foo", "bar", "zap"}, + "missing one of foo, bar, or zap: at most one of these fields is required", + }, + { + []string{"foo", "bar", "zap", "baz"}, + "missing one of foo, bar, zap, or baz: at most one of these fields is required", + }, + } + + for i, tc := range testCases { + err := api.OneOfMissing(tc.fields...) + require.EqualError(t, err, tc.expected, "test case %d failed", i) + } + }) + + t.Run("OneOfTooMany", func(t *testing.T) { + testCases := []struct { + fields []string + expected string + }{ + { + []string{"foo", "bar"}, + "specify only one of foo or bar: at most one of these fields may be specified", + }, + { + []string{"foo", "bar", "zap"}, + "specify only one of foo, bar, or zap: at most one of these fields may be specified", + }, + { + []string{"foo", "bar", "zap", "baz"}, + "specify only one of foo, bar, zap, or baz: at most one of these fields may be specified", + }, + } + + for i, tc := range testCases { + err := api.OneOfTooMany(tc.fields...) + require.EqualError(t, err, tc.expected, "test case %d failed", i) + } + }) +} + +func ExampleValidationErrors() { + err := api.ValidationError( + nil, + api.MissingField("name"), + api.IncorrectField("ssn", "ssn should be 8 digits only"), + nil, + api.MissingField("date_of_birth"), + nil, + ) + + fmt.Println(err) + // Output: + // 3 validation errors occurred: + // missing name: this field is required + // invalid field ssn: ssn should be 8 digits only + // missing date_of_birth: this field is required +} + +func TestFieldError(t *testing.T) { + t.Run("Subfield", func(t *testing.T) { + tests := []struct { + err *api.FieldError + parent string + expected string + }{ + { + api.MissingField("last_name"), + "person", + "missing person.last_name: this field is required", + }, + { + api.IncorrectField("banner", "banner must have ## prefix"), + "prom.queen", + "invalid field prom.queen.banner: banner must have ## prefix", + }, + } + + for i, tc := range tests { + err := tc.err.Subfield(tc.parent) + require.EqualError(t, err, tc.expected, "test case %d failed", i) + } + }) + + t.Run("SubfieldArray", func(t *testing.T) { + tests := []struct { + err *api.FieldError + parent string + index int + expected string + }{ + { + api.MissingField("last_name"), + "persons", + 0, + "missing persons[0].last_name: this field is required", + }, + { + api.IncorrectField("banner", "banner must have ## prefix"), + "prom.queens", + 14, + "invalid field prom.queens[14].banner: banner must have ## prefix", + }, + } + + for i, tc := range tests { + err := tc.err.SubfieldArray(tc.parent, tc.index) + require.EqualError(t, err, tc.expected, "test case %d failed", i) + } + }) +} diff --git a/pkg/api/v1/options.go b/pkg/api/v1/options.go new file mode 100644 index 0000000..b2457ce --- /dev/null +++ b/pkg/api/v1/options.go @@ -0,0 +1,13 @@ +package api + +import "net/http" + +// ClientOption allows us to configure the APIv1 client when it is created. +type ClientOption func(c *APIv1) error + +func WithClient(client *http.Client) ClientOption { + return func(c *APIv1) error { + c.client = client + return nil + } +} diff --git a/pkg/api/v1/testdata/statusOK.json b/pkg/api/v1/testdata/statusOK.json new file mode 100644 index 0000000..4e46738 --- /dev/null +++ b/pkg/api/v1/testdata/statusOK.json @@ -0,0 +1,5 @@ +{ + "status": "ok", + "uptime": "5s", + "version": "1.0.0" +} \ No newline at end of file diff --git a/pkg/config/config.go b/pkg/config/config.go index 8e92452..5b73c3b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,6 +1,8 @@ package config import ( + "time" + "github.com/rotationalio/confire" "github.com/rotationalio/honu/pkg/logger" "github.com/rs/zerolog" @@ -16,10 +18,14 @@ const Prefix = "honu" // values that are omitted. The Config should be validated in preparation for running // the honudb instance to ensure that all server operations work as expected. type Config struct { - Maintenance bool `default:"false" desc:"if true, the replica will start in maintenance mode"` - LogLevel logger.LevelDecoder `split_words:"true" default:"info" desc:"specify the verbosity of logging (trace, debug, info, warn, error, fatal panic)"` - ConsoleLog bool `split_words:"true" default:"false" desc:"if true logs colorized human readable output instead of json"` - processed bool + Maintenance bool `default:"false" desc:"if true, the replica will start in maintenance mode"` + LogLevel logger.LevelDecoder `split_words:"true" default:"info" desc:"specify the verbosity of logging (trace, debug, info, warn, error, fatal panic)"` + ConsoleLog bool `split_words:"true" default:"false" desc:"if true logs colorized human readable output instead of json"` + BindAddr string `split_words:"true" default:":3264" desc:"the ip address and port to bind the honu database server on"` + ReadTimeout time.Duration `split_words:"true" default:"20s" desc:"amount of time allowed to read request headers before server decides the request is too slow"` + WriteTimeout time.Duration `split_words:"true" default:"20s" desc:"maximum amount of time before timing out a write to a response"` + IdleTimeout time.Duration `split_words:"true" default:"10m" desc:"maximum amount of time to wait for the next request while keep alives are enabled"` + processed bool } func New() (conf Config, err error) { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index bc5371c..786310f 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -13,6 +13,7 @@ var testEnv = map[string]string{ "HONU_MAINTENANCE": "true", "HONU_LOG_LEVEL": "debug", "HONU_CONSOLE_LOG": "true", + "HONU_BIND_ADDR": "127.0.0.1:443", } func TestConfig(t *testing.T) { @@ -28,6 +29,7 @@ func TestConfig(t *testing.T) { require.True(t, conf.Maintenance) require.Equal(t, zerolog.DebugLevel, conf.GetLogLevel()) require.True(t, conf.ConsoleLog) + require.Equal(t, testEnv["HONU_BIND_ADDR"], conf.BindAddr) } // Returns the current environment for the specified keys, or if no keys are specified diff --git a/pkg/server/maintenance.go b/pkg/server/maintenance.go new file mode 100644 index 0000000..fabbfb5 --- /dev/null +++ b/pkg/server/maintenance.go @@ -0,0 +1,29 @@ +package server + +import ( + "net/http" + "time" + + "github.com/julienschmidt/httprouter" + "github.com/rotationalio/honu/pkg" + "github.com/rotationalio/honu/pkg/api/v1" + "github.com/rotationalio/honu/pkg/server/middleware" + "github.com/rotationalio/honu/pkg/server/render" +) + +// If the server is in maintenance mode, aborts the current request and renders the +// maintenance mode page instead. Returns nil if not in maintenance mode. +func (s *Server) Maintenance() middleware.Middleware { + if s.conf.Maintenance { + return func(next httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { + render.JSON(http.StatusServiceUnavailable, w, &api.StatusReply{ + Status: "maintenance", + Version: pkg.Version(), + Uptime: time.Since(s.started).String(), + }) + } + } + } + return nil +} diff --git a/pkg/server/middleware/middleware.go b/pkg/server/middleware/middleware.go new file mode 100644 index 0000000..c90e7bd --- /dev/null +++ b/pkg/server/middleware/middleware.go @@ -0,0 +1,20 @@ +package middleware + +import "github.com/julienschmidt/httprouter" + +type Middleware func(next httprouter.Handle) httprouter.Handle + +func Chain(h httprouter.Handle, m ...Middleware) httprouter.Handle { + if len(m) < 1 { + return h + } + + // Wrap the handler in the middleware in a reverse loop to preserve order + wrapped := h + for i := len(m) - 1; i >= 0; i-- { + if m[i] != nil { + wrapped = m[i](wrapped) + } + } + return wrapped +} diff --git a/pkg/server/middleware/middleware_test.go b/pkg/server/middleware/middleware_test.go new file mode 100644 index 0000000..0b6feb9 --- /dev/null +++ b/pkg/server/middleware/middleware_test.go @@ -0,0 +1,144 @@ +package middleware_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/julienschmidt/httprouter" + . "github.com/rotationalio/honu/pkg/server/middleware" + "github.com/stretchr/testify/require" +) + +func MakeTestMiddleware(name string, abort bool, calls *Calls) Middleware { + return func(next httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + calls.before(name) + if !abort { + next(w, r, p) + calls.after(name) + } + } + } +} + +func MakeTestHandler(calls *Calls) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + calls.call("handler") + fmt.Fprintln(w, "success") + } +} + +func TestChain(t *testing.T) { + calls := &Calls{} + h := Chain( + MakeTestHandler(calls), + MakeTestMiddleware("A", false, calls), + MakeTestMiddleware("B", false, calls), + MakeTestMiddleware("C", false, calls), + MakeTestMiddleware("D", false, calls), + ) + + srv := testServer(t, h) + _, err := srv.Client().Get(srv.URL + "/") + require.NoError(t, err, "expected no error making request") + + expected := []string{ + "A-before", + "B-before", + "C-before", + "D-before", + "handler", + "D-after", + "C-after", + "B-after", + "A-after", + } + require.Equal(t, len(expected), calls.calls, "incorrect number of calls") + require.Equal(t, expected, calls.callers, "middleware not chained correctly") +} + +func TestAbort(t *testing.T) { + calls := &Calls{} + h := Chain( + MakeTestHandler(calls), + MakeTestMiddleware("A", false, calls), + MakeTestMiddleware("B", true, calls), + MakeTestMiddleware("C", false, calls), + MakeTestMiddleware("D", false, calls), + ) + + srv := testServer(t, h) + _, err := srv.Client().Get(srv.URL + "/") + require.NoError(t, err, "expected no error making request") + + expected := []string{ + "A-before", + "B-before", + "A-after", + } + require.Equal(t, len(expected), calls.calls, "incorrect number of calls") + require.Equal(t, expected, calls.callers, "middleware not chained correctly") +} + +func TestChainWithNil(t *testing.T) { + calls := &Calls{} + h := Chain( + MakeTestHandler(calls), + nil, + MakeTestMiddleware("A", false, calls), + nil, nil, nil, + MakeTestMiddleware("B", false, calls), + nil, nil, nil, + ) + + srv := testServer(t, h) + _, err := srv.Client().Get(srv.URL + "/") + require.NoError(t, err, "expected no error making request") + + expected := []string{ + "A-before", + "B-before", + "handler", + "B-after", + "A-after", + } + require.Equal(t, len(expected), calls.calls, "incorrect number of calls") + require.Equal(t, expected, calls.callers, "middleware not chained correctly") +} + +func testServer(t *testing.T, h httprouter.Handle) *httptest.Server { + // Setup the test server and router + router := httprouter.New() + router.GET("/", h) + + srv := httptest.NewServer(router) + + // Ensure the server is closed when the test is complete + t.Cleanup(srv.Close) + + return srv +} + +type Calls struct { + calls int + callers []string +} + +func (c *Calls) call(name string) { + if c.callers == nil { + c.callers = make([]string, 0, 16) + } + + c.callers = append(c.callers, name) + c.calls++ +} + +func (c *Calls) before(name string) { + c.call(fmt.Sprintf("%s-before", name)) +} + +func (c *Calls) after(name string) { + c.call(fmt.Sprintf("%s-after", name)) +} diff --git a/pkg/server/render/json.go b/pkg/server/render/json.go new file mode 100644 index 0000000..b74afe9 --- /dev/null +++ b/pkg/server/render/json.go @@ -0,0 +1,21 @@ +package render + +import ( + "encoding/json" + "net/http" +) + +const ( + jsonContentType = "application/json; charset=utf-8" +) + +// JSON marshals the given interface object and writes it with the correct ContentType. +func JSON(code int, w http.ResponseWriter, obj any) error { + w.Header().Set(ContentType, jsonContentType) + w.WriteHeader(code) + + if err := json.NewEncoder(w).Encode(obj); err != nil { + return err + } + return nil +} diff --git a/pkg/server/render/render.go b/pkg/server/render/render.go new file mode 100644 index 0000000..f47f750 --- /dev/null +++ b/pkg/server/render/render.go @@ -0,0 +1,34 @@ +package render + +import ( + "fmt" + "net/http" +) + +// Header keys for http responses +const ( + ContentType = "Content-Type" +) + +// Content types for plain text responses +const ( + plainContentType = "text/plain; charset=utf-8" +) + +type Renderer func(code int, w http.ResponseWriter, obj any) error + +func Text(code int, w http.ResponseWriter, text string) error { + w.Header().Set(ContentType, plainContentType) + w.WriteHeader(code) + + fmt.Fprintln(w, text) + return nil +} + +func Textf(code int, w http.ResponseWriter, text string, a ...any) error { + w.Header().Set(ContentType, plainContentType) + w.WriteHeader(code) + + fmt.Fprintf(w, text, a...) + return nil +} diff --git a/pkg/server/routes.go b/pkg/server/routes.go new file mode 100644 index 0000000..3ccebf7 --- /dev/null +++ b/pkg/server/routes.go @@ -0,0 +1,30 @@ +package server + +import ( + "net/http" + + "github.com/julienschmidt/httprouter" + "github.com/rotationalio/honu/pkg/server/middleware" +) + +// Sets up the server's middleware and routes. +func (s *Server) setupRoutes() (err error) { + middleware := []middleware.Middleware{ + s.Maintenance(), + } + + // Kubernetes liveness probes added before middleware. + s.router.GET("/healthz", s.Healthz) + s.router.GET("/livez", s.Healthz) + s.router.GET("/readyz", s.Readyz) + + // API Routes + // Status/Heartbeat endpoint + s.addRoute(http.MethodGet, "/v1/status", s.Status, middleware...) + + return nil +} + +func (s *Server) addRoute(method, path string, h httprouter.Handle, m ...middleware.Middleware) { + s.router.Handle(method, path, middleware.Chain(h, m...)) +} diff --git a/pkg/server/server.go b/pkg/server/server.go new file mode 100644 index 0000000..0342d61 --- /dev/null +++ b/pkg/server/server.go @@ -0,0 +1,187 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "os" + "sync" + "time" + + "github.com/julienschmidt/httprouter" + "github.com/rotationalio/honu/pkg/config" + "github.com/rotationalio/honu/pkg/logger" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +func init() { + // Initialize zerolog with GCP logging requirements + zerolog.TimeFieldFormat = time.RFC3339 + zerolog.TimestampFieldName = logger.GCPFieldKeyTime + zerolog.MessageFieldName = logger.GCPFieldKeyMsg + + // Add the severity hook for GCP logging + var gcpHook logger.SeverityHook + log.Logger = zerolog.New(os.Stdout).Hook(gcpHook).With().Timestamp().Logger() +} + +// Create a new Honu database server/replica instance using the specified configuration. +// This function is the main entry point to initializing a honudb instance and should +// be called rather than constructing a server directly. This method ensures that the +// configuration is correctly loaded from the environment, that the logging defaults +// are set correctly, and that any observability tools are correctly configured. +func New(conf config.Config) (s *Server, err error) { + // Load the default configuration from the environment + if conf.IsZero() { + if conf, err = config.New(); err != nil { + return nil, err + } + } + + // Set the global level + zerolog.SetGlobalLevel(conf.GetLogLevel()) + + // Set human readable logging if specified. + if conf.ConsoleLog { + console := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339} + log.Logger = zerolog.New(console).With().Timestamp().Logger() + } + + // Create the server and prepare to serve. + s = &Server{ + conf: conf, + errc: make(chan error, 1), + } + + // Create the httprouter + s.router = httprouter.New() + s.router.RedirectFixedPath = true + s.router.HandleMethodNotAllowed = true + s.router.RedirectTrailingSlash = true + if err = s.setupRoutes(); err != nil { + return nil, err + } + + // Create the http server + s.srv = &http.Server{ + Addr: s.conf.BindAddr, + Handler: s.router, + ErrorLog: nil, + ReadHeaderTimeout: s.conf.ReadTimeout, + WriteTimeout: s.conf.WriteTimeout, + IdleTimeout: s.conf.IdleTimeout, + } + + return s, nil +} + +// A Honu Database server implements several services as enabled for interaction with +// the Honu replica network including: +// +// 1. A database api client for users to interact with the database +// 2. An administrative client for managing peers and the replica +// 3. A replication service with auto-adapting anti-entropy replication +// 4. A metrics server for prometheus to scrape data +// +// The server may also implement background services as required. +type Server struct { + sync.RWMutex + conf config.Config + srv *http.Server + router *httprouter.Router + url *url.URL + started time.Time + errc chan error + healthy bool + ready bool +} + +func (s *Server) Serve() (err error) { + // Create a socket to listen on and infer the final URL. + // NOTE: if the bindaddr is 127.0.0.1:0 for testing, a random port will be assigned, + // manually creating the listener will allow us to determine which port. + // When we start listening all incoming requests will be buffered until the server + // actually starts up in its own go routine below. + var sock net.Listener + if sock, err = net.Listen("tcp", s.srv.Addr); err != nil { + return fmt.Errorf("could not listen on bind addr %s: %s", s.srv.Addr, err) + } + + s.setURL(sock.Addr()) + s.SetStatus(true, true) + s.started = time.Now() + + // Listen for HTTP requests and handle them. + go func() { + // Make sure we don't use the external err to avoid data races. + if serr := s.serve(sock); !errors.Is(serr, http.ErrServerClosed) { + s.errc <- serr + } + }() + + log.Info().Str("url", s.URL()).Msg("honu database server started") + return <-s.errc +} + +// ServeTLS if a tls configuration is provided, otherwise Serve. +func (s *Server) serve(sock net.Listener) error { + if s.srv.TLSConfig != nil { + return s.srv.ServeTLS(sock, "", "") + } + return s.srv.Serve(sock) +} + +func (s *Server) Shutdown() (err error) { + log.Info().Msg("gracefully shutting down honu database server") + s.SetStatus(false, false) + + ctx, cancel := context.WithTimeout(context.Background(), 35*time.Second) + defer cancel() + + s.srv.SetKeepAlivesEnabled(false) + if err = s.srv.Shutdown(ctx); err != nil { + return err + } + + return nil +} + +// SetStatus sets the health and ready status on the server, modifying the behavior of +// the kubernetes probe responses. +func (s *Server) SetStatus(health, ready bool) { + s.Lock() + s.healthy = health + s.ready = ready + s.Unlock() + log.Debug().Bool("health", health).Bool("ready", ready).Msg("server status set") +} + +// URL returns the endpoint of the server as determined by the configuration and the +// socket address and port (if specified). +func (s *Server) URL() string { + s.RLock() + defer s.RUnlock() + return s.url.String() +} + +func (s *Server) setURL(addr net.Addr) { + s.Lock() + defer s.Unlock() + + s.url = &url.URL{ + Scheme: "http", + Host: addr.String(), + } + + if s.srv.TLSConfig != nil { + s.url.Scheme = "https" + } + + if tcp, ok := addr.(*net.TCPAddr); ok && tcp.IP.IsUnspecified() { + s.url.Host = fmt.Sprintf("127.0.0.1:%d", tcp.Port) + } +} diff --git a/pkg/server/status.go b/pkg/server/status.go new file mode 100644 index 0000000..c94954f --- /dev/null +++ b/pkg/server/status.go @@ -0,0 +1,67 @@ +package server + +import ( + "net/http" + "time" + + "github.com/julienschmidt/httprouter" + "github.com/rotationalio/honu/pkg" + "github.com/rotationalio/honu/pkg/api/v1" + "github.com/rotationalio/honu/pkg/server/render" +) + +const ( + serverStatusOK = "ok" + serverStatusNotReady = "not ready" + serverStatusUnhealthy = "unhealthy" + serverStatusMaintenance = "maintenance" +) + +// Status reports the version and uptime of the server +func (s *Server) Status(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + var state string + s.RLock() + switch { + case s.healthy && s.ready: + state = serverStatusOK + case s.healthy && !s.ready: + state = serverStatusNotReady + case !s.healthy: + state = serverStatusUnhealthy + } + s.RUnlock() + + render.JSON(http.StatusOK, w, &api.StatusReply{ + Status: state, + Version: pkg.Version(), + Uptime: time.Since(s.started).String(), + }) +} + +// Healthz is used to alert k8s to the health/liveness status of the server. +func (s *Server) Healthz(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + s.RLock() + healthy := s.healthy + s.RUnlock() + + if !healthy { + render.Text(http.StatusServiceUnavailable, w, serverStatusUnhealthy) + return + } + + render.Text(http.StatusOK, w, serverStatusOK) +} + +// Readyz is used to alert k8s to the readiness status of the server. +func (s *Server) Readyz(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + s.RLock() + ready := s.ready + s.RUnlock() + + if !ready { + render.Text(http.StatusServiceUnavailable, w, serverStatusNotReady) + return + } + + render.Text(http.StatusOK, w, serverStatusOK) +} diff --git a/pkg/store/decode.go b/pkg/store/decode.go index 5f44b3b..b2f5eca 100644 --- a/pkg/store/decode.go +++ b/pkg/store/decode.go @@ -5,7 +5,7 @@ import ( "io" "time" - "github.com/oklog/ulid" + "github.com/oklog/ulid/v2" ) // Unmarshal a decodable object from a byte slice for deserialization. diff --git a/pkg/store/decode_test.go b/pkg/store/decode_test.go index 92aa3bb..3fe7a85 100644 --- a/pkg/store/decode_test.go +++ b/pkg/store/decode_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/oklog/ulid" + "github.com/oklog/ulid/v2" . "github.com/rotationalio/honu/pkg/store" "github.com/stretchr/testify/require" ) diff --git a/pkg/store/encode.go b/pkg/store/encode.go index 13a7ede..bbcbdd2 100644 --- a/pkg/store/encode.go +++ b/pkg/store/encode.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "github.com/oklog/ulid" + "github.com/oklog/ulid/v2" ) const ( diff --git a/pkg/store/encode_test.go b/pkg/store/encode_test.go index a9ee970..62cab9e 100644 --- a/pkg/store/encode_test.go +++ b/pkg/store/encode_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/oklog/ulid" + "github.com/oklog/ulid/v2" . "github.com/rotationalio/honu/pkg/store" "github.com/stretchr/testify/require" ) diff --git a/pkg/store/object.go b/pkg/store/object.go index 569fe19..908dd48 100644 --- a/pkg/store/object.go +++ b/pkg/store/object.go @@ -5,7 +5,7 @@ import ( "net" "time" - "github.com/oklog/ulid" + "github.com/oklog/ulid/v2" ) //=========================================================================== diff --git a/pkg/store/object_test.go b/pkg/store/object_test.go index 562c552..8c813d1 100644 --- a/pkg/store/object_test.go +++ b/pkg/store/object_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/oklog/ulid" + "github.com/oklog/ulid/v2" "github.com/rotationalio/honu/pkg/store" "github.com/stretchr/testify/require" )