diff --git a/contrib/net/http/context.go b/contrib/net/http/context.go new file mode 100644 index 0000000000..9603dc3edf --- /dev/null +++ b/contrib/net/http/context.go @@ -0,0 +1,43 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package http + +import ( + "context" + "io" + "net/http" + "net/url" + "strings" +) + +func Get(ctx context.Context, url string) (resp *http.Response, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + return http.DefaultClient.Do(req) +} + +func Head(ctx context.Context, url string) (resp *http.Response, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err != nil { + return nil, err + } + return http.DefaultClient.Do(req) +} + +func Post(ctx context.Context, url string, contentType string, body io.Reader) (resp *http.Response, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + return http.DefaultClient.Do(req) +} + +func PostForm(ctx context.Context, url string, data url.Values) (resp *http.Response, err error) { + return Post(ctx, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} diff --git a/contrib/net/http/context_test.go b/contrib/net/http/context_test.go new file mode 100644 index 0000000000..1c1b4a5152 --- /dev/null +++ b/contrib/net/http/context_test.go @@ -0,0 +1,190 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package http + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + validURL = "http://example.com" + invalidURL = "http:/\x00/invalid." +) + +func TestGet(t *testing.T) { + ctx := context.Background() + t.Run("valid URL", func(t *testing.T) { + withMockDefaultClient( + func(req *http.Request) (*http.Response, error) { + assert.Equal(t, ctx, req.Context()) + assert.Equal(t, "GET", req.Method) + assert.Equal(t, validURL, req.URL.String()) + return &http.Response{StatusCode: 200}, nil + }, + func() { + res, err := Get(ctx, validURL) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + }, + ) + }) + + t.Run("invalid URL", func(t *testing.T) { + withMockDefaultClient( + func(*http.Request) (*http.Response, error) { + assert.Fail(t, "unexpected call to RoundTrip") + return nil, errors.New("unreachable") + }, + func() { + ctx := context.Background() + res, err := Get(ctx, invalidURL) + require.Error(t, err) + require.Nil(t, res) + }, + ) + }) +} + +func TestHead(t *testing.T) { + t.Run("valid URL", func(t *testing.T) { + ctx := context.Background() + withMockDefaultClient( + func(req *http.Request) (*http.Response, error) { + assert.Equal(t, ctx, req.Context()) + assert.Equal(t, "HEAD", req.Method) + assert.Equal(t, validURL, req.URL.String()) + return &http.Response{StatusCode: 200}, nil + }, + func() { + res, err := Head(ctx, validURL) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + }, + ) + }) + + t.Run("invalid URL", func(t *testing.T) { + withMockDefaultClient( + func(*http.Request) (*http.Response, error) { + assert.Fail(t, "unexpected call to RoundTrip") + return nil, errors.New("unreachable") + }, + func() { + ctx := context.Background() + res, err := Head(ctx, invalidURL) + require.Error(t, err) + require.Nil(t, res) + }, + ) + }) +} + +func TestPost(t *testing.T) { + const contentType = "text/plain" + body := []byte("hello") + + t.Run("valid URL", func(t *testing.T) { + ctx := context.Background() + withMockDefaultClient( + func(req *http.Request) (*http.Response, error) { + assert.Equal(t, ctx, req.Context()) + assert.Equal(t, "POST", req.Method) + assert.Equal(t, validURL, req.URL.String()) + assert.Equal(t, contentType, req.Header.Get("content-type")) + data, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, body, data) + return &http.Response{StatusCode: 200}, nil + }, + func() { + res, err := Post(ctx, validURL, contentType, bytes.NewReader(body)) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + }, + ) + }) + + t.Run("invalid URL", func(t *testing.T) { + withMockDefaultClient( + func(*http.Request) (*http.Response, error) { + assert.Fail(t, "unexpected call to RoundTrip") + return nil, errors.New("unreachable") + }, + func() { + ctx := context.Background() + res, err := Post(ctx, invalidURL, contentType, bytes.NewReader(body)) + require.Error(t, err) + require.Nil(t, res) + }, + ) + }) +} + +func TestPostForm(t *testing.T) { + values := url.Values{ + "key": {"value1", "value2"}, + "foo": {"bar"}, + } + + t.Run("valid URL", func(t *testing.T) { + ctx := context.Background() + withMockDefaultClient( + func(req *http.Request) (*http.Response, error) { + assert.Equal(t, ctx, req.Context()) + assert.Equal(t, "POST", req.Method) + assert.Equal(t, validURL, req.URL.String()) + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("content-type")) + data, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, []byte(values.Encode()), data) + return &http.Response{StatusCode: 200}, nil + }, + func() { + res, err := PostForm(ctx, validURL, values) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + }, + ) + }) + + t.Run("invalid URL", func(t *testing.T) { + withMockDefaultClient( + func(*http.Request) (*http.Response, error) { + assert.Fail(t, "unexpected call to RoundTrip") + return nil, errors.New("unreachable") + }, + func() { + ctx := context.Background() + res, err := PostForm(ctx, invalidURL, values) + require.Error(t, err) + require.Nil(t, res) + }, + ) + }) +} + +func withMockDefaultClient(roundTrip func(*http.Request) (*http.Response, error), cb func()) { + backup := http.DefaultClient + defer func() { http.DefaultClient = backup }() + + http.DefaultClient = &http.Client{Transport: testTransport(roundTrip)} + cb() +} + +type testTransport func(*http.Request) (*http.Response, error) + +func (t testTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t(req) +} diff --git a/contrib/net/http/orchestrion.client.yml b/contrib/net/http/orchestrion.client.yml index 169a4fa9f1..13c5daefc3 100644 --- a/contrib/net/http/orchestrion.client.yml +++ b/contrib/net/http/orchestrion.client.yml @@ -192,18 +192,18 @@ aspects: # Wire the context that is found to the handlers... - wrap-expression: imports: - instrument: github.com/DataDog/orchestrion/instrument/net/http + httptrace: gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http template: |- {{- $ctx := .Function.ArgumentOfType "context.Context" -}} {{- $req := .Function.ArgumentOfType "*net/http.Request" }} {{- if $ctx -}} - instrument.{{ .AST.Fun.Name }}( + httptrace.{{ .AST.Fun.Name }}( {{ $ctx }}, {{ range .AST.Args }}{{ . }}, {{ end }} ) {{- else if $req -}} - instrument.{{ .AST.Fun.Name }}( + httptrace.{{ .AST.Fun.Name }}( {{ $req }}.Context(), {{ range .AST.Args }}{{ . }}, {{ end }}