diff --git a/network/http/gin_test.go b/network/http/gin_test.go index 76ab7efe..65b44739 100644 --- a/network/http/gin_test.go +++ b/network/http/gin_test.go @@ -6,8 +6,10 @@ package http import ( + "github.com/stretchr/testify/assert" "io" "net/http" + "strings" "testing" "time" @@ -31,11 +33,23 @@ func BenchmarkAllMiddlewares(b *testing.B) { }, }, { - name: "cors", + name: "cors-0", ms: []gin.HandlerFunc{ CORSMiddleware([]string{}), }, }, + { + name: "cors-1", + ms: []gin.HandlerFunc{ + CORSMiddleware([]string{"http://foobar.com"}), + }, + }, + { + name: "cors-2", + ms: []gin.HandlerFunc{ + CORSMiddleware([]string{"www.baidu.com"}), + }, + }, { name: "trace-id", ms: []gin.HandlerFunc{ @@ -79,14 +93,34 @@ func BenchmarkAllMiddlewares(b *testing.B) { time.Sleep(time.Second) for i := 0; i < b.N; i++ { - resp, err := http.Get("http://localhost:1234/v1/get") - if err != nil { - b.Logf("get error: %s, ignored", err) - } - - if resp.Body != nil { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + if !strings.Contains(bc.name, "cors") { + resp, err := http.Get("http://localhost:1234/v1/get") + if err != nil { + b.Logf("get error: %s, ignored", err) + } + + if resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + } else { + req, err := http.NewRequest("GET", "http://localhost:1234/v1/get", nil) + if err != nil { + b.Error(err) + } + origin := "http://foobar.com" + req.Header.Set("Origin", origin) + c := &http.Client{} + resp, err := c.Do(req) + if err != nil { + b.Error(err) + } + defer resp.Body.Close() + got := resp.Header.Get("Access-Control-Allow-Origin") + if bc.name == "cors-2" { + origin = "" + } + assert.Equal(b, origin, got, "expect %s, got '%s'", origin, got) } } srv.Close()