diff --git a/components/providers/base/ammo.go b/components/providers/base/ammo.go index a417a8d68..5293cf168 100644 --- a/components/providers/base/ammo.go +++ b/components/providers/base/ammo.go @@ -1,47 +1,110 @@ package base -import "github.com/yandex/pandora/core/aggregator/netsample" +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + urlpkg "net/url" -type Ammo[R any] struct { - Req *R - tag string - id uint64 - isInvalid bool + "github.com/yandex/pandora/components/providers/http/util" + "github.com/yandex/pandora/core/aggregator/netsample" + "github.com/yandex/pandora/lib/netutil" +) + +func NewAmmo(method string, url string, body []byte, header http.Header, tag string) (*Ammo, error) { + if ok := netutil.ValidHTTPMethod(method); !ok { + return nil, errors.New("invalid HTTP method " + method) + } + if _, err := urlpkg.Parse(url); err != nil { + return nil, fmt.Errorf("invalid URL %s; err %w ", url, err) + } + return &Ammo{ + method: method, + body: body, + url: url, + tag: tag, + header: header, + constructor: true, + }, nil +} + +type Ammo struct { + Req *http.Request + method string + body []byte + url string + tag string + header http.Header + id uint64 + isInvalid bool + constructor bool } -func (a *Ammo[R]) Request() (*R, *netsample.Sample) { +func (a *Ammo) Request() (*http.Request, *netsample.Sample) { + if a.Req == nil { + _ = a.BuildRequest() // TODO: what if error. There isn't a logger + } sample := netsample.Acquire(a.Tag()) sample.SetID(a.ID()) return a.Req, sample } -func (a *Ammo[R]) Reset(req *R, tag string) { - a.Req = req - a.tag = tag - a.id = 0 - a.isInvalid = false -} - -func (a *Ammo[_]) SetID(id uint64) { +func (a *Ammo) SetID(id uint64) { a.id = id } -func (a *Ammo[_]) ID() uint64 { +func (a *Ammo) ID() uint64 { return a.id } -func (a *Ammo[_]) Invalidate() { +func (a *Ammo) Invalidate() { a.isInvalid = true } -func (a *Ammo[_]) IsInvalid() bool { +func (a *Ammo) IsInvalid() bool { return a.isInvalid } -func (a *Ammo[_]) IsValid() bool { +func (a *Ammo) IsValid() bool { return !a.isInvalid } -func (a *Ammo[_]) Tag() string { +func (a *Ammo) SetTag(tag string) { + a.tag = tag +} + +func (a *Ammo) Tag() string { return a.tag } + +func (a *Ammo) FromConstructor() bool { + return a.constructor +} + +// use NewAmmo() for skipping error here +func (a *Ammo) BuildRequest() error { + var buff io.Reader + if a.body != nil { + buff = bytes.NewReader(a.body) + } + req, err := http.NewRequest(a.method, a.url, buff) + if err != nil { + return fmt.Errorf("cant create request: %w", err) + } + a.Req = req + util.EnrichRequestWithHeaders(req, a.header) + return nil +} + +func (a *Ammo) Reset() { + a.Req = nil + a.method = "" + a.body = nil + a.url = "" + a.tag = "" + a.header = nil + a.id = 0 + a.isInvalid = false +} diff --git a/components/providers/base/decoder.go b/components/providers/base/decoder.go deleted file mode 100644 index 1e71484ba..000000000 --- a/components/providers/base/decoder.go +++ /dev/null @@ -1,17 +0,0 @@ -package base - -import "sync" - -type Decoder[R any] struct { - Sink chan<- *Ammo[R] - Pool *sync.Pool -} - -func NewDecoder[R any](sink chan<- *Ammo[R]) Decoder[R] { - return Decoder[R]{ - Sink: sink, - Pool: &sync.Pool{New: func() any { - return new(Ammo[R]) - }}, - } -} diff --git a/components/providers/base/provider.go b/components/providers/base/provider.go index 5f7096146..948e88090 100644 --- a/components/providers/base/provider.go +++ b/components/providers/base/provider.go @@ -8,7 +8,7 @@ import ( ) type ProviderBase struct { - core.ProviderDeps + Deps core.ProviderDeps FS afero.Fs idCounter atomic.Uint64 } diff --git a/components/providers/http/ammo.go b/components/providers/http/ammo.go index 1ef3da677..780be9ee2 100644 --- a/components/providers/http/ammo.go +++ b/components/providers/http/ammo.go @@ -16,4 +16,4 @@ type Request interface { http.Request } -var _ phttp.Ammo = (*base.Ammo[http.Request])(nil) +var _ phttp.Ammo = (*base.Ammo)(nil) diff --git a/components/providers/http/decoders/decoder.go b/components/providers/http/decoders/decoder.go index 17975b1d3..a2e20aa29 100644 --- a/components/providers/http/decoders/decoder.go +++ b/components/providers/http/decoders/decoder.go @@ -6,6 +6,7 @@ import ( "io" "net/http" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/util" ) @@ -24,7 +25,7 @@ var ( type Decoder interface { // Decode(context.Context, chan<- *base.Ammo[http.Request], io.ReadSeeker) error - Scan(context.Context) (*http.Request, string, error) + Scan(context.Context) (*base.Ammo, error) } type protoDecoder struct { diff --git a/components/providers/http/decoders/jsonline.go b/components/providers/http/decoders/jsonline.go index 0d972c5e4..939029888 100644 --- a/components/providers/http/decoders/jsonline.go +++ b/components/providers/http/decoders/jsonline.go @@ -7,9 +7,9 @@ import ( "net/http" "strings" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/decoders/jsonline" - "github.com/yandex/pandora/components/providers/http/util" "golang.org/x/xerrors" ) @@ -35,53 +35,51 @@ type jsonlineDecoder struct { line uint } -func (d *jsonlineDecoder) Scan(ctx context.Context) (*http.Request, string, error) { +func (d *jsonlineDecoder) Scan(ctx context.Context) (*base.Ammo, error) { if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, "", ErrAmmoLimit + return nil, ErrAmmoLimit } - for ; ; d.line++ { - if ctx.Err() != nil { - return nil, "", ctx.Err() + for { + if d.config.Passes != 0 && d.passNum >= d.config.Passes { + return nil, ErrPassLimit } - if !d.scanner.Scan() { - if d.scanner.Err() == nil { // assume as io.EOF; FIXME: check possible nil error with other reason - d.line = 0 - d.passNum++ - if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, "", ErrPassLimit - } - if d.ammoNum == 0 { - return nil, "", ErrNoAmmo - } - _, err := d.file.Seek(0, io.SeekStart) - if err != nil { - return nil, "", err - } - d.scanner = bufio.NewScanner(d.file) - if d.config.MaxAmmoSize != 0 { - var buffer []byte - d.scanner.Buffer(buffer, d.config.MaxAmmoSize) - } + for d.scanner.Scan() { + d.line++ + data := d.scanner.Bytes() + if len(strings.TrimSpace(string(data))) == 0 { continue } - return nil, "", d.scanner.Err() + d.ammoNum++ + ammo, err := jsonline.DecodeAmmo(data, d.decodedConfigHeaders) + if err != nil { + if !d.config.ContinueOnError { + return nil, xerrors.Errorf("failed to decode ammo at line: %v; data: %q, with err: %w", d.line+1, data, err) + } + // TODO: add log message about error + continue // skipping ammo + } + return ammo, err } - data := d.scanner.Bytes() - if len(strings.TrimSpace(string(data))) == 0 { - continue + + err := d.scanner.Err() + if err != nil { + return nil, err + } + if d.ammoNum == 0 { + return nil, ErrNoAmmo } - d.ammoNum++ + d.line = 0 + d.passNum++ - req, tag, err := jsonline.DecodeAmmo(data) + _, err = d.file.Seek(0, io.SeekStart) if err != nil { - if !d.config.ContinueOnError { - return nil, "", xerrors.Errorf("failed to decode ammo at line: %v; data: %q, with err: %w", d.line+1, data, err) - } - // TODO: add log message about error - continue // skipping ammo + return nil, err + } + d.scanner = bufio.NewScanner(d.file) + if d.config.MaxAmmoSize != 0 { + var buffer []byte + d.scanner.Buffer(buffer, d.config.MaxAmmoSize) } - util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) - return req, tag, err } } diff --git a/components/providers/http/decoders/jsonline/data.go b/components/providers/http/decoders/jsonline/data.go index 2fb4f5594..4e24feba1 100644 --- a/components/providers/http/decoders/jsonline/data.go +++ b/components/providers/http/decoders/jsonline/data.go @@ -1,12 +1,12 @@ -//go:generate github.com/pquerna/ffjson data_ffjson.go +//go:generate github.com/pquerna/ffjson@latest data_ffjson.go package jsonline import ( "net/http" - "strings" "github.com/pkg/errors" + "github.com/yandex/pandora/components/providers/base" ) // ffjson: noencoder @@ -24,24 +24,21 @@ type data struct { Body string `json:"body"` } -func (d *data) ToRequest() (*http.Request, error) { - uri := "http://" + d.Host + d.URI - req, err := http.NewRequest(d.Method, uri, strings.NewReader(d.Body)) - if err != nil { - return nil, errors.WithStack(err) +func DecodeAmmo(jsonDoc []byte, baseHeader http.Header) (*base.Ammo, error) { + var d = new(data) + if err := d.UnmarshalJSON(jsonDoc); err != nil { + err = errors.WithStack(err) + return nil, err } + + header := baseHeader.Clone() for k, v := range d.Headers { - req.Header.Set(k, v) + header.Set(k, v) } - return req, err -} - -func DecodeAmmo(jsonDoc []byte) (*http.Request, string, error) { - var data = new(data) - if err := data.UnmarshalJSON(jsonDoc); err != nil { - err = errors.WithStack(err) - return nil, data.Tag, err + url := "http://" + d.Host + d.URI + var body []byte + if d.Body != "" { + body = []byte(d.Body) } - req, err := data.ToRequest() - return req, data.Tag, err + return base.NewAmmo(d.Method, url, body, header, d.Tag) } diff --git a/components/providers/http/decoders/jsonline/data_test.go b/components/providers/http/decoders/jsonline/data_test.go index 60ecd564f..585cdc48e 100644 --- a/components/providers/http/decoders/jsonline/data_test.go +++ b/components/providers/http/decoders/jsonline/data_test.go @@ -1,117 +1,60 @@ package jsonline import ( - "context" - "encoding/json" "net/http" - "net/url" "testing" - "testing/iotest" "github.com/stretchr/testify/assert" - "github.com/yandex/pandora/lib/testutil" + "github.com/stretchr/testify/require" + "github.com/yandex/pandora/components/providers/base" ) -const testFile = "./ammo.jsonline" - -// testData holds jsonline.data that contains in testFile -var testData = []data{ - { - Host: "example.com", - Method: "GET", - URI: "/00", - Headers: map[string]string{"Accept": "*/*", "Accept-Encoding": "gzip, deflate", "User-Agent": "Pandora/0.0.1"}, - }, - { - Host: "ya.ru", - Method: "HEAD", - URI: "/01", - Headers: map[string]string{"Accept": "*/*", "Accept-Encoding": "gzip, brotli", "User-Agent": "YaBro/0.1"}, - Tag: "head", - }, -} - -type ToRequestWant struct { - req *http.Request - error - body []byte -} - func TestToRequest(t *testing.T) { var tests = []struct { - name string - input data - want ToRequestWant + name string + json []byte + confHeader http.Header + want *base.Ammo + wantErr bool }{ { - name: "decoded well", - input: data{ - Host: "ya.ru", - Method: "GET", - URI: "/00", - Headers: map[string]string{"A": "a", "B": "b"}, - Tag: "tag", - }, - want: ToRequestWant{ - req: &http.Request{ - Method: "GET", - URL: testutil.Must(url.Parse("http://ya.ru/00")), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: http.Header{"A": []string{"a"}, "B": []string{"b"}}, - Body: http.NoBody, - Host: "ya.ru", - }, - error: nil, - }, + name: "GET request", + json: []byte(`{"host": "ya.ru", "method": "GET", "uri": "/00", "tag": "tag", "headers": {"A": "a", "B": "b"}}`), + confHeader: http.Header{"Default": []string{"def"}}, + want: MustNewAmmo(t, "GET", "http://ya.ru/00", nil, http.Header{"Default": []string{"def"}, "A": []string{"a"}, "B": []string{"b"}}, "tag"), + wantErr: false, + }, + { + name: "POST request", + json: []byte(`{"host": "ya.ru", "method": "POST", "uri": "/01?sleep=10", "tag": "tag", "headers": {"A": "a", "B": "b"}, "body": "body"}`), + confHeader: http.Header{"Default": []string{"def"}}, + want: MustNewAmmo(t, "POST", "http://ya.ru/01?sleep=10", []byte(`body`), http.Header{"Default": []string{"def"}, "A": []string{"a"}, "B": []string{"b"}}, "tag"), + wantErr: false, + }, + { + name: "POST request with json", + json: []byte(`{"host": "ya.ru", "method": "POST", "uri": "/01?sleep=10", "tag": "tag", "headers": {"A": "a", "B": "b"}, "body": "{\"field\":\"value\"}"}`), + confHeader: http.Header{"Default": []string{"def"}}, + want: MustNewAmmo(t, "POST", "http://ya.ru/01?sleep=10", []byte(`{"field":"value"}`), http.Header{"Default": []string{"def"}, "A": []string{"a"}, "B": []string{"b"}}, "tag"), + wantErr: false, }, } - var ans ToRequestWant for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert := assert.New(t) - ans.req, ans.error = tt.input.ToRequest() - if tt.want.body != nil { - assert.NotNil(ans.req) - assert.NoError(iotest.TestReader(ans.req.Body, tt.want.body)) - ans.req.Body = nil - tt.want.body = nil + ans, err := DecodeAmmo(tt.json, tt.confHeader) + if tt.wantErr { + require.Error(t, err) + return } - ans.req.GetBody = nil - tt.want.req = ans.req.WithContext(context.Background()) + assert.NoError(err) assert.Equal(tt.want, ans) }) } } -type DecodeAmmoWant struct { - req *http.Request - tag string - error -} - -func TestDecodeAmmo(t *testing.T) { - var tests = []struct { - name string - input []byte - want DecodeAmmoWant - }{} - var ans DecodeAmmoWant - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert := assert.New(t) - ans.req, ans.tag, ans.error = DecodeAmmo(tt.input) - assert.Equal(tt.want, ans) - }) - } -} - -func BenchmarkDecodeAmmo(b *testing.B) { - jsonDoc, err := json.Marshal(testData[0]) - assert.NoError(b, err) - b.ResetTimer() - for n := 0; n < b.N; n++ { - _, _, _ = DecodeAmmo(jsonDoc) - } +func MustNewAmmo(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { + ammo, err := base.NewAmmo(method, url, body, header, tag) + require.NoError(t, err) + return ammo } diff --git a/components/providers/http/decoders/jsonline_test.go b/components/providers/http/decoders/jsonline_test.go index fe7ae755c..947e0b4ed 100644 --- a/components/providers/http/decoders/jsonline_test.go +++ b/components/providers/http/decoders/jsonline_test.go @@ -3,18 +3,25 @@ package decoders import ( "context" "net/http" - "net/http/httputil" "strings" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" ) func Test_jsonlineDecoder_Scan(t *testing.T) { - input := `{"host": "4bs65mu2kdulxmir.myt.yp-c.yandex.net", "method": "GET", "uri": "/?sleep=100", "tag": "sleep1", "headers": {"User-agent": "Tank", "Connection": "close"}} -{"host": "4bs65mu2kdulxmir.myt.yp-c.yandex.net", "method": "POST", "uri": "/?sleep=200", "tag": "sleep2", "headers": {"User-agent": "Tank", "Connection": "close"}, "body": "body_data"} + var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { + ammo, err := base.NewAmmo(method, url, body, header, tag) + require.NoError(t, err) + return ammo + } + + input := `{"host": "ya.net", "method": "GET", "uri": "/?sleep=100", "tag": "sleep1", "headers": {"User-agent": "Tank", "Connection": "close"}} +{"host": "ya.net", "method": "POST", "uri": "/?sleep=200", "tag": "sleep2", "headers": {"User-agent": "Tank", "Connection": "close"}, "body": "body_data"} ` @@ -26,41 +33,32 @@ func Test_jsonlineDecoder_Scan(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - tests := []struct { - wantTag string - wantErr bool - wantBody string - }{ - { - wantTag: "sleep1", - wantErr: false, - wantBody: "GET /?sleep=100 HTTP/1.1\r\nHost: 4bs65mu2kdulxmir.myt.yp-c.yandex.net\r\nConnection: close\r\nContent-Type: application/json\r\nUser-Agent: Tank\r\n\r\n", - }, - { - wantTag: "sleep2", - wantErr: false, - wantBody: "POST /?sleep=200 HTTP/1.1\r\nHost: 4bs65mu2kdulxmir.myt.yp-c.yandex.net\r\nConnection: close\r\nContent-Type: application/json\r\nUser-Agent: Tank\r\n\r\nbody_data", - }, + wants := []*base.Ammo{ + mustNewAmmo(t, + "GET", + "http://ya.net/?sleep=100", + nil, + http.Header{"Connection": []string{"close"}, "Content-Type": []string{"application/json"}, "User-Agent": []string{"Tank"}}, + "sleep1", + ), + mustNewAmmo(t, + "POST", + "http://ya.net/?sleep=200", + []byte("body_data"), + http.Header{"Connection": []string{"close"}, "Content-Type": []string{"application/json"}, "User-Agent": []string{"Tank"}}, + "sleep2", + ), } for j := 0; j < 2; j++ { - for i, tt := range tests { - req, tag, err := decoder.Scan(ctx) - if tt.wantErr { - assert.Error(t, err, "iteration %d-%d", j, i) - continue - } else { - assert.NoError(t, err, "iteration %d-%d", j, i) - } - assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) - - req.Close = false - body, _ := httputil.DumpRequest(req, true) - assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) + for i, want := range wants { + ammo, err := decoder.Scan(ctx) + assert.NoError(t, err, "iteration %d-%d", j, i) + assert.Equal(t, want, ammo, "iteration %d-%d", j, i) } } - _, _, err := decoder.Scan(ctx) + _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) - assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) + assert.Equal(t, decoder.ammoNum, uint(len(wants)*2)) assert.Equal(t, decoder.passNum, uint(1)) } diff --git a/components/providers/http/decoders/raw.go b/components/providers/http/decoders/raw.go index 0ace2dafb..fd7ab470a 100644 --- a/components/providers/http/decoders/raw.go +++ b/components/providers/http/decoders/raw.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/decoders/raw" "github.com/yandex/pandora/components/providers/http/util" @@ -58,38 +59,38 @@ type rawDecoder struct { reader *bufio.Reader } -func (d *rawDecoder) Scan(ctx context.Context) (*http.Request, string, error) { +func (d *rawDecoder) Scan(ctx context.Context) (*base.Ammo, error) { var data string var buff []byte var req *http.Request var err error if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, "", ErrAmmoLimit + return nil, ErrAmmoLimit } for { if ctx.Err() != nil { - return nil, "", ctx.Err() + return nil, ctx.Err() } data, err = d.reader.ReadString('\n') if err == io.EOF { d.passNum++ if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, "", ErrPassLimit + return nil, ErrPassLimit } if d.ammoNum == 0 { - return nil, "", ErrNoAmmo + return nil, ErrNoAmmo } _, err := d.file.Seek(0, io.SeekStart) if err != nil { - return nil, "", err + return nil, err } d.reader.Reset(d.file) continue } if err != nil { - return nil, "", xerrors.Errorf("reading ammo failed with err: %w, at position: %v", err, filePosition(d.file)) + return nil, xerrors.Errorf("reading ammo failed with err: %w, at position: %v", err, filePosition(d.file)) } data = strings.TrimSpace(data) if len(data) == 0 { @@ -98,7 +99,7 @@ func (d *rawDecoder) Scan(ctx context.Context) (*http.Request, string, error) { d.ammoNum++ reqSize, tag, err := raw.DecodeHeader(data) if err != nil { - return nil, "", xerrors.Errorf("header decoding error: %w", err) + return nil, xerrors.Errorf("header decoding error: %w", err) } if reqSize != 0 { @@ -107,11 +108,11 @@ func (d *rawDecoder) Scan(ctx context.Context) (*http.Request, string, error) { } buff = buff[:reqSize] if n, err := io.ReadFull(d.reader, buff); err != nil { - return nil, "", xerrors.Errorf("failed to read ammo with err: %w, at position: %v; tried to read: %v; have read: %v", err, filePosition(d.file), reqSize, n) + return nil, xerrors.Errorf("failed to read ammo with err: %w, at position: %v; tried to read: %v; have read: %v", err, filePosition(d.file), reqSize, n) } req, err = raw.DecodeRequest(buff) if err != nil { - return nil, "", xerrors.Errorf("failed to decode ammo with err: %w, at position: %v; data: %q", err, filePosition(d.file), buff) + return nil, xerrors.Errorf("failed to decode ammo with err: %w, at position: %v; data: %q", err, filePosition(d.file), buff) } } else { req, _ = http.NewRequest("", "/", nil) @@ -119,6 +120,8 @@ func (d *rawDecoder) Scan(ctx context.Context) (*http.Request, string, error) { // add new Headers to request from config util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) - return req, tag, nil + ammo := &base.Ammo{Req: req} + ammo.SetTag(tag) + return ammo, nil } } diff --git a/components/providers/http/decoders/raw/decoder_bench_test.go b/components/providers/http/decoders/raw/decoder_bench_test.go index 0bb2dc09b..2f4c655dc 100644 --- a/components/providers/http/decoders/raw/decoder_bench_test.go +++ b/components/providers/http/decoders/raw/decoder_bench_test.go @@ -3,8 +3,8 @@ package raw import ( "testing" + "github.com/stretchr/testify/require" "github.com/yandex/pandora/components/providers/http/util" - "github.com/yandex/pandora/lib/testutil" ) var ( @@ -30,7 +30,8 @@ func BenchmarkRawDecoder(b *testing.B) { } func BenchmarkRawDecoderWithHeaders(b *testing.B) { - decodedHTTPConfigHeaders := testutil.Must(util.DecodeHTTPConfigHeaders(benchTestConfigHeaders)) + decodedHTTPConfigHeaders, err := util.DecodeHTTPConfigHeaders(benchTestConfigHeaders) + require.NoError(b, err) b.ResetTimer() for i := 0; i < b.N; i++ { req, _ := DecodeRequest(benchTestRequest) diff --git a/components/providers/http/decoders/raw/decoder_test.go b/components/providers/http/decoders/raw/decoder_test.go index 50d2cb3bf..3ae1eed01 100644 --- a/components/providers/http/decoders/raw/decoder_test.go +++ b/components/providers/http/decoders/raw/decoder_test.go @@ -8,7 +8,7 @@ import ( "testing/iotest" "github.com/stretchr/testify/assert" - "github.com/yandex/pandora/lib/testutil" + "github.com/stretchr/testify/require" ) type DecoderHeaderWant struct { @@ -64,7 +64,7 @@ func TestDecodeRequest(t *testing.T) { want: DecoderRequestWant{ &http.Request{ Method: "GET", - URL: testutil.Must(url.Parse("/some/path")), + URL: MustURL(t, "/some/path"), Proto: "HTTP/1.0", ProtoMajor: 1, ProtoMinor: 0, @@ -89,7 +89,7 @@ func TestDecodeRequest(t *testing.T) { want: DecoderRequestWant{ &http.Request{ Method: "POST", - URL: testutil.Must(url.Parse("/some/path")), + URL: MustURL(t, "/some/path"), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, @@ -125,7 +125,7 @@ func TestDecodeRequest(t *testing.T) { want: DecoderRequestWant{ &http.Request{ Method: "GET", - URL: testutil.Must(url.Parse("/etc/passwd")), + URL: MustURL(t, "/etc/passwd"), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, @@ -151,3 +151,9 @@ func TestDecodeRequest(t *testing.T) { }) } } + +func MustURL(t *testing.T, rawURL string) *url.URL { + url, err := url.Parse(rawURL) + require.NoError(t, err) + return url +} diff --git a/components/providers/http/decoders/raw_test.go b/components/providers/http/decoders/raw_test.go index c368f9f14..3dbb1f89f 100644 --- a/components/providers/http/decoders/raw_test.go +++ b/components/providers/http/decoders/raw_test.go @@ -13,9 +13,9 @@ import ( ) func Test_rawDecoder_Scan(t *testing.T) { - input := `68 good50 + input := `38 good50 GET /?sleep=50 HTTP/1.0 -Host: 4bs65mu2kdulxmir.myt.yp-c.yandex.net +Host: ya.net 74 bad @@ -58,7 +58,7 @@ User-Agent: xxx (shell 1) { wantTag: "good50", wantErr: false, - wantBody: "GET /?sleep=50 HTTP/1.0\r\nHost: 4bs65mu2kdulxmir.myt.yp-c.yandex.net\r\nContent-Type: application/json\r\n\r\n", + wantBody: "GET /?sleep=50 HTTP/1.0\r\nHost: ya.net\r\nContent-Type: application/json\r\n\r\n", }, { wantTag: "bad", @@ -78,22 +78,23 @@ User-Agent: xxx (shell 1) } for j := 0; j < 2; j++ { for i, tt := range tests { - req, tag, err := decoder.Scan(ctx) + ammo, err := decoder.Scan(ctx) if tt.wantErr { assert.Error(t, err, "iteration %d-%d", j, i) continue } else { assert.NoError(t, err, "iteration %d-%d", j, i) } - assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) + assert.Equal(t, tt.wantTag, ammo.Tag(), "iteration %d-%d", j, i) + req := ammo.Req req.Close = false body, _ := httputil.DumpRequest(req, true) assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) } } - _, _, err := decoder.Scan(ctx) + _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) assert.Equal(t, decoder.passNum, uint(1)) diff --git a/components/providers/http/decoders/uri.go b/components/providers/http/decoders/uri.go index 8fc4278e6..3706ba2b6 100644 --- a/components/providers/http/decoders/uri.go +++ b/components/providers/http/decoders/uri.go @@ -6,8 +6,10 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/util" ) @@ -31,76 +33,72 @@ type uriDecoder struct { line uint } -func (d *uriDecoder) readLine(data string, commonHeader http.Header) (*http.Request, string, error) { +func (d *uriDecoder) readLine(data string, commonHeader http.Header) (*base.Ammo, error) { data = strings.TrimSpace(data) if len(data) == 0 { - return nil, "", nil // skip empty line + return nil, nil // skip empty line } - var req *http.Request - var tag string - var err error if data[0] == '[' { key, val, err := util.DecodeHeader(data) if err != nil { err = fmt.Errorf("decoding header error: %w", err) - return nil, "", err + return nil, err } commonHeader.Set(key, val) - } else { - var rawURL string - rawURL, tag, _ = strings.Cut(data, " ") - req, err = http.NewRequest("GET", rawURL, nil) - if err != nil { - err = fmt.Errorf("failed to decode uri: %w", err) - return nil, "", err - } - if host, ok := commonHeader["Host"]; ok { - req.Host = host[0] - } - req.Header = commonHeader.Clone() + return nil, nil + } - // add new Headers to request from config - util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) + var rawURL string + rawURL, tag, _ := strings.Cut(data, " ") + _, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + header := commonHeader.Clone() + for k, vv := range d.decodedConfigHeaders { + for _, v := range vv { + header.Set(k, v) + } } - return req, tag, nil + return base.NewAmmo("GET", rawURL, nil, header, tag) } -func (d *uriDecoder) Scan(ctx context.Context) (*http.Request, string, error) { +func (d *uriDecoder) Scan(ctx context.Context) (*base.Ammo, error) { if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, "", ErrAmmoLimit + return nil, ErrAmmoLimit } for ; ; d.line++ { if ctx.Err() != nil { - return nil, "", ctx.Err() + return nil, ctx.Err() } if !d.scanner.Scan() { if d.scanner.Err() == nil { // assume as io.EOF; FIXME: check possible nil error with other reason d.line = 0 d.passNum++ if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, "", ErrPassLimit + return nil, ErrPassLimit } if d.ammoNum == 0 { - return nil, "", ErrNoAmmo + return nil, ErrNoAmmo } d.Header = http.Header{} _, err := d.file.Seek(0, io.SeekStart) if err != nil { - return nil, "", err + return nil, err } d.scanner = bufio.NewScanner(d.file) continue } - return nil, "", d.scanner.Err() + return nil, d.scanner.Err() } data := d.scanner.Text() - req, tag, err := d.readLine(data, d.Header) + ammo, err := d.readLine(data, d.Header) if err != nil { - return nil, "", fmt.Errorf("decode at line %d `%s` error: %w", d.line+1, data, err) + return nil, fmt.Errorf("decode at line %d `%s` error: %w", d.line+1, data, err) } - if req != nil { + if ammo != nil { d.ammoNum++ - return req, tag, nil + return ammo, nil } } } diff --git a/components/providers/http/decoders/uri_test.go b/components/providers/http/decoders/uri_test.go index 2b4dcdfb1..e32011ea4 100644 --- a/components/providers/http/decoders/uri_test.go +++ b/components/providers/http/decoders/uri_test.go @@ -3,31 +3,35 @@ package decoders import ( "context" "net/http" - "net/http/httputil" - "net/url" "strings" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" ) func Test_uriDecoder_readLine(t *testing.T) { + var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { + ammo, err := base.NewAmmo(method, url, body, header, tag) + require.NoError(t, err) + return ammo + } + tests := []struct { name string data string - expectedReq *http.Request - expectedTag string - expectedErr bool + want *base.Ammo + wantErr bool expectedCommonHeaders http.Header }{ { - name: "Header line", - data: "[Content-Type: application/json]", - expectedReq: nil, - expectedTag: "", - expectedErr: false, + name: "Header line", + data: "[Content-Type: application/json]", + want: nil, + wantErr: false, expectedCommonHeaders: http.Header{ "Content-Type": []string{"application/json"}, "User-Agent": []string{"TestAgent"}, @@ -36,20 +40,11 @@ func Test_uriDecoder_readLine(t *testing.T) { { name: "Valid URI", data: "http://example.com/test", - expectedReq: &http.Request{ - Method: "GET", - Proto: "HTTP/1.1", - URL: &url.URL{Scheme: "http", Host: "example.com", Path: "/test"}, - Header: http.Header{ - "User-Agent": []string{"TestAgent"}, - "Authorization": []string{"Bearer xxx"}, - }, - Host: "example.com", - ProtoMajor: 1, - ProtoMinor: 1, - }, - expectedTag: "", - expectedErr: false, + want: mustNewAmmo(t, "GET", "http://example.com/test", nil, http.Header{ + "User-Agent": []string{"TestAgent"}, + "Authorization": []string{"Bearer xxx"}, + }, ""), + wantErr: false, expectedCommonHeaders: http.Header{ "User-Agent": []string{"TestAgent"}, }, @@ -57,30 +52,20 @@ func Test_uriDecoder_readLine(t *testing.T) { { name: "URI with tag", data: "http://example.com/test tag\n", - expectedReq: &http.Request{ - Method: "GET", - Proto: "HTTP/1.1", - URL: &url.URL{Scheme: "http", Host: "example.com", Path: "/test"}, - Header: http.Header{ - "User-Agent": []string{"TestAgent"}, - "Authorization": []string{"Bearer xxx"}, - }, - Host: "example.com", - ProtoMajor: 1, - ProtoMinor: 1, - }, - expectedTag: "tag", - expectedErr: false, + want: mustNewAmmo(t, "GET", "http://example.com/test", nil, http.Header{ + "User-Agent": []string{"TestAgent"}, + "Authorization": []string{"Bearer xxx"}, + }, "tag"), + wantErr: false, expectedCommonHeaders: http.Header{ "User-Agent": []string{"TestAgent"}, }, }, { - name: "Invalid data", - data: "1http://foo.com tag", - expectedReq: nil, - expectedTag: "", - expectedErr: true, + name: "Invalid data", + data: "1http://foo.com tag", + want: nil, + wantErr: true, }, } for _, test := range tests { @@ -89,19 +74,14 @@ func Test_uriDecoder_readLine(t *testing.T) { decodedConfigHeaders := http.Header{"Authorization": []string{"Bearer xxx"}} decoder := newURIDecoder(nil, config.Config{}, decodedConfigHeaders) - req, tag, err := decoder.readLine(test.data, commonHeader) - - if test.expectedReq != nil { - test.expectedReq = test.expectedReq.WithContext(context.Background()) - } + ammo, err := decoder.readLine(test.data, commonHeader) - if test.expectedErr { + if test.wantErr { assert.Error(t, err) return } - assert.Equal(t, test.expectedTag, tag) + assert.Equal(t, test.want, ammo) assert.Equal(t, test.expectedCommonHeaders, commonHeader) - assert.Equal(t, test.expectedReq, req) }) } } @@ -119,6 +99,12 @@ const uriInput = ` /0 /4 some tag` func Test_uriDecoder_Scan(t *testing.T) { + var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { + ammo, err := base.NewAmmo(method, url, body, header, tag) + require.NoError(t, err) + return ammo + } + decoder := newURIDecoder(strings.NewReader(uriInput), config.Config{ Limit: 10, }, http.Header{"Content-Type": []string{"application/json"}}) @@ -126,57 +112,23 @@ func Test_uriDecoder_Scan(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - tests := []struct { - wantTag string - wantErr bool - wantBody string - }{ - { - wantTag: "", - wantErr: false, - wantBody: "GET /0 HTTP/1.1\r\nContent-Type: application/json\r\n\r\n", - }, - { - wantTag: "", - wantErr: false, - wantBody: "GET /1 HTTP/1.1\r\nA: b\r\nContent-Type: application/json\r\n\r\n", - }, - { - wantTag: "", - wantErr: false, - wantBody: "GET /2 HTTP/1.1\r\nHost: example.com\r\nA: b\r\nC: d\r\nContent-Type: application/json\r\n\r\n", - }, - { - wantTag: "", - wantErr: false, - wantBody: "GET /3 HTTP/1.1\r\nHost: other.net\r\nA: \r\nC: d\r\nContent-Type: application/json\r\n\r\n", - }, - { - wantTag: "some tag", - wantErr: false, - wantBody: "GET /4 HTTP/1.1\r\nHost: other.net\r\nA: \r\nC: d\r\nContent-Type: application/json\r\n\r\n", - }, + wants := []*base.Ammo{ + mustNewAmmo(t, "GET", "/0", nil, http.Header{"Content-Type": []string{"application/json"}}, ""), + mustNewAmmo(t, "GET", "/1", nil, http.Header{"A": []string{"b"}, "Content-Type": []string{"application/json"}}, ""), + mustNewAmmo(t, "GET", "/2", nil, http.Header{"Host": []string{"example.com"}, "A": []string{"b"}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, ""), + mustNewAmmo(t, "GET", "/3", nil, http.Header{"Host": []string{"other.net"}, "A": []string{""}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, ""), + mustNewAmmo(t, "GET", "/4", nil, http.Header{"Host": []string{"other.net"}, "A": []string{""}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, "some tag"), } - for j := 0; j < 2; j++ { - for i, tt := range tests { - req, tag, err := decoder.Scan(ctx) - if tt.wantErr { - assert.Error(t, err, "iteration %d-%d", j, i) - continue - } else { - assert.NoError(t, err, "iteration %d-%d", j, i) - } - assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) - - req.Close = false - body, _ := httputil.DumpRequest(req, true) - assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) + for i, want := range wants { + ammo, err := decoder.Scan(ctx) + assert.NoError(t, err, "iteration %d-%d", j, i) + assert.Equal(t, want, ammo, "iteration %d-%d", j, i) } } - _, _, err := decoder.Scan(ctx) + _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) - assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) + assert.Equal(t, decoder.ammoNum, uint(len(wants)*2)) assert.Equal(t, decoder.passNum, uint(1)) } diff --git a/components/providers/http/decoders/uripost.go b/components/providers/http/decoders/uripost.go index 28c4eb6dd..717b9238e 100644 --- a/components/providers/http/decoders/uripost.go +++ b/components/providers/http/decoders/uripost.go @@ -2,13 +2,14 @@ package decoders import ( "bufio" - "bytes" "context" "errors" "io" "net/http" + "net/url" "strings" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" "github.com/yandex/pandora/components/providers/http/decoders/uripost" "github.com/yandex/pandora/components/providers/http/util" @@ -34,26 +35,26 @@ type uripostDecoder struct { line uint } -func (d *uripostDecoder) Scan(ctx context.Context) (*http.Request, string, error) { +func (d *uripostDecoder) Scan(ctx context.Context) (*base.Ammo, error) { if d.config.Limit != 0 && d.ammoNum >= d.config.Limit { - return nil, "", ErrAmmoLimit + return nil, ErrAmmoLimit } for i := 0; i < 2; i++ { for { if ctx.Err() != nil { - return nil, "", ctx.Err() + return nil, ctx.Err() } - req, tag, err := d.readBlock(d.reader, d.header) + req, err := d.readBlock(d.reader, d.header) if err == io.EOF { break } if err != nil { - return nil, "", err + return nil, err } if req != nil { d.ammoNum++ - return req, tag, nil + return req, nil } // here only if read header } @@ -61,72 +62,64 @@ func (d *uripostDecoder) Scan(ctx context.Context) (*http.Request, string, error // seek file d.passNum++ if d.config.Passes != 0 && d.passNum >= d.config.Passes { - return nil, "", ErrPassLimit + return nil, ErrPassLimit } if d.ammoNum == 0 { - return nil, "", ErrNoAmmo + return nil, ErrNoAmmo } d.header = make(http.Header) _, err := d.file.Seek(0, io.SeekStart) if err != nil { - return nil, "", err + return nil, err } d.reader.Reset(d.file) } - return nil, "", errors.New("unexpected behavior") + return nil, errors.New("unexpected behavior") } // readBlock read one header at time and set to commonHeader or read full request -func (d *uripostDecoder) readBlock(reader *bufio.Reader, commonHeader http.Header) (*http.Request, string, error) { +func (d *uripostDecoder) readBlock(reader *bufio.Reader, commonHeader http.Header) (*base.Ammo, error) { data, err := reader.ReadString('\n') if err != nil { - return nil, "", err + return nil, err } data = strings.TrimSpace(data) if len(data) == 0 { - return nil, "", nil // skip empty lines + return nil, nil // skip empty lines } if data[0] == '[' { key, val, err := util.DecodeHeader(data) if err != nil { - return nil, "", err + return nil, err } commonHeader.Set(key, val) - return nil, "", nil + return nil, nil } bodySize, uri, tag, err := uripost.DecodeURI(data) if err != nil { - return nil, "", err + return nil, err + } + _, err = url.Parse(uri) + if err != nil { + return nil, err } - var buffReader io.Reader buff := make([]byte, bodySize) if bodySize != 0 { if n, err := io.ReadFull(reader, buff); err != nil { err = xerrors.Errorf("failed to read ammo with err: %w, at position: %v; tried to read: %v; have read: %v", err, filePosition(d.file), bodySize, n) - return nil, "", err + return nil, err } - buffReader = bytes.NewReader(buff) - } - req, err := http.NewRequest("POST", uri, buffReader) - if err != nil { - err = xerrors.Errorf("failed to decode ammo with err: %w, at position: %v; data: %q", err, filePosition(d.file), buff) - return nil, "", err } - for k, v := range commonHeader { - // http.Request.Write sends Host header based on req.URL.Host - if k == "Host" { - req.Host = v[0] - } else { - req.Header[k] = v + header := commonHeader.Clone() + for k, vv := range d.decodedConfigHeaders { + for _, v := range vv { + header.Set(k, v) } } - // add new Headers to request from config - util.EnrichRequestWithHeaders(req, d.decodedConfigHeaders) - - return req, tag, nil + return base.NewAmmo("POST", uri, buff, header, tag) } diff --git a/components/providers/http/decoders/uripost_test.go b/components/providers/http/decoders/uripost_test.go index 5494a6655..db8f43d3c 100644 --- a/components/providers/http/decoders/uripost_test.go +++ b/components/providers/http/decoders/uripost_test.go @@ -3,16 +3,22 @@ package decoders import ( "context" "net/http" - "net/http/httputil" "strings" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yandex/pandora/components/providers/base" "github.com/yandex/pandora/components/providers/http/config" ) func Test_uripostDecoder_Scan(t *testing.T) { + var mustNewAmmo = func(t *testing.T, method string, url string, body []byte, header http.Header, tag string) *base.Ammo { + ammo, err := base.NewAmmo(method, url, body, header, tag) + require.NoError(t, err) + return ammo + } input := `5 /0 class [A:b] @@ -36,51 +42,22 @@ classclassclass ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) defer cancel() - tests := []struct { - wantTag string - wantErr bool - wantBody string - }{ - { - wantTag: "", - wantErr: false, - wantBody: "POST /0 HTTP/1.1\r\nContent-Type: application/json\r\n\r\nclass", - }, - { - wantTag: "", - wantErr: false, - wantBody: "POST /1 HTTP/1.1\r\nA: b\r\nContent-Type: application/json\r\n\r\nclass", - }, - { - wantTag: "", - wantErr: false, - wantBody: "POST /2 HTTP/1.1\r\nHost: example.com\r\nA: b\r\nC: d\r\nContent-Type: application/json\r\n\r\nclassclass", - }, - { - wantTag: "wantTag", - wantErr: false, - wantBody: "POST /3 HTTP/1.1\r\nHost: other.net\r\nA: \r\nC: d\r\nContent-Type: application/json\r\n\r\nclassclassclass", - }, + wants := []*base.Ammo{ + mustNewAmmo(t, "POST", "/0", []byte("class"), http.Header{"Content-Type": []string{"application/json"}}, ""), + mustNewAmmo(t, "POST", "/1", []byte("class"), http.Header{"A": []string{"b"}, "Content-Type": []string{"application/json"}}, ""), + mustNewAmmo(t, "POST", "/2", []byte("classclass"), http.Header{"Host": []string{"example.com"}, "A": []string{"b"}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, ""), + mustNewAmmo(t, "POST", "/3", []byte("classclassclass"), http.Header{"Host": []string{"other.net"}, "A": []string{""}, "C": []string{"d"}, "Content-Type": []string{"application/json"}}, "wantTag"), } for j := 0; j < 2; j++ { - for i, tt := range tests { - req, tag, err := decoder.Scan(ctx) - if tt.wantErr { - assert.Error(t, err, "iteration %d-%d", j, i) - continue - } else { - assert.NoError(t, err, "iteration %d-%d", j, i) - } - assert.Equal(t, tt.wantTag, tag, "iteration %d-%d", j, i) - - req.Close = false - body, _ := httputil.DumpRequest(req, true) - assert.Equal(t, tt.wantBody, string(body), "iteration %d-%d", j, i) + for i, want := range wants { + ammo, err := decoder.Scan(ctx) + assert.NoError(t, err, "iteration %d-%d", j, i) + assert.Equal(t, want, ammo, "iteration %d-%d", j, i) } } - _, _, err := decoder.Scan(ctx) + _, err := decoder.Scan(ctx) assert.Equal(t, err, ErrAmmoLimit) - assert.Equal(t, decoder.ammoNum, uint(len(tests)*2)) + assert.Equal(t, decoder.ammoNum, uint(len(wants)*2)) assert.Equal(t, decoder.passNum, uint(1)) } diff --git a/components/providers/http/provider.go b/components/providers/http/provider.go index b910a33d7..d88c0bff6 100644 --- a/components/providers/http/provider.go +++ b/components/providers/http/provider.go @@ -7,7 +7,6 @@ package http import ( "bytes" "io" - "net/http" "strings" "sync" @@ -48,8 +47,8 @@ func NewProvider(fs afero.Fs, conf config.Config) (core.Provider, error) { Config: conf, Decoder: decoder, Close: closer.Close, - AmmoPool: sync.Pool{New: func() interface{} { return new(base.Ammo[http.Request]) }}, - Sink: make(chan *base.Ammo[http.Request]), + AmmoPool: sync.Pool{New: func() interface{} { return new(base.Ammo) }}, + Sink: make(chan *base.Ammo), }, nil } diff --git a/components/providers/http/provider/provider.go b/components/providers/http/provider/provider.go index dfd1671d4..47fd6e629 100644 --- a/components/providers/http/provider/provider.go +++ b/components/providers/http/provider/provider.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net/http" "sync" "github.com/yandex/pandora/components/providers/base" @@ -24,7 +23,7 @@ type Provider struct { Close func() error AmmoPool sync.Pool - Sink chan *base.Ammo[http.Request] + Sink chan *base.Ammo } func (p *Provider) Acquire() (core.Ammo, bool) { @@ -32,10 +31,14 @@ func (p *Provider) Acquire() (core.Ammo, bool) { if ok { ammo.SetID(p.NextID()) } + if err := ammo.BuildRequest(); err != nil { + p.Deps.Log.Error("http build request error", zap.Error(err)) + return ammo, false + } for _, mw := range p.Middlewares { err := mw.UpdateRequest(ammo.Req) if err != nil { - p.Log.Error("error on Middleware.UpdateRequest", zap.Error(err)) + p.Deps.Log.Error("error on Middleware.UpdateRequest", zap.Error(err)) return ammo, false } } @@ -43,18 +46,15 @@ func (p *Provider) Acquire() (core.Ammo, bool) { } func (p *Provider) Release(a core.Ammo) { - ammo := a.(*base.Ammo[http.Request]) - // TODO: add request release for example for future fasthttp - // ammo.Req.Body = nil - ammo.Req = nil + ammo := a.(*base.Ammo) + ammo.Reset() p.AmmoPool.Put(ammo) } func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) { - var req *http.Request - var tag string + var ammo *base.Ammo - p.ProviderDeps = deps + p.Deps = deps defer func() { // TODO: wrap in go 1.20 // err = errors.Join(err, p.Close()) @@ -84,8 +84,8 @@ func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) } return } - req, tag, err = p.Decoder.Scan(ctx) - if !confutil.IsChosenCase(tag, p.Config.ChosenCases) { + ammo, err = p.Decoder.Scan(ctx) + if !confutil.IsChosenCase(ammo.Tag(), p.Config.ChosenCases) { continue } if err != nil { @@ -95,8 +95,6 @@ func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) return } - a := p.AmmoPool.Get().(*base.Ammo[http.Request]) - a.Reset(req, tag) select { case <-ctx.Done(): err = ctx.Err() @@ -104,7 +102,7 @@ func (p *Provider) Run(ctx context.Context, deps core.ProviderDeps) (err error) err = xerrors.Errorf("error from context: %w", err) } return - case p.Sink <- a: + case p.Sink <- ammo: } } } diff --git a/lib/netutil/validator.go b/lib/netutil/validator.go new file mode 100644 index 000000000..094fce25a --- /dev/null +++ b/lib/netutil/validator.go @@ -0,0 +1,35 @@ +package netutil + +import ( + "strings" + + "golang.org/x/net/http/httpguts" +) + +func isNotToken(r rune) bool { + return !httpguts.IsTokenRune(r) +} + +// ValidHTTPMethod just copy net/http/request.go validMethod(method string) bool +func ValidHTTPMethod(method string) bool { + if method == "" { + // We document that "" means "GET" for Request.Method, and people have + // relied on that from NewRequest, so keep that working. + // We still enforce validMethod for non-empty methods. + method = "GET" + } + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} diff --git a/lib/testutil/testing.go b/lib/testutil/testing.go deleted file mode 100644 index 0abce8d65..000000000 --- a/lib/testutil/testing.go +++ /dev/null @@ -1,8 +0,0 @@ -package testutil - -func Must[T any](obj T, err error) T { - if err != nil { - panic(err) - } - return obj -}