Skip to content

Commit

Permalink
Close socket if Bolt handshake fails
Browse files Browse the repository at this point in the history
Co-authored-by: Rouven Bauer <[email protected]>
  • Loading branch information
fbiville and robsdedude authored Mar 17, 2023
1 parent 62bae2d commit 3f5f3f4
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 19 deletions.
58 changes: 39 additions & 19 deletions neo4j/internal/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,17 @@ import (
)

type Connector struct {
SkipEncryption bool
SkipVerify bool
RootCAs *x509.CertPool
DialTimeout time.Duration
SocketKeepAlive bool
Auth map[string]interface{}
Log log.Logger
UserAgent string
RoutingContext map[string]string
Network string
SkipEncryption bool
SkipVerify bool
RootCAs *x509.CertPool
DialTimeout time.Duration
SocketKeepAlive bool
Auth map[string]interface{}
Log log.Logger
UserAgent string
RoutingContext map[string]string
Network string
SupplyConnection func(address string) (net.Conn, error)
}

type ConnectError struct {
Expand All @@ -63,19 +64,24 @@ func (e *TlsError) Error() string {
}

func (c Connector) Connect(address string, boltLogger log.BoltLogger) (db.Connection, error) {
dialer := net.Dialer{Timeout: c.DialTimeout}
if !c.SocketKeepAlive {
dialer.KeepAlive = -1 * time.Second // Turns keep-alive off
if c.SupplyConnection == nil {
c.SupplyConnection = c.createConnection
}

conn, err := dialer.Dial(c.Network, address)
conn, err := c.SupplyConnection(address)
if err != nil {
return nil, &ConnectError{inner: err}
}

// TLS not requested, perform Bolt handshake
// TLS not requested
if c.SkipEncryption {
return bolt.Connect(address, conn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
connection, err := bolt.Connect(address, conn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
if err != nil {
if connErr := conn.Close(); connErr != nil {
c.Log.Warnf(log.Driver, "", "Could not close underlying socket after Bolt handshake error")
}
return nil, err
}
return connection, err
}

// TLS requested, continue with handshake
Expand All @@ -100,6 +106,20 @@ func (c Connector) Connect(address string, boltLogger log.BoltLogger) (db.Connec
conn.Close()
return nil, &TlsError{inner: err}
}
// Perform Bolt handshake
return bolt.Connect(address, tlsconn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
connection, err := bolt.Connect(address, tlsconn, c.Auth, c.UserAgent, c.RoutingContext, c.Log, boltLogger)
if err != nil {
if connErr := conn.Close(); connErr != nil {
c.Log.Warnf(log.Driver, "", "Could not close underlying socket after Bolt handshake error")
}
return nil, err
}
return connection, nil
}

func (c Connector) createConnection(address string) (net.Conn, error) {
dialer := net.Dialer{Timeout: c.DialTimeout}
if !c.SocketKeepAlive {
dialer.KeepAlive = -1 * time.Second // Turns keep-alive off
}
return dialer.Dial(c.Network, address)
}
159 changes: 159 additions & 0 deletions neo4j/internal/connector/connector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package connector_test

import (
"github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/connector"
. "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/testutil"
"io"
"net"
"testing"
"time"
)

func TestConnect(outer *testing.T) {
outer.Parallel()

outer.Run("closes connection if Bolt handshake does not reach agreement", func(t *testing.T) {
clientConnection, server := setUp(t)
go func() {
server.acceptVersion(1, 0)
}()
connectionDelegate := &ConnDelegate{Delegate: clientConnection}
connector := &connector.Connector{SupplyConnection: supplyThis(connectionDelegate), SkipEncryption: true}

connection, err := connector.Connect("irrelevant", nil)

AssertNil(t, connection)
AssertErrorMessageContains(t, err, "unsupported version 1.0")
AssertTrue(t, connectionDelegate.Closed)
})

outer.Run("closes connection if Bolt handshake errors", func(t *testing.T) {
clientConnection, server := setUp(t)
go func() {
server.failAcceptingVersion()
}()
connectionDelegate := &ConnDelegate{Delegate: clientConnection}
connector := &connector.Connector{SupplyConnection: supplyThis(connectionDelegate), SkipEncryption: true}

connection, err := connector.Connect("irrelevant", nil)

AssertNil(t, connection)
AssertError(t, err)
AssertTrue(t, connectionDelegate.Closed)
})
}

func setUp(t *testing.T) (net.Conn, *boltHandshakeServer) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Unable to listen: %s", err)
}
t.Cleanup(func() {
_ = listener.Close()
})

address := listener.Addr()
clientConnection, err := net.Dial(address.Network(), address.String())
if err != nil {
t.Fatalf("Dial error: %s", err)
}
t.Cleanup(func() {
_ = clientConnection.Close()
})
serverConnection, err := listener.Accept()
if err != nil {
t.Fatalf("Accept error: %s", err)
}
t.Cleanup(func() {
_ = serverConnection.Close()
})
handshakeServer := &boltHandshakeServer{t, serverConnection}
return clientConnection, handshakeServer
}

func supplyThis(connection net.Conn) func(address string) (net.Conn, error) {
return func(address string) (net.Conn, error) {
return connection, nil
}
}

type boltHandshakeServer struct {
t *testing.T
conn net.Conn
}

func (server *boltHandshakeServer) waitForHandshake() []byte {
handshake := make([]byte, 4*5)
if _, err := io.ReadFull(server.conn, handshake); err != nil {
server.t.Fatalf("Unable to read client versions: %s", err)
}
return handshake
}

func (server *boltHandshakeServer) acceptVersion(major, minor byte) {
server.waitForHandshake()
if _, err := server.conn.Write([]byte{0x00, 0x00, minor, major}); err != nil {
panic(err)
}
}

func (server *boltHandshakeServer) failAcceptingVersion() {
_ = server.conn.Close()
}

type ConnDelegate struct {
Closed bool
Delegate net.Conn
}

func (cd *ConnDelegate) Read(b []byte) (n int, err error) {
return cd.Delegate.Read(b)
}

func (cd *ConnDelegate) Write(b []byte) (n int, err error) {
return cd.Delegate.Write(b)
}

func (cd *ConnDelegate) Close() error {
cd.Closed = true
return cd.Delegate.Close()
}

func (cd *ConnDelegate) LocalAddr() net.Addr {
return cd.Delegate.LocalAddr()
}

func (cd *ConnDelegate) RemoteAddr() net.Addr {
return cd.Delegate.RemoteAddr()
}

func (cd *ConnDelegate) SetDeadline(t time.Time) error {
return cd.Delegate.SetDeadline(t)
}

func (cd *ConnDelegate) SetReadDeadline(t time.Time) error {
return cd.Delegate.SetReadDeadline(t)
}

func (cd *ConnDelegate) SetWriteDeadline(t time.Time) error {
return cd.Delegate.SetWriteDeadline(t)
}

0 comments on commit 3f5f3f4

Please sign in to comment.