diff --git a/components/providers/base/ammo.go b/components/providers/base/ammo.go index 5293cf168..a417a8d68 100644 --- a/components/providers/base/ammo.go +++ b/components/providers/base/ammo.go @@ -1,110 +1,47 @@ package base -import ( - "bytes" - "errors" - "fmt" - "io" - "net/http" - urlpkg "net/url" +import "github.com/yandex/pandora/core/aggregator/netsample" - "github.com/yandex/pandora/components/providers/http/util" - "github.com/yandex/pandora/core/aggregator/netsample" - "github.com/yandex/pandora/lib/netutil" -) - -func NewAmmo(method string, url string, body []byte, header http.Header, tag string) (*Ammo, error) { - if ok := netutil.ValidHTTPMethod(method); !ok { - return nil, errors.New("invalid HTTP method " + method) - } - if _, err := urlpkg.Parse(url); err != nil { - return nil, fmt.Errorf("invalid URL %s; err %w ", url, err) - } - return &Ammo{ - method: method, - body: body, - url: url, - tag: tag, - header: header, - constructor: true, - }, nil -} - -type Ammo struct { - Req *http.Request - method string - body []byte - url string - tag string - header http.Header - id uint64 - isInvalid bool - constructor bool +type Ammo[R any] struct { + Req *R + tag string + id uint64 + isInvalid bool } -func (a *Ammo) Request() (*http.Request, *netsample.Sample) { - if a.Req == nil { - _ = a.BuildRequest() // TODO: what if error. There isn't a logger - } +func (a *Ammo[R]) Request() (*R, *netsample.Sample) { sample := netsample.Acquire(a.Tag()) sample.SetID(a.ID()) return a.Req, sample } -func (a *Ammo) SetID(id uint64) { +func (a *Ammo[R]) Reset(req *R, tag string) { + a.Req = req + a.tag = tag + a.id = 0 + a.isInvalid = false +} + +func (a *Ammo[_]) SetID(id uint64) { a.id = id } -func (a *Ammo) ID() uint64 { +func (a *Ammo[_]) ID() uint64 { return a.id } -func (a *Ammo) Invalidate() { +func (a *Ammo[_]) Invalidate() { a.isInvalid = true } -func (a *Ammo) IsInvalid() bool { +func (a *Ammo[_]) IsInvalid() bool { return a.isInvalid } -func (a *Ammo) IsValid() bool { +func (a *Ammo[_]) IsValid() bool { return !a.isInvalid } -func (a *Ammo) SetTag(tag string) { - a.tag = tag -} - -func (a *Ammo) Tag() string { +func (a *Ammo[_]) Tag() string { return a.tag } - -func (a *Ammo) FromConstructor() bool { - return a.constructor -} - -// use NewAmmo() for skipping error here -func (a *Ammo) BuildRequest() error { - var buff io.Reader - if a.body != nil { - buff = bytes.NewReader(a.body) - } - req, err := http.NewRequest(a.method, a.url, buff) - if err != nil { - return fmt.Errorf("cant create request: %w", err) - } - a.Req = req - util.EnrichRequestWithHeaders(req, a.header) - return nil -} - -func (a *Ammo) Reset() { - a.Req = nil - a.method = "" - a.body = nil - a.url = "" - a.tag = "" - a.header = nil - a.id = 0 - a.isInvalid = false -} diff --git a/components/providers/base/decoder.go b/components/providers/base/decoder.go new file mode 100644 index 000000000..1e71484ba --- /dev/null +++ b/components/providers/base/decoder.go @@ -0,0 +1,17 @@ +package base + +import "sync" + +type Decoder[R any] struct { + Sink chan<- *Ammo[R] + Pool *sync.Pool +} + +func NewDecoder[R any](sink chan<- *Ammo[R]) Decoder[R] { + return Decoder[R]{ + Sink: sink, + Pool: &sync.Pool{New: func() any { + return new(Ammo[R]) + }}, + } +} diff --git a/components/providers/base/provider.go b/components/providers/base/provider.go index 948e88090..5f7096146 100644 --- a/components/providers/base/provider.go +++ b/components/providers/base/provider.go @@ -8,7 +8,7 @@ import ( ) type ProviderBase struct { - Deps core.ProviderDeps + core.ProviderDeps FS afero.Fs idCounter atomic.Uint64 } diff --git a/components/providers/http/ammo.go b/components/providers/http/ammo.go index 780be9ee2..1ef3da677 100644 --- a/components/providers/http/ammo.go +++ b/components/providers/http/ammo.go @@ -16,4 +16,4 @@ type Request interface { http.Request } -var _ phttp.Ammo = (*base.Ammo)(nil) +var _ phttp.Ammo = (*base.Ammo[http.Request])(nil) diff --git a/components/providers/http/decoders/decoder.go b/components/providers/http/decoders/decoder.go index a2e20aa29..17975b1d3 100644 --- a/components/providers/http/decoders/decoder.go +++ b/components/providers/http/decoders/decoder.go @@ -6,7 +6,6 @@ import ( "io" "net/http" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/util" ) @@ -25,7 +24,7 @@ var ( type Decoder interface { // Decode(context.Context, chan<- *base.Ammo[http.Request], io.ReadSeeker) error - Scan(context.Context) (*base.Ammo, error) + Scan(context.Context) (*http.Request, string, error) } type protoDecoder struct { diff --git a/components/providers/http/decoders/jsonline.go b/components/providers/http/decoders/jsonline.go index 939029888..0d972c5e4 100644 --- a/components/providers/http/decoders/jsonline.go +++ b/components/providers/http/decoders/jsonline.go @@ -7,9 +7,9 @@ import ( "net/http" "strings" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/decoders/jsonline" + "github.com/yandex/pandora/components/providers/http/util" "golang.org/x/xerrors" ) @@ -35,51 +35,53 @@ type jsonlineDecoder struct { line uint } -func (d *jsonlineDecoder) Scan(ctx context.Context) (*base.Ammo, error) { +func (d *jsonlineDecoder) Scan(ctx context.Context) (*http.Request, string, error) { if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, ErrAmmoLimit + return nil, "", ErrAmmoLimit } - for { - if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, ErrPassLimit + for ; ; d.line++ { + if ctx.Err() != nil { + return nil, "", ctx.Err() } - for d.scanner.Scan() { - d.line++ - data := d.scanner.Bytes() - if len(strings.TrimSpace(string(data))) == 0 { - continue - } - d.ammoNum++ - ammo, err := jsonline.DecodeAmmo(data, d.decodedConfigHeaders) - if err != nil { - if !d.config.ContinueOnError { - return nil, xerrors.Errorf("failed to decode ammo at line: %v; data: %q, with err: %w", d.line+1, data, err) + if !d.scanner.Scan() { + if d.scanner.Err() == nil { // assume as io.EOF; FIXME: check possible nil error with other reason + d.line = 0 + d.passNum++ + if d.config.Passes != 0 && d.passNum >= d.config.Passes { + return nil, "", ErrPassLimit + } + if d.ammoNum == 0 { + return nil, "", ErrNoAmmo } - // TODO: add log message about error - continue // skipping ammo + _, err := d.file.Seek(0, io.SeekStart) + if err != nil { + return nil, "", err + } + d.scanner = bufio.NewScanner(d.file) + if d.config.MaxAmmoSize != 0 { + var buffer []byte + d.scanner.Buffer(buffer, d.config.MaxAmmoSize) + } + continue } - return ammo, err + return nil, "", d.scanner.Err() } - - err := d.scanner.Err() - if err != nil { - return nil, err - } - if d.ammoNum == 0 { - return nil, ErrNoAmmo + data := d.scanner.Bytes() + if len(strings.TrimSpace(string(data))) == 0 { + continue } - d.line = 0 - d.passNum++ + d.ammoNum++ - _, err = d.file.Seek(0, io.SeekStart) + req, tag, err := jsonline.DecodeAmmo(data) if err != nil { - return nil, err - } - d.scanner = bufio.NewScanner(d.file) - if d.config.MaxAmmoSize != 0 { - var buffer []byte - d.scanner.Buffer(buffer, d.config.MaxAmmoSize) + if !d.config.ContinueOnError { + return nil, "", xerrors.Errorf("failed to decode ammo at line: %v; data: %q, with err: %w", d.line+1, data, err) + } + // TODO: add log message about error + continue // skipping ammo } + util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) + return req, tag, err } } diff --git a/components/providers/http/decoders/jsonline/data.go b/components/providers/http/decoders/jsonline/data.go index 03df32d4e..2fb4f5594 100644 --- a/components/providers/http/decoders/jsonline/data.go +++ b/components/providers/http/decoders/jsonline/data.go @@ -1,12 +1,12 @@ -//go:generate github.com/pquerna/ffjson@latest data_ffjson.go +//go:generate github.com/pquerna/ffjson data_ffjson.go package jsonline import ( "net/http" + "strings" "github.com/pkg/errors" - "github.com/yandex/pandora/components/providers/base" ) // ffjson: noencoder @@ -24,20 +24,24 @@ type data struct { Body string `json:"body"` } -func DecodeAmmo(jsonDoc []byte, headers http.Header) (*base.Ammo, error) { - var d = new(data) - if err := d.UnmarshalJSON(jsonDoc); err != nil { - err = errors.WithStack(err) - return nil, err +func (d *data) ToRequest() (*http.Request, error) { + uri := "http://" + d.Host + d.URI + req, err := http.NewRequest(d.Method, uri, strings.NewReader(d.Body)) + if err != nil { + return nil, errors.WithStack(err) } - for k, v := range d.Headers { - headers.Set(k, v) + req.Header.Set(k, v) } - url := "http://" + d.Host + d.URI - var body []byte - if d.Body != "" { - body = []byte(d.Body) + return req, err +} + +func DecodeAmmo(jsonDoc []byte) (*http.Request, string, error) { + var data = new(data) + if err := data.UnmarshalJSON(jsonDoc); err != nil { + err = errors.WithStack(err) + return nil, data.Tag, err } - return base.NewAmmo(d.Method, url, body, headers, d.Tag) + req, err := data.ToRequest() + return req, data.Tag, err } diff --git a/components/providers/http/decoders/jsonline/data_test.go b/components/providers/http/decoders/jsonline/data_test.go index 585cdc48e..60ecd564f 100644 --- a/components/providers/http/decoders/jsonline/data_test.go +++ b/components/providers/http/decoders/jsonline/data_test.go @@ -1,60 +1,117 @@ package jsonline import ( + "context" + "encoding/json" "net/http" + "net/url" "testing" + "testing/iotest" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yandex/pandora/components/providers/base" + "github.com/yandex/pandora/lib/testutil" ) +const testFile = "./ammo.jsonline" + +// testData holds jsonline.data that contains in testFile +var testData = []data{ + { + Host: "example.com", + Method: "GET", + URI: "/00", + Headers: map[string]string{"Accept": "*/*", "Accept-Encoding": "gzip, deflate", "User-Agent": "Pandora/0.0.1"}, + }, + { + Host: "ya.ru", + Method: "HEAD", + URI: "/01", + Headers: map[string]string{"Accept": "*/*", "Accept-Encoding": "gzip, brotli", "User-Agent": "YaBro/0.1"}, + Tag: "head", + }, +} + +type ToRequestWant struct { + req *http.Request + error + body []byte +} + func TestToRequest(t *testing.T) { var tests = []struct { - name string - json []byte - confHeader http.Header - want *base.Ammo - wantErr bool + name string + input data + want ToRequestWant }{ { - name: "GET request", - json: []byte(`{"host": "ya.ru", "method": "GET", "uri": "/00", "tag": "tag", "headers": {"A": "a", "B": "b"}}`), - confHeader: http.Header{"Default": []string{"def"}}, - want: MustNewAmmo(t, "GET", "http://ya.ru/00", nil, http.Header{"Default": []string{"def"}, "A": []string{"a"}, "B": []string{"b"}}, "tag"), - wantErr: false, - }, - { - name: "POST request", - json: []byte(`{"host": "ya.ru", "method": "POST", "uri": "/01?sleep=10", "tag": "tag", "headers": {"A": "a", "B": "b"}, "body": "body"}`), - confHeader: http.Header{"Default": []string{"def"}}, - want: MustNewAmmo(t, "POST", "http://ya.ru/01?sleep=10", []byte(`body`), http.Header{"Default": []string{"def"}, "A": []string{"a"}, "B": []string{"b"}}, "tag"), - wantErr: false, - }, - { - name: "POST request with json", - json: []byte(`{"host": "ya.ru", "method": "POST", "uri": "/01?sleep=10", "tag": "tag", "headers": {"A": "a", "B": "b"}, "body": "{\"field\":\"value\"}"}`), - confHeader: http.Header{"Default": []string{"def"}}, - want: MustNewAmmo(t, "POST", "http://ya.ru/01?sleep=10", []byte(`{"field":"value"}`), http.Header{"Default": []string{"def"}, "A": []string{"a"}, "B": []string{"b"}}, "tag"), - wantErr: false, + name: "decoded well", + input: data{ + Host: "ya.ru", + Method: "GET", + URI: "/00", + Headers: map[string]string{"A": "a", "B": "b"}, + Tag: "tag", + }, + want: ToRequestWant{ + req: &http.Request{ + Method: "GET", + URL: testutil.Must(url.Parse("http://ya.ru/00")), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{"A": []string{"a"}, "B": []string{"b"}}, + Body: http.NoBody, + Host: "ya.ru", + }, + error: nil, + }, }, } + var ans ToRequestWant for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert := assert.New(t) - ans, err := DecodeAmmo(tt.json, tt.confHeader) - if tt.wantErr { - require.Error(t, err) - return + ans.req, ans.error = tt.input.ToRequest() + if tt.want.body != nil { + assert.NotNil(ans.req) + assert.NoError(iotest.TestReader(ans.req.Body, tt.want.body)) + ans.req.Body = nil + tt.want.body = nil } - assert.NoError(err) + ans.req.GetBody = nil + tt.want.req = ans.req.WithContext(context.Background()) assert.Equal(tt.want, ans) }) } } -func MustNewAmmo(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { - ammo, err := base.NewAmmo(method, url, body, header, tag) - require.NoError(t, err) - return ammo +type DecodeAmmoWant struct { + req *http.Request + tag string + error +} + +func TestDecodeAmmo(t *testing.T) { + var tests = []struct { + name string + input []byte + want DecodeAmmoWant + }{} + var ans DecodeAmmoWant + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + ans.req, ans.tag, ans.error = DecodeAmmo(tt.input) + assert.Equal(tt.want, ans) + }) + } +} + +func BenchmarkDecodeAmmo(b *testing.B) { + jsonDoc, err := json.Marshal(testData[0]) + assert.NoError(b, err) + b.ResetTimer() + for n := 0; n < b.N; n++ { + _, _, _ = DecodeAmmo(jsonDoc) + } } diff --git a/components/providers/http/decoders/jsonline_test.go b/components/providers/http/decoders/jsonline_test.go index 947e0b4ed..fe7ae755c 100644 --- a/components/providers/http/decoders/jsonline_test.go +++ b/components/providers/http/decoders/jsonline_test.go @@ -3,25 +3,18 @@ package decoders import ( "context" "net/http" + "net/http/httputil" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" ) func Test_jsonlineDecoder_Scan(t *testing.T) { - var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { - ammo, err := base.NewAmmo(method, url, body, header, tag) - require.NoError(t, err) - return ammo - } - - input := `{"host": "ya.net", "method": "GET", "uri": "/?sleep=100", "tag": "sleep1", "headers": {"User-agent": "Tank", "Connection": "close"}} -{"host": "ya.net", "method": "POST", "uri": "/?sleep=200", "tag": "sleep2", "headers": {"User-agent": "Tank", "Connection": "close"}, "body": "body_data"} + input := `{"host": "4bs65mu2kdulxmir.myt.yp-c.yandex.net", "method": "GET", "uri": "/?sleep=100", "tag": "sleep1", "headers": {"User-agent": "Tank", "Connection": "close"}} +{"host": "4bs65mu2kdulxmir.myt.yp-c.yandex.net", "method": "POST", "uri": "/?sleep=200", "tag": "sleep2", "headers": {"User-agent": "Tank", "Connection": "close"}, "body": "body_data"} ` @@ -33,32 +26,41 @@ func Test_jsonlineDecoder_Scan(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - wants := []*base.Ammo{ - mustNewAmmo(t, - "GET", - "http://ya.net/?sleep=100", - nil, - http.Header{"Connection": []string{"close"}, "Content-Type": []string{"application/json"}, "User-Agent": []string{"Tank"}}, - "sleep1", - ), - mustNewAmmo(t, - "POST", - "http://ya.net/?sleep=200", - []byte("body_data"), - http.Header{"Connection": []string{"close"}, "Content-Type": []string{"application/json"}, "User-Agent": []string{"Tank"}}, - "sleep2", - ), + tests := []struct { + wantTag string + wantErr bool + wantBody string + }{ + { + wantTag: "sleep1", + wantErr: false, + wantBody: "GET /?sleep=100 HTTP/1.1\r\nHost: 4bs65mu2kdulxmir.myt.yp-c.yandex.net\r\nConnection: close\r\nContent-Type: application/json\r\nUser-Agent: Tank\r\n\r\n", + }, + { + wantTag: "sleep2", + wantErr: false, + wantBody: "POST /?sleep=200 HTTP/1.1\r\nHost: 4bs65mu2kdulxmir.myt.yp-c.yandex.net\r\nConnection: close\r\nContent-Type: application/json\r\nUser-Agent: Tank\r\n\r\nbody_data", + }, } for j := 0; j < 2; j++ { - for i, want := range wants { - ammo, err := decoder.Scan(ctx) - assert.NoError(t, err, "iteration %d-%d", j, i) - assert.Equal(t, want, ammo, "iteration %d-%d", j, i) + for i, tt := range tests { + req, tag, err := decoder.Scan(ctx) + if tt.wantErr { + assert.Error(t, err, "iteration %d-%d", j, i) + continue + } else { + assert.NoError(t, err, "iteration %d-%d", j, i) + } + assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) + + req.Close = false + body, _ := httputil.DumpRequest(req, true) + assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) } } - _, err := decoder.Scan(ctx) + _, _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) - assert.Equal(t, decoder.ammoNum, uint(len(wants)*2)) + assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) assert.Equal(t, decoder.passNum, uint(1)) } diff --git a/components/providers/http/decoders/raw.go b/components/providers/http/decoders/raw.go index fd7ab470a..0ace2dafb 100644 --- a/components/providers/http/decoders/raw.go +++ b/components/providers/http/decoders/raw.go @@ -7,7 +7,6 @@ import ( "net/http" "strings" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/decoders/raw" "github.com/yandex/pandora/components/providers/http/util" @@ -59,38 +58,38 @@ type rawDecoder struct { reader *bufio.Reader } -func (d *rawDecoder) Scan(ctx context.Context) (*base.Ammo, error) { +func (d *rawDecoder) Scan(ctx context.Context) (*http.Request, string, error) { var data string var buff []byte var req *http.Request var err error if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, ErrAmmoLimit + return nil, "", ErrAmmoLimit } for { if ctx.Err() != nil { - return nil, ctx.Err() + return nil, "", ctx.Err() } data, err = d.reader.ReadString('\n') if err == io.EOF { d.passNum++ if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, ErrPassLimit + return nil, "", ErrPassLimit } if d.ammoNum == 0 { - return nil, ErrNoAmmo + return nil, "", ErrNoAmmo } _, err := d.file.Seek(0, io.SeekStart) if err != nil { - return nil, err + return nil, "", err } d.reader.Reset(d.file) continue } if err != nil { - return nil, xerrors.Errorf("reading ammo failed with err: %w, at position: %v", err, filePosition(d.file)) + return nil, "", xerrors.Errorf("reading ammo failed with err: %w, at position: %v", err, filePosition(d.file)) } data = strings.TrimSpace(data) if len(data) == 0 { @@ -99,7 +98,7 @@ func (d *rawDecoder) Scan(ctx context.Context) (*base.Ammo, error) { d.ammoNum++ reqSize, tag, err := raw.DecodeHeader(data) if err != nil { - return nil, xerrors.Errorf("header decoding error: %w", err) + return nil, "", xerrors.Errorf("header decoding error: %w", err) } if reqSize != 0 { @@ -108,11 +107,11 @@ func (d *rawDecoder) Scan(ctx context.Context) (*base.Ammo, error) { } buff = buff[:reqSize] if n, err := io.ReadFull(d.reader, buff); err != nil { - return nil, xerrors.Errorf("failed to read ammo with err: %w, at position: %v; tried to read: %v; have read: %v", err, filePosition(d.file), reqSize, n) + return nil, "", xerrors.Errorf("failed to read ammo with err: %w, at position: %v; tried to read: %v; have read: %v", err, filePosition(d.file), reqSize, n) } req, err = raw.DecodeRequest(buff) if err != nil { - return nil, xerrors.Errorf("failed to decode ammo with err: %w, at position: %v; data: %q", err, filePosition(d.file), buff) + return nil, "", xerrors.Errorf("failed to decode ammo with err: %w, at position: %v; data: %q", err, filePosition(d.file), buff) } } else { req, _ = http.NewRequest("", "/", nil) @@ -120,8 +119,6 @@ func (d *rawDecoder) Scan(ctx context.Context) (*base.Ammo, error) { // add new Headers to request from config util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) - ammo := &base.Ammo{Req: req} - ammo.SetTag(tag) - return ammo, nil + return req, tag, nil } } diff --git a/components/providers/http/decoders/raw/decoder_bench_test.go b/components/providers/http/decoders/raw/decoder_bench_test.go index 2f4c655dc..0bb2dc09b 100644 --- a/components/providers/http/decoders/raw/decoder_bench_test.go +++ b/components/providers/http/decoders/raw/decoder_bench_test.go @@ -3,8 +3,8 @@ package raw import ( "testing" - "github.com/stretchr/testify/require" "github.com/yandex/pandora/components/providers/http/util" + "github.com/yandex/pandora/lib/testutil" ) var ( @@ -30,8 +30,7 @@ func BenchmarkRawDecoder(b *testing.B) { } func BenchmarkRawDecoderWithHeaders(b *testing.B) { - decodedHTTPConfigHeaders, err := util.DecodeHTTPConfigHeaders(benchTestConfigHeaders) - require.NoError(b, err) + decodedHTTPConfigHeaders := testutil.Must(util.DecodeHTTPConfigHeaders(benchTestConfigHeaders)) b.ResetTimer() for i := 0; i < b.N; i++ { req, _ := DecodeRequest(benchTestRequest) diff --git a/components/providers/http/decoders/raw/decoder_test.go b/components/providers/http/decoders/raw/decoder_test.go index 3ae1eed01..50d2cb3bf 100644 --- a/components/providers/http/decoders/raw/decoder_test.go +++ b/components/providers/http/decoders/raw/decoder_test.go @@ -8,7 +8,7 @@ import ( "testing/iotest" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/yandex/pandora/lib/testutil" ) type DecoderHeaderWant struct { @@ -64,7 +64,7 @@ func TestDecodeRequest(t *testing.T) { want: DecoderRequestWant{ &http.Request{ Method: "GET", - URL: MustURL(t, "/some/path"), + URL: testutil.Must(url.Parse("/some/path")), Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, @@ -89,7 +89,7 @@ func TestDecodeRequest(t *testing.T) { want: DecoderRequestWant{ &http.Request{ Method: "POST", - URL: MustURL(t, "/some/path"), + URL: testutil.Must(url.Parse("/some/path")), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, @@ -125,7 +125,7 @@ func TestDecodeRequest(t *testing.T) { want: DecoderRequestWant{ &http.Request{ Method: "GET", - URL: MustURL(t, "/etc/passwd"), + URL: testutil.Must(url.Parse("/etc/passwd")), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, @@ -151,9 +151,3 @@ func TestDecodeRequest(t *testing.T) { }) } } - -func MustURL(t *testing.T, rawURL string) *url.URL { - url, err := url.Parse(rawURL) - require.NoError(t, err) - return url -} diff --git a/components/providers/http/decoders/raw_test.go b/components/providers/http/decoders/raw_test.go index 3dbb1f89f..c368f9f14 100644 --- a/components/providers/http/decoders/raw_test.go +++ b/components/providers/http/decoders/raw_test.go @@ -13,9 +13,9 @@ import ( ) func Test_rawDecoder_Scan(t *testing.T) { - input := `38 good50 + input := `68 good50 GET /?sleep=50 HTTP/1.0 -Host: ya.net +Host: 4bs65mu2kdulxmir.myt.yp-c.yandex.net 74 bad @@ -58,7 +58,7 @@ User-Agent: xxx (shell 1) { wantTag: "good50", wantErr: false, - wantBody: "GET /?sleep=50 HTTP/1.0\r\nHost: ya.net\r\nContent-Type: application/json\r\n\r\n", + wantBody: "GET /?sleep=50 HTTP/1.0\r\nHost: 4bs65mu2kdulxmir.myt.yp-c.yandex.net\r\nContent-Type: application/json\r\n\r\n", }, { wantTag: "bad", @@ -78,23 +78,22 @@ User-Agent: xxx (shell 1) } for j := 0; j < 2; j++ { for i, tt := range tests { - ammo, err := decoder.Scan(ctx) + req, tag, err := decoder.Scan(ctx) if tt.wantErr { assert.Error(t, err, "iteration %d-%d", j, i) continue } else { assert.NoError(t, err, "iteration %d-%d", j, i) } - assert.Equal(t, tt.wantTag, ammo.Tag(), "iteration %d-%d", j, i) + assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) - req := ammo.Req req.Close = false body, _ := httputil.DumpRequest(req, true) assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) } } - _, err := decoder.Scan(ctx) + _, _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) assert.Equal(t, decoder.passNum, uint(1)) diff --git a/components/providers/http/decoders/uri.go b/components/providers/http/decoders/uri.go index 3706ba2b6..8fc4278e6 100644 --- a/components/providers/http/decoders/uri.go +++ b/components/providers/http/decoders/uri.go @@ -6,10 +6,8 @@ import ( "fmt" "io" "net/http" - "net/url" "strings" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/util" ) @@ -33,72 +31,76 @@ type uriDecoder struct { line uint } -func (d *uriDecoder) readLine(data string, commonHeader http.Header) (*base.Ammo, error) { +func (d *uriDecoder) readLine(data string, commonHeader http.Header) (*http.Request, string, error) { data = strings.TrimSpace(data) if len(data) == 0 { - return nil, nil // skip empty line + return nil, "", nil // skip empty line } + var req *http.Request + var tag string + var err error if data[0] == '[' { key, val, err := util.DecodeHeader(data) if err != nil { err = fmt.Errorf("decoding header error: %w", err) - return nil, err + return nil, "", err } commonHeader.Set(key, val) - return nil, nil - } - - var rawURL string - rawURL, tag, _ := strings.Cut(data, " ") - _, err := url.Parse(rawURL) - if err != nil { - return nil, err - } - header := commonHeader.Clone() - for k, vv := range d.decodedConfigHeaders { - for _, v := range vv { - header.Set(k, v) + } else { + var rawURL string + rawURL, tag, _ = strings.Cut(data, " ") + req, err = http.NewRequest("GET", rawURL, nil) + if err != nil { + err = fmt.Errorf("failed to decode uri: %w", err) + return nil, "", err + } + if host, ok := commonHeader["Host"]; ok { + req.Host = host[0] } + req.Header = commonHeader.Clone() + + // add new Headers to request from config + util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) } - return base.NewAmmo("GET", rawURL, nil, header, tag) + return req, tag, nil } -func (d *uriDecoder) Scan(ctx context.Context) (*base.Ammo, error) { +func (d *uriDecoder) Scan(ctx context.Context) (*http.Request, string, error) { if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, ErrAmmoLimit + return nil, "", ErrAmmoLimit } for ; ; d.line++ { if ctx.Err() != nil { - return nil, ctx.Err() + return nil, "", ctx.Err() } if !d.scanner.Scan() { if d.scanner.Err() == nil { // assume as io.EOF; FIXME: check possible nil error with other reason d.line = 0 d.passNum++ if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, ErrPassLimit + return nil, "", ErrPassLimit } if d.ammoNum == 0 { - return nil, ErrNoAmmo + return nil, "", ErrNoAmmo } d.Header = http.Header{} _, err := d.file.Seek(0, io.SeekStart) if err != nil { - return nil, err + return nil, "", err } d.scanner = bufio.NewScanner(d.file) continue } - return nil, d.scanner.Err() + return nil, "", d.scanner.Err() } data := d.scanner.Text() - ammo, err := d.readLine(data, d.Header) + req, tag, err := d.readLine(data, d.Header) if err != nil { - return nil, fmt.Errorf("decode at line %d `%s` error: %w", d.line+1, data, err) + return nil, "", fmt.Errorf("decode at line %d `%s` error: %w", d.line+1, data, err) } - if ammo != nil { + if req != nil { d.ammoNum++ - return ammo, nil + return req, tag, nil } } } diff --git a/components/providers/http/decoders/uri_test.go b/components/providers/http/decoders/uri_test.go index e32011ea4..2b4dcdfb1 100644 --- a/components/providers/http/decoders/uri_test.go +++ b/components/providers/http/decoders/uri_test.go @@ -3,35 +3,31 @@ package decoders import ( "context" "net/http" + "net/http/httputil" + "net/url" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" ) func Test_uriDecoder_readLine(t *testing.T) { - var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { - ammo, err := base.NewAmmo(method, url, body, header, tag) - require.NoError(t, err) - return ammo - } - tests := []struct { name string data string - want *base.Ammo - wantErr bool + expectedReq *http.Request + expectedTag string + expectedErr bool expectedCommonHeaders http.Header }{ { - name: "Header line", - data: "[Content-Type: application/json]", - want: nil, - wantErr: false, + name: "Header line", + data: "[Content-Type: application/json]", + expectedReq: nil, + expectedTag: "", + expectedErr: false, expectedCommonHeaders: http.Header{ "Content-Type": []string{"application/json"}, "User-Agent": []string{"TestAgent"}, @@ -40,11 +36,20 @@ func Test_uriDecoder_readLine(t *testing.T) { { name: "Valid URI", data: "http://example.com/test", - want: mustNewAmmo(t, "GET", "http://example.com/test", nil, http.Header{ - "User-Agent": []string{"TestAgent"}, - "Authorization": []string{"Bearer xxx"}, - }, ""), - wantErr: false, + expectedReq: &http.Request{ + Method: "GET", + Proto: "HTTP/1.1", + URL: &url.URL{Scheme: "http", Host: "example.com", Path: "/test"}, + Header: http.Header{ + "User-Agent": []string{"TestAgent"}, + "Authorization": []string{"Bearer xxx"}, + }, + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + }, + expectedTag: "", + expectedErr: false, expectedCommonHeaders: http.Header{ "User-Agent": []string{"TestAgent"}, }, @@ -52,20 +57,30 @@ func Test_uriDecoder_readLine(t *testing.T) { { name: "URI with tag", data: "http://example.com/test tag\n", - want: mustNewAmmo(t, "GET", "http://example.com/test", nil, http.Header{ - "User-Agent": []string{"TestAgent"}, - "Authorization": []string{"Bearer xxx"}, - }, "tag"), - wantErr: false, + expectedReq: &http.Request{ + Method: "GET", + Proto: "HTTP/1.1", + URL: &url.URL{Scheme: "http", Host: "example.com", Path: "/test"}, + Header: http.Header{ + "User-Agent": []string{"TestAgent"}, + "Authorization": []string{"Bearer xxx"}, + }, + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + }, + expectedTag: "tag", + expectedErr: false, expectedCommonHeaders: http.Header{ "User-Agent": []string{"TestAgent"}, }, }, { - name: "Invalid data", - data: "1http://foo.com tag", - want: nil, - wantErr: true, + name: "Invalid data", + data: "1http://foo.com tag", + expectedReq: nil, + expectedTag: "", + expectedErr: true, }, } for _, test := range tests { @@ -74,14 +89,19 @@ func Test_uriDecoder_readLine(t *testing.T) { decodedConfigHeaders := http.Header{"Authorization": []string{"Bearer xxx"}} decoder := newURIDecoder(nil, config.Config{}, decodedConfigHeaders) - ammo, err := decoder.readLine(test.data, commonHeader) + req, tag, err := decoder.readLine(test.data, commonHeader) + + if test.expectedReq != nil { + test.expectedReq = test.expectedReq.WithContext(context.Background()) + } - if test.wantErr { + if test.expectedErr { assert.Error(t, err) return } - assert.Equal(t, test.want, ammo) + assert.Equal(t, test.expectedTag, tag) assert.Equal(t, test.expectedCommonHeaders, commonHeader) + assert.Equal(t, test.expectedReq, req) }) } } @@ -99,12 +119,6 @@ const uriInput = ` /0 /4 some tag` func Test_uriDecoder_Scan(t *testing.T) { - var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { - ammo, err := base.NewAmmo(method, url, body, header, tag) - require.NoError(t, err) - return ammo - } - decoder := newURIDecoder(strings.NewReader(uriInput), config.Config{ Limit: 10, }, http.Header{"Content-Type": []string{"application/json"}}) @@ -112,23 +126,57 @@ func Test_uriDecoder_Scan(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - wants := []*base.Ammo{ - mustNewAmmo(t, "GET", "/0", nil, http.Header{"Content-Type": []string{"application/json"}}, ""), - mustNewAmmo(t, "GET", "/1", nil, http.Header{"A": []string{"b"}, "Content-Type": []string{"application/json"}}, ""), - mustNewAmmo(t, "GET", "/2", nil, http.Header{"Host": []string{"example.com"}, "A": []string{"b"}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, ""), - mustNewAmmo(t, "GET", "/3", nil, http.Header{"Host": []string{"other.net"}, "A": []string{""}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, ""), - mustNewAmmo(t, "GET", "/4", nil, http.Header{"Host": []string{"other.net"}, "A": []string{""}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, "some tag"), + tests := []struct { + wantTag string + wantErr bool + wantBody string + }{ + { + wantTag: "", + wantErr: false, + wantBody: "GET /0 HTTP/1.1\r\nContent-Type: application/json\r\n\r\n", + }, + { + wantTag: "", + wantErr: false, + wantBody: "GET /1 HTTP/1.1\r\nA: b\r\nContent-Type: application/json\r\n\r\n", + }, + { + wantTag: "", + wantErr: false, + wantBody: "GET /2 HTTP/1.1\r\nHost: example.com\r\nA: b\r\nC: d\r\nContent-Type: application/json\r\n\r\n", + }, + { + wantTag: "", + wantErr: false, + wantBody: "GET /3 HTTP/1.1\r\nHost: other.net\r\nA: \r\nC: d\r\nContent-Type: application/json\r\n\r\n", + }, + { + wantTag: "some tag", + wantErr: false, + wantBody: "GET /4 HTTP/1.1\r\nHost: other.net\r\nA: \r\nC: d\r\nContent-Type: application/json\r\n\r\n", + }, } + for j := 0; j < 2; j++ { - for i, want := range wants { - ammo, err := decoder.Scan(ctx) - assert.NoError(t, err, "iteration %d-%d", j, i) - assert.Equal(t, want, ammo, "iteration %d-%d", j, i) + for i, tt := range tests { + req, tag, err := decoder.Scan(ctx) + if tt.wantErr { + assert.Error(t, err, "iteration %d-%d", j, i) + continue + } else { + assert.NoError(t, err, "iteration %d-%d", j, i) + } + assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) + + req.Close = false + body, _ := httputil.DumpRequest(req, true) + assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) } } - _, err := decoder.Scan(ctx) + _, _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) - assert.Equal(t, decoder.ammoNum, uint(len(wants)*2)) + assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) assert.Equal(t, decoder.passNum, uint(1)) } diff --git a/components/providers/http/decoders/uripost.go b/components/providers/http/decoders/uripost.go index 717b9238e..28c4eb6dd 100644 --- a/components/providers/http/decoders/uripost.go +++ b/components/providers/http/decoders/uripost.go @@ -2,14 +2,13 @@ package decoders import ( "bufio" + "bytes" "context" "errors" "io" "net/http" - "net/url" "strings" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/decoders/uripost" "github.com/yandex/pandora/components/providers/http/util" @@ -35,26 +34,26 @@ type uripostDecoder struct { line uint } -func (d *uripostDecoder) Scan(ctx context.Context) (*base.Ammo, error) { +func (d *uripostDecoder) Scan(ctx context.Context) (*http.Request, string, error) { if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, ErrAmmoLimit + return nil, "", ErrAmmoLimit } for i := 0; i < 2; i++ { for { if ctx.Err() != nil { - return nil, ctx.Err() + return nil, "", ctx.Err() } - req, err := d.readBlock(d.reader, d.header) + req, tag, err := d.readBlock(d.reader, d.header) if err == io.EOF { break } if err != nil { - return nil, err + return nil, "", err } if req != nil { d.ammoNum++ - return req, nil + return req, tag, nil } // here only if read header } @@ -62,64 +61,72 @@ func (d *uripostDecoder) Scan(ctx context.Context) (*base.Ammo, error) { // seek file d.passNum++ if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, ErrPassLimit + return nil, "", ErrPassLimit } if d.ammoNum == 0 { - return nil, ErrNoAmmo + return nil, "", ErrNoAmmo } d.header = make(http.Header) _, err := d.file.Seek(0, io.SeekStart) if err != nil { - return nil, err + return nil, "", err } d.reader.Reset(d.file) } - return nil, errors.New("unexpected behavior") + return nil, "", errors.New("unexpected behavior") } // readBlock read one header at time and set to commonHeader or read full request -func (d *uripostDecoder) readBlock(reader *bufio.Reader, commonHeader http.Header) (*base.Ammo, error) { +func (d *uripostDecoder) readBlock(reader *bufio.Reader, commonHeader http.Header) (*http.Request, string, error) { data, err := reader.ReadString('\n') if err != nil { - return nil, err + return nil, "", err } data = strings.TrimSpace(data) if len(data) == 0 { - return nil, nil // skip empty lines + return nil, "", nil // skip empty lines } if data[0] == '[' { key, val, err := util.DecodeHeader(data) if err != nil { - return nil, err + return nil, "", err } commonHeader.Set(key, val) - return nil, nil + return nil, "", nil } bodySize, uri, tag, err := uripost.DecodeURI(data) if err != nil { - return nil, err - } - _, err = url.Parse(uri) - if err != nil { - return nil, err + return nil, "", err } + var buffReader io.Reader buff := make([]byte, bodySize) if bodySize != 0 { if n, err := io.ReadFull(reader, buff); err != nil { err = xerrors.Errorf("failed to read ammo with err: %w, at position: %v; tried to read: %v; have read: %v", err, filePosition(d.file), bodySize, n) - return nil, err + return nil, "", err } + buffReader = bytes.NewReader(buff) + } + req, err := http.NewRequest("POST", uri, buffReader) + if err != nil { + err = xerrors.Errorf("failed to decode ammo with err: %w, at position: %v; data: %q", err, filePosition(d.file), buff) + return nil, "", err } - header := commonHeader.Clone() - for k, vv := range d.decodedConfigHeaders { - for _, v := range vv { - header.Set(k, v) + for k, v := range commonHeader { + // http.Request.Write sends Host header based on req.URL.Host + if k == "Host" { + req.Host = v[0] + } else { + req.Header[k] = v } } - return base.NewAmmo("POST", uri, buff, header, tag) + // add new Headers to request from config + util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) + + return req, tag, nil } diff --git a/components/providers/http/decoders/uripost_test.go b/components/providers/http/decoders/uripost_test.go index db8f43d3c..5494a6655 100644 --- a/components/providers/http/decoders/uripost_test.go +++ b/components/providers/http/decoders/uripost_test.go @@ -3,22 +3,16 @@ package decoders import ( "context" "net/http" + "net/http/httputil" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" ) func Test_uripostDecoder_Scan(t *testing.T) { - var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { - ammo, err := base.NewAmmo(method, url, body, header, tag) - require.NoError(t, err) - return ammo - } input := `5 /0 class [A:b] @@ -42,22 +36,51 @@ classclassclass ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) defer cancel() - wants := []*base.Ammo{ - mustNewAmmo(t, "POST", "/0", []byte("class"), http.Header{"Content-Type": []string{"application/json"}}, ""), - mustNewAmmo(t, "POST", "/1", []byte("class"), http.Header{"A": []string{"b"}, "Content-Type": []string{"application/json"}}, ""), - mustNewAmmo(t, "POST", "/2", []byte("classclass"), http.Header{"Host": []string{"example.com"}, "A": []string{"b"}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, ""), - mustNewAmmo(t, "POST", "/3", []byte("classclassclass"), http.Header{"Host": []string{"other.net"}, "A": []string{""}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, "wantTag"), + tests := []struct { + wantTag string + wantErr bool + wantBody string + }{ + { + wantTag: "", + wantErr: false, + wantBody: "POST /0 HTTP/1.1\r\nContent-Type: application/json\r\n\r\nclass", + }, + { + wantTag: "", + wantErr: false, + wantBody: "POST /1 HTTP/1.1\r\nA: b\r\nContent-Type: application/json\r\n\r\nclass", + }, + { + wantTag: "", + wantErr: false, + wantBody: "POST /2 HTTP/1.1\r\nHost: example.com\r\nA: b\r\nC: d\r\nContent-Type: application/json\r\n\r\nclassclass", + }, + { + wantTag: "wantTag", + wantErr: false, + wantBody: "POST /3 HTTP/1.1\r\nHost: other.net\r\nA: \r\nC: d\r\nContent-Type: application/json\r\n\r\nclassclassclass", + }, } for j := 0; j < 2; j++ { - for i, want := range wants { - ammo, err := decoder.Scan(ctx) - assert.NoError(t, err, "iteration %d-%d", j, i) - assert.Equal(t, want, ammo, "iteration %d-%d", j, i) + for i, tt := range tests { + req, tag, err := decoder.Scan(ctx) + if tt.wantErr { + assert.Error(t, err, "iteration %d-%d", j, i) + continue + } else { + assert.NoError(t, err, "iteration %d-%d", j, i) + } + assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) + + req.Close = false + body, _ := httputil.DumpRequest(req, true) + assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) } } - _, err := decoder.Scan(ctx) + _, _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) - assert.Equal(t, decoder.ammoNum, uint(len(wants)*2)) + assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) assert.Equal(t, decoder.passNum, uint(1)) } diff --git a/components/providers/http/provider.go b/components/providers/http/provider.go index d88c0bff6..b910a33d7 100644 --- a/components/providers/http/provider.go +++ b/components/providers/http/provider.go @@ -7,6 +7,7 @@ package http import ( "bytes" "io" + "net/http" "strings" "sync" @@ -47,8 +48,8 @@ func NewProvider(fs afero.Fs, conf config.Config) (core.Provider, error) { Config: conf, Decoder: decoder, Close: closer.Close, - AmmoPool: sync.Pool{New: func() interface{} { return new(base.Ammo) }}, - Sink: make(chan *base.Ammo), + AmmoPool: sync.Pool{New: func() interface{} { return new(base.Ammo[http.Request]) }}, + Sink: make(chan *base.Ammo[http.Request]), }, nil } diff --git a/components/providers/http/provider/provider.go b/components/providers/http/provider/provider.go index 47fd6e629..dfd1671d4 100644 --- a/components/providers/http/provider/provider.go +++ b/components/providers/http/provider/provider.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "sync" "github.com/yandex/pandora/components/providers/base" @@ -23,7 +24,7 @@ type Provider struct { Close func() error AmmoPool sync.Pool - Sink chan *base.Ammo + Sink chan *base.Ammo[http.Request] } func (p *Provider) Acquire() (core.Ammo, bool) { @@ -31,14 +32,10 @@ func (p *Provider) Acquire() (core.Ammo, bool) { if ok { ammo.SetID(p.NextID()) } - if err := ammo.BuildRequest(); err != nil { - p.Deps.Log.Error("http build request error", zap.Error(err)) - return ammo, false - } for _, mw := range p.Middlewares { err := mw.UpdateRequest(ammo.Req) if err != nil { - p.Deps.Log.Error("error on Middleware.UpdateRequest", zap.Error(err)) + p.Log.Error("error on Middleware.UpdateRequest", zap.Error(err)) return ammo, false } } @@ -46,15 +43,18 @@ func (p *Provider) Acquire() (core.Ammo, bool) { } func (p *Provider) Release(a core.Ammo) { - ammo := a.(*base.Ammo) - ammo.Reset() + ammo := a.(*base.Ammo[http.Request]) + // TODO: add request release for example for future fasthttp + // ammo.Req.Body = nil + ammo.Req = nil p.AmmoPool.Put(ammo) } func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) { - var ammo *base.Ammo + var req *http.Request + var tag string - p.Deps = deps + p.ProviderDeps = deps defer func() { // TODO: wrap in go 1.20 // err = errors.Join(err, p.Close()) @@ -84,8 +84,8 @@ func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) } return } - ammo, err = p.Decoder.Scan(ctx) - if !confutil.IsChosenCase(ammo.Tag(), p.Config.ChosenCases) { + req, tag, err = p.Decoder.Scan(ctx) + if !confutil.IsChosenCase(tag, p.Config.ChosenCases) { continue } if err != nil { @@ -95,6 +95,8 @@ func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) return } + a := p.AmmoPool.Get().(*base.Ammo[http.Request]) + a.Reset(req, tag) select { case <-ctx.Done(): err = ctx.Err() @@ -102,7 +104,7 @@ func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) err = xerrors.Errorf("error from context: %w", err) } return - case p.Sink <- ammo: + case p.Sink <- a: } } } diff --git a/lib/netutil/validator.go b/lib/netutil/validator.go deleted file mode 100644 index 094fce25a..000000000 --- a/lib/netutil/validator.go +++ /dev/null @@ -1,35 +0,0 @@ -package netutil - -import ( - "strings" - - "golang.org/x/net/http/httpguts" -) - -func isNotToken(r rune) bool { - return !httpguts.IsTokenRune(r) -} - -// ValidHTTPMethod just copy net/http/request.go validMethod(method string) bool -func ValidHTTPMethod(method string) bool { - if method == "" { - // We document that "" means "GET" for Request.Method, and people have - // relied on that from NewRequest, so keep that working. - // We still enforce validMethod for non-empty methods. - method = "GET" - } - /* - Method = "OPTIONS" ; Section 9.2 - | "GET" ; Section 9.3 - | "HEAD" ; Section 9.4 - | "POST" ; Section 9.5 - | "PUT" ; Section 9.6 - | "DELETE" ; Section 9.7 - | "TRACE" ; Section 9.8 - | "CONNECT" ; Section 9.9 - | extension-method - extension-method = token - token = 1* - */ - return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 -} diff --git a/lib/testutil/testing.go b/lib/testutil/testing.go new file mode 100644 index 000000000..0abce8d65 --- /dev/null +++ b/lib/testutil/testing.go @@ -0,0 +1,8 @@ +package testutil + +func Must[T any](obj T, err error) T { + if err != nil { + panic(err) + } + return obj +}