From b1f764d0052895b819254833750f186a16bc5b0b Mon Sep 17 00:00:00 2001 From: luopengift <870148195@qq.com> Date: Fri, 26 Apr 2024 02:49:33 +0800 Subject: [PATCH] update --- api.go | 2 +- api_test.go | 7 ------- options.go | 31 +++++++++++++++---------------- request.go | 10 +--------- requests_test.go | 24 ++++++++++++++---------- response.go | 22 ++-------------------- response_writer.go | 1 + server_test.go | 14 +++++--------- session_test.go | 12 +++++++++--- stat.go | 10 +++++----- transport.go | 2 +- util.go | 18 ------------------ 12 files changed, 54 insertions(+), 99 deletions(-) diff --git a/api.go b/api.go index a57ea9a..aa9182d 100644 --- a/api.go +++ b/api.go @@ -37,7 +37,7 @@ func Head(url string) (resp *http.Response, err error) { // PostForm send post request, content-type = application/x-www-form-urlencoded func PostForm(url string, data url.Values) (*http.Response, error) { - return s.Do(context.Background(), MethodPost, URL(url), Header("Content-Type", "application/x-www-form-urlencoded"), + return s.Do(context.TODO(), MethodPost, URL(url), Header("Content-Type", "application/x-www-form-urlencoded"), Body(strings.NewReader(data.Encode())), ) } diff --git a/api_test.go b/api_test.go index c98b6fa..c63e22e 100644 --- a/api_test.go +++ b/api_test.go @@ -1,7 +1,6 @@ package requests_test import ( - "github.com/golang-io/requests" "io" "net/http" "net/http/httptest" @@ -18,9 +17,3 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } - -func TestGet(t *testing.T) { - resp, err := requests.Get(ss.URL) - stat := requests.StatLoad(&requests.Response{Response: resp, Err: err}) - t.Logf("%s", stat) -} diff --git a/options.go b/options.go index 4103294..b8097bd 100644 --- a/options.go +++ b/options.go @@ -17,24 +17,21 @@ type Options struct { Method string URL string Path []string - Params map[string]any + RawQuery url.Values body any Header http.Header Cookies []http.Cookie Timeout time.Duration MaxConns int - TraceLv int - mLimit int Verify bool Stream func(int64, []byte) error Transport http.RoundTripper HttpRoundTripper []func(http.RoundTripper) http.RoundTripper + HttpHandler []func(http.Handler) http.Handler - // HttpHandler is only used by server mode - HttpHandler []func(http.Handler) http.Handler - certFile string - keyFile string + certFile string + keyFile string // client session used LocalAddr net.Addr @@ -48,12 +45,11 @@ type Option func(*Options) func newOptions(opts []Option, extends ...Option) Options { opt := Options{ Method: "GET", - Params: make(map[string]any), + RawQuery: make(url.Values), Header: make(http.Header), Timeout: 30 * time.Second, MaxConns: 100, Proxy: http.ProxyFromEnvironment, - mLimit: 1024, } for _, o := range opts { o(&opt) @@ -70,7 +66,8 @@ var ( MethodPost = Method("POST") ) -func CertAndKey(cert, key string) Option { +// CertKey is cert and key file. +func CertKey(cert, key string) Option { return func(o *Options) { o.certFile, o.keyFile = cert, key } @@ -109,18 +106,20 @@ func Path(path string) Option { } // Params add query args -func Params(query map[string]any) Option { +func Params(query map[string]string) Option { return func(o *Options) { for k, v := range query { - o.Params[k] = v + o.RawQuery.Add(k, v) } } } // Param params -func Param(k string, v any) Option { +func Param(k string, v ...string) Option { return func(o *Options) { - o.Params[k] = v + for _, x := range v { + o.RawQuery.Add(k, x) + } } } @@ -192,8 +191,8 @@ func Cookies(cookies ...http.Cookie) Option { } // BasicAuth base auth -func BasicAuth(user, pass string) Option { - return Header("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(user+":"+pass))) +func BasicAuth(username, password string) Option { + return Header("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(username+":"+password))) } diff --git a/request.go b/request.go index 35237d0..7ae76f3 100644 --- a/request.go +++ b/request.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" "net/http" "net/url" @@ -53,18 +52,11 @@ func NewRequestWithContext(ctx context.Context, options Options) (*http.Request, r.URL.Path += p } - for k, v := range options.Params { - if r.URL.RawQuery != "" { - r.URL.RawQuery += "&" - } - r.URL.RawQuery += k + "=" + url.QueryEscape(fmt.Sprintf("%v", v)) - } + r.URL.RawQuery = options.RawQuery.Encode() r.Header = options.Header - for _, cookie := range options.Cookies { r.AddCookie(&cookie) } - return r, nil } diff --git a/requests_test.go b/requests_test.go index eee59cb..09cd738 100644 --- a/requests_test.go +++ b/requests_test.go @@ -60,17 +60,19 @@ func Test_PostBody(t *testing.T) { resp, err := sess.DoRequest(context.Background(), Method("POST"), URL("http://httpbin.org/post"), - Params(map[string]any{ + Params(map[string]string{ "a": "b/c", - "c": 3, - "d": []int{1, 2, 3}, + "c": "3", + "d": "ddd", }), + Param("e", "ea", "es"), + Body(`{"body":"QWER"}`), Header("hello", "world"), //TraceLv(9), - //Logf(func(ctx context.Context, stat Stat) { - // fmt.Println(stat) - //}), + Logf(func(ctx context.Context, stat *Stat) { + t.Logf("%v", stat.String()) + }), ) if err != nil { t.Logf("%v", err) @@ -89,11 +91,13 @@ func Test_FormPost(t *testing.T) { Method("POST"), URL("http://httpbin.org/post"), Form(url.Values{"name": {"12.com"}}), - Params(map[string]any{ + Params(map[string]string{ "a": "b/c", - "c": 3, - "d": []int{1, 2, 3}, + "c": "cc", + "d": "dddd", }), + Param("e", "ea", "es"), + //TraceLv(9), ) if err != nil { @@ -111,7 +115,7 @@ func Test_Race(t *testing.T) { sess := New(URL("http://httpbin.org/post")) //, Auth("user", "123456")) for i := 0; i < 10; i++ { go func() { - _, _ = sess.DoRequest(ctx, MethodPost, Body(`{"a":"b"}`), Params(map[string]any{"1": "2/2"})) // nolint: errcheck + _, _ = sess.DoRequest(ctx, MethodPost, Body(`{"a":"b"}`), Params(map[string]string{"1": "2/2"})) // nolint: errcheck }() } time.Sleep(3 * time.Second) diff --git a/response.go b/response.go index 085d07c..5b5c571 100644 --- a/response.go +++ b/response.go @@ -5,14 +5,13 @@ import ( "bytes" "io" "net/http" - "os" "time" ) // Response wrap std response type Response struct { - *http.Response *http.Request + *http.Response StartAt time.Time Cost time.Duration Content *bytes.Buffer @@ -36,26 +35,9 @@ func (resp *Response) Error() string { return resp.Err.Error() } -// Text parse to string -func (resp *Response) Text() string { - return resp.Content.String() -} - -// Download parse response to a file -func (resp *Response) Download(name string) (int64, error) { - f, err := os.OpenFile(name, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) - if err != nil { - return 0, err - } - defer func(f *os.File) { - _ = f.Close() - }(f) - return io.Copy(f, resp.Content) -} - // Stat stat func (resp *Response) Stat() *Stat { - return StatLoad(resp) + return responseLoad(resp) } // streamRead xx diff --git a/response_writer.go b/response_writer.go index cc53849..f0feccd 100644 --- a/response_writer.go +++ b/response_writer.go @@ -9,6 +9,7 @@ import ( // ResponseWriter wrap `http.ResponseWriter` interface. type ResponseWriter struct { http.ResponseWriter + wroteHeader bool StatusCode int ContentLength int64 diff --git a/server_test.go b/server_test.go index 2b524c8..47d9b0e 100644 --- a/server_test.go +++ b/server_test.go @@ -10,13 +10,9 @@ import ( "time" ) -var STEP2 = func(next http.Handler) http.Handler { - fmt.Println("STEP2 init") - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Println("STEP2 start") - next.ServeHTTP(w, r) - fmt.Println("STEP2 end") - }) +// LogS supply default handle Stat, print to stdout. +func LogS(_ context.Context, stat *requests.Stat) { + _, _ = fmt.Printf("%s\n", stat) } func Test_Use(t *testing.T) { @@ -63,8 +59,8 @@ func Test_Use(t *testing.T) { }() time.Sleep(1 * time.Second) sess := requests.New(requests.URL("http://127.0.0.1:9099")) - _, _ = sess.DoRequest(context.Background(), requests.Path("/echo"), requests.Body("12345"), requests.Logf(requests.LogS), requests.Method("OPTIONS")) - _, _ = sess.DoRequest(context.Background(), requests.Path("/echo"), requests.Body("12345"), requests.Logf(requests.LogS), requests.Method("GET")) + _, _ = sess.DoRequest(context.Background(), requests.Path("/echo"), requests.Body("12345"), requests.Logf(LogS), requests.Method("OPTIONS")) + _, _ = sess.DoRequest(context.Background(), requests.Path("/echo"), requests.Body("12345"), requests.Logf(LogS), requests.Method("GET")) cancel() time.Sleep(3 * time.Second) //sess.DoRequest(context.Background(), Path("/ping"), Logf(LogS)) diff --git a/session_test.go b/session_test.go index 4fbc5d5..27255e8 100644 --- a/session_test.go +++ b/session_test.go @@ -1,8 +1,9 @@ -package requests +package requests_test import ( "context" "fmt" + "github.com/golang-io/requests" "io" "net" "net/http" @@ -36,7 +37,12 @@ func TestSession_Do(t *testing.T) { s.Serve(l) }() - sess := New(URL(sock)) - sess.DoRequest(context.Background(), URL("http://path?k=v"), Body("12345"), MethodPost, Logf(LogS)) + sess := requests.New(requests.URL(sock)) + sess.DoRequest(context.Background(), + requests.URL("http://path?k=v"), + requests.Body("12345"), requests.MethodPost, + requests.Logf(func(ctx context.Context, stat *requests.Stat) { + _, _ = fmt.Printf("%s\n", stat) + })) } diff --git a/stat.go b/stat.go index c5233f1..9cae969 100644 --- a/stat.go +++ b/stat.go @@ -15,9 +15,9 @@ type Stat struct { Cost int64 `json:"Cost"` Request struct { - // Remote is remote addr in server side, + // RemoteAddr is remote addr in server side, // For client requests, it is unused. - Remote string `json:"Remote"` + RemoteAddr string `json:"RemoteAddr"` // URL is Request.URL // For client requests, is request addr. contains schema://ip:port/path/xx @@ -46,8 +46,8 @@ func (stat *Stat) String() string { return string(b) } -// StatLoad stat. -func StatLoad(resp *Response) *Stat { +// statLoad stat. +func responseLoad(resp *Response) *Stat { stat := &Stat{ StartAt: resp.StartAt.Format(dateTime), Cost: resp.Cost.Milliseconds(), @@ -55,7 +55,7 @@ func StatLoad(resp *Response) *Stat { if resp.Response != nil { var err error if resp.Content == nil || resp.Content.Len() == 0 { - if resp.Content, err = ParseBody(resp.Response.Body); err != nil { + if resp.Content, resp.Response.Body, err = CopyBody(resp.Response.Body); err != nil { stat.Err += fmt.Sprintf("read response: %s", err) return stat } diff --git a/transport.go b/transport.go index 2ef6297..4fd3c9b 100644 --- a/transport.go +++ b/transport.go @@ -29,7 +29,7 @@ func fprintf(f func(ctx context.Context, stat *Stat)) func(http.RoundTripper) ht return RoundTripperFunc(func(r *http.Request) (*http.Response, error) { resp.Request = r defer func() { - f(r.Context(), StatLoad(resp)) + f(r.Context(), resp.Stat()) }() resp.Response, resp.Err = next.RoundTrip(r) return resp.Response, resp.Err diff --git a/util.go b/util.go index bf3bd0c..703a307 100644 --- a/util.go +++ b/util.go @@ -2,8 +2,6 @@ package requests import ( "bytes" - "context" - "fmt" "io" "net/http" ) @@ -43,19 +41,3 @@ func CopyBody(b io.ReadCloser) (*bytes.Buffer, io.ReadCloser, error) { } return &buf, io.NopCloser(bytes.NewReader(buf.Bytes())), nil } - -// Log print -func Log(format string, v ...any) { - _, _ = fmt.Printf(format+"\n", v...) -} - -// LogS supply default handle Stat, print to stdout. -func LogS(_ context.Context, stat *Stat) { - Log("%s\n", stat) -} - -// StreamS supply default handle Stream, print raw msg in stream to stdout. -func StreamS(i int64, raw []byte) error { - _, err := fmt.Printf("i=%d, raw=%s\n", i, raw) - return err -}