From 29ef5e4b5d0634b1bd249d5c68b35b2148a0ba8f Mon Sep 17 00:00:00 2001 From: Abhishek Pandey Date: Fri, 2 Jun 2023 17:48:21 -0700 Subject: [PATCH] Refactor downloadItem (#3534) This refactor is being done to allow calling `downloadFile` with an URL obtained from URL cache. --- #### Does this PR need a docs update or release note? - [ ] :white_check_mark: Yes, it's included - [ ] :clock1: Yes, but in a later PR - [x] :no_entry: No #### Type of change - [ ] :sunflower: Feature - [ ] :bug: Bugfix - [ ] :world_map: Documentation - [ ] :robot: Supportability/Tests - [ ] :computer: CI/Deployment - [x] :broom: Tech Debt/Cleanup #### Issue(s) * # #### Test Plan - [ ] :muscle: Manual - [x] :zap: Unit test - [ ] :green_heart: E2E --- src/internal/connector/onedrive/item.go | 52 ++++-- src/internal/connector/onedrive/item_test.go | 174 +++++++++++++++++++ 2 files changed, 208 insertions(+), 18 deletions(-) diff --git a/src/internal/connector/onedrive/item.go b/src/internal/connector/onedrive/item.go index 9247f377e9..6c954fd409 100644 --- a/src/internal/connector/onedrive/item.go +++ b/src/internal/connector/onedrive/item.go @@ -27,9 +27,14 @@ func downloadItem( ag api.Getter, item models.DriveItemable, ) (io.ReadCloser, error) { + if item == nil { + return nil, clues.New("nil item") + } + var ( rc io.ReadCloser isFile = item.GetFile() != nil + err error ) if isFile { @@ -45,31 +50,42 @@ func downloadItem( } } - if len(url) == 0 { - return nil, clues.New("extracting file url") - } - - resp, err := ag.Get(ctx, url, nil) + rc, err = downloadFile(ctx, ag, url) if err != nil { - return nil, clues.Wrap(err, "getting item") + return nil, clues.Stack(err) } + } - if graph.IsMalwareResp(ctx, resp) { - return nil, clues.New("malware detected").Label(graph.LabelsMalware) - } + return rc, nil +} - if (resp.StatusCode / 100) != 2 { - // upstream error checks can compare the status with - // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) - return nil, clues. - Wrap(clues.New(resp.Status), "non-2xx http response"). - Label(graph.LabelStatus(resp.StatusCode)) - } +func downloadFile( + ctx context.Context, + ag api.Getter, + url string, +) (io.ReadCloser, error) { + if len(url) == 0 { + return nil, clues.New("empty file url") + } - rc = resp.Body + resp, err := ag.Get(ctx, url, nil) + if err != nil { + return nil, clues.Wrap(err, "getting file") } - return rc, nil + if graph.IsMalwareResp(ctx, resp) { + return nil, clues.New("malware detected").Label(graph.LabelsMalware) + } + + if (resp.StatusCode / 100) != 2 { + // upstream error checks can compare the status with + // clues.HasLabel(err, graph.LabelStatus(http.KnownStatusCode)) + return nil, clues. + Wrap(clues.New(resp.Status), "non-2xx http response"). + Label(graph.LabelStatus(resp.StatusCode)) + } + + return resp.Body, nil } func downloadItemMeta( diff --git a/src/internal/connector/onedrive/item_test.go b/src/internal/connector/onedrive/item_test.go index fd556e7dd0..8c1af9ca70 100644 --- a/src/internal/connector/onedrive/item_test.go +++ b/src/internal/connector/onedrive/item_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "net/http" "testing" "github.com/alcionai/clues" @@ -256,3 +257,176 @@ func (suite *ItemIntegrationSuite) TestDriveGetFolder() { }) } } + +// Unit tests + +type mockGetter struct { + GetFunc func(ctx context.Context, url string) (*http.Response, error) +} + +func (m mockGetter) Get( + ctx context.Context, + url string, + headers map[string]string, +) (*http.Response, error) { + return m.GetFunc(ctx, url) +} + +type ItemUnitTestSuite struct { + tester.Suite +} + +func TestItemUnitTestSuite(t *testing.T) { + suite.Run(t, &ItemUnitTestSuite{Suite: tester.NewUnitSuite(t)}) +} + +func (suite *ItemUnitTestSuite) TestDownloadItem() { + testRc := io.NopCloser(bytes.NewReader([]byte("test"))) + url := "https://example.com" + + table := []struct { + name string + itemFunc func() models.DriveItemable + GetFunc func(ctx context.Context, url string) (*http.Response, error) + errorExpected require.ErrorAssertionFunc + rcExpected require.ValueAssertionFunc + label string + }{ + { + name: "nil item", + itemFunc: func() models.DriveItemable { + return nil + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + return nil, nil + }, + errorExpected: require.Error, + rcExpected: require.Nil, + }, + { + name: "success", + itemFunc: func() models.DriveItemable { + di := newItem("test", false) + di.SetAdditionalData(map[string]interface{}{ + "@microsoft.graph.downloadUrl": url, + }) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: testRc, + }, nil + }, + errorExpected: require.NoError, + rcExpected: require.NotNil, + }, + { + name: "success, content url set instead of download url", + itemFunc: func() models.DriveItemable { + di := newItem("test", false) + di.SetAdditionalData(map[string]interface{}{ + "@content.downloadUrl": url, + }) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: testRc, + }, nil + }, + errorExpected: require.NoError, + rcExpected: require.NotNil, + }, + { + name: "api getter returns error", + itemFunc: func() models.DriveItemable { + di := newItem("test", false) + di.SetAdditionalData(map[string]interface{}{ + "@microsoft.graph.downloadUrl": url, + }) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + return nil, clues.New("test error") + }, + errorExpected: require.Error, + rcExpected: require.Nil, + }, + { + name: "download url is empty", + itemFunc: func() models.DriveItemable { + di := newItem("test", false) + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: testRc, + }, nil + }, + errorExpected: require.Error, + rcExpected: require.Nil, + }, + { + name: "malware", + itemFunc: func() models.DriveItemable { + di := newItem("test", false) + di.SetAdditionalData(map[string]interface{}{ + "@microsoft.graph.downloadUrl": url, + }) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + return &http.Response{ + Header: http.Header{ + "X-Virus-Infected": []string{"true"}, + }, + StatusCode: http.StatusOK, + Body: testRc, + }, nil + }, + errorExpected: require.Error, + rcExpected: require.Nil, + }, + { + name: "non-2xx http response", + itemFunc: func() models.DriveItemable { + di := newItem("test", false) + di.SetAdditionalData(map[string]interface{}{ + "@microsoft.graph.downloadUrl": url, + }) + + return di + }, + GetFunc: func(ctx context.Context, url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: nil, + }, nil + }, + errorExpected: require.Error, + rcExpected: require.Nil, + }, + } + + for _, test := range table { + suite.Run(test.name, func() { + t := suite.T() + ctx, flush := tester.NewContext(t) + defer flush() + + mg := mockGetter{ + GetFunc: test.GetFunc, + } + rc, err := downloadItem(ctx, mg, test.itemFunc()) + test.errorExpected(t, err, clues.ToCore(err)) + test.rcExpected(t, rc) + }) + } +}