Skip to content

Commit

Permalink
Backport new 5.0 implementation of connection timeout hints
Browse files Browse the repository at this point in the history
  • Loading branch information
fbiville committed Apr 11, 2022
1 parent e232f4b commit c8d2795
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 43 deletions.
6 changes: 5 additions & 1 deletion neo4j/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package neo4j

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -114,7 +116,9 @@ func wrapError(err error) error {
if err == nil {
return nil
}
if err == io.EOF {
if err == io.EOF ||
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) {
return &ConnectivityError{inner: err}
}
switch e := err.(type) {
Expand Down
64 changes: 49 additions & 15 deletions neo4j/internal/bolt/dechunker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,48 @@
package bolt

import (
"context"
"encoding/binary"
rio "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/racingio"
"github.com/neo4j/neo4j-go-driver/v4/neo4j/log"
"io"
"net"
"time"
)

// dechunkMessage takes a buffer to be reused and returns the reusable buffer
// (might have been reallocated to handle growth), the message buffer and
// error.
// If a non-default connection read timeout configuration hint is passed, the dechunker resets the connection read
// deadline as well after successfully reading a chunk (NOOP messages included)
func dechunkMessage(conn net.Conn, msgBuf []byte, readTimeout time.Duration,
logger log.Logger, logName, logId string) ([]byte, []byte, error) {
// Reads will race against the provided context ctx
// If the server provides the connection read timeout hint readTimeout, a new context will be created from that timeout
// and the user-provided context ctx before every read
func dechunkMessage(
conn net.Conn,
msgBuf []byte,
readTimeout time.Duration,
logger log.Logger,
logName string,
logId string) ([]byte, []byte, error) {

sizeBuf := []byte{0x00, 0x00}
off := 0

reader := rio.NewRacingReader(conn)

for {
_, err := io.ReadFull(conn, sizeBuf)
updatedCtx, cancelFunc := newContext(readTimeout, logger, logName, logId)
_, err := reader.ReadFull(updatedCtx, sizeBuf)
if err != nil {
return msgBuf, nil, err
}
if cancelFunc != nil { // reading has been completed, time to release the context
cancelFunc()
}
chunkSize := int(binary.BigEndian.Uint16(sizeBuf))
if chunkSize == 0 {
if off > 0 {
return msgBuf, msgBuf[:off], nil
}
// Got a nop chunk
resetConnectionReadDeadline(conn, readTimeout, logger,
logName, logId)
continue
}

Expand All @@ -60,20 +72,42 @@ func dechunkMessage(conn net.Conn, msgBuf []byte, readTimeout time.Duration,
msgBuf = newMsgBuf
}
// Read the chunk into buffer
_, err = io.ReadFull(conn, msgBuf[off:(off+chunkSize)])
updatedCtx, cancelFunc = newContext(readTimeout, logger, logName, logId)
_, err = reader.ReadFull(updatedCtx, msgBuf[off:(off+chunkSize)])
if err != nil {
return msgBuf, nil, err
}
if cancelFunc != nil { // reading has been completed, time to release the context
cancelFunc()
}
off += chunkSize
resetConnectionReadDeadline(conn, readTimeout, logger, logName, logId)
}
}

func resetConnectionReadDeadline(conn net.Conn, readTimeout time.Duration, logger log.Logger, logName, logId string) {
if readTimeout < 0 {
return
// newContext computes a new context and cancel function if a readTimeout is set
func newContext(
readTimeout time.Duration,
logger log.Logger,
logName string,
logId string) (context.Context, context.CancelFunc) {

ctx := context.Background()
if readTimeout >= 0 {
newCtx, cancelFunc := context.WithTimeout(ctx, readTimeout)
logger.Debugf(logName, logId,
"read timeout of %s applied, chunk read deadline is now: %s",
readTimeout.String(),
deadlineOf(newCtx),
)
return newCtx, cancelFunc
}
if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
logger.Error(logName, logId, err)
return ctx, nil
}

func deadlineOf(ctx context.Context) string {
deadline, hasDeadline := ctx.Deadline()
if !hasDeadline {
return "N/A (no deadline set)"
}
return deadline.String()
}
40 changes: 13 additions & 27 deletions neo4j/internal/bolt/dechunker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package bolt
import (
"bytes"
"encoding/binary"
"github.com/neo4j/neo4j-go-driver/v4/neo4j/log"
"net"
"reflect"
"testing"
Expand Down Expand Up @@ -108,17 +109,10 @@ func TestDechunker(t *testing.T) {

func TestDechunkerWithTimeout(ot *testing.T) {
timeout := time.Millisecond * 600
serv, cli := net.Pipe()
defer func() {
AssertNoError(ot, serv.Close())
AssertNoError(ot, cli.Close())
}()
AssertNoError(ot, serv.SetReadDeadline(time.Now().Add(timeout)))
logger := &noopLogger{}
logName := "dechunker"
logId := "dechunker-test"

ot.Run("Resets connection deadline upon successful reads", func(t *testing.T) {
serv, cli := net.Pipe()
defer closePipe(ot, serv, cli)
go func() {
time.Sleep(timeout / 2)
AssertWriteSucceeds(t, cli, []byte{0x00, 0x00})
Expand All @@ -128,32 +122,24 @@ func TestDechunkerWithTimeout(ot *testing.T) {
AssertWriteSucceeds(t, cli, []byte{0x00, 0x00})
}()
buffer := make([]byte, 2)
_, _, err := dechunkMessage(serv, buffer, timeout, logger, logName,
logId)
_, _, err := dechunkMessage(serv, buffer, timeout, log.Void{}, "", "")
AssertNoError(t, err)
AssertTrue(t, reflect.DeepEqual(buffer, []byte{0xCA, 0xFE}))
})

ot.Run("Fails when connection deadline is reached", func(t *testing.T) {
_, _, err := dechunkMessage(serv, nil, timeout, logger, logName,
logId)
AssertError(t, err)
AssertStringContain(t, err.Error(), "read pipe")
})

}

type noopLogger struct {
}
serv, cli := net.Pipe()
defer closePipe(ot, serv, cli)

func (*noopLogger) Error(string, string, error) {
}
_, _, err := dechunkMessage(serv, nil, timeout, log.Void{}, "", "")

func (*noopLogger) Warnf(string, string, string, ...interface{}) {
}
AssertError(t, err)
AssertStringContain(t, err.Error(), "context deadline exceeded")
})

func (*noopLogger) Infof(string, string, string, ...interface{}) {
}

func (*noopLogger) Debugf(string, string, string, ...interface{}) {
func closePipe(t *testing.T, srv, cli net.Conn) {
AssertNoError(t, srv.Close())
AssertNoError(t, cli.Close())
}
99 changes: 99 additions & 0 deletions neo4j/internal/racingio/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package racingio

import (
"context"
"errors"
"fmt"
"io"
)

type RacingReader interface {
Read(ctx context.Context, bytes []byte) (int, error)
ReadFull(ctx context.Context, bytes []byte) (int, error)
}

func NewRacingReader(reader io.Reader) RacingReader {
return &racingReader{reader: reader}
}

type racingReader struct {
reader io.Reader
}

func (rr *racingReader) Read(ctx context.Context, bytes []byte) (int, error) {
return rr.race(ctx, bytes, read)
}

func (rr *racingReader) ReadFull(ctx context.Context, bytes []byte) (int, error) {
return rr.race(ctx, bytes, readFull)
}

func (rr *racingReader) race(ctx context.Context, bytes []byte, readFn func(io.Reader, []byte) (int, error)) (int, error) {
if err := ctx.Err(); err != nil {
return 0, wrapRaceError(err)
}
resultChan := make(chan *ioResult, 1)
defer close(resultChan)
go func() {
n, err := readFn(rr.reader, bytes)
defer func() {
// When the read operation completes, the outer function may have returned already.
// In that situation, the channel will have been closed and the result emission will crash.
// Let's just swallow the panic that may happen and ignore it
_ = recover()
}()
resultChan <- &ioResult{
n: n,
err: err,
}
}()
select {
case <-ctx.Done():
return 0, wrapRaceError(ctx.Err())
case result := <-resultChan:
return result.n, wrapRaceError(result.err)
}
}

func wrapRaceError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// temporary adjustment for 4.x
return fmt.Errorf("i/o timeout: %w", err)
}
return err
}

type ioResult struct {
n int
err error
}

func read(reader io.Reader, bytes []byte) (int, error) {
return reader.Read(bytes)
}

func readFull(reader io.Reader, bytes []byte) (int, error) {
return io.ReadFull(reader, bytes)
}
Loading

0 comments on commit c8d2795

Please sign in to comment.