Skip to content

Commit

Permalink
Move with-context HTTP client short-hand functions to http contrib
Browse files Browse the repository at this point in the history
  • Loading branch information
RomainMuller committed Jan 13, 2025
1 parent f0b349b commit 5b36f94
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 3 deletions.
43 changes: 43 additions & 0 deletions contrib/net/http/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2016 Datadog, Inc.

package http

import (
"context"
"io"
"net/http"
"net/url"
"strings"
)

func Get(ctx context.Context, url string) (resp *http.Response, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
return http.DefaultClient.Do(req)
}

func Head(ctx context.Context, url string) (resp *http.Response, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
return nil, err
}
return http.DefaultClient.Do(req)
}

func Post(ctx context.Context, url string, contentType string, body io.Reader) (resp *http.Response, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
return http.DefaultClient.Do(req)
}

func PostForm(ctx context.Context, url string, data url.Values) (resp *http.Response, err error) {
return Post(ctx, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
190 changes: 190 additions & 0 deletions contrib/net/http/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2016 Datadog, Inc.

package http

import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
validURL = "http://example.com"
invalidURL = "http:/\x00/invalid."
)

func TestGet(t *testing.T) {
ctx := context.Background()
t.Run("valid URL", func(t *testing.T) {
withMockDefaultClient(
func(req *http.Request) (*http.Response, error) {
assert.Equal(t, ctx, req.Context())
assert.Equal(t, "GET", req.Method)
assert.Equal(t, validURL, req.URL.String())
return &http.Response{StatusCode: 200}, nil
},
func() {
res, err := Get(ctx, validURL)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
},
)
})

t.Run("invalid URL", func(t *testing.T) {
withMockDefaultClient(
func(*http.Request) (*http.Response, error) {
assert.Fail(t, "unexpected call to RoundTrip")
return nil, errors.New("unreachable")
},
func() {
ctx := context.Background()
res, err := Get(ctx, invalidURL)
require.Error(t, err)
require.Nil(t, res)
},
)
})
}

func TestHead(t *testing.T) {
t.Run("valid URL", func(t *testing.T) {
ctx := context.Background()
withMockDefaultClient(
func(req *http.Request) (*http.Response, error) {
assert.Equal(t, ctx, req.Context())
assert.Equal(t, "HEAD", req.Method)
assert.Equal(t, validURL, req.URL.String())
return &http.Response{StatusCode: 200}, nil
},
func() {
res, err := Head(ctx, validURL)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
},
)
})

t.Run("invalid URL", func(t *testing.T) {
withMockDefaultClient(
func(*http.Request) (*http.Response, error) {
assert.Fail(t, "unexpected call to RoundTrip")
return nil, errors.New("unreachable")
},
func() {
ctx := context.Background()
res, err := Head(ctx, invalidURL)
require.Error(t, err)
require.Nil(t, res)
},
)
})
}

func TestPost(t *testing.T) {
const contentType = "text/plain"
body := []byte("hello")

t.Run("valid URL", func(t *testing.T) {
ctx := context.Background()
withMockDefaultClient(
func(req *http.Request) (*http.Response, error) {
assert.Equal(t, ctx, req.Context())
assert.Equal(t, "POST", req.Method)
assert.Equal(t, validURL, req.URL.String())
assert.Equal(t, contentType, req.Header.Get("content-type"))
data, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, body, data)
return &http.Response{StatusCode: 200}, nil
},
func() {
res, err := Post(ctx, validURL, contentType, bytes.NewReader(body))
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
},
)
})

t.Run("invalid URL", func(t *testing.T) {
withMockDefaultClient(
func(*http.Request) (*http.Response, error) {
assert.Fail(t, "unexpected call to RoundTrip")
return nil, errors.New("unreachable")
},
func() {
ctx := context.Background()
res, err := Post(ctx, invalidURL, contentType, bytes.NewReader(body))
require.Error(t, err)
require.Nil(t, res)
},
)
})
}

func TestPostForm(t *testing.T) {
values := url.Values{
"key": {"value1", "value2"},
"foo": {"bar"},
}

t.Run("valid URL", func(t *testing.T) {
ctx := context.Background()
withMockDefaultClient(
func(req *http.Request) (*http.Response, error) {
assert.Equal(t, ctx, req.Context())
assert.Equal(t, "POST", req.Method)
assert.Equal(t, validURL, req.URL.String())
assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("content-type"))
data, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, []byte(values.Encode()), data)
return &http.Response{StatusCode: 200}, nil
},
func() {
res, err := PostForm(ctx, validURL, values)
require.NoError(t, err)
require.Equal(t, 200, res.StatusCode)
},
)
})

t.Run("invalid URL", func(t *testing.T) {
withMockDefaultClient(
func(*http.Request) (*http.Response, error) {
assert.Fail(t, "unexpected call to RoundTrip")
return nil, errors.New("unreachable")
},
func() {
ctx := context.Background()
res, err := PostForm(ctx, invalidURL, values)
require.Error(t, err)
require.Nil(t, res)
},
)
})
}

func withMockDefaultClient(roundTrip func(*http.Request) (*http.Response, error), cb func()) {
backup := http.DefaultClient
defer func() { http.DefaultClient = backup }()

http.DefaultClient = &http.Client{Transport: testTransport(roundTrip)}
cb()
}

type testTransport func(*http.Request) (*http.Response, error)

func (t testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t(req)
}
6 changes: 3 additions & 3 deletions contrib/net/http/orchestrion.client.yml
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,18 @@ aspects:
# Wire the context that is found to the handlers...
- wrap-expression:
imports:
instrument: github.com/DataDog/orchestrion/instrument/net/http
httptrace: gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http
template: |-
{{- $ctx := .Function.ArgumentOfType "context.Context" -}}
{{- $req := .Function.ArgumentOfType "*net/http.Request" }}
{{- if $ctx -}}
instrument.{{ .AST.Fun.Name }}(
httptrace.{{ .AST.Fun.Name }}(
{{ $ctx }},
{{ range .AST.Args }}{{ . }},
{{ end }}
)
{{- else if $req -}}
instrument.{{ .AST.Fun.Name }}(
httptrace.{{ .AST.Fun.Name }}(
{{ $req }}.Context(),
{{ range .AST.Args }}{{ . }},
{{ end }}
Expand Down

0 comments on commit 5b36f94

Please sign in to comment.