From 725a426d2d0109bcf0fb39576e7a16261dcd8a5f Mon Sep 17 00:00:00 2001 From: rkodev <43806892+rkodev@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:42:40 +0300 Subject: [PATCH] fix: allows registration of page iterator headers --- page_iterator.go | 8 ++++++- page_iterator_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/page_iterator.go b/page_iterator.go index 0e65b1c..5729edf 100644 --- a/page_iterator.go +++ b/page_iterator.go @@ -10,6 +10,8 @@ import ( "github.com/microsoft/kiota-abstractions-go/serialization" ) +const PageIteratorErrorRegistryKey = "PAGE_ITERATOR_ERROR_REGISTRY_KEY" + // PageIterator represents an iterator object that can be used to get subsequent pages of a collection. type PageIterator[T interface{}] struct { currentPage PageResult[T] @@ -18,6 +20,7 @@ type PageIterator[T interface{}] struct { constructorFunc serialization.ParsableFactory headers *abstractions.RequestHeaders reqOptions []abstractions.RequestOption + errorMappings abstractions.ErrorMappings } // PageResult represents a page object built from a graph response object @@ -57,12 +60,15 @@ func NewPageIterator[T interface{}](res interface{}, reqAdapter abstractions.Req return nil, err } + errorMapping := getErrorMapper(PageIteratorErrorRegistryKey) + return &PageIterator[T]{ currentPage: page, reqAdapter: reqAdapter, pauseIndex: 0, constructorFunc: constructorFunc, headers: abstractions.NewRequestHeaders(), + errorMappings: errorMapping, }, nil } @@ -160,7 +166,7 @@ func (pI *PageIterator[T]) fetchNextPage(context context.Context) (serialization requestInfo.Headers.AddAll(pI.headers) requestInfo.AddRequestOptions(pI.reqOptions) - graphResponse, err = pI.reqAdapter.Send(context, requestInfo, pI.constructorFunc, nil) + graphResponse, err = pI.reqAdapter.Send(context, requestInfo, pI.constructorFunc, pI.errorMappings) if err != nil { return nil, err } diff --git a/page_iterator_test.go b/page_iterator_test.go index 01c4942..73ba6bc 100644 --- a/page_iterator_test.go +++ b/page_iterator_test.go @@ -2,7 +2,9 @@ package msgraphgocore import ( "context" + "errors" "fmt" + "github.com/stretchr/testify/require" nethttp "net/http" httptest "net/http/httptest" testing "testing" @@ -67,6 +69,49 @@ func TestConstructorWithInvalidUserGraphResponse(t *testing.T) { assert.NotNil(t, err) } +func TestPageIteratorHandlesHTTPError(t *testing.T) { + errorMapping := abstractions.ErrorMappings{ + "4XX": internal.CreateSampleErrorFromDiscriminatorValue, + "5XX": internal.CreateSampleErrorFromDiscriminatorValue, + } + // register errorMapper + err := RegisterError(PageIteratorErrorRegistryKey, errorMapping) + require.NoError(t, err) + + testServer := httptest.NewServer(nethttp.HandlerFunc(func(w nethttp.ResponseWriter, req *nethttp.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + fmt.Fprint(w, "{}") + })) + defer testServer.Close() + + graphResponse := buildGraphResponse() + mockPath := testServer.URL + "/next-page" + graphResponse.SetOdataNextLink(&mockPath) + + pageIterator, _ := NewPageIterator[internal.User](graphResponse, reqAdapter, ParsableCons) + headers := abstractions.NewRequestHeaders() + headers.Add("ConsistencyLevel", "eventual") + pageIterator.SetHeaders(headers) + res := make([]string, 0) + + err = pageIterator.Iterate(context.Background(), func(item internal.User) bool { + res = append(res, *item.GetDisplayName()) + return true + }) + + var sampleError *internal.SampleError + switch { + case errors.As(err, &sampleError): + assert.Equal(t, "error status code received from the API", err.Error()) + default: + assert.Fail(t, "error type is not as expected") + } + + err = DeRegisterError(PageIteratorErrorRegistryKey) + require.NoError(t, err) +} + func TestIterateStopsWhenCallbackReturnsFalse(t *testing.T) { res := make([]string, 0) graphResponse := buildGraphResponse() @@ -90,10 +135,13 @@ func TestIterateStopsWhenCallbackReturnsFalse(t *testing.T) { headers.Add("ConsistencyLevel", "eventual") pageIterator.SetHeaders(headers) - pageIterator.Iterate(context.Background(), func(item internal.User) bool { + err := pageIterator.Iterate(context.Background(), func(item internal.User) bool { res = append(res, *item.GetDisplayName()) return !(*item.GetId() == "2") }) + if err != nil { + t.Error(err) + } assert.Equal(t, len(res), 3) }