diff --git a/go/pkg/client/bytestream.go b/go/pkg/client/bytestream.go index 0a1bffd97..ac85b5db1 100644 --- a/go/pkg/client/bytestream.go +++ b/go/pkg/client/bytestream.go @@ -42,6 +42,20 @@ func (c *Client) WriteBytesAtRemoteOffset(ctx context.Context, name string, data return writtenBytes, nil } +// withCtx makes the niladic function f behaves like one that accepts a ctx. +func withCtx(ctx context.Context, f func() error) error { + errChan := make(chan error, 1) + go func() { + errChan <- f() + }() + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errChan: + return err + } +} + // writeChunked uploads chunked data with a given resource name to the CAS. func (c *Client) writeChunked(ctx context.Context, name string, ch *chunker.Chunker, doNotFinalize bool, initialOffset int64) (int64, error) { var totalBytes int64 @@ -54,6 +68,8 @@ func (c *Client) writeChunked(ctx context.Context, name string, ch *chunker.Chun // TODO(olaola): implement resumable uploads. initialOffset passed in allows to // start writing data at an arbitrary offset, but retries still restart from initialOffset. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := c.Write(ctx) if err != nil { return err @@ -70,7 +86,11 @@ func (c *Client) writeChunked(ctx context.Context, name string, ch *chunker.Chun if !ch.HasNext() && !doNotFinalize { req.FinishWrite = true } - err = c.CallWithTimeout(ctx, "Write", func(_ context.Context) error { return stream.Send(req) }) + err = c.CallWithTimeout(ctx, "Write", func(ctx context.Context) error { + return withCtx(ctx, func() error { + return stream.Send(req) + }) + }) if err == io.EOF { break } @@ -79,7 +99,12 @@ func (c *Client) writeChunked(ctx context.Context, name string, ch *chunker.Chun } totalBytes += int64(len(req.Data)) } - if _, err := stream.CloseAndRecv(); err != nil { + if err := c.CallWithTimeout(ctx, "Write", func(ctx context.Context) error { + return withCtx(ctx, func() error { + _, err := stream.CloseAndRecv() + return err + }) + }); err != nil { return err } return nil @@ -132,6 +157,8 @@ func (c *Client) readToFile(ctx context.Context, name string, fpath string) (int // stream. The limit must be non-negative, although offset+limit may exceed the length of the // stream. func (c *Client) readStreamed(ctx context.Context, name string, offset, limit int64, w io.Writer) (int64, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := c.Read(ctx, &bspb.ReadRequest{ ResourceName: name, ReadOffset: offset, @@ -144,10 +171,12 @@ func (c *Client) readStreamed(ctx context.Context, name string, offset, limit in var n int64 for { var resp *bspb.ReadResponse - err := c.CallWithTimeout(ctx, "Read", func(_ context.Context) error { - r, err := stream.Recv() - resp = r - return err + err := c.CallWithTimeout(ctx, "Read", func(ctx context.Context) error { + return withCtx(ctx, func() error { + r, err := stream.Recv() + resp = r + return err + }) }) if err == io.EOF { break diff --git a/go/pkg/client/bytestream_test.go b/go/pkg/client/bytestream_test.go index de4d0d622..0d71d9474 100644 --- a/go/pkg/client/bytestream_test.go +++ b/go/pkg/client/bytestream_test.go @@ -2,9 +2,11 @@ package client import ( "context" + "errors" "fmt" "net" "testing" + "time" "google.golang.org/grpc" @@ -21,6 +23,40 @@ type logStream struct { finalized bool } +func TestReadTimeout(t *testing.T) { + s := newServer(t) + defer s.shutDown() + + s.client.Retrier = nil + s.client.rpcTimeouts["Read"] = 100 * time.Millisecond + s.fake.read = func(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) error { + time.Sleep(1 * time.Second) + return stream.Send(&bspb.ReadResponse{}) + } + + _, err := s.client.ReadBytes(context.Background(), "test") + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected error %v, but got %v", context.DeadlineExceeded, err) + } +} + +func TestWriteTimeout(t *testing.T) { + s := newServer(t) + defer s.shutDown() + + s.client.Retrier = nil + s.client.rpcTimeouts["Write"] = 100 * time.Millisecond + s.fake.write = func(stream bsgrpc.ByteStream_WriteServer) error { + time.Sleep(1 * time.Second) + return fmt.Errorf("write should have timed out") + } + + err := s.client.WriteBytes(context.Background(), "test", []byte("hello")) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected error %v, but got %v", context.DeadlineExceeded, err) + } +} + func TestWriteBytesAtRemoteOffsetSuccess_LogStream(t *testing.T) { tests := []struct { description string @@ -117,6 +153,9 @@ func TestWriteBytesAtRemoteOffsetSuccess_LogStream(t *testing.T) { } func TestWriteBytesAtRemoteOffsetErrors_LogStream(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow test because short is set") + } tests := []struct { description string ls *logStream @@ -180,6 +219,8 @@ func TestWriteBytesAtRemoteOffsetErrors_LogStream(t *testing.T) { type ByteStream struct { logStreams map[string]*logStream + read func(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) error + write func(stream bsgrpc.ByteStream_WriteServer) error } type Server struct { @@ -202,7 +243,7 @@ func newServer(t *testing.T) *Server { bsgrpc.RegisterByteStreamServer(s.server, s.fake) go s.server.Serve(s.listener) - s.client, err = NewClient(s.ctx, instance, DialParams{ + s.client, err = NewClient(s.ctx, "test", DialParams{ Service: s.listener.Addr().String(), NoSecurity: true, }, StartupCapabilities(false), ChunkMaxSize(2)) @@ -223,11 +264,18 @@ func (b *ByteStream) QueryWriteStatus(context.Context, *bspb.QueryWriteStatusReq } func (b *ByteStream) Read(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) error { - return nil + if b.read != nil { + return b.read(req, stream) + } + return stream.Send(&bspb.ReadResponse{Data: logStreamData}) } // Write implements the write operation for LogStream Write API. func (b *ByteStream) Write(stream bsgrpc.ByteStream_WriteServer) error { + if b.write != nil { + return b.write(stream) + } + defer stream.SendAndClose(&bspb.WriteResponse{}) req, err := stream.Recv() if err != nil { diff --git a/go/pkg/client/cas_test.go b/go/pkg/client/cas_test.go index 16e2c6a34..8f659a16a 100644 --- a/go/pkg/client/cas_test.go +++ b/go/pkg/client/cas_test.go @@ -435,6 +435,9 @@ func TestMissingBlobs(t *testing.T) { } func TestUploadConcurrent(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow test because short is set") + } t.Parallel() blobs := make([][]byte, 50) for i := range blobs { diff --git a/go/pkg/client/retries_test.go b/go/pkg/client/retries_test.go index 015b6416a..8e4e2fba1 100644 --- a/go/pkg/client/retries_test.go +++ b/go/pkg/client/retries_test.go @@ -306,6 +306,9 @@ func TestWriteRetries(t *testing.T) { } func TestRetryWriteBytesAtRemoteOffset(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow test because short is set") + } tests := []struct { description string initialOffset int64 @@ -441,6 +444,9 @@ func TestBatchWriteBlobsRpcRetriesExhausted(t *testing.T) { } func TestGetTreeRetries(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow test because short is set") + } t.Parallel() f := setup(t) defer f.shutDown()