Skip to content

Commit

Permalink
fix: allows registration of page iterator headers
Browse files Browse the repository at this point in the history
  • Loading branch information
rkodev committed Jul 12, 2024
1 parent 6d96cd9 commit 725a426
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
8 changes: 7 additions & 1 deletion page_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/microsoft/kiota-abstractions-go/serialization"
)

const PageIteratorErrorRegistryKey = "PAGE_ITERATOR_ERROR_REGISTRY_KEY"

// PageIterator represents an iterator object that can be used to get subsequent pages of a collection.
type PageIterator[T interface{}] struct {
currentPage PageResult[T]
Expand All @@ -18,6 +20,7 @@ type PageIterator[T interface{}] struct {
constructorFunc serialization.ParsableFactory
headers *abstractions.RequestHeaders
reqOptions []abstractions.RequestOption
errorMappings abstractions.ErrorMappings
}

// PageResult represents a page object built from a graph response object
Expand Down Expand Up @@ -57,12 +60,15 @@ func NewPageIterator[T interface{}](res interface{}, reqAdapter abstractions.Req
return nil, err
}

errorMapping := getErrorMapper(PageIteratorErrorRegistryKey)

return &PageIterator[T]{
currentPage: page,
reqAdapter: reqAdapter,
pauseIndex: 0,
constructorFunc: constructorFunc,
headers: abstractions.NewRequestHeaders(),
errorMappings: errorMapping,
}, nil
}

Expand Down Expand Up @@ -160,7 +166,7 @@ func (pI *PageIterator[T]) fetchNextPage(context context.Context) (serialization
requestInfo.Headers.AddAll(pI.headers)
requestInfo.AddRequestOptions(pI.reqOptions)

graphResponse, err = pI.reqAdapter.Send(context, requestInfo, pI.constructorFunc, nil)
graphResponse, err = pI.reqAdapter.Send(context, requestInfo, pI.constructorFunc, pI.errorMappings)
if err != nil {
return nil, err
}
Expand Down
50 changes: 49 additions & 1 deletion page_iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package msgraphgocore

import (
"context"
"errors"
"fmt"
"github.com/stretchr/testify/require"
nethttp "net/http"
httptest "net/http/httptest"
testing "testing"
Expand Down Expand Up @@ -67,6 +69,49 @@ func TestConstructorWithInvalidUserGraphResponse(t *testing.T) {
assert.NotNil(t, err)
}

func TestPageIteratorHandlesHTTPError(t *testing.T) {
errorMapping := abstractions.ErrorMappings{
"4XX": internal.CreateSampleErrorFromDiscriminatorValue,
"5XX": internal.CreateSampleErrorFromDiscriminatorValue,
}
// register errorMapper
err := RegisterError(PageIteratorErrorRegistryKey, errorMapping)
require.NoError(t, err)

testServer := httptest.NewServer(nethttp.HandlerFunc(func(w nethttp.ResponseWriter, req *nethttp.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(403)
fmt.Fprint(w, "{}")
}))
defer testServer.Close()

graphResponse := buildGraphResponse()
mockPath := testServer.URL + "/next-page"
graphResponse.SetOdataNextLink(&mockPath)

pageIterator, _ := NewPageIterator[internal.User](graphResponse, reqAdapter, ParsableCons)
headers := abstractions.NewRequestHeaders()
headers.Add("ConsistencyLevel", "eventual")
pageIterator.SetHeaders(headers)
res := make([]string, 0)

err = pageIterator.Iterate(context.Background(), func(item internal.User) bool {
res = append(res, *item.GetDisplayName())
return true
})

var sampleError *internal.SampleError
switch {
case errors.As(err, &sampleError):
assert.Equal(t, "error status code received from the API", err.Error())
default:
assert.Fail(t, "error type is not as expected")
}

err = DeRegisterError(PageIteratorErrorRegistryKey)
require.NoError(t, err)
}

func TestIterateStopsWhenCallbackReturnsFalse(t *testing.T) {
res := make([]string, 0)
graphResponse := buildGraphResponse()
Expand All @@ -90,10 +135,13 @@ func TestIterateStopsWhenCallbackReturnsFalse(t *testing.T) {
headers.Add("ConsistencyLevel", "eventual")
pageIterator.SetHeaders(headers)

pageIterator.Iterate(context.Background(), func(item internal.User) bool {
err := pageIterator.Iterate(context.Background(), func(item internal.User) bool {
res = append(res, *item.GetDisplayName())
return !(*item.GetId() == "2")
})
if err != nil {
t.Error(err)
}

assert.Equal(t, len(res), 3)
}
Expand Down

0 comments on commit 725a426

Please sign in to comment.