From f4f06900e4fe603be9458496c3236295b40286a7 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 1 Feb 2024 11:39:17 -0800 Subject: [PATCH 1/7] http config refactor --- cli/start.go | 33 ++- http/handler.go | 10 +- http/handler_ccip_test.go | 12 +- http/middleware.go | 12 +- http/server.go | 373 +++++++++++----------------------- http/server_test.go | 273 +++++++++---------------- tests/clients/cli/wrapper.go | 2 +- tests/clients/http/wrapper.go | 2 +- tests/gen/cli/gendocs.go | 10 + tests/gen/cli/gendocs_test.go | 63 +++--- tests/gen/cli/util_test.go | 45 +--- 11 files changed, 290 insertions(+), 545 deletions(-) diff --git a/cli/start.go b/cli/start.go index 55b168205f..0c344ea94c 100644 --- a/cli/start.go +++ b/cli/start.go @@ -175,7 +175,7 @@ func (di *defraInstance) close(ctx context.Context) { } else { di.db.Close() } - if err := di.server.Close(); err != nil { + if err := di.server.Shutdown(ctx); err != nil { log.FeedbackInfo( ctx, "The server could not be closed successfully", @@ -259,40 +259,31 @@ func start(ctx context.Context, cfg *config.Config) (*defraInstance, error) { } } - sOpt := []func(*httpapi.Server){ + serverOpts := []httpapi.ServerOpt{ httpapi.WithAddress(cfg.API.Address), - httpapi.WithRootDir(cfg.Rootdir), httpapi.WithAllowedOrigins(cfg.API.AllowedOrigins...), + httpapi.WithTLSCertPath(cfg.API.PubKeyPath), + httpapi.WithTLSKeyPath(cfg.API.PrivKeyPath), } - if cfg.API.TLS { - sOpt = append( - sOpt, - httpapi.WithTLS(), - httpapi.WithSelfSignedCert(cfg.API.PubKeyPath, cfg.API.PrivKeyPath), - httpapi.WithCAEmail(cfg.API.Email), - ) - } - - var server *httpapi.Server + var handler *httpapi.Handler if node != nil { - server, err = httpapi.NewServer(node, sOpt...) + handler, err = httpapi.NewHandler(node) } else { - server, err = httpapi.NewServer(db, sOpt...) + handler, err = httpapi.NewHandler(db) } if err != nil { - return nil, errors.Wrap("failed to create http server", err) + return nil, errors.Wrap("failed to create http handler", err) } - if err := server.Listen(ctx); err != nil { - return nil, errors.Wrap(fmt.Sprintf("failed to listen on TCP address %v", server.Addr), err) + server, err := httpapi.NewServer(handler, serverOpts...) + if err != nil { + return nil, errors.Wrap("failed to create http server", err) } - // save the address on the config in case the port number was set to random - cfg.API.Address = server.AssignedAddr() // run the server in a separate goroutine go func() { log.FeedbackInfo(ctx, fmt.Sprintf("Providing HTTP API at %s.", cfg.API.AddressToURL())) - if err := server.Run(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.FeedbackErrorE(ctx, "Failed to run the HTTP server", err) if node != nil { node.Close() diff --git a/http/handler.go b/http/handler.go index 328ea8fab9..b06ef06cb6 100644 --- a/http/handler.go +++ b/http/handler.go @@ -20,7 +20,6 @@ import ( "github.com/sourcenetwork/defradb/datastore" "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" ) // Version is the identifier for the current API version. @@ -69,7 +68,7 @@ type Handler struct { txs *sync.Map } -func NewHandler(db client.DB, opts ServerOptions) (*Handler, error) { +func NewHandler(db client.DB) (*Handler, error) { router, err := NewApiRouter() if err != nil { return nil, err @@ -77,14 +76,9 @@ func NewHandler(db client.DB, opts ServerOptions) (*Handler, error) { txs := &sync.Map{} mux := chi.NewMux() - mux.Use( - middleware.RequestLogger(&logFormatter{}), - middleware.Recoverer, - CorsMiddleware(opts), - ) mux.Route("/api/"+Version, func(r chi.Router) { r.Use( - ApiMiddleware(db, txs, opts), + ApiMiddleware(db, txs), TransactionMiddleware, StoreMiddleware, ) diff --git a/http/handler_ccip_test.go b/http/handler_ccip_test.go index c0df7e6a26..2a2cc4f077 100644 --- a/http/handler_ccip_test.go +++ b/http/handler_ccip_test.go @@ -49,7 +49,7 @@ func TestCCIPGet_WithValidData(t *testing.T) { req := httptest.NewRequest(http.MethodGet, url, nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -88,7 +88,7 @@ func TestCCIPGet_WithSubscription(t *testing.T) { req := httptest.NewRequest(http.MethodGet, url, nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -106,7 +106,7 @@ func TestCCIPGet_WithInvalidData(t *testing.T) { req := httptest.NewRequest(http.MethodGet, url, nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -135,7 +135,7 @@ func TestCCIPPost_WithValidData(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost:9181/api/v0/ccip", bytes.NewBuffer(body)) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -167,7 +167,7 @@ func TestCCIPPost_WithInvalidGraphQLRequest(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost:9181/api/v0/ccip", bytes.NewBuffer(body)) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -181,7 +181,7 @@ func TestCCIPPost_WithInvalidBody(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost:9181/api/v0/ccip", nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) diff --git a/http/middleware.go b/http/middleware.go index d33cbfb5ff..f18ba8bf60 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -56,13 +56,13 @@ var ( ) // CorsMiddleware handles cross origin request -func CorsMiddleware(opts ServerOptions) func(http.Handler) http.Handler { +func CorsMiddleware(allowedOrigins []string) func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowOriginFunc: func(r *http.Request, origin string) bool { - if slices.Contains(opts.AllowedOrigins, "*") { + if slices.Contains(allowedOrigins, "*") { return true } - return slices.Contains(opts.AllowedOrigins, strings.ToLower(origin)) + return slices.Contains(allowedOrigins, strings.ToLower(origin)) }, AllowedMethods: []string{"GET", "HEAD", "POST", "PATCH", "DELETE"}, AllowedHeaders: []string{"Content-Type"}, @@ -71,13 +71,9 @@ func CorsMiddleware(opts ServerOptions) func(http.Handler) http.Handler { } // ApiMiddleware sets the required context values for all API requests. -func ApiMiddleware(db client.DB, txs *sync.Map, opts ServerOptions) func(http.Handler) http.Handler { +func ApiMiddleware(db client.DB, txs *sync.Map) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if opts.TLS.HasValue() { - rw.Header().Add("Strict-Transport-Security", "max-age=63072000; includeSubDomains") - } - ctx := req.Context() ctx = context.WithValue(ctx, dbContextKey, db) ctx = context.WithValue(ctx, txsContextKey, txs) diff --git a/http/server.go b/http/server.go index 768542c68d..ddbbf7f73b 100644 --- a/http/server.go +++ b/http/server.go @@ -13,304 +13,171 @@ package http import ( "context" "crypto/tls" - "fmt" "net" "net/http" - "path" - "strings" + "sync/atomic" + "time" - "github.com/sourcenetwork/immutable" - "golang.org/x/crypto/acme/autocert" - - "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/config" - "github.com/sourcenetwork/defradb/errors" - "github.com/sourcenetwork/defradb/logging" -) - -const ( - // These constants are best effort durations that fit our current API - // and possibly prevent from running out of file descriptors. - // readTimeout = 5 * time.Second - // writeTimeout = 10 * time.Second - // idleTimeout = 120 * time.Second - - // Temparily disabling timeouts until [this proposal](https://github.com/golang/go/issues/54136) is merged. - // https://github.com/sourcenetwork/defradb/issues/927 - readTimeout = 0 - writeTimeout = 0 - idleTimeout = 0 + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" ) -const ( - httpPort = ":80" - httpsPort = ":443" -) - -// Server struct holds the Handler for the HTTP API. -type Server struct { - options ServerOptions - listener net.Listener - certManager *autocert.Manager - // address that is assigned to the server on listen - address string - - http.Server +// tlsConfig contains the default tls config settings +var tlsConfig = &tls.Config{ + ServerName: "DefraDB", + MinVersion: tls.VersionTLS12, + // We only allow cipher suites that are marked secure + // by ssllabs + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, } type ServerOptions struct { + // Address is the bind address the server listens on. + Address string // AllowedOrigins is the list of allowed origins for CORS. AllowedOrigins []string - // TLS enables https when the value is present. - TLS immutable.Option[TLSOptions] - // RootDirectory is the directory for the node config. - RootDir string - // Domain is the domain for the API (optional). - Domain immutable.Option[string] -} - -type TLSOptions struct { - // PublicKey is the public key for TLS. Ignored if domain is set. - PublicKey string - // PrivateKey is the private key for TLS. Ignored if domain is set. - PrivateKey string - // Email is the address for the CA to send problem notifications (optional) - Email string - // Port is the tls port - Port string -} - -// NewServer instantiates a new server with the given http.Handler. -func NewServer(db client.DB, options ...func(*Server)) (*Server, error) { - srv := &Server{ - Server: http.Server{ - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - IdleTimeout: idleTimeout, - }, - } - - for _, opt := range append(options, DefaultOpts()) { - opt(srv) - } - - handler, err := NewHandler(db, srv.options) - if err != nil { - return nil, err - } - srv.Handler = handler - return srv, nil -} - -func newHTTPRedirServer(m *autocert.Manager) *Server { - srv := &Server{ - Server: http.Server{ - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - IdleTimeout: idleTimeout, - }, - } - - srv.Addr = httpPort - srv.Handler = m.HTTPHandler(nil) - - return srv + // TLSCertPath is the path to the TLS certificate. + TLSCertPath string + // TLSKeyPath is the path to the TLS key. + TLSKeyPath string + // ReadTimeout is the read timeout for connections. + ReadTimeout time.Duration + // WriteTimeout is the write timeout for connections. + WriteTimeout time.Duration + // IdleTimeout is the idle timeout for connections. + IdleTimeout time.Duration } // DefaultOpts returns the default options for the server. -func DefaultOpts() func(*Server) { - return func(s *Server) { - if s.Addr == "" { - s.Addr = "localhost:9181" - } - } -} - -// WithAllowedOrigins returns an option to set the allowed origins for CORS. -func WithAllowedOrigins(origins ...string) func(*Server) { - return func(s *Server) { - s.options.AllowedOrigins = append(s.options.AllowedOrigins, origins...) +func DefaultServerOptions() *ServerOptions { + return &ServerOptions{ + Address: "127.0.0.1:9181", } } -// WithAddress returns an option to set the address for the server. -func WithAddress(addr string) func(*Server) { - return func(s *Server) { - s.Addr = addr +// ServerOpt is a function that configures server options. +type ServerOpt func(*ServerOptions) - // If the address is not localhost, we check to see if it's a valid IP address. - // If it's not a valid IP, we assume that it's a domain name to be used with TLS. - if !strings.HasPrefix(addr, "localhost:") && !strings.HasPrefix(addr, ":") { - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = addr - } - ip := net.ParseIP(host) - if ip == nil { - s.Addr = httpPort - s.options.Domain = immutable.Some(host) - } - } +// WithAllowedOrigins sets the allowed origins for CORS. +func WithAllowedOrigins(origins ...string) ServerOpt { + return func(opts *ServerOptions) { + opts.AllowedOrigins = origins } } -// WithCAEmail returns an option to set the email address for the CA to send problem notifications. -func WithCAEmail(email string) func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.Email = email - s.options.TLS = immutable.Some(tlsOpt) +// WithAddress sets the bind address for the server. +func WithAddress(addr string) ServerOpt { + return func(opts *ServerOptions) { + opts.Address = addr } } -// WithRootDir returns an option to set the root directory for the node config. -func WithRootDir(rootDir string) func(*Server) { - return func(s *Server) { - s.options.RootDir = rootDir +// WithReadTimeout sets the server read timeout. +func WithReadTimeout(timeout time.Duration) ServerOpt { + return func(opts *ServerOptions) { + opts.ReadTimeout = timeout } } -// WithSelfSignedCert returns an option to set the public and private keys for TLS. -func WithSelfSignedCert(pubKey, privKey string) func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.PublicKey = pubKey - tlsOpt.PrivateKey = privKey - s.options.TLS = immutable.Some(tlsOpt) +// WithWriteTimeout sets the server write timeout. +func WithWriteTimeout(timeout time.Duration) ServerOpt { + return func(opts *ServerOptions) { + opts.WriteTimeout = timeout } } -// WithTLS returns an option to enable TLS. -func WithTLS() func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.Port = httpsPort - s.options.TLS = immutable.Some(tlsOpt) +// WithIdleTimeout sets the server idle timeout. +func WithIdleTimeout(timeout time.Duration) ServerOpt { + return func(opts *ServerOptions) { + opts.IdleTimeout = timeout } } -// WithTLSPort returns an option to set the port for TLS. -func WithTLSPort(port int) func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.Port = fmt.Sprintf(":%d", port) - s.options.TLS = immutable.Some(tlsOpt) +// WithTLSCertPath sets the server TLS certificate path. +func WithTLSCertPath(path string) ServerOpt { + return func(opts *ServerOptions) { + opts.TLSCertPath = path } } -// Listen creates a new net.Listener and saves it on the receiver. -func (s *Server) Listen(ctx context.Context) error { - var err error - if s.options.TLS.HasValue() { - return s.listenWithTLS(ctx) +// WithTLSKeyPath sets the server TLS private key path. +func WithTLSKeyPath(path string) ServerOpt { + return func(opts *ServerOptions) { + opts.TLSKeyPath = path } - - lc := net.ListenConfig{} - s.listener, err = lc.Listen(ctx, "tcp", s.Addr) - if err != nil { - return errors.WithStack(err) - } - - // Save the address on the server in case the port was set to random - // and that we want to see what was assigned. - s.address = s.listener.Addr().String() - - return nil } -func (s *Server) listenWithTLS(ctx context.Context) error { - cfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - // We only allow cipher suites that are marked secure - // by ssllabs - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - }, - ServerName: "DefraDB", - } - - if s.options.Domain.HasValue() && s.options.Domain.Value() != "" { - s.Addr = s.options.TLS.Value().Port - - if s.options.TLS.Value().Email == "" || s.options.TLS.Value().Email == config.DefaultAPIEmail { - return ErrNoEmail - } - - certCache := path.Join(s.options.RootDir, "autocerts") - - log.FeedbackInfo( - ctx, - "Generating auto certificate", - logging.NewKV("Domain", s.options.Domain.Value()), - logging.NewKV("Certificate cache", certCache), - ) - - m := &autocert.Manager{ - Cache: autocert.DirCache(certCache), - Prompt: autocert.AcceptTOS, - Email: s.options.TLS.Value().Email, - HostPolicy: autocert.HostWhitelist(s.options.Domain.Value()), - } - - cfg.GetCertificate = m.GetCertificate - - // We set manager on the server instance to later start - // a redirection server. - s.certManager = m - } else { - // When not using auto cert, we create a self signed certificate - // with the provided public and prive keys. - log.FeedbackInfo(ctx, "Generating self signed certificate") - - cert, err := tls.LoadX509KeyPair( - s.options.TLS.Value().PrivateKey, - s.options.TLS.Value().PublicKey, - ) - if err != nil { - return NewErrFailedToLoadKeys(err, s.options.TLS.Value().PublicKey, s.options.TLS.Value().PrivateKey) - } - - cfg.Certificates = []tls.Certificate{cert} - } - - var err error - s.listener, err = tls.Listen("tcp", s.Addr, cfg) - if err != nil { - return errors.WithStack(err) - } - - // Save the address on the server in case the port was set to random - // and that we want to see what was assigned. - s.address = s.listener.Addr().String() - - return nil +// Server struct holds the Handler for the HTTP API. +type Server struct { + address atomic.Value + options *ServerOptions + server *http.Server } -// Run calls Serve with the receiver's listener. -func (s *Server) Run(ctx context.Context) error { - if s.listener == nil { - return ErrNoListener +// NewServer instantiates a new server with the given http.Handler. +func NewServer(handler http.Handler, opts ...ServerOpt) (*Server, error) { + options := DefaultServerOptions() + for _, opt := range opts { + opt(options) + } + + // setup a mux with the default middleware stack + mux := chi.NewMux() + mux.Use( + middleware.RequestLogger(&logFormatter{}), + middleware.Recoverer, + CorsMiddleware(options.AllowedOrigins), + ) + mux.Handle("/*", handler) + + server := &http.Server{ + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + IdleTimeout: options.IdleTimeout, + TLSConfig: tlsConfig, + Handler: mux, + } + + var address atomic.Value + address.Store("") + + return &Server{ + address: address, + options: options, + server: server, + }, nil +} + +// ListenAndServe listens for and serves incoming connections. +func (s *Server) ListenAndServe() error { + listener, err := net.Listen("tcp", s.options.Address) + if err != nil { + return err } + // ignore close errors as they cannot be handled + // from the caller of this method + defer listener.Close() //nolint:errcheck - if s.certManager != nil { - // When using TLS it's important to redirect http requests to https - go func() { - srv := newHTTPRedirServer(s.certManager) - err := srv.ListenAndServe() - if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Info(ctx, "Something went wrong with the redirection server", logging.NewKV("Error", err)) - } - }() + s.address.Store(listener.Addr().String()) + if s.options.TLSCertPath == "" && s.options.TLSKeyPath == "" { + return s.server.Serve(listener) } - return s.Serve(s.listener) + return s.server.ServeTLS(listener, s.options.TLSCertPath, s.options.TLSKeyPath) } // AssignedAddr returns the address that was assigned to the server on calls to listen. func (s *Server) AssignedAddr() string { - return s.address + return s.address.Load().(string) +} + +// Shutdown gracefully shuts down the server without interrupting any active connections. +func (s *Server) Shutdown(ctx context.Context) error { + return s.server.Shutdown(ctx) } diff --git a/http/server_test.go b/http/server_test.go index 04095b7c15..a568abe16e 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -12,111 +12,19 @@ package http import ( "context" + "crypto/tls" "net/http" "os" + "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/crypto/acme/autocert" ) -func TestNewServerAndRunWithoutListener(t *testing.T) { - ctx := context.Background() - s, err := NewServer(nil, WithAddress(":0")) - require.NoError(t, err) - if ok := assert.NotNil(t, s); ok { - assert.Equal(t, ErrNoListener, s.Run(ctx)) - } -} - -func TestNewServerAndRunWithListenerAndInvalidPort(t *testing.T) { - ctx := context.Background() - s, err := NewServer(nil, WithAddress(":303000")) - require.NoError(t, err) - if ok := assert.NotNil(t, s); ok { - assert.Error(t, s.Listen(ctx)) - } -} - -func TestNewServerAndRunWithListenerAndValidPort(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - s, err := NewServer(nil, WithAddress(":0")) - require.NoError(t, err) - go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.NoError(t, err) - err = s.Run(ctx) - assert.ErrorIs(t, http.ErrServerClosed, err) - defer close(serverDone) - }() - - <-serverRunning - - s.Shutdown(context.Background()) - - <-serverDone -} - -func TestNewServerAndRunWithAutocertWithoutEmail(t *testing.T) { - ctx := context.Background() - dir := t.TempDir() - s, err := NewServer(nil, WithAddress("example.com"), WithRootDir(dir), WithTLSPort(0)) - require.NoError(t, err) - err = s.Listen(ctx) - assert.ErrorIs(t, err, ErrNoEmail) - - s.Shutdown(context.Background()) -} - -func TestNewServerAndRunWithAutocert(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - s, err := NewServer(nil, WithAddress("example.com"), WithRootDir(dir), WithTLSPort(0), WithCAEmail("dev@defradb.net")) - require.NoError(t, err) - go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.NoError(t, err) - err = s.Run(ctx) - assert.ErrorIs(t, http.ErrServerClosed, err) - defer close(serverDone) - }() - - <-serverRunning - - s.Shutdown(context.Background()) - - <-serverDone -} - -func TestNewServerAndRunWithSelfSignedCertAndNoKeyFiles(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - s, err := NewServer(nil, WithAddress("localhost:0"), WithSelfSignedCert(dir+"/server.crt", dir+"/server.key")) - require.NoError(t, err) - go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.Contains(t, err.Error(), "failed to load given keys") - defer close(serverDone) - }() - - <-serverRunning - - s.Shutdown(context.Background()) - - <-serverDone -} - -const pubKey = `-----BEGIN EC PARAMETERS----- +// tlsKey is the TLS private key in PEM format +const tlsKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- @@ -126,7 +34,8 @@ pS0gW/SYpAncHhRuz18RQ2ycuXlSN1S/PAryRZ5PK2xORKfwpguEDEMdVwbHorZO K44P/h3dhyNyAyf8rcRoqKXcl/K/uew= -----END EC PRIVATE KEY-----` -const privKey = `-----BEGIN CERTIFICATE----- +// tlsKey is the TLS certificate in PEM format +const tlsCert = `-----BEGIN CERTIFICATE----- MIICQDCCAcUCCQDpMnN1gQ4fGTAKBggqhkjOPQQDAjCBiDELMAkGA1UEBhMCY2Ex DzANBgNVBAgMBlF1ZWJlYzEQMA4GA1UEBwwHQ2hlbHNlYTEPMA0GA1UECgwGU291 cmNlMRAwDgYDVQQLDAdEZWZyYURCMQ8wDQYDVQQDDAZzb3VyY2UxIjAgBgkqhkiG @@ -142,121 +51,123 @@ kgIxAKaEGC+lqp0aaN+yubYLRiTDxOlNpyiHox3nZiL4bG/CCdPDvbX63QcdI2yq XPKczg== -----END CERTIFICATE-----` -func TestNewServerAndRunWithSelfSignedCertAndInvalidPort(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - err := os.WriteFile(dir+"/server.key", []byte(privKey), 0644) - if err != nil { - t.Fatal(err) - } - err = os.WriteFile(dir+"/server.crt", []byte(pubKey), 0644) - if err != nil { - t.Fatal(err) - } - s, err := NewServer(nil, WithAddress(":303000"), WithSelfSignedCert(dir+"/server.crt", dir+"/server.key")) +// insecureClient is an http client that trusts all tls certificates +var insecureClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, +} + +// testHandler returns an empty body and 200 status code +var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +}) + +func TestServerListenAndServeWithInvalidAddress(t *testing.T) { + srv, err := NewServer(testHandler, WithAddress("invalid")) + require.NoError(t, err) + + err = srv.ListenAndServe() + require.ErrorContains(t, err, "address invalid") +} + +func TestServerListenAndServeWithAddress(t *testing.T) { + srv, err := NewServer(testHandler, WithAddress("127.0.0.1:30001")) require.NoError(t, err) + go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.Contains(t, err.Error(), "invalid port") - defer close(serverDone) + err := srv.ListenAndServe() + require.ErrorIs(t, http.ErrServerClosed, err) }() - <-serverRunning + // wait for server to start + <-time.After(time.Second * 1) + + res, err := http.Get("http://" + srv.AssignedAddr()) + require.NoError(t, err) - s.Shutdown(context.Background()) + defer res.Body.Close() + assert.Equal(t, 200, res.StatusCode) - <-serverDone + err = srv.Shutdown(context.Background()) + require.NoError(t, err) } -func TestNewServerAndRunWithSelfSignedCert(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - err := os.WriteFile(dir+"/server.key", []byte(privKey), 0644) - if err != nil { - t.Fatal(err) - } - err = os.WriteFile(dir+"/server.crt", []byte(pubKey), 0644) - if err != nil { - t.Fatal(err) - } - s, err := NewServer(nil, WithAddress("localhost:0"), WithSelfSignedCert(dir+"/server.crt", dir+"/server.key")) +func TestServerListenAndServeWithTLS(t *testing.T) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "cert.pub") + keyPath := filepath.Join(tempDir, "cert.key") + + err := os.WriteFile(certPath, []byte(tlsCert), 0644) + require.NoError(t, err) + + err = os.WriteFile(keyPath, []byte(tlsKey), 0644) require.NoError(t, err) + + srv, err := NewServer(testHandler, WithAddress("127.0.0.1:8443"), WithTLSCertPath(certPath), WithTLSKeyPath(keyPath)) + require.NoError(t, err) + go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.NoError(t, err) - err = s.Run(ctx) - assert.ErrorIs(t, http.ErrServerClosed, err) - defer close(serverDone) + err := srv.ListenAndServe() + require.ErrorIs(t, http.ErrServerClosed, err) }() - <-serverRunning + // wait for server to start + <-time.After(time.Second * 1) - s.Shutdown(context.Background()) + res, err := insecureClient.Get("https://" + srv.AssignedAddr()) + require.NoError(t, err) - <-serverDone -} + defer res.Body.Close() + assert.Equal(t, 200, res.StatusCode) -func TestNewServerWithoutOptions(t *testing.T) { - s, err := NewServer(nil) + err = srv.Shutdown(context.Background()) require.NoError(t, err) - assert.Equal(t, "localhost:9181", s.Addr) - assert.Equal(t, []string(nil), s.options.AllowedOrigins) } -func TestNewServerWithAddress(t *testing.T) { - s, err := NewServer(nil, WithAddress("localhost:9999")) +func TestServerListenAndServeWithAllowedOrigins(t *testing.T) { + srv, err := NewServer(testHandler, WithAllowedOrigins("localhost")) require.NoError(t, err) - assert.Equal(t, "localhost:9999", s.Addr) -} -func TestNewServerWithDomainAddress(t *testing.T) { - s, err := NewServer(nil, WithAddress("example.com")) - require.NoError(t, err) - assert.Equal(t, "example.com", s.options.Domain.Value()) - assert.NotNil(t, s.options.TLS) -} + go func() { + err := srv.ListenAndServe() + require.ErrorIs(t, http.ErrServerClosed, err) + }() -func TestNewServerWithAllowedOrigins(t *testing.T) { - s, err := NewServer(nil, WithAllowedOrigins("https://source.network", "https://app.source.network")) + // wait for server to start + <-time.After(time.Second * 1) + + req, err := http.NewRequest(http.MethodOptions, "http://"+srv.AssignedAddr(), nil) require.NoError(t, err) - assert.Equal(t, []string{"https://source.network", "https://app.source.network"}, s.options.AllowedOrigins) -} + req.Header.Add("origin", "localhost") -func TestNewServerWithCAEmail(t *testing.T) { - s, err := NewServer(nil, WithCAEmail("me@example.com")) + res, err := http.DefaultClient.Do(req) require.NoError(t, err) - assert.Equal(t, "me@example.com", s.options.TLS.Value().Email) -} -func TestNewServerWithRootDir(t *testing.T) { - dir := t.TempDir() - s, err := NewServer(nil, WithRootDir(dir)) + defer res.Body.Close() + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "localhost", res.Header.Get("Access-Control-Allow-Origin")) + + err = srv.Shutdown(context.Background()) require.NoError(t, err) - assert.Equal(t, dir, s.options.RootDir) } -func TestNewServerWithTLSPort(t *testing.T) { - s, err := NewServer(nil, WithTLSPort(44343)) +func TestServerWithReadTimeout(t *testing.T) { + srv, err := NewServer(testHandler, WithReadTimeout(time.Second)) require.NoError(t, err) - assert.Equal(t, ":44343", s.options.TLS.Value().Port) + assert.Equal(t, time.Second, srv.server.ReadTimeout) } -func TestNewServerWithSelfSignedCert(t *testing.T) { - s, err := NewServer(nil, WithSelfSignedCert("pub.key", "priv.key")) +func TestServerWithWriteTimeout(t *testing.T) { + srv, err := NewServer(testHandler, WithWriteTimeout(time.Second)) require.NoError(t, err) - assert.Equal(t, "pub.key", s.options.TLS.Value().PublicKey) - assert.Equal(t, "priv.key", s.options.TLS.Value().PrivateKey) - assert.NotNil(t, s.options.TLS) + assert.Equal(t, time.Second, srv.server.WriteTimeout) } -func TestNewHTTPRedirServer(t *testing.T) { - m := &autocert.Manager{} - s := newHTTPRedirServer(m) - assert.Equal(t, ":80", s.Addr) +func TestServerWithIdleTimeout(t *testing.T) { + srv, err := NewServer(testHandler, WithIdleTimeout(time.Second)) + require.NoError(t, err) + assert.Equal(t, time.Second, srv.server.IdleTimeout) } diff --git a/tests/clients/cli/wrapper.go b/tests/clients/cli/wrapper.go index 49a0605598..a15b263709 100644 --- a/tests/clients/cli/wrapper.go +++ b/tests/clients/cli/wrapper.go @@ -40,7 +40,7 @@ type Wrapper struct { } func NewWrapper(node *net.Node) (*Wrapper, error) { - handler, err := http.NewHandler(node, http.ServerOptions{}) + handler, err := http.NewHandler(node) if err != nil { return nil, err } diff --git a/tests/clients/http/wrapper.go b/tests/clients/http/wrapper.go index 040ab9c1b4..4b935b8a0b 100644 --- a/tests/clients/http/wrapper.go +++ b/tests/clients/http/wrapper.go @@ -36,7 +36,7 @@ type Wrapper struct { } func NewWrapper(node *net.Node) (*Wrapper, error) { - handler, err := http.NewHandler(node, http.ServerOptions{}) + handler, err := http.NewHandler(node) if err != nil { return nil, err } diff --git a/tests/gen/cli/gendocs.go b/tests/gen/cli/gendocs.go index 226d73bc97..3bb94aef09 100644 --- a/tests/gen/cli/gendocs.go +++ b/tests/gen/cli/gendocs.go @@ -100,6 +100,16 @@ Example: The following command generates 100 User documents and 500 Device docum return nil }, } + + cmd.PersistentFlags().String( + "url", cfg.API.Address, + "URL of HTTP endpoint to listen on or connect to", + ) + err := cfg.BindFlag("api.address", cmd.PersistentFlags().Lookup("url")) + if err != nil { + panic(err) + } + cmd.Flags().StringVarP(&demandJSON, "demand", "d", "", "Documents' demand in JSON format") return cmd diff --git a/tests/gen/cli/gendocs_test.go b/tests/gen/cli/gendocs_test.go index 18b9b157c1..1bd30297a6 100644 --- a/tests/gen/cli/gendocs_test.go +++ b/tests/gen/cli/gendocs_test.go @@ -12,43 +12,39 @@ package cli import ( "bytes" + "context" "io" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/sourcenetwork/defradb/cli" - "github.com/sourcenetwork/defradb/config" "github.com/sourcenetwork/defradb/tests/gen" ) -func execAddSchemaCmd(t *testing.T, cfg *config.Config, schema string) { - rootCmd := cli.NewDefraCommand(cfg) - rootCmd.SetArgs([]string{"client", "schema", "add", schema}) - err := rootCmd.Execute() - require.NoError(t, err) -} - func TestGendocsCmd_IfNoErrors_ReturnGenerationOutput(t *testing.T) { - cfg, _, close := startTestNode(t) + defra, close := startTestNode(t) defer close() - execAddSchemaCmd(t, cfg, ` - type User { - name: String - devices: [Device] - } - type Device { - model: String - owner: User - }`) - - genDocsCmd := MakeGenDocCommand(cfg) + defra.db.AddSchema(context.Background(), ` + type User { + name: String + devices: [Device] + } + type Device { + model: String + owner: User + }`) + + genDocsCmd := MakeGenDocCommand(getTestConfig(t)) outputBuf := bytes.NewBufferString("") genDocsCmd.SetOut(outputBuf) - genDocsCmd.SetArgs([]string{"--demand", `{"User": 3, "Device": 12}`}) + genDocsCmd.SetArgs([]string{ + "--demand", `{"User": 3, "Device": 12}`, + "--url", strings.TrimPrefix(defra.server.URL, "http://"), + }) err := genDocsCmd.Execute() require.NoError(t, err) @@ -67,33 +63,38 @@ func TestGendocsCmd_IfNoErrors_ReturnGenerationOutput(t *testing.T) { } func TestGendocsCmd_IfInvalidDemandValue_ReturnError(t *testing.T) { - cfg, _, close := startTestNode(t) + defra, close := startTestNode(t) defer close() - execAddSchemaCmd(t, cfg, ` + defra.db.AddSchema(context.Background(), ` type User { name: String }`) - genDocsCmd := MakeGenDocCommand(cfg) - genDocsCmd.SetArgs([]string{"--demand", `{"User": invalid}`}) + genDocsCmd := MakeGenDocCommand(getTestConfig(t)) + genDocsCmd.SetArgs([]string{ + "--demand", `{"User": invalid}`, + "--url", strings.TrimPrefix(defra.server.URL, "http://"), + }) err := genDocsCmd.Execute() require.ErrorContains(t, err, errInvalidDemandValue) } func TestGendocsCmd_IfInvalidConfig_ReturnError(t *testing.T) { - cfg, _, close := startTestNode(t) + defra, close := startTestNode(t) defer close() - execAddSchemaCmd(t, cfg, ` + defra.db.AddSchema(context.Background(), ` type User { name: String }`) - genDocsCmd := MakeGenDocCommand(cfg) - - genDocsCmd.SetArgs([]string{"--demand", `{"Unknown": 3}`}) + genDocsCmd := MakeGenDocCommand(getTestConfig(t)) + genDocsCmd.SetArgs([]string{ + "--demand", `{"Unknown": 3}`, + "--url", strings.TrimPrefix(defra.server.URL, "http://"), + }) err := genDocsCmd.Execute() require.Error(t, err, gen.NewErrInvalidConfiguration("")) diff --git a/tests/gen/cli/util_test.go b/tests/gen/cli/util_test.go index a04761f6b5..81d713955c 100644 --- a/tests/gen/cli/util_test.go +++ b/tests/gen/cli/util_test.go @@ -12,9 +12,7 @@ package cli import ( "context" - "fmt" - "net/http" - "os" + "net/http/httptest" "testing" badger "github.com/sourcenetwork/badger/v4" @@ -33,21 +31,15 @@ var log = logging.MustNewLogger("cli") type defraInstance struct { db client.DB - server *httpapi.Server + server *httptest.Server } func (di *defraInstance) close(ctx context.Context) { di.db.Close() - if err := di.server.Close(); err != nil { - log.FeedbackInfo( - ctx, - "The server could not be closed successfully", - logging.NewKV("Error", err.Error()), - ) - } + di.server.Close() } -func start(ctx context.Context, cfg *config.Config) (*defraInstance, error) { +func start(ctx context.Context) (*defraInstance, error) { log.FeedbackInfo(ctx, "Starting DefraDB service...") log.FeedbackInfo(ctx, "Building new memory store") @@ -63,26 +55,11 @@ func start(ctx context.Context, cfg *config.Config) (*defraInstance, error) { return nil, errors.Wrap("failed to create database", err) } - server, err := httpapi.NewServer(db, httpapi.WithAddress(cfg.API.Address)) + handler, err := httpapi.NewHandler(db) if err != nil { - return nil, errors.Wrap("failed to create http server", err) - } - if err := server.Listen(ctx); err != nil { - return nil, errors.Wrap(fmt.Sprintf("failed to listen on TCP address %v", server.Addr), err) + return nil, errors.Wrap("failed to create http handler", err) } - // save the address on the config in case the port number was set to random - cfg.API.Address = server.AssignedAddr() - cfg.Persist() - - // run the server in a separate goroutine - go func(apiAddress string) { - log.FeedbackInfo(ctx, fmt.Sprintf("Providing HTTP API at %s.", apiAddress)) - if err := server.Run(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.FeedbackErrorE(ctx, "Failed to run the HTTP server", err) - db.Close() - os.Exit(1) - } - }(cfg.API.AddressToURL()) + server := httptest.NewServer(handler) return &defraInstance{ db: db, @@ -101,11 +78,9 @@ func getTestConfig(t *testing.T) *config.Config { return cfg } -func startTestNode(t *testing.T) (*config.Config, *defraInstance, func()) { - cfg := getTestConfig(t) - +func startTestNode(t *testing.T) (*defraInstance, func()) { ctx := context.Background() - di, err := start(ctx, cfg) + di, err := start(ctx) require.NoError(t, err) - return cfg, di, func() { di.close(ctx) } + return di, func() { di.close(ctx) } } From af5e0c93e75d44e06b6f677a071b536d5ca6e63b Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Thu, 1 Feb 2024 11:51:55 -0800 Subject: [PATCH 2/7] remove default tls cert from config --- config/config.go | 12 ++++++------ config/configfile_yaml.gotmpl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/config/config.go b/config/config.go index ef3aeaa8cc..23fd2d43fe 100644 --- a/config/config.go +++ b/config/config.go @@ -207,11 +207,13 @@ func (cfg *Config) paramsPreprocessing() error { if !filepath.IsAbs(cfg.v.GetString("datastore.badger.path")) { cfg.v.Set("datastore.badger.path", filepath.Join(cfg.Rootdir, cfg.v.GetString("datastore.badger.path"))) } - if !filepath.IsAbs(cfg.v.GetString("api.privkeypath")) { - cfg.v.Set("api.privkeypath", filepath.Join(cfg.Rootdir, cfg.v.GetString("api.privkeypath"))) + privKeyPath := cfg.v.GetString("api.privkeypath") + if privKeyPath != "" && !filepath.IsAbs(privKeyPath) { + cfg.v.Set("api.privkeypath", filepath.Join(cfg.Rootdir, privKeyPath)) } - if !filepath.IsAbs(cfg.v.GetString("api.pubkeypath")) { - cfg.v.Set("api.pubkeypath", filepath.Join(cfg.Rootdir, cfg.v.GetString("api.pubkeypath"))) + pubKeyPath := cfg.v.GetString("api.pubkeypath") + if pubKeyPath != "" && !filepath.IsAbs(pubKeyPath) { + cfg.v.Set("api.pubkeypath", filepath.Join(cfg.Rootdir, pubKeyPath)) } // log.logger configuration as a string @@ -303,8 +305,6 @@ func defaultAPIConfig() *APIConfig { Address: "localhost:9181", TLS: false, AllowedOrigins: []string{}, - PubKeyPath: "certs/server.key", - PrivKeyPath: "certs/server.crt", Email: DefaultAPIEmail, } } diff --git a/config/configfile_yaml.gotmpl b/config/configfile_yaml.gotmpl index 53d87c46d3..d789456a83 100644 --- a/config/configfile_yaml.gotmpl +++ b/config/configfile_yaml.gotmpl @@ -26,9 +26,9 @@ api: # The list of origins a cross-domain request can be executed from. # allowed-origins: {{ .API.AllowedOrigins }} # The path to the public key file. Ignored if domains is set. - pubkeypath: {{ .API.PubKeyPath }} + # pubkeypath: {{ .API.PubKeyPath }} # The path to the private key file. Ignored if domains is set. - privkeypath: {{ .API.PrivKeyPath }} + # privkeypath: {{ .API.PrivKeyPath }} # Email address to let the CA (Let's Encrypt) send notifications via email when there are issues (optional). # email: {{ .API.Email }} From 5f305ea7d11ad23d5abe779e9ae71265cf609e58 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Tue, 6 Feb 2024 09:48:16 -0800 Subject: [PATCH 3/7] add documentation to server address property --- http/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/http/server.go b/http/server.go index ddbbf7f73b..5578f0b455 100644 --- a/http/server.go +++ b/http/server.go @@ -116,6 +116,10 @@ func WithTLSKeyPath(path string) ServerOpt { // Server struct holds the Handler for the HTTP API. type Server struct { + // address is the assigned listen address for the server. + // + // The value is atomic to avoid a race condition between + // the listener starting and calling AssignedAddr. address atomic.Value options *ServerOptions server *http.Server From f2f1aac869494e641e4efafe67125e164f7d8a9c Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Tue, 6 Feb 2024 09:58:44 -0800 Subject: [PATCH 4/7] remove redundant listener close call in ListenAndServe --- http/server.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/http/server.go b/http/server.go index 5578f0b455..8c8d019f10 100644 --- a/http/server.go +++ b/http/server.go @@ -165,10 +165,6 @@ func (s *Server) ListenAndServe() error { if err != nil { return err } - // ignore close errors as they cannot be handled - // from the caller of this method - defer listener.Close() //nolint:errcheck - s.address.Store(listener.Addr().String()) if s.options.TLSCertPath == "" && s.options.TLSKeyPath == "" { return s.server.Serve(listener) From 6528764c0fdab371267733c55eda733f088a6369 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Tue, 6 Feb 2024 10:26:44 -0800 Subject: [PATCH 5/7] cleanup listen and serve logic --- http/server.go | 73 ++++++++++++++++++++++++++++----------------- http/server_test.go | 42 +++++++++++++++++++------- 2 files changed, 78 insertions(+), 37 deletions(-) diff --git a/http/server.go b/http/server.go index 8c8d019f10..dc463a653a 100644 --- a/http/server.go +++ b/http/server.go @@ -22,20 +22,15 @@ import ( "github.com/go-chi/chi/v5/middleware" ) -// tlsConfig contains the default tls config settings -var tlsConfig = &tls.Config{ - ServerName: "DefraDB", - MinVersion: tls.VersionTLS12, - // We only allow cipher suites that are marked secure - // by ssllabs - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - }, +// We only allow cipher suites that are marked secure +// by ssllabs +var tlsCipherSuites = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } type ServerOptions struct { @@ -145,7 +140,6 @@ func NewServer(handler http.Handler, opts ...ServerOpt) (*Server, error) { ReadTimeout: options.ReadTimeout, WriteTimeout: options.WriteTimeout, IdleTimeout: options.IdleTimeout, - TLSConfig: tlsConfig, Handler: mux, } @@ -159,25 +153,50 @@ func NewServer(handler http.Handler, opts ...ServerOpt) (*Server, error) { }, nil } +// AssignedAddr returns the address that was assigned to the server on calls to listen. +func (s *Server) AssignedAddr() string { + return s.address.Load().(string) +} + +// Shutdown gracefully shuts down the server without interrupting any active connections. +func (s *Server) Shutdown(ctx context.Context) error { + return s.server.Shutdown(ctx) +} + // ListenAndServe listens for and serves incoming connections. func (s *Server) ListenAndServe() error { + if s.options.TLSCertPath == "" && s.options.TLSKeyPath == "" { + return s.listenAndServe() + } + return s.listenAndServeTLS() +} + +// listenAndServe listens for and serves http connections. +func (s *Server) listenAndServe() error { listener, err := net.Listen("tcp", s.options.Address) if err != nil { return err } s.address.Store(listener.Addr().String()) - if s.options.TLSCertPath == "" && s.options.TLSKeyPath == "" { - return s.server.Serve(listener) - } - return s.server.ServeTLS(listener, s.options.TLSCertPath, s.options.TLSKeyPath) + return s.server.Serve(listener) } -// AssignedAddr returns the address that was assigned to the server on calls to listen. -func (s *Server) AssignedAddr() string { - return s.address.Load().(string) -} - -// Shutdown gracefully shuts down the server without interrupting any active connections. -func (s *Server) Shutdown(ctx context.Context) error { - return s.server.Shutdown(ctx) +// listenAndServeTLS listens for and serves https connections. +func (s *Server) listenAndServeTLS() error { + cert, err := tls.LoadX509KeyPair(s.options.TLSCertPath, s.options.TLSKeyPath) + if err != nil { + return err + } + config := &tls.Config{ + ServerName: "DefraDB", + MinVersion: tls.VersionTLS12, + CipherSuites: tlsCipherSuites, + Certificates: []tls.Certificate{cert}, + } + listener, err := net.Listen("tcp", s.options.Address) + if err != nil { + return err + } + s.address.Store(listener.Addr().String()) + return s.server.Serve(tls.NewListener(listener, config)) } diff --git a/http/server_test.go b/http/server_test.go index a568abe16e..0c89506e54 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -73,6 +73,23 @@ func TestServerListenAndServeWithInvalidAddress(t *testing.T) { require.ErrorContains(t, err, "address invalid") } +func TestServerListenAndServeWithTLSAndInvalidAddress(t *testing.T) { + certPath, keyPath := writeTestCerts(t) + srv, err := NewServer(testHandler, WithAddress("invalid"), WithTLSCertPath(certPath), WithTLSKeyPath(keyPath)) + require.NoError(t, err) + + err = srv.ListenAndServe() + require.ErrorContains(t, err, "address invalid") +} + +func TestServerListenAndServeWithTLSAndInvalidCerts(t *testing.T) { + srv, err := NewServer(testHandler, WithAddress("invalid"), WithTLSCertPath("invalid.crt"), WithTLSKeyPath("invalid.key")) + require.NoError(t, err) + + err = srv.ListenAndServe() + require.ErrorContains(t, err, "no such file or directory") +} + func TestServerListenAndServeWithAddress(t *testing.T) { srv, err := NewServer(testHandler, WithAddress("127.0.0.1:30001")) require.NoError(t, err) @@ -96,16 +113,7 @@ func TestServerListenAndServeWithAddress(t *testing.T) { } func TestServerListenAndServeWithTLS(t *testing.T) { - tempDir := t.TempDir() - certPath := filepath.Join(tempDir, "cert.pub") - keyPath := filepath.Join(tempDir, "cert.key") - - err := os.WriteFile(certPath, []byte(tlsCert), 0644) - require.NoError(t, err) - - err = os.WriteFile(keyPath, []byte(tlsKey), 0644) - require.NoError(t, err) - + certPath, keyPath := writeTestCerts(t) srv, err := NewServer(testHandler, WithAddress("127.0.0.1:8443"), WithTLSCertPath(certPath), WithTLSKeyPath(keyPath)) require.NoError(t, err) @@ -171,3 +179,17 @@ func TestServerWithIdleTimeout(t *testing.T) { require.NoError(t, err) assert.Equal(t, time.Second, srv.server.IdleTimeout) } + +func writeTestCerts(t *testing.T) (string, string) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "cert.pub") + keyPath := filepath.Join(tempDir, "cert.key") + + err := os.WriteFile(certPath, []byte(tlsCert), 0644) + require.NoError(t, err) + + err = os.WriteFile(keyPath, []byte(tlsKey), 0644) + require.NoError(t, err) + + return certPath, keyPath +} From 59a42aaba024956985b3bf9142a68ba22460ebe5 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Tue, 6 Feb 2024 10:46:53 -0800 Subject: [PATCH 6/7] remove AssignedAddr --- http/server.go | 17 ----------------- http/server_test.go | 8 ++++---- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/http/server.go b/http/server.go index dc463a653a..80d46c6679 100644 --- a/http/server.go +++ b/http/server.go @@ -15,7 +15,6 @@ import ( "crypto/tls" "net" "net/http" - "sync/atomic" "time" "github.com/go-chi/chi/v5" @@ -111,11 +110,6 @@ func WithTLSKeyPath(path string) ServerOpt { // Server struct holds the Handler for the HTTP API. type Server struct { - // address is the assigned listen address for the server. - // - // The value is atomic to avoid a race condition between - // the listener starting and calling AssignedAddr. - address atomic.Value options *ServerOptions server *http.Server } @@ -143,21 +137,12 @@ func NewServer(handler http.Handler, opts ...ServerOpt) (*Server, error) { Handler: mux, } - var address atomic.Value - address.Store("") - return &Server{ - address: address, options: options, server: server, }, nil } -// AssignedAddr returns the address that was assigned to the server on calls to listen. -func (s *Server) AssignedAddr() string { - return s.address.Load().(string) -} - // Shutdown gracefully shuts down the server without interrupting any active connections. func (s *Server) Shutdown(ctx context.Context) error { return s.server.Shutdown(ctx) @@ -177,7 +162,6 @@ func (s *Server) listenAndServe() error { if err != nil { return err } - s.address.Store(listener.Addr().String()) return s.server.Serve(listener) } @@ -197,6 +181,5 @@ func (s *Server) listenAndServeTLS() error { if err != nil { return err } - s.address.Store(listener.Addr().String()) return s.server.Serve(tls.NewListener(listener, config)) } diff --git a/http/server_test.go b/http/server_test.go index 0c89506e54..4065267c26 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -102,7 +102,7 @@ func TestServerListenAndServeWithAddress(t *testing.T) { // wait for server to start <-time.After(time.Second * 1) - res, err := http.Get("http://" + srv.AssignedAddr()) + res, err := http.Get("http://127.0.0.1:30001") require.NoError(t, err) defer res.Body.Close() @@ -125,7 +125,7 @@ func TestServerListenAndServeWithTLS(t *testing.T) { // wait for server to start <-time.After(time.Second * 1) - res, err := insecureClient.Get("https://" + srv.AssignedAddr()) + res, err := insecureClient.Get("https://127.0.0.1:8443") require.NoError(t, err) defer res.Body.Close() @@ -136,7 +136,7 @@ func TestServerListenAndServeWithTLS(t *testing.T) { } func TestServerListenAndServeWithAllowedOrigins(t *testing.T) { - srv, err := NewServer(testHandler, WithAllowedOrigins("localhost")) + srv, err := NewServer(testHandler, WithAllowedOrigins("localhost"), WithAddress("127.0.0.1:30001")) require.NoError(t, err) go func() { @@ -147,7 +147,7 @@ func TestServerListenAndServeWithAllowedOrigins(t *testing.T) { // wait for server to start <-time.After(time.Second * 1) - req, err := http.NewRequest(http.MethodOptions, "http://"+srv.AssignedAddr(), nil) + req, err := http.NewRequest(http.MethodOptions, "http://127.0.0.1:30001", nil) require.NoError(t, err) req.Header.Add("origin", "localhost") From 3a6c4b4e231631fa756969ccbff38d9032363c12 Mon Sep 17 00:00:00 2001 From: Keenan Nemetz Date: Wed, 7 Feb 2024 08:58:47 -0800 Subject: [PATCH 7/7] remove autotls section from readme --- README.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/README.md b/README.md index 1019e34dc4..f1ac6cb88f 100644 --- a/README.md +++ b/README.md @@ -395,16 +395,6 @@ defradb start --tls --pubkeypath ~/path-to-pubkey.key --privkeypath ~/path-to-pr ``` -DefraDB also comes with automatic HTTPS for deployments on the public web. To enable HTTPS, - deploy DefraDB to a server with both port 80 and port 443 open. With your domain's DNS A record - pointed to the IP of your server, you can run the database using the following command: -```shell -sudo defradb start --tls --url=your-domain.net --email=email@example.com -``` -Note: `sudo` is needed above for the redirection server (to bind port 80). - -A valid email address is necessary for the creation of the certificate, and is important to get notifications from the Certificate Authority - in case the certificate is about to expire, etc. - ## Supporting CORS When accessing DefraDB through a frontend interface, you may be confronted with a CORS error. That is because, by default, DefraDB will not have any allowed origins set. To specify which origins should be allowed to access your DefraDB endpoint, you can specify them when starting the database: