From 298bc02b72b75580a4b1e758f5503eb35fd39245 Mon Sep 17 00:00:00 2001 From: Gareth Date: Thu, 13 Jun 2024 09:08:24 -0700 Subject: [PATCH] fix: cancel request context when timeout exceeded (#244) * feat: requests timeout respecting CLOUD_RUN_TIMEOUT_SECONDS * add test coverage * fix windows test --- funcframework/events.go | 4 ++ funcframework/framework.go | 26 ++++++- funcframework/framework_test.go | 121 ++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 1 deletion(-) diff --git a/funcframework/events.go b/funcframework/events.go index 148423be..c095d940 100644 --- a/funcframework/events.go +++ b/funcframework/events.go @@ -192,6 +192,10 @@ func convertBackgroundToCloudEvent(ceHandler http.Handler) http.Handler { return } } + r, cancel := setContextTimeoutIfRequested(r) + if cancel != nil { + defer cancel() + } ceHandler.ServeHTTP(w, r) }) } diff --git a/funcframework/framework.go b/funcframework/framework.go index e9bdcd3b..e95f9273 100644 --- a/funcframework/framework.go +++ b/funcframework/framework.go @@ -25,7 +25,9 @@ import ( "os" "reflect" "runtime/debug" + "strconv" "strings" + "time" "github.com/GoogleCloudPlatform/functions-framework-go/internal/registry" cloudevents "github.com/cloudevents/sdk-go/v2" @@ -196,6 +198,10 @@ func wrapHTTPFunction(fn func(http.ResponseWriter, *http.Request)) (http.Handler defer fmt.Println() defer fmt.Fprintln(os.Stderr) } + r, cancel := setContextTimeoutIfRequested(r) + if cancel != nil { + defer cancel() + } defer recoverPanic(w, "user function execution", false) fn(w, r) }), nil @@ -212,7 +218,10 @@ func wrapEventFunction(fn interface{}) (http.Handler, error) { defer fmt.Println() defer fmt.Fprintln(os.Stderr) } - + r, cancel := setContextTimeoutIfRequested(r) + if cancel != nil { + defer cancel() + } if shouldConvertCloudEventToBackgroundRequest(r) { if err := convertCloudEventToBackgroundRequest(r); err != nil { writeHTTPErrorResponse(w, http.StatusBadRequest, crashStatus, fmt.Sprintf("error converting CloudEvent to Background Event: %v", err)) @@ -388,3 +397,18 @@ func writeHTTPErrorResponse(w http.ResponseWriter, statusCode int, status, msg s w.WriteHeader(statusCode) fmt.Fprint(w, msg) } + +// setContextTimeoutIfRequested replaces the request's context with a cancellation if requested +func setContextTimeoutIfRequested(r *http.Request) (*http.Request, func()) { + timeoutStr := os.Getenv("CLOUD_RUN_TIMEOUT_SECONDS") + if timeoutStr == "" { + return r, nil + } + timeoutSecs, err := strconv.Atoi(timeoutStr) + if err != nil { + fmt.Fprintf(os.Stderr, "Could not parse CLOUD_RUN_TIMEOUT_SECONDS as an integer value in seconds: %v\n", err) + return r, nil + } + ctx, cancel := context.WithTimeout(r.Context(), time.Duration(timeoutSecs)*time.Second) + return r.WithContext(ctx), cancel +} diff --git a/funcframework/framework_test.go b/funcframework/framework_test.go index 5be77ec9..41d66143 100644 --- a/funcframework/framework_test.go +++ b/funcframework/framework_test.go @@ -25,10 +25,12 @@ import ( "os" "strings" "testing" + "time" "github.com/GoogleCloudPlatform/functions-framework-go/functions" "github.com/GoogleCloudPlatform/functions-framework-go/internal/registry" cloudevents "github.com/cloudevents/sdk-go/v2" + "github.com/cloudevents/sdk-go/v2/event" "github.com/google/go-cmp/cmp" ) @@ -995,6 +997,125 @@ func TestServeMultipleFunctions(t *testing.T) { } } +func TestHTTPRequestTimeout(t *testing.T) { + timeoutEnvVar := "CLOUD_RUN_TIMEOUT_SECONDS" + prev := os.Getenv(timeoutEnvVar) + defer os.Setenv(timeoutEnvVar, prev) + + cloudeventsJSON := []byte(`{ + "specversion" : "1.0", + "type" : "com.github.pull.create", + "source" : "https://github.com/cloudevents/spec/pull", + "subject" : "123", + "id" : "A234-1234-1234", + "time" : "2018-04-05T17:31:00Z", + "comexampleextension1" : "value", + "datacontenttype" : "application/xml", + "data" : "" + }`) + + tcs := []struct { + name string + wantDeadline bool + waitForExpiration bool + timeout string + }{ + { + name: "deadline not requested", + wantDeadline: false, + waitForExpiration: false, + timeout: "", + }, + { + name: "NaN deadline", + wantDeadline: false, + waitForExpiration: false, + timeout: "aaa", + }, + { + name: "very long deadline", + wantDeadline: true, + waitForExpiration: false, + timeout: "3600", + }, + { + name: "short deadline should terminate", + wantDeadline: true, + waitForExpiration: true, + timeout: "1", + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + defer cleanup() + os.Setenv(timeoutEnvVar, tc.timeout) + + var httpReqCtx context.Context + functions.HTTP("http", func(w http.ResponseWriter, r *http.Request) { + if tc.waitForExpiration { + <-r.Context().Done() + } + httpReqCtx = r.Context() + }) + var ceReqCtx context.Context + functions.CloudEvent("cloudevent", func(ctx context.Context, event event.Event) error { + if tc.waitForExpiration { + <-ctx.Done() + } + ceReqCtx = ctx + return nil + }) + server, err := initServer() + if err != nil { + t.Fatalf("initServer(): %v", err) + } + srv := httptest.NewServer(server) + defer srv.Close() + + t.Run("http", func(t *testing.T) { + _, err = http.Get(srv.URL + "/http") + if err != nil { + t.Fatalf("expected success") + } + if httpReqCtx == nil { + t.Fatalf("expected non-nil request context") + } + deadline, ok := httpReqCtx.Deadline() + if ok != tc.wantDeadline { + t.Errorf("expected deadline %v but got %v", tc.wantDeadline, ok) + } + if expired := deadline.Before(time.Now()); ok && expired != tc.waitForExpiration { + t.Errorf("expected expired %v but got %v", tc.waitForExpiration, expired) + } + }) + + t.Run("cloudevent", func(t *testing.T) { + req, err := http.NewRequest("POST", srv.URL+"/cloudevent", bytes.NewBuffer(cloudeventsJSON)) + if err != nil { + t.Fatalf("failed to create request") + } + req.Header.Add("Content-Type", "application/cloudevents+json") + client := &http.Client{} + _, err = client.Do(req) + if err != nil { + t.Fatalf("request failed") + } + if ceReqCtx == nil { + t.Fatalf("expected non-nil request context") + } + deadline, ok := ceReqCtx.Deadline() + if ok != tc.wantDeadline { + t.Errorf("expected deadline %v but got %v", tc.wantDeadline, ok) + } + if expired := deadline.Before(time.Now()); ok && expired != tc.waitForExpiration { + t.Errorf("expected expired %v but got %v", tc.waitForExpiration, expired) + } + }) + }) + } +} + func cleanup() { os.Unsetenv("FUNCTION_TARGET") registry.Default().Reset()