From c0e3240b09536752bcb5cf3be47c2bfce44d3b61 Mon Sep 17 00:00:00 2001 From: JT Olds Date: Wed, 28 Dec 2016 16:42:20 -0700 Subject: [PATCH 1/2] app engine support App Engine requires use of context-specific HTTP clients. Outbound HTTP requests simply don't work unless there's a way to provide the current request context to the outgoing request. Because a *Consumer might outlast multiple requests, we need a way to thread the request context through and then optionally create an *http.Client based on it. Unfortunately, this requires adding a bunch of new arguments. My preference would be to require contexts for all callsites, but I realize that breaks backwards compatibility, so I've added new methods. --- oauth.go | 85 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 59 insertions(+), 26 deletions(-) diff --git a/oauth.go b/oauth.go index c34e080..1a8125c 100644 --- a/oauth.go +++ b/oauth.go @@ -56,6 +56,9 @@ import ( "strings" "sync" "time" + + "github.com/jtolds/webhelp/whcompat" + "golang.org/x/net/context" ) const ( @@ -196,6 +199,9 @@ type Consumer struct { // Defaults to http.Client{}, can be overridden (e.g. for testing) as necessary HttpClient HttpClient + // If HttpClientFunc is set, will be used instead of HttpClient. + HttpClientFunc func(ctx context.Context) (HttpClient, error) + // Some APIs (e.g. Intuit/Quickbooks) require sending additional headers along with // requests. (like "Accept" to specify the response type as XML or JSON) Note that this // will only *add* headers, not set existing ones. @@ -383,10 +389,14 @@ func NewCustomRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey, // - err: // Set only if there was an error, nil otherwise. func (c *Consumer) GetRequestTokenAndUrl(callbackUrl string) (rtoken *RequestToken, loginUrl string, err error) { - return c.GetRequestTokenAndUrlWithParams(callbackUrl, c.AdditionalParams) + return c.GetRequestTokenAndUrlWithParamsCtx(context.TODO(), callbackUrl, c.AdditionalParams) } func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) { + return c.GetRequestTokenAndUrlWithParamsCtx(context.TODO(), callbackUrl, additionalParams) +} + +func (c *Consumer) GetRequestTokenAndUrlWithParamsCtx(ctx context.Context, callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) { params := c.baseParams(c.consumerKey, additionalParams) if callbackUrl != "" { params.Add(CALLBACK_PARAM, callbackUrl) @@ -401,7 +411,7 @@ func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additiona return nil, "", err } - resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params) + resp, err := c.getBody(ctx, c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params) if err != nil { return nil, "", errors.New("getBody: " + err.Error()) } @@ -440,15 +450,19 @@ func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additiona // - err: // Set only if there was an error, nil otherwise. func (c *Consumer) AuthorizeToken(rtoken *RequestToken, verificationCode string) (atoken *AccessToken, err error) { - return c.AuthorizeTokenWithParams(rtoken, verificationCode, c.AdditionalParams) + return c.AuthorizeTokenWithParamsCtx(context.TODO(), rtoken, verificationCode, c.AdditionalParams) } func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) { + return c.AuthorizeTokenWithParamsCtx(context.TODO(), rtoken, verificationCode, additionalParams) +} + +func (c *Consumer) AuthorizeTokenWithParamsCtx(ctx context.Context, rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) { params := map[string]string{ VERIFIER_PARAM: verificationCode, TOKEN_PARAM: rtoken.Token, } - return c.makeAccessTokenRequestWithParams(params, rtoken.Secret, additionalParams) + return c.makeAccessTokenRequestWithParams(ctx, params, rtoken.Secret, additionalParams) } // Use the service provider to refresh the AccessToken for a given session. @@ -472,6 +486,10 @@ func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCo // Set if accessToken does not contain the SESSION_HANDLE_PARAM needed to // refresh the token, or if an error occurred when making the request. func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, err error) { + return c.RefreshTokenCtx(context.TODO(), accessToken) +} + +func (c *Consumer) RefreshTokenCtx(ctx context.Context, accessToken *AccessToken) (atoken *AccessToken, err error) { params := make(map[string]string) sessionHandle, ok := accessToken.AdditionalData[SESSION_HANDLE_PARAM] if !ok { @@ -480,7 +498,7 @@ func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, params[SESSION_HANDLE_PARAM] = sessionHandle params[TOKEN_PARAM] = accessToken.Token - return c.makeAccessTokenRequest(params, accessToken.Secret) + return c.makeAccessTokenRequest(ctx, params, accessToken.Secret) } // Use the service provider to obtain an AccessToken for a given session @@ -497,11 +515,11 @@ func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, // // - err: // Set only if there was an error, nil otherwise. -func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret string) (atoken *AccessToken, err error) { - return c.makeAccessTokenRequestWithParams(params, secret, c.AdditionalParams) +func (c *Consumer) makeAccessTokenRequest(ctx context.Context, params map[string]string, secret string) (atoken *AccessToken, err error) { + return c.makeAccessTokenRequestWithParams(ctx, params, secret, c.AdditionalParams) } -func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) { +func (c *Consumer) makeAccessTokenRequestWithParams(ctx context.Context, params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) { orderedParams := c.baseParams(c.consumerKey, additionalParams) for key, value := range params { orderedParams.Add(key, value) @@ -516,7 +534,7 @@ func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, se return nil, err } - resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams) + resp, err := c.getBody(ctx, c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams) if err != nil { return nil, err } @@ -559,7 +577,7 @@ func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) { // - err: // Set only if there was an error, nil otherwise. func (c *Consumer) Get(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequest("GET", url, LOC_URL, "", userParams, token) + return c.makeAuthorizedRequest(context.TODO(), "GET", url, LOC_URL, "", userParams, token) } func encodeUserParams(userParams map[string]string) string { @@ -585,40 +603,40 @@ func (c *Consumer) Post(url string, userParams map[string]string, token *AccessT // ** DEPRECATED ** // Please call "Post" on the http client returned by MakeHttpClient instead func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token) + return c.makeAuthorizedRequest(context.TODO(), "POST", url, LOC_BODY, body, userParams, token) } // ** DEPRECATED ** // Please call "Do" on the http client returned by MakeHttpClient instead // (and set the "Content-Type" header explicitly in the http.Request) func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token) + return c.makeAuthorizedRequest(context.TODO(), "POST", url, LOC_JSON, body, nil, token) } // ** DEPRECATED ** // Please call "Do" on the http client returned by MakeHttpClient instead // (and set the "Content-Type" header explicitly in the http.Request) func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token) + return c.makeAuthorizedRequest(context.TODO(), "POST", url, LOC_XML, body, nil, token) } // ** DEPRECATED ** // Please call "Do" on the http client returned by MakeHttpClient instead // (and setup the multipart data explicitly in the http.Request) func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token) + return c.makeAuthorizedRequestReader(context.TODO(), "POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token) } // ** DEPRECATED ** // Please call "Delete" on the http client returned by MakeHttpClient instead func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token) + return c.makeAuthorizedRequest(context.TODO(), "DELETE", url, LOC_URL, "", userParams, token) } // ** DEPRECATED ** // Please call "Put" on the http client returned by MakeHttpClient instead func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token) + return c.makeAuthorizedRequest(context.TODO(), "PUT", url, LOC_URL, body, userParams, token) } func (c *Consumer) Debug(enabled bool) { @@ -642,19 +660,19 @@ func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // consumer.Post() etc), and the new API (which takes actual http.Requests) // // So, here we construct the appropriate HTTP request for the inputs. -func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { +func (c *Consumer) makeAuthorizedRequestReader(ctx context.Context, method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { urlObject, err := url.Parse(urlString) if err != nil { return nil, err } - request := &http.Request{ + request := whcompat.WithContext(&http.Request{ Method: method, URL: urlObject, Header: http.Header{}, Body: body, ContentLength: int64(contentLength), - } + }, ctx) vals := url.Values{} for k, v := range userParams { @@ -926,8 +944,12 @@ func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, er fmt.Printf("Request: %v\n", serverRequest) } - resp, err := rt.consumer.HttpClient.Do(serverRequest) + client, err := rt.consumer.httpClient(whcompat.Context(userRequest)) + if err != nil { + return nil, errors.New("httpClient: " + err.Error()) + } + resp, err := client.Do(serverRequest) if err != nil { return resp, err } @@ -935,8 +957,8 @@ func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, er return resp, nil } -func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { - return c.makeAuthorizedRequestReader(method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token) +func (c *Consumer) makeAuthorizedRequest(ctx context.Context, method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequestReader(ctx, method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token) } type request struct { @@ -1215,8 +1237,8 @@ func (c *Consumer) requestString(method string, url string, params *OrderedParam return result } -func (c *Consumer) getBody(method, url string, oauthParams *OrderedParams) (*string, error) { - resp, err := c.httpExecute(method, url, "", 0, nil, oauthParams) +func (c *Consumer) getBody(ctx context.Context, method, url string, oauthParams *OrderedParams) (*string, error) { + resp, err := c.httpExecute(ctx, method, url, "", 0, nil, oauthParams) if err != nil { return nil, errors.New("httpExecute: " + err.Error()) } @@ -1254,7 +1276,14 @@ func (e HTTPExecuteError) Error() string { "\tRequest Headers: " + e.RequestHeaders } -func (c *Consumer) httpExecute( +func (c *Consumer) httpClient(ctx context.Context) (HttpClient, error) { + if c.HttpClientFunc != nil { + return c.HttpClientFunc(ctx) + } + return c.HttpClient, nil +} + +func (c *Consumer) httpExecute(ctx context.Context, method string, urlStr string, contentType string, contentLength int, body io.Reader, oauthParams *OrderedParams) (*http.Response, error) { // Create base request. req, err := http.NewRequest(method, urlStr, body) @@ -1295,7 +1324,11 @@ func (c *Consumer) httpExecute( if c.debug { fmt.Printf("Request: %v\n", req) } - resp, err := c.HttpClient.Do(req) + client, err := c.httpClient(ctx) + if err != nil { + return nil, errors.New("httpClient: " + err.Error()) + } + resp, err := client.Do(req) if err != nil { return nil, errors.New("Do: " + err.Error()) } From 75c3a5bf5860ce396a1ada3aa52c66e586119109 Mon Sep 17 00:00:00 2001 From: JT Olds Date: Wed, 28 Dec 2016 17:54:30 -0700 Subject: [PATCH 2/2] webhelp moved --- oauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oauth.go b/oauth.go index 1a8125c..8a87c39 100644 --- a/oauth.go +++ b/oauth.go @@ -57,7 +57,7 @@ import ( "sync" "time" - "github.com/jtolds/webhelp/whcompat" + "gopkg.in/webhelp.v1/whcompat" "golang.org/x/net/context" )