diff --git a/storage/grpc_client.go b/storage/grpc_client.go index fa9fcc2b8bac..35f0ceaf5809 100644 --- a/storage/grpc_client.go +++ b/storage/grpc_client.go @@ -1194,17 +1194,18 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params case <-rr.ctx.Done(): rr.mu.Lock() rr.done = true + if rr.stream != nil { + rr.stream.CloseSend() + } rr.mu.Unlock() return case <-rr.managerRetry: + // We are not closing stream here as it is already closed and we are retring it. return case <-rr.closeManager: rr.mu.Lock() - if len(rr.mp) != 0 { - for key := range rr.mp { - rr.mp[key].callback(rr.mp[key].offset, rr.mp[key].limit, fmt.Errorf("stream closed early")) - delete(rr.mp, key) - } + if rr.stream != nil { + rr.stream.CloseSend() } rr.mu.Unlock() return @@ -1253,11 +1254,27 @@ func (c *grpcStorageClient) NewMultiRangeDownloader(ctx context.Context, params for { select { case <-rr.ctx.Done(): + rr.mu.Lock() rr.done = true + if len(rr.mp) != 0 { + drainInboundReadStream(rr.stream) + } + for key := range rr.mp { + rr.mp[key].callback(rr.mp[key].offset, rr.mp[key].limit, rr.ctx.Err()) + delete(rr.mp, key) + } + rr.activeTask = 0 + rr.mu.Unlock() return case <-rr.receiverRetry: + // We are not draining from stream here as it is already closed and we are retring it. return case <-rr.closeReceiver: + rr.mu.Lock() + if len(rr.mp) != 0 { + drainInboundReadStream(rr.stream) + } + rr.mu.Unlock() return default: // This function reads the data sent for a particular range request and has a callback @@ -1470,18 +1487,32 @@ func (mr *gRPCBidiReader) wait() { // Close will notify stream manager goroutine that the reader has been closed, if it's still running. func (mr *gRPCBidiReader) close() error { - if mr.cancel != nil { - mr.cancel() - } + mr.closeManager <- true + mr.closeReceiver <- true mr.mu.Lock() + for key := range mr.mp { + mr.mp[key].callback(mr.mp[key].offset, mr.mp[key].limit, fmt.Errorf("stream closed early")) + delete(mr.mp, key) + } mr.done = true mr.activeTask = 0 + if mr.cancel != nil { + mr.cancel() + } mr.mu.Unlock() - mr.closeReceiver <- true - mr.closeManager <- true return nil } +// drainInboundReadStream calls stream.Recv() repeatedly until an error is returned. +// drainInboundReadStream always returns a non-nil error. io.EOF indicates all +// messages were successfully read. +func drainInboundReadStream(stream storagepb.Storage_BidiReadObjectClient) (err error) { + for err == nil { + _, err = stream.Recv() + } + return err +} + func (mrr *gRPCBidiReader) getHandle() []byte { return mrr.readHandle } @@ -2640,11 +2671,11 @@ func bucketContext(ctx context.Context, bucket string) context.Context { return gax.InsertMetadataIntoOutgoingContext(ctx, hds...) } -// drainInboundStream calls stream.Recv() repeatedly until an error is returned. +// drainInboundWriteStream calls stream.Recv() repeatedly until an error is returned. // It returns the last Resource received on the stream, or nil if no Resource -// was returned. drainInboundStream always returns a non-nil error. io.EOF +// was returned. drainInboundWriteStream always returns a non-nil error. io.EOF // indicates all messages were successfully read. -func drainInboundStream(stream storagepb.Storage_BidiWriteObjectClient) (object *storagepb.Object, err error) { +func drainInboundWriteStream(stream storagepb.Storage_BidiWriteObjectClient) (object *storagepb.Object, err error) { for err == nil { var resp *storagepb.BidiWriteObjectResponse resp, err = stream.Recv() @@ -2724,7 +2755,7 @@ func (s *gRPCOneshotBidiWriteBufferSender) sendBuffer(ctx context.Context, buf [ sendErr := s.stream.Send(req) if sendErr != nil { - obj, err = drainInboundStream(s.stream) + obj, err = drainInboundWriteStream(s.stream) s.stream = nil if sendErr != io.EOF { err = sendErr @@ -2737,7 +2768,7 @@ func (s *gRPCOneshotBidiWriteBufferSender) sendBuffer(ctx context.Context, buf [ s.stream.CloseSend() // Oneshot uploads only read from the response stream on completion or // failure - obj, err = drainInboundStream(s.stream) + obj, err = drainInboundWriteStream(s.stream) s.stream = nil if err == io.EOF { err = nil @@ -2849,7 +2880,7 @@ func (s *gRPCResumableBidiWriteBufferSender) sendBuffer(ctx context.Context, buf sendErr := s.stream.Send(req) if sendErr != nil { - obj, err = drainInboundStream(s.stream) + obj, err = drainInboundWriteStream(s.stream) s.stream = nil if err == io.EOF { // This is unexpected - we got an error on Send(), but not on Recv(). @@ -2861,7 +2892,7 @@ func (s *gRPCResumableBidiWriteBufferSender) sendBuffer(ctx context.Context, buf if finishWrite { s.stream.CloseSend() - obj, err = drainInboundStream(s.stream) + obj, err = drainInboundWriteStream(s.stream) s.stream = nil if err == io.EOF { err = nil diff --git a/storage/integration_test.go b/storage/integration_test.go index f2aa5ba69d06..9f5f291b8038 100644 --- a/storage/integration_test.go +++ b/storage/integration_test.go @@ -548,6 +548,123 @@ func TestIntegration_MRDWithNonRetriableError(t *testing.T) { }) } +// Test that context cancellation correctly stops a multi range download before completion. +func TestIntegration_MultiRangeDownloaderContextCancel(t *testing.T) { + multiTransportTest(skipHTTP("gRPC implementation specific test"), t, func(t *testing.T, ctx context.Context, bucket, _ string, client *Client) { + ctx, close := context.WithDeadline(ctx, time.Now().Add(time.Second*30)) + defer close() + content := make([]byte, 5<<20) + rand.New(rand.NewSource(0)).Read(content) + objName := "mrdnonretry" + // Upload test data. + obj := client.Bucket(bucket).Object(objName) + if err := writeObject(ctx, obj, "text/plain", content); err != nil { + t.Fatal(err) + } + defer func() { + if err := obj.Delete(ctx); err != nil { + log.Printf("failed to delete test object: %v", err) + } + }() + // Create a multi-range-reader and then cancel the context before completing the reads. + readerCtx, cancel := context.WithCancel(ctx) + reader, err := obj.NewMultiRangeDownloader(readerCtx) + if err != nil { + t.Fatalf("NewMultiRangeDownloader: %v", err) + } + res := make([]multiRangeDownloaderOutput, 3) + callback := func(x, y int64, err error) { + res[0].offset = x + res[0].limit = y + res[0].err = err + } + callback1 := func(x, y int64, err error) { + res[1].offset = x + res[1].limit = y + res[1].err = err + } + callback2 := func(x, y int64, err error) { + res[2].offset = x + res[2].limit = y + res[2].err = err + } + // Add one range on the reader, and then cancel the context. + reader.Add(&res[0].buf, 0, int64(len(content)), callback) + // As context is cancelled remaining ranges would result in context cancelled error or stream is closed errors. + cancel() + reader.Add(&res[1].buf, -10, 0, callback1) + reader.Add(&res[2].buf, 0, 10, callback2) + reader.Wait() + // we can get stream is closed, can't add range error in case process is over before we add the range. + expErr := fmt.Errorf("stream is closed, can't add range") + for i, k := range res { + // if we get nil error for any callback other than first, that should be an error. + if i == 0 && k.err == nil && !bytes.Equal(content, k.buf.Bytes()) { + t.Errorf("Error in read range offset %v, limit %v, got: %v; want: %v", + k.offset, k.limit, len(k.buf.Bytes()), len(content)) + } + if k.err == nil && k.err.Error() != expErr.Error() && !errors.Is(err, context.Canceled) && !(status.Code(err) == codes.Canceled) { + t.Fatalf("read range %v to %v: got error %v, want nil, context.Canceled or stream is closed error", k.offset, k.limit, k.err) + } + } + if err = reader.Close(); err != nil { + t.Fatalf("Error while closing reader %v", err) + } + }) +} + +func TestIntegration_MultiRangeDownloaderSuddenClose(t *testing.T) { + multiTransportTest(skipHTTP("gRPC implementation specific test"), t, func(t *testing.T, ctx context.Context, bucket string, _ string, client *Client) { + content := make([]byte, 5<<20) + rand.New(rand.NewSource(0)).Read(content) + objName := "MultiRangeDownloader" + + // Upload test data. + obj := client.Bucket(bucket).Object(objName) + if err := writeObject(ctx, obj, "text/plain", content); err != nil { + t.Fatal(err) + } + defer func() { + if err := obj.Delete(ctx); err != nil { + log.Printf("failed to delete test object: %v", err) + } + }() + reader, err := obj.NewMultiRangeDownloader(ctx) + if err != nil { + t.Fatalf("NewMultiRangeDownloader: %v", err) + } + res := make([]multiRangeDownloaderOutput, 3) + callback := func(x, y int64, err error) { + res[0].offset = x + res[0].limit = y + res[0].err = err + } + callback1 := func(x, y int64, err error) { + res[1].offset = x + res[1].limit = y + res[1].err = err + } + callback2 := func(x, y int64, err error) { + res[2].offset = x + res[2].limit = y + res[2].err = err + } + // Add three ranges on the reader, and then do a sudden close. + reader.Add(&res[0].buf, 0, int64(len(content)), callback) + reader.Close() + reader.Add(&res[1].buf, -10, 0, callback1) + reader.Add(&res[2].buf, 0, 10, callback2) + // we can get stream is closed, can't add range error in case process is over before we add the range. + expErr := fmt.Errorf("stream is closed, can't add range") + expErr2 := fmt.Errorf("stream closed early") + for _, k := range res { + if k.err.Error() != expErr.Error() && k.err.Error() != expErr2.Error() { + t.Fatalf("read range %v to %v: got error %v, want stream closed error", k.offset, k.limit, k.err) + } + } + }) +} + // Test in a GCE environment expected to be located in one of: // - us-west1-a, us-west1-b, us-west-c //