diff --git a/examples/tls/main.go b/examples/tls/main.go index 627de54..dbaff3a 100644 --- a/examples/tls/main.go +++ b/examples/tls/main.go @@ -23,8 +23,8 @@ func run() error { return err } - certs := []tls.Certificate{cert} - server, err := wire.NewServer(handler, wire.Certificates(certs), wire.Logger(logger), wire.MessageBufferSize(100)) + config := &tls.Config{Certificates: []tls.Certificate{cert}} + server, err := wire.NewServer(handler, wire.TLSConfig(config), wire.Logger(logger), wire.MessageBufferSize(100)) if err != nil { return err } diff --git a/handshake.go b/handshake.go index 721bfb7..cc29717 100644 --- a/handshake.go +++ b/handshake.go @@ -146,7 +146,7 @@ func (srv *Server) potentialConnUpgrade(conn net.Conn, reader *buffer.Reader, ve srv.logger.Debug("attempting to upgrade the client to a TLS connection") - if len(srv.Certificates) == 0 { + if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 { srv.logger.Debug("no TLS certificates available continuing with a insecure connection") return srv.sslUnsupported(conn, reader, version) } @@ -156,15 +156,9 @@ func (srv *Server) potentialConnUpgrade(conn net.Conn, reader *buffer.Reader, ve return conn, reader, version, err } - tlsConfig := tls.Config{ - Certificates: srv.Certificates, - ClientAuth: srv.ClientAuth, - ClientCAs: srv.ClientCAs, - } - // NOTE: initialize the TLS connection and construct a new buffered // reader for the constructed TLS connection. - conn = tls.Server(conn, &tlsConfig) + conn = tls.Server(conn, srv.TLSConfig) reader = buffer.NewReader(srv.logger, conn, srv.BufferedMsgSize) version, err = srv.readVersion(reader) diff --git a/options.go b/options.go index 982b19f..15a9545 100644 --- a/options.go +++ b/options.go @@ -3,7 +3,6 @@ package wire import ( "context" "crypto/tls" - "crypto/x509" "log/slog" "regexp" "strconv" @@ -146,29 +145,11 @@ func MessageBufferSize(size int) OptionFn { } } -// Certificates sets the given TLS certificates to be used to initialize a +// TLSConfig sets the given TLS config to be used to initialize a // secure connection between the front-end (client) and back-end (server). -func Certificates(certs []tls.Certificate) OptionFn { +func TLSConfig(config *tls.Config) OptionFn { return func(srv *Server) error { - srv.Certificates = certs - return nil - } -} - -// ClientCAs sets the given Client CAs to be used, by the server, to verify a -// secure connection between the front-end (client) and back-end (server). -func ClientCAs(cas *x509.CertPool) OptionFn { - return func(srv *Server) error { - srv.ClientCAs = cas - return nil - } -} - -// ClientAuth sets the given Client Auth to be used, by the server, to verify a -// secure connection between the front-end (client) and back-end (server). -func ClientAuth(authType tls.ClientAuthType) OptionFn { - return func(srv *Server) error { - srv.ClientAuth = authType + srv.TLSConfig = config return nil } } diff --git a/wire.go b/wire.go index b5d77ab..f84a513 100644 --- a/wire.go +++ b/wire.go @@ -3,7 +3,6 @@ package wire import ( "context" "crypto/tls" - "crypto/x509" "errors" "fmt" "log/slog" @@ -60,9 +59,7 @@ type Server struct { Auth AuthStrategy BufferedMsgSize int Parameters Parameters - Certificates []tls.Certificate - ClientCAs *x509.CertPool - ClientAuth tls.ClientAuthType + TLSConfig *tls.Config parse ParseFn Session SessionHandler Statements StatementCache