diff --git a/go/pkg/client/cas_download.go b/go/pkg/client/cas_download.go index 9fe5e1fb0..1cff3cf1b 100644 --- a/go/pkg/client/cas_download.go +++ b/go/pkg/client/cas_download.go @@ -256,12 +256,16 @@ func (c *Client) BatchDownloadBlobsWithStats(ctx context.Context, dgs []digest.D if err != nil { errDg = r.Digest errMsg = err.Error() + numErrs++ + allRetriable = false continue } r.Data = b default: errDg = r.Digest errMsg = fmt.Sprintf("blob returned with unsupported compressor %s", r.Compressor) + numErrs++ + allRetriable = false continue } bi := CompressedBlobInfo{ diff --git a/go/pkg/client/cas_test.go b/go/pkg/client/cas_test.go index e2a075dcd..64136b748 100644 --- a/go/pkg/client/cas_test.go +++ b/go/pkg/client/cas_test.go @@ -1930,3 +1930,123 @@ func TestBatchDownloadBlobsCompressed(t *testing.T) { t.Errorf("client.BatchDownloadBlobs(ctx, digests) had diff (want -> got):\n%s", diff) } } + +type readResponseModifier func(idx int, resp *repb.BatchReadBlobsResponse_Response) + +type invalidReadServer struct { + repb.ContentAddressableStorageServer + readResponseModifier +} + +func (s *invalidReadServer) setModifier(modifier readResponseModifier) { + s.readResponseModifier = modifier +} + +func (s *invalidReadServer) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (*repb.BatchReadBlobsResponse, error) { + resp, err := s.ContentAddressableStorageServer.BatchReadBlobs(ctx, req) + + if s.readResponseModifier == nil { + return resp, err + } + + for idx, r := range resp.GetResponses() { + s.readResponseModifier(idx, r) + } + + return resp, err +} + +func TestBatchDownloadBlobsBrokenCompression(t *testing.T) { + t.Parallel() + ctx := context.Background() + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Cannot listen: %v", err) + } + fakeCAS := fakes.NewCAS() + defer listener.Close() + server := grpc.NewServer() + + s := invalidReadServer{ + ContentAddressableStorageServer: fakeCAS, + } + regrpc.RegisterContentAddressableStorageServer(server, &s) + go server.Serve(listener) + defer server.Stop() + c, err := client.NewClient(ctx, instance, client.DialParams{ + Service: listener.Addr().String(), + NoSecurity: true, + }, client.StartupCapabilities(false)) + if err != nil { + t.Fatalf("Error connecting to server: %v", err) + } + defer c.Close() + + fooDigest := fakeCAS.Put([]byte("foo")) + barDigest := fakeCAS.Put([]byte("bar")) + fakeDigest := fakeCAS.Put([]byte("fake")) + digests := []digest.Digest{fooDigest, barDigest, fakeDigest} + client.UseBatchCompression(true).Apply(c) + + type testCase struct { + name string + expected map[digest.Digest]client.CompressedBlobInfo + modifier readResponseModifier + errorContains string + } + tests := []testCase{ + { + name: "invalid compressor", + modifier: func(idx int, resp *repb.BatchReadBlobsResponse_Response) { + resp.Compressor = repb.Compressor_DEFLATE + }, + errorContains: "with unsupported compressor DEFLATE", + }, + { + name: "invalid data", + modifier: func(idx int, resp *repb.BatchReadBlobsResponse_Response) { + resp.Data = []byte{1, 3, 5, 9} + }, + errorContains: "magic number mismatch", + }, + { + name: "mixed errors", + modifier: func(idx int, resp *repb.BatchReadBlobsResponse_Response) { + if idx == 0 { + resp.Data[0] = 1 // Corrupt compressed data + } else if idx == 1 { + resp.Compressor = repb.Compressor_DEFLATE + } + }, + expected: map[digest.Digest]client.CompressedBlobInfo{ + fakeDigest: { + CompressedSize: 17, + Data: []byte("fake"), + }, + }, + errorContains: "with unsupported compressor DEFLATE", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + wantBlobs := test.expected + if wantBlobs == nil { + wantBlobs = map[digest.Digest]client.CompressedBlobInfo{} + } + s.setModifier(test.modifier) + defer s.setModifier(nil) + gotBlobs, err := c.BatchDownloadBlobsWithStats(ctx, digests) + if err == nil { + t.Error("client.BatchDownloadBlobs(ctx, digests) should return download error") + } + errMsg := err.Error() + if !strings.Contains(errMsg, test.errorContains) { + t.Errorf("client.BatchDownloadBlobs(ctx, digests) should report %s: %s", test.errorContains, errMsg) + } + if diff := cmp.Diff(wantBlobs, gotBlobs); diff != "" { + t.Errorf("client.BatchDownloadBlobs(ctx, digests) had diff (want -> got):\n%s", diff) + } + }) + } +}