diff --git a/README.md b/README.md index 1019e34dc4..f1ac6cb88f 100644 --- a/README.md +++ b/README.md @@ -395,16 +395,6 @@ defradb start --tls --pubkeypath ~/path-to-pubkey.key --privkeypath ~/path-to-pr ``` -DefraDB also comes with automatic HTTPS for deployments on the public web. To enable HTTPS, - deploy DefraDB to a server with both port 80 and port 443 open. With your domain's DNS A record - pointed to the IP of your server, you can run the database using the following command: -```shell -sudo defradb start --tls --url=your-domain.net --email=email@example.com -``` -Note: `sudo` is needed above for the redirection server (to bind port 80). - -A valid email address is necessary for the creation of the certificate, and is important to get notifications from the Certificate Authority - in case the certificate is about to expire, etc. - ## Supporting CORS When accessing DefraDB through a frontend interface, you may be confronted with a CORS error. That is because, by default, DefraDB will not have any allowed origins set. To specify which origins should be allowed to access your DefraDB endpoint, you can specify them when starting the database: diff --git a/cli/start.go b/cli/start.go index 55b168205f..0c344ea94c 100644 --- a/cli/start.go +++ b/cli/start.go @@ -175,7 +175,7 @@ func (di *defraInstance) close(ctx context.Context) { } else { di.db.Close() } - if err := di.server.Close(); err != nil { + if err := di.server.Shutdown(ctx); err != nil { log.FeedbackInfo( ctx, "The server could not be closed successfully", @@ -259,40 +259,31 @@ func start(ctx context.Context, cfg *config.Config) (*defraInstance, error) { } } - sOpt := []func(*httpapi.Server){ + serverOpts := []httpapi.ServerOpt{ httpapi.WithAddress(cfg.API.Address), - httpapi.WithRootDir(cfg.Rootdir), httpapi.WithAllowedOrigins(cfg.API.AllowedOrigins...), + httpapi.WithTLSCertPath(cfg.API.PubKeyPath), + httpapi.WithTLSKeyPath(cfg.API.PrivKeyPath), } - if cfg.API.TLS { - sOpt = append( - sOpt, - httpapi.WithTLS(), - httpapi.WithSelfSignedCert(cfg.API.PubKeyPath, cfg.API.PrivKeyPath), - httpapi.WithCAEmail(cfg.API.Email), - ) - } - - var server *httpapi.Server + var handler *httpapi.Handler if node != nil { - server, err = httpapi.NewServer(node, sOpt...) + handler, err = httpapi.NewHandler(node) } else { - server, err = httpapi.NewServer(db, sOpt...) + handler, err = httpapi.NewHandler(db) } if err != nil { - return nil, errors.Wrap("failed to create http server", err) + return nil, errors.Wrap("failed to create http handler", err) } - if err := server.Listen(ctx); err != nil { - return nil, errors.Wrap(fmt.Sprintf("failed to listen on TCP address %v", server.Addr), err) + server, err := httpapi.NewServer(handler, serverOpts...) + if err != nil { + return nil, errors.Wrap("failed to create http server", err) } - // save the address on the config in case the port number was set to random - cfg.API.Address = server.AssignedAddr() // run the server in a separate goroutine go func() { log.FeedbackInfo(ctx, fmt.Sprintf("Providing HTTP API at %s.", cfg.API.AddressToURL())) - if err := server.Run(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.FeedbackErrorE(ctx, "Failed to run the HTTP server", err) if node != nil { node.Close() diff --git a/config/config.go b/config/config.go index ef3aeaa8cc..23fd2d43fe 100644 --- a/config/config.go +++ b/config/config.go @@ -207,11 +207,13 @@ func (cfg *Config) paramsPreprocessing() error { if !filepath.IsAbs(cfg.v.GetString("datastore.badger.path")) { cfg.v.Set("datastore.badger.path", filepath.Join(cfg.Rootdir, cfg.v.GetString("datastore.badger.path"))) } - if !filepath.IsAbs(cfg.v.GetString("api.privkeypath")) { - cfg.v.Set("api.privkeypath", filepath.Join(cfg.Rootdir, cfg.v.GetString("api.privkeypath"))) + privKeyPath := cfg.v.GetString("api.privkeypath") + if privKeyPath != "" && !filepath.IsAbs(privKeyPath) { + cfg.v.Set("api.privkeypath", filepath.Join(cfg.Rootdir, privKeyPath)) } - if !filepath.IsAbs(cfg.v.GetString("api.pubkeypath")) { - cfg.v.Set("api.pubkeypath", filepath.Join(cfg.Rootdir, cfg.v.GetString("api.pubkeypath"))) + pubKeyPath := cfg.v.GetString("api.pubkeypath") + if pubKeyPath != "" && !filepath.IsAbs(pubKeyPath) { + cfg.v.Set("api.pubkeypath", filepath.Join(cfg.Rootdir, pubKeyPath)) } // log.logger configuration as a string @@ -303,8 +305,6 @@ func defaultAPIConfig() *APIConfig { Address: "localhost:9181", TLS: false, AllowedOrigins: []string{}, - PubKeyPath: "certs/server.key", - PrivKeyPath: "certs/server.crt", Email: DefaultAPIEmail, } } diff --git a/config/configfile_yaml.gotmpl b/config/configfile_yaml.gotmpl index 53d87c46d3..d789456a83 100644 --- a/config/configfile_yaml.gotmpl +++ b/config/configfile_yaml.gotmpl @@ -26,9 +26,9 @@ api: # The list of origins a cross-domain request can be executed from. # allowed-origins: {{ .API.AllowedOrigins }} # The path to the public key file. Ignored if domains is set. - pubkeypath: {{ .API.PubKeyPath }} + # pubkeypath: {{ .API.PubKeyPath }} # The path to the private key file. Ignored if domains is set. - privkeypath: {{ .API.PrivKeyPath }} + # privkeypath: {{ .API.PrivKeyPath }} # Email address to let the CA (Let's Encrypt) send notifications via email when there are issues (optional). # email: {{ .API.Email }} diff --git a/http/handler.go b/http/handler.go index 328ea8fab9..b06ef06cb6 100644 --- a/http/handler.go +++ b/http/handler.go @@ -20,7 +20,6 @@ import ( "github.com/sourcenetwork/defradb/datastore" "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" ) // Version is the identifier for the current API version. @@ -69,7 +68,7 @@ type Handler struct { txs *sync.Map } -func NewHandler(db client.DB, opts ServerOptions) (*Handler, error) { +func NewHandler(db client.DB) (*Handler, error) { router, err := NewApiRouter() if err != nil { return nil, err @@ -77,14 +76,9 @@ func NewHandler(db client.DB, opts ServerOptions) (*Handler, error) { txs := &sync.Map{} mux := chi.NewMux() - mux.Use( - middleware.RequestLogger(&logFormatter{}), - middleware.Recoverer, - CorsMiddleware(opts), - ) mux.Route("/api/"+Version, func(r chi.Router) { r.Use( - ApiMiddleware(db, txs, opts), + ApiMiddleware(db, txs), TransactionMiddleware, StoreMiddleware, ) diff --git a/http/handler_ccip_test.go b/http/handler_ccip_test.go index c0df7e6a26..2a2cc4f077 100644 --- a/http/handler_ccip_test.go +++ b/http/handler_ccip_test.go @@ -49,7 +49,7 @@ func TestCCIPGet_WithValidData(t *testing.T) { req := httptest.NewRequest(http.MethodGet, url, nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -88,7 +88,7 @@ func TestCCIPGet_WithSubscription(t *testing.T) { req := httptest.NewRequest(http.MethodGet, url, nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -106,7 +106,7 @@ func TestCCIPGet_WithInvalidData(t *testing.T) { req := httptest.NewRequest(http.MethodGet, url, nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -135,7 +135,7 @@ func TestCCIPPost_WithValidData(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost:9181/api/v0/ccip", bytes.NewBuffer(body)) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -167,7 +167,7 @@ func TestCCIPPost_WithInvalidGraphQLRequest(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost:9181/api/v0/ccip", bytes.NewBuffer(body)) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) @@ -181,7 +181,7 @@ func TestCCIPPost_WithInvalidBody(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost:9181/api/v0/ccip", nil) rec := httptest.NewRecorder() - handler, err := NewHandler(cdb, ServerOptions{}) + handler, err := NewHandler(cdb) require.NoError(t, err) handler.ServeHTTP(rec, req) diff --git a/http/middleware.go b/http/middleware.go index d33cbfb5ff..f18ba8bf60 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -56,13 +56,13 @@ var ( ) // CorsMiddleware handles cross origin request -func CorsMiddleware(opts ServerOptions) func(http.Handler) http.Handler { +func CorsMiddleware(allowedOrigins []string) func(http.Handler) http.Handler { return cors.Handler(cors.Options{ AllowOriginFunc: func(r *http.Request, origin string) bool { - if slices.Contains(opts.AllowedOrigins, "*") { + if slices.Contains(allowedOrigins, "*") { return true } - return slices.Contains(opts.AllowedOrigins, strings.ToLower(origin)) + return slices.Contains(allowedOrigins, strings.ToLower(origin)) }, AllowedMethods: []string{"GET", "HEAD", "POST", "PATCH", "DELETE"}, AllowedHeaders: []string{"Content-Type"}, @@ -71,13 +71,9 @@ func CorsMiddleware(opts ServerOptions) func(http.Handler) http.Handler { } // ApiMiddleware sets the required context values for all API requests. -func ApiMiddleware(db client.DB, txs *sync.Map, opts ServerOptions) func(http.Handler) http.Handler { +func ApiMiddleware(db client.DB, txs *sync.Map) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if opts.TLS.HasValue() { - rw.Header().Add("Strict-Transport-Security", "max-age=63072000; includeSubDomains") - } - ctx := req.Context() ctx = context.WithValue(ctx, dbContextKey, db) ctx = context.WithValue(ctx, txsContextKey, txs) diff --git a/http/server.go b/http/server.go index 768542c68d..80d46c6679 100644 --- a/http/server.go +++ b/http/server.go @@ -13,304 +13,173 @@ package http import ( "context" "crypto/tls" - "fmt" "net" "net/http" - "path" - "strings" + "time" - "github.com/sourcenetwork/immutable" - "golang.org/x/crypto/acme/autocert" - - "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/config" - "github.com/sourcenetwork/defradb/errors" - "github.com/sourcenetwork/defradb/logging" -) - -const ( - // These constants are best effort durations that fit our current API - // and possibly prevent from running out of file descriptors. - // readTimeout = 5 * time.Second - // writeTimeout = 10 * time.Second - // idleTimeout = 120 * time.Second - - // Temparily disabling timeouts until [this proposal](https://github.com/golang/go/issues/54136) is merged. - // https://github.com/sourcenetwork/defradb/issues/927 - readTimeout = 0 - writeTimeout = 0 - idleTimeout = 0 + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" ) -const ( - httpPort = ":80" - httpsPort = ":443" -) - -// Server struct holds the Handler for the HTTP API. -type Server struct { - options ServerOptions - listener net.Listener - certManager *autocert.Manager - // address that is assigned to the server on listen - address string - - http.Server +// We only allow cipher suites that are marked secure +// by ssllabs +var tlsCipherSuites = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } type ServerOptions struct { + // Address is the bind address the server listens on. + Address string // AllowedOrigins is the list of allowed origins for CORS. AllowedOrigins []string - // TLS enables https when the value is present. - TLS immutable.Option[TLSOptions] - // RootDirectory is the directory for the node config. - RootDir string - // Domain is the domain for the API (optional). - Domain immutable.Option[string] -} - -type TLSOptions struct { - // PublicKey is the public key for TLS. Ignored if domain is set. - PublicKey string - // PrivateKey is the private key for TLS. Ignored if domain is set. - PrivateKey string - // Email is the address for the CA to send problem notifications (optional) - Email string - // Port is the tls port - Port string + // TLSCertPath is the path to the TLS certificate. + TLSCertPath string + // TLSKeyPath is the path to the TLS key. + TLSKeyPath string + // ReadTimeout is the read timeout for connections. + ReadTimeout time.Duration + // WriteTimeout is the write timeout for connections. + WriteTimeout time.Duration + // IdleTimeout is the idle timeout for connections. + IdleTimeout time.Duration } -// NewServer instantiates a new server with the given http.Handler. -func NewServer(db client.DB, options ...func(*Server)) (*Server, error) { - srv := &Server{ - Server: http.Server{ - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - IdleTimeout: idleTimeout, - }, - } - - for _, opt := range append(options, DefaultOpts()) { - opt(srv) - } - - handler, err := NewHandler(db, srv.options) - if err != nil { - return nil, err +// DefaultOpts returns the default options for the server. +func DefaultServerOptions() *ServerOptions { + return &ServerOptions{ + Address: "127.0.0.1:9181", } - srv.Handler = handler - return srv, nil } -func newHTTPRedirServer(m *autocert.Manager) *Server { - srv := &Server{ - Server: http.Server{ - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - IdleTimeout: idleTimeout, - }, - } - - srv.Addr = httpPort - srv.Handler = m.HTTPHandler(nil) - - return srv -} +// ServerOpt is a function that configures server options. +type ServerOpt func(*ServerOptions) -// DefaultOpts returns the default options for the server. -func DefaultOpts() func(*Server) { - return func(s *Server) { - if s.Addr == "" { - s.Addr = "localhost:9181" - } +// WithAllowedOrigins sets the allowed origins for CORS. +func WithAllowedOrigins(origins ...string) ServerOpt { + return func(opts *ServerOptions) { + opts.AllowedOrigins = origins } } -// WithAllowedOrigins returns an option to set the allowed origins for CORS. -func WithAllowedOrigins(origins ...string) func(*Server) { - return func(s *Server) { - s.options.AllowedOrigins = append(s.options.AllowedOrigins, origins...) +// WithAddress sets the bind address for the server. +func WithAddress(addr string) ServerOpt { + return func(opts *ServerOptions) { + opts.Address = addr } } -// WithAddress returns an option to set the address for the server. -func WithAddress(addr string) func(*Server) { - return func(s *Server) { - s.Addr = addr - - // If the address is not localhost, we check to see if it's a valid IP address. - // If it's not a valid IP, we assume that it's a domain name to be used with TLS. - if !strings.HasPrefix(addr, "localhost:") && !strings.HasPrefix(addr, ":") { - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = addr - } - ip := net.ParseIP(host) - if ip == nil { - s.Addr = httpPort - s.options.Domain = immutable.Some(host) - } - } +// WithReadTimeout sets the server read timeout. +func WithReadTimeout(timeout time.Duration) ServerOpt { + return func(opts *ServerOptions) { + opts.ReadTimeout = timeout } } -// WithCAEmail returns an option to set the email address for the CA to send problem notifications. -func WithCAEmail(email string) func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.Email = email - s.options.TLS = immutable.Some(tlsOpt) +// WithWriteTimeout sets the server write timeout. +func WithWriteTimeout(timeout time.Duration) ServerOpt { + return func(opts *ServerOptions) { + opts.WriteTimeout = timeout } } -// WithRootDir returns an option to set the root directory for the node config. -func WithRootDir(rootDir string) func(*Server) { - return func(s *Server) { - s.options.RootDir = rootDir +// WithIdleTimeout sets the server idle timeout. +func WithIdleTimeout(timeout time.Duration) ServerOpt { + return func(opts *ServerOptions) { + opts.IdleTimeout = timeout } } -// WithSelfSignedCert returns an option to set the public and private keys for TLS. -func WithSelfSignedCert(pubKey, privKey string) func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.PublicKey = pubKey - tlsOpt.PrivateKey = privKey - s.options.TLS = immutable.Some(tlsOpt) +// WithTLSCertPath sets the server TLS certificate path. +func WithTLSCertPath(path string) ServerOpt { + return func(opts *ServerOptions) { + opts.TLSCertPath = path } } -// WithTLS returns an option to enable TLS. -func WithTLS() func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.Port = httpsPort - s.options.TLS = immutable.Some(tlsOpt) +// WithTLSKeyPath sets the server TLS private key path. +func WithTLSKeyPath(path string) ServerOpt { + return func(opts *ServerOptions) { + opts.TLSKeyPath = path } } -// WithTLSPort returns an option to set the port for TLS. -func WithTLSPort(port int) func(*Server) { - return func(s *Server) { - tlsOpt := s.options.TLS.Value() - tlsOpt.Port = fmt.Sprintf(":%d", port) - s.options.TLS = immutable.Some(tlsOpt) - } +// Server struct holds the Handler for the HTTP API. +type Server struct { + options *ServerOptions + server *http.Server } -// Listen creates a new net.Listener and saves it on the receiver. -func (s *Server) Listen(ctx context.Context) error { - var err error - if s.options.TLS.HasValue() { - return s.listenWithTLS(ctx) - } - - lc := net.ListenConfig{} - s.listener, err = lc.Listen(ctx, "tcp", s.Addr) - if err != nil { - return errors.WithStack(err) +// NewServer instantiates a new server with the given http.Handler. +func NewServer(handler http.Handler, opts ...ServerOpt) (*Server, error) { + options := DefaultServerOptions() + for _, opt := range opts { + opt(options) } - // Save the address on the server in case the port was set to random - // and that we want to see what was assigned. - s.address = s.listener.Addr().String() + // setup a mux with the default middleware stack + mux := chi.NewMux() + mux.Use( + middleware.RequestLogger(&logFormatter{}), + middleware.Recoverer, + CorsMiddleware(options.AllowedOrigins), + ) + mux.Handle("/*", handler) - return nil -} - -func (s *Server) listenWithTLS(ctx context.Context) error { - cfg := &tls.Config{ - MinVersion: tls.VersionTLS12, - // We only allow cipher suites that are marked secure - // by ssllabs - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - }, - ServerName: "DefraDB", + server := &http.Server{ + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + IdleTimeout: options.IdleTimeout, + Handler: mux, } - if s.options.Domain.HasValue() && s.options.Domain.Value() != "" { - s.Addr = s.options.TLS.Value().Port - - if s.options.TLS.Value().Email == "" || s.options.TLS.Value().Email == config.DefaultAPIEmail { - return ErrNoEmail - } - - certCache := path.Join(s.options.RootDir, "autocerts") - - log.FeedbackInfo( - ctx, - "Generating auto certificate", - logging.NewKV("Domain", s.options.Domain.Value()), - logging.NewKV("Certificate cache", certCache), - ) - - m := &autocert.Manager{ - Cache: autocert.DirCache(certCache), - Prompt: autocert.AcceptTOS, - Email: s.options.TLS.Value().Email, - HostPolicy: autocert.HostWhitelist(s.options.Domain.Value()), - } - - cfg.GetCertificate = m.GetCertificate - - // We set manager on the server instance to later start - // a redirection server. - s.certManager = m - } else { - // When not using auto cert, we create a self signed certificate - // with the provided public and prive keys. - log.FeedbackInfo(ctx, "Generating self signed certificate") + return &Server{ + options: options, + server: server, + }, nil +} - cert, err := tls.LoadX509KeyPair( - s.options.TLS.Value().PrivateKey, - s.options.TLS.Value().PublicKey, - ) - if err != nil { - return NewErrFailedToLoadKeys(err, s.options.TLS.Value().PublicKey, s.options.TLS.Value().PrivateKey) - } +// Shutdown gracefully shuts down the server without interrupting any active connections. +func (s *Server) Shutdown(ctx context.Context) error { + return s.server.Shutdown(ctx) +} - cfg.Certificates = []tls.Certificate{cert} +// ListenAndServe listens for and serves incoming connections. +func (s *Server) ListenAndServe() error { + if s.options.TLSCertPath == "" && s.options.TLSKeyPath == "" { + return s.listenAndServe() } + return s.listenAndServeTLS() +} - var err error - s.listener, err = tls.Listen("tcp", s.Addr, cfg) +// listenAndServe listens for and serves http connections. +func (s *Server) listenAndServe() error { + listener, err := net.Listen("tcp", s.options.Address) if err != nil { - return errors.WithStack(err) + return err } - - // Save the address on the server in case the port was set to random - // and that we want to see what was assigned. - s.address = s.listener.Addr().String() - - return nil + return s.server.Serve(listener) } -// Run calls Serve with the receiver's listener. -func (s *Server) Run(ctx context.Context) error { - if s.listener == nil { - return ErrNoListener +// listenAndServeTLS listens for and serves https connections. +func (s *Server) listenAndServeTLS() error { + cert, err := tls.LoadX509KeyPair(s.options.TLSCertPath, s.options.TLSKeyPath) + if err != nil { + return err } - - if s.certManager != nil { - // When using TLS it's important to redirect http requests to https - go func() { - srv := newHTTPRedirServer(s.certManager) - err := srv.ListenAndServe() - if err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Info(ctx, "Something went wrong with the redirection server", logging.NewKV("Error", err)) - } - }() + config := &tls.Config{ + ServerName: "DefraDB", + MinVersion: tls.VersionTLS12, + CipherSuites: tlsCipherSuites, + Certificates: []tls.Certificate{cert}, } - return s.Serve(s.listener) -} - -// AssignedAddr returns the address that was assigned to the server on calls to listen. -func (s *Server) AssignedAddr() string { - return s.address + listener, err := net.Listen("tcp", s.options.Address) + if err != nil { + return err + } + return s.server.Serve(tls.NewListener(listener, config)) } diff --git a/http/server_test.go b/http/server_test.go index 04095b7c15..4065267c26 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -12,111 +12,19 @@ package http import ( "context" + "crypto/tls" "net/http" "os" + "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/crypto/acme/autocert" ) -func TestNewServerAndRunWithoutListener(t *testing.T) { - ctx := context.Background() - s, err := NewServer(nil, WithAddress(":0")) - require.NoError(t, err) - if ok := assert.NotNil(t, s); ok { - assert.Equal(t, ErrNoListener, s.Run(ctx)) - } -} - -func TestNewServerAndRunWithListenerAndInvalidPort(t *testing.T) { - ctx := context.Background() - s, err := NewServer(nil, WithAddress(":303000")) - require.NoError(t, err) - if ok := assert.NotNil(t, s); ok { - assert.Error(t, s.Listen(ctx)) - } -} - -func TestNewServerAndRunWithListenerAndValidPort(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - s, err := NewServer(nil, WithAddress(":0")) - require.NoError(t, err) - go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.NoError(t, err) - err = s.Run(ctx) - assert.ErrorIs(t, http.ErrServerClosed, err) - defer close(serverDone) - }() - - <-serverRunning - - s.Shutdown(context.Background()) - - <-serverDone -} - -func TestNewServerAndRunWithAutocertWithoutEmail(t *testing.T) { - ctx := context.Background() - dir := t.TempDir() - s, err := NewServer(nil, WithAddress("example.com"), WithRootDir(dir), WithTLSPort(0)) - require.NoError(t, err) - err = s.Listen(ctx) - assert.ErrorIs(t, err, ErrNoEmail) - - s.Shutdown(context.Background()) -} - -func TestNewServerAndRunWithAutocert(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - s, err := NewServer(nil, WithAddress("example.com"), WithRootDir(dir), WithTLSPort(0), WithCAEmail("dev@defradb.net")) - require.NoError(t, err) - go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.NoError(t, err) - err = s.Run(ctx) - assert.ErrorIs(t, http.ErrServerClosed, err) - defer close(serverDone) - }() - - <-serverRunning - - s.Shutdown(context.Background()) - - <-serverDone -} - -func TestNewServerAndRunWithSelfSignedCertAndNoKeyFiles(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - s, err := NewServer(nil, WithAddress("localhost:0"), WithSelfSignedCert(dir+"/server.crt", dir+"/server.key")) - require.NoError(t, err) - go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.Contains(t, err.Error(), "failed to load given keys") - defer close(serverDone) - }() - - <-serverRunning - - s.Shutdown(context.Background()) - - <-serverDone -} - -const pubKey = `-----BEGIN EC PARAMETERS----- +// tlsKey is the TLS private key in PEM format +const tlsKey = `-----BEGIN EC PARAMETERS----- BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- @@ -126,7 +34,8 @@ pS0gW/SYpAncHhRuz18RQ2ycuXlSN1S/PAryRZ5PK2xORKfwpguEDEMdVwbHorZO K44P/h3dhyNyAyf8rcRoqKXcl/K/uew= -----END EC PRIVATE KEY-----` -const privKey = `-----BEGIN CERTIFICATE----- +// tlsKey is the TLS certificate in PEM format +const tlsCert = `-----BEGIN CERTIFICATE----- MIICQDCCAcUCCQDpMnN1gQ4fGTAKBggqhkjOPQQDAjCBiDELMAkGA1UEBhMCY2Ex DzANBgNVBAgMBlF1ZWJlYzEQMA4GA1UEBwwHQ2hlbHNlYTEPMA0GA1UECgwGU291 cmNlMRAwDgYDVQQLDAdEZWZyYURCMQ8wDQYDVQQDDAZzb3VyY2UxIjAgBgkqhkiG @@ -142,121 +51,145 @@ kgIxAKaEGC+lqp0aaN+yubYLRiTDxOlNpyiHox3nZiL4bG/CCdPDvbX63QcdI2yq XPKczg== -----END CERTIFICATE-----` -func TestNewServerAndRunWithSelfSignedCertAndInvalidPort(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - err := os.WriteFile(dir+"/server.key", []byte(privKey), 0644) - if err != nil { - t.Fatal(err) - } - err = os.WriteFile(dir+"/server.crt", []byte(pubKey), 0644) - if err != nil { - t.Fatal(err) - } - s, err := NewServer(nil, WithAddress(":303000"), WithSelfSignedCert(dir+"/server.crt", dir+"/server.key")) +// insecureClient is an http client that trusts all tls certificates +var insecureClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, +} + +// testHandler returns an empty body and 200 status code +var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +}) + +func TestServerListenAndServeWithInvalidAddress(t *testing.T) { + srv, err := NewServer(testHandler, WithAddress("invalid")) require.NoError(t, err) + + err = srv.ListenAndServe() + require.ErrorContains(t, err, "address invalid") +} + +func TestServerListenAndServeWithTLSAndInvalidAddress(t *testing.T) { + certPath, keyPath := writeTestCerts(t) + srv, err := NewServer(testHandler, WithAddress("invalid"), WithTLSCertPath(certPath), WithTLSKeyPath(keyPath)) + require.NoError(t, err) + + err = srv.ListenAndServe() + require.ErrorContains(t, err, "address invalid") +} + +func TestServerListenAndServeWithTLSAndInvalidCerts(t *testing.T) { + srv, err := NewServer(testHandler, WithAddress("invalid"), WithTLSCertPath("invalid.crt"), WithTLSKeyPath("invalid.key")) + require.NoError(t, err) + + err = srv.ListenAndServe() + require.ErrorContains(t, err, "no such file or directory") +} + +func TestServerListenAndServeWithAddress(t *testing.T) { + srv, err := NewServer(testHandler, WithAddress("127.0.0.1:30001")) + require.NoError(t, err) + go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.Contains(t, err.Error(), "invalid port") - defer close(serverDone) + err := srv.ListenAndServe() + require.ErrorIs(t, http.ErrServerClosed, err) }() - <-serverRunning + // wait for server to start + <-time.After(time.Second * 1) + + res, err := http.Get("http://127.0.0.1:30001") + require.NoError(t, err) - s.Shutdown(context.Background()) + defer res.Body.Close() + assert.Equal(t, 200, res.StatusCode) - <-serverDone + err = srv.Shutdown(context.Background()) + require.NoError(t, err) } -func TestNewServerAndRunWithSelfSignedCert(t *testing.T) { - ctx := context.Background() - serverRunning := make(chan struct{}) - serverDone := make(chan struct{}) - dir := t.TempDir() - err := os.WriteFile(dir+"/server.key", []byte(privKey), 0644) - if err != nil { - t.Fatal(err) - } - err = os.WriteFile(dir+"/server.crt", []byte(pubKey), 0644) - if err != nil { - t.Fatal(err) - } - s, err := NewServer(nil, WithAddress("localhost:0"), WithSelfSignedCert(dir+"/server.crt", dir+"/server.key")) +func TestServerListenAndServeWithTLS(t *testing.T) { + certPath, keyPath := writeTestCerts(t) + srv, err := NewServer(testHandler, WithAddress("127.0.0.1:8443"), WithTLSCertPath(certPath), WithTLSKeyPath(keyPath)) require.NoError(t, err) + go func() { - close(serverRunning) - err := s.Listen(ctx) - assert.NoError(t, err) - err = s.Run(ctx) - assert.ErrorIs(t, http.ErrServerClosed, err) - defer close(serverDone) + err := srv.ListenAndServe() + require.ErrorIs(t, http.ErrServerClosed, err) }() - <-serverRunning + // wait for server to start + <-time.After(time.Second * 1) - s.Shutdown(context.Background()) + res, err := insecureClient.Get("https://127.0.0.1:8443") + require.NoError(t, err) - <-serverDone -} + defer res.Body.Close() + assert.Equal(t, 200, res.StatusCode) -func TestNewServerWithoutOptions(t *testing.T) { - s, err := NewServer(nil) + err = srv.Shutdown(context.Background()) require.NoError(t, err) - assert.Equal(t, "localhost:9181", s.Addr) - assert.Equal(t, []string(nil), s.options.AllowedOrigins) } -func TestNewServerWithAddress(t *testing.T) { - s, err := NewServer(nil, WithAddress("localhost:9999")) +func TestServerListenAndServeWithAllowedOrigins(t *testing.T) { + srv, err := NewServer(testHandler, WithAllowedOrigins("localhost"), WithAddress("127.0.0.1:30001")) require.NoError(t, err) - assert.Equal(t, "localhost:9999", s.Addr) -} -func TestNewServerWithDomainAddress(t *testing.T) { - s, err := NewServer(nil, WithAddress("example.com")) + go func() { + err := srv.ListenAndServe() + require.ErrorIs(t, http.ErrServerClosed, err) + }() + + // wait for server to start + <-time.After(time.Second * 1) + + req, err := http.NewRequest(http.MethodOptions, "http://127.0.0.1:30001", nil) require.NoError(t, err) - assert.Equal(t, "example.com", s.options.Domain.Value()) - assert.NotNil(t, s.options.TLS) -} + req.Header.Add("origin", "localhost") -func TestNewServerWithAllowedOrigins(t *testing.T) { - s, err := NewServer(nil, WithAllowedOrigins("https://source.network", "https://app.source.network")) + res, err := http.DefaultClient.Do(req) require.NoError(t, err) - assert.Equal(t, []string{"https://source.network", "https://app.source.network"}, s.options.AllowedOrigins) -} -func TestNewServerWithCAEmail(t *testing.T) { - s, err := NewServer(nil, WithCAEmail("me@example.com")) + defer res.Body.Close() + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "localhost", res.Header.Get("Access-Control-Allow-Origin")) + + err = srv.Shutdown(context.Background()) require.NoError(t, err) - assert.Equal(t, "me@example.com", s.options.TLS.Value().Email) } -func TestNewServerWithRootDir(t *testing.T) { - dir := t.TempDir() - s, err := NewServer(nil, WithRootDir(dir)) +func TestServerWithReadTimeout(t *testing.T) { + srv, err := NewServer(testHandler, WithReadTimeout(time.Second)) require.NoError(t, err) - assert.Equal(t, dir, s.options.RootDir) + assert.Equal(t, time.Second, srv.server.ReadTimeout) } -func TestNewServerWithTLSPort(t *testing.T) { - s, err := NewServer(nil, WithTLSPort(44343)) +func TestServerWithWriteTimeout(t *testing.T) { + srv, err := NewServer(testHandler, WithWriteTimeout(time.Second)) require.NoError(t, err) - assert.Equal(t, ":44343", s.options.TLS.Value().Port) + assert.Equal(t, time.Second, srv.server.WriteTimeout) } -func TestNewServerWithSelfSignedCert(t *testing.T) { - s, err := NewServer(nil, WithSelfSignedCert("pub.key", "priv.key")) +func TestServerWithIdleTimeout(t *testing.T) { + srv, err := NewServer(testHandler, WithIdleTimeout(time.Second)) require.NoError(t, err) - assert.Equal(t, "pub.key", s.options.TLS.Value().PublicKey) - assert.Equal(t, "priv.key", s.options.TLS.Value().PrivateKey) - assert.NotNil(t, s.options.TLS) + assert.Equal(t, time.Second, srv.server.IdleTimeout) } -func TestNewHTTPRedirServer(t *testing.T) { - m := &autocert.Manager{} - s := newHTTPRedirServer(m) - assert.Equal(t, ":80", s.Addr) +func writeTestCerts(t *testing.T) (string, string) { + tempDir := t.TempDir() + certPath := filepath.Join(tempDir, "cert.pub") + keyPath := filepath.Join(tempDir, "cert.key") + + err := os.WriteFile(certPath, []byte(tlsCert), 0644) + require.NoError(t, err) + + err = os.WriteFile(keyPath, []byte(tlsKey), 0644) + require.NoError(t, err) + + return certPath, keyPath } diff --git a/tests/clients/cli/wrapper.go b/tests/clients/cli/wrapper.go index 49a0605598..a15b263709 100644 --- a/tests/clients/cli/wrapper.go +++ b/tests/clients/cli/wrapper.go @@ -40,7 +40,7 @@ type Wrapper struct { } func NewWrapper(node *net.Node) (*Wrapper, error) { - handler, err := http.NewHandler(node, http.ServerOptions{}) + handler, err := http.NewHandler(node) if err != nil { return nil, err } diff --git a/tests/clients/http/wrapper.go b/tests/clients/http/wrapper.go index 040ab9c1b4..4b935b8a0b 100644 --- a/tests/clients/http/wrapper.go +++ b/tests/clients/http/wrapper.go @@ -36,7 +36,7 @@ type Wrapper struct { } func NewWrapper(node *net.Node) (*Wrapper, error) { - handler, err := http.NewHandler(node, http.ServerOptions{}) + handler, err := http.NewHandler(node) if err != nil { return nil, err } diff --git a/tests/gen/cli/gendocs.go b/tests/gen/cli/gendocs.go index 226d73bc97..3bb94aef09 100644 --- a/tests/gen/cli/gendocs.go +++ b/tests/gen/cli/gendocs.go @@ -100,6 +100,16 @@ Example: The following command generates 100 User documents and 500 Device docum return nil }, } + + cmd.PersistentFlags().String( + "url", cfg.API.Address, + "URL of HTTP endpoint to listen on or connect to", + ) + err := cfg.BindFlag("api.address", cmd.PersistentFlags().Lookup("url")) + if err != nil { + panic(err) + } + cmd.Flags().StringVarP(&demandJSON, "demand", "d", "", "Documents' demand in JSON format") return cmd diff --git a/tests/gen/cli/gendocs_test.go b/tests/gen/cli/gendocs_test.go index 18b9b157c1..1bd30297a6 100644 --- a/tests/gen/cli/gendocs_test.go +++ b/tests/gen/cli/gendocs_test.go @@ -12,43 +12,39 @@ package cli import ( "bytes" + "context" "io" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/sourcenetwork/defradb/cli" - "github.com/sourcenetwork/defradb/config" "github.com/sourcenetwork/defradb/tests/gen" ) -func execAddSchemaCmd(t *testing.T, cfg *config.Config, schema string) { - rootCmd := cli.NewDefraCommand(cfg) - rootCmd.SetArgs([]string{"client", "schema", "add", schema}) - err := rootCmd.Execute() - require.NoError(t, err) -} - func TestGendocsCmd_IfNoErrors_ReturnGenerationOutput(t *testing.T) { - cfg, _, close := startTestNode(t) + defra, close := startTestNode(t) defer close() - execAddSchemaCmd(t, cfg, ` - type User { - name: String - devices: [Device] - } - type Device { - model: String - owner: User - }`) - - genDocsCmd := MakeGenDocCommand(cfg) + defra.db.AddSchema(context.Background(), ` + type User { + name: String + devices: [Device] + } + type Device { + model: String + owner: User + }`) + + genDocsCmd := MakeGenDocCommand(getTestConfig(t)) outputBuf := bytes.NewBufferString("") genDocsCmd.SetOut(outputBuf) - genDocsCmd.SetArgs([]string{"--demand", `{"User": 3, "Device": 12}`}) + genDocsCmd.SetArgs([]string{ + "--demand", `{"User": 3, "Device": 12}`, + "--url", strings.TrimPrefix(defra.server.URL, "http://"), + }) err := genDocsCmd.Execute() require.NoError(t, err) @@ -67,33 +63,38 @@ func TestGendocsCmd_IfNoErrors_ReturnGenerationOutput(t *testing.T) { } func TestGendocsCmd_IfInvalidDemandValue_ReturnError(t *testing.T) { - cfg, _, close := startTestNode(t) + defra, close := startTestNode(t) defer close() - execAddSchemaCmd(t, cfg, ` + defra.db.AddSchema(context.Background(), ` type User { name: String }`) - genDocsCmd := MakeGenDocCommand(cfg) - genDocsCmd.SetArgs([]string{"--demand", `{"User": invalid}`}) + genDocsCmd := MakeGenDocCommand(getTestConfig(t)) + genDocsCmd.SetArgs([]string{ + "--demand", `{"User": invalid}`, + "--url", strings.TrimPrefix(defra.server.URL, "http://"), + }) err := genDocsCmd.Execute() require.ErrorContains(t, err, errInvalidDemandValue) } func TestGendocsCmd_IfInvalidConfig_ReturnError(t *testing.T) { - cfg, _, close := startTestNode(t) + defra, close := startTestNode(t) defer close() - execAddSchemaCmd(t, cfg, ` + defra.db.AddSchema(context.Background(), ` type User { name: String }`) - genDocsCmd := MakeGenDocCommand(cfg) - - genDocsCmd.SetArgs([]string{"--demand", `{"Unknown": 3}`}) + genDocsCmd := MakeGenDocCommand(getTestConfig(t)) + genDocsCmd.SetArgs([]string{ + "--demand", `{"Unknown": 3}`, + "--url", strings.TrimPrefix(defra.server.URL, "http://"), + }) err := genDocsCmd.Execute() require.Error(t, err, gen.NewErrInvalidConfiguration("")) diff --git a/tests/gen/cli/util_test.go b/tests/gen/cli/util_test.go index a04761f6b5..81d713955c 100644 --- a/tests/gen/cli/util_test.go +++ b/tests/gen/cli/util_test.go @@ -12,9 +12,7 @@ package cli import ( "context" - "fmt" - "net/http" - "os" + "net/http/httptest" "testing" badger "github.com/sourcenetwork/badger/v4" @@ -33,21 +31,15 @@ var log = logging.MustNewLogger("cli") type defraInstance struct { db client.DB - server *httpapi.Server + server *httptest.Server } func (di *defraInstance) close(ctx context.Context) { di.db.Close() - if err := di.server.Close(); err != nil { - log.FeedbackInfo( - ctx, - "The server could not be closed successfully", - logging.NewKV("Error", err.Error()), - ) - } + di.server.Close() } -func start(ctx context.Context, cfg *config.Config) (*defraInstance, error) { +func start(ctx context.Context) (*defraInstance, error) { log.FeedbackInfo(ctx, "Starting DefraDB service...") log.FeedbackInfo(ctx, "Building new memory store") @@ -63,26 +55,11 @@ func start(ctx context.Context, cfg *config.Config) (*defraInstance, error) { return nil, errors.Wrap("failed to create database", err) } - server, err := httpapi.NewServer(db, httpapi.WithAddress(cfg.API.Address)) + handler, err := httpapi.NewHandler(db) if err != nil { - return nil, errors.Wrap("failed to create http server", err) - } - if err := server.Listen(ctx); err != nil { - return nil, errors.Wrap(fmt.Sprintf("failed to listen on TCP address %v", server.Addr), err) + return nil, errors.Wrap("failed to create http handler", err) } - // save the address on the config in case the port number was set to random - cfg.API.Address = server.AssignedAddr() - cfg.Persist() - - // run the server in a separate goroutine - go func(apiAddress string) { - log.FeedbackInfo(ctx, fmt.Sprintf("Providing HTTP API at %s.", apiAddress)) - if err := server.Run(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.FeedbackErrorE(ctx, "Failed to run the HTTP server", err) - db.Close() - os.Exit(1) - } - }(cfg.API.AddressToURL()) + server := httptest.NewServer(handler) return &defraInstance{ db: db, @@ -101,11 +78,9 @@ func getTestConfig(t *testing.T) *config.Config { return cfg } -func startTestNode(t *testing.T) (*config.Config, *defraInstance, func()) { - cfg := getTestConfig(t) - +func startTestNode(t *testing.T) (*defraInstance, func()) { ctx := context.Background() - di, err := start(ctx, cfg) + di, err := start(ctx) require.NoError(t, err) - return cfg, di, func() { di.close(ctx) } + return di, func() { di.close(ctx) } }