From 310da48fc3e81bf89feab59681a58cb059817639 Mon Sep 17 00:00:00 2001 From: luopengift <870148195@qq.com> Date: Mon, 25 Mar 2024 20:57:45 +0800 Subject: [PATCH] update --- .gitignore | 3 ++- go.sum | 0 options.go | 9 +++++++++ requests_test.go | 27 +++++++++++++++++++++++++++ response.go | 24 +++++++++++++++--------- session.go | 8 ++------ stat.go | 25 +++++++++++++++---------- transport.go | 23 +++++++++++++++-------- transport_test.go | 13 ++++++++----- uid.go | 4 ++-- util.go | 18 +++++++++--------- 11 files changed, 104 insertions(+), 50 deletions(-) create mode 100644 go.sum diff --git a/.gitignore b/.gitignore index 4e81372..cad7f12 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,5 @@ # Go workspace file go.work -vendor/ \ No newline at end of file +vendor/ +.idea/ \ No newline at end of file diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/options.go b/options.go index b7bac60..26b3432 100644 --- a/options.go +++ b/options.go @@ -260,6 +260,7 @@ func Proxy(addr string) Option { } } +// Setup use middleware func Setup(httpFn ...func(HttpRoundTripFunc) HttpRoundTripFunc) Option { return func(o *Options) { for _, fn := range httpFn { @@ -268,6 +269,14 @@ func Setup(httpFn ...func(HttpRoundTripFunc) HttpRoundTripFunc) Option { } } +// RoundTripFunc set default `*http.Transport` by customer define. +func RoundTripFunc(fn HttpRoundTripFunc) Option { + return func(o *Options) { + o.Transport = fn + } +} + +// Logf print log func Logf(f func(ctx context.Context, stat *Stat)) Option { return func(o *Options) { o.RoundTripFunc = append(o.RoundTripFunc, fprintf(f)) diff --git a/requests_test.go b/requests_test.go index 597c77d..ebfd83c 100644 --- a/requests_test.go +++ b/requests_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "sync/atomic" "testing" "time" @@ -186,3 +187,29 @@ func Test_ForEach(t *testing.T) { t.Logf("%v, %v", resp.Stat(), err) } + +func TestResponse_Download(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + text := "abc\ndef\nghij\n\n123" + fmt.Fprint(w, text) + })) + defer s.Close() + + u := "https://go.dev/dl/go1.22.1.darwin-amd64.tar.gz" // a35015fca6f631f3501a36b3bccba9c5 + sess := New(URL(u)) + f, err := os.OpenFile("tmp/xx.tar.gz", os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644) + defer f.Close() + sum := 0 + resp, err := sess.DoRequest(context.Background(), + TraceLv(3), + Stream(func(i int64, row []byte) error { + cnt, err := f.Write(row) + sum += cnt + return err + })) + if err != nil { + t.Logf("resp=%d, err=%s", resp.Content, err) + return + } + t.Logf("resp=%d, err=%s", resp.Content, err) +} diff --git a/response.go b/response.go index 4ad789f..6d10063 100644 --- a/response.go +++ b/response.go @@ -58,20 +58,26 @@ func (resp *Response) Stat() *Stat { return StatLoad(resp) } -// streamRead +// streamRead xx func streamRead(reader io.Reader, f func(int64, []byte) error) (int64, error) { i, cnt, r := int64(0), int64(0), bufio.NewReaderSize(reader, 1024*1024) for { - raw, err := r.ReadBytes(10) // ascii('\n') = 10 - if err != nil { - if err != io.EOF { - return cnt, err - } - return cnt, nil + raw, err1 := r.ReadBytes(10) // ascii('\n') = 10 + if err1 != nil && err1 != io.EOF { + return cnt, err1 } + // 保证最后一行能被处理,并且可以正常返回 i, cnt = i+1, cnt+int64(len(raw)) - if err = f(i, raw); err != nil { - return cnt, err + if err2 := f(i, raw); err1 == io.EOF || err2 != nil { + return cnt, err2 } } } + +// CopyResponseBody xx +func CopyResponseBody(resp *http.Response) (b *bytes.Buffer, err error) { + if b, resp.Body, err = copyBody(resp.Body); err != nil { + return nil, err + } + return b, err +} diff --git a/session.go b/session.go index cc648e5..de89923 100644 --- a/session.go +++ b/session.go @@ -123,16 +123,12 @@ func (s *Session) DoRequest(ctx context.Context, opts ...Option) (*Response, err defer resp.Response.Body.Close() - var err error if options.Stream != nil { - _, err = streamRead(resp.Response.Body, options.Stream) + _, resp.Err = streamRead(resp.Response.Body, options.Stream) resp.Content = bytes.NewBufferString("[consumed]") } else { - _, err = resp.Content.ReadFrom(resp.Response.Body) + _, resp.Err = resp.Content.ReadFrom(resp.Response.Body) resp.Response.Body = io.NopCloser(bytes.NewReader(resp.Content.Bytes())) } - if err != nil { - resp.Err = fmt.Errorf("err1=%w, err2=%w", resp.Err, err) - } return resp, resp.Err } diff --git a/stat.go b/stat.go index 301f079..c4e94a3 100644 --- a/stat.go +++ b/stat.go @@ -1,9 +1,8 @@ package requests import ( - "bytes" "encoding/json" - "io" + "fmt" ) const RequestId = "Request-Id" @@ -44,11 +43,10 @@ func StatLoad(resp *Response) *Stat { } if resp.Response != nil { var err error - if resp.Content == nil || resp.Response.ContentLength != 0 { - resp.Content = &bytes.Buffer{} - if resp.Response.ContentLength, err = resp.Content.ReadFrom(resp.Response.Body); err == nil { - defer resp.Response.Body.Close() - resp.Response.Body = io.NopCloser(bytes.NewReader(resp.Content.Bytes())) + if resp.Content == nil || resp.Content.Len() == 0 { + if resp.Content, err = CopyResponseBody(resp.Response); err != nil { + stat.Err += fmt.Sprintf("read response: %s", err) + return stat } } stat.Response.Body = make(map[string]any) @@ -68,10 +66,17 @@ func StatLoad(resp *Response) *Stat { stat.Request.Method = resp.Request.Method stat.Request.URL = resp.Request.URL.String() if resp.Request.GetBody != nil { - body, _ := resp.Request.GetBody() + body, err := resp.Request.GetBody() + if err != nil { + stat.Err += fmt.Sprintf("read request1: %s", err) + return stat + } - var buf bytes.Buffer - _, _ = buf.ReadFrom(body) + buf, _, err := copyBody(body) + if err != nil { + stat.Err += fmt.Sprintf("read request2: %s", err) + return stat + } m := make(map[string]any) diff --git a/transport.go b/transport.go index bb4c47b..2f4328e 100644 --- a/transport.go +++ b/transport.go @@ -58,16 +58,23 @@ func verbose(v int, mLimit ...int) func(fn HttpRoundTripFunc) HttpRoundTripFunc return nil, err } - respLog, err := httputil.DumpResponse(resp, v > 3) - if err != nil { - return nil, err + if v >= 3 { + // 答应响应头和响应体长度 + Log("< %s %s", resp.Proto, resp.Status) + for k, vs := range resp.Header { + for _, v := range vs { + Log("< %s: %s", k, v) + } + } } - if v > 3 { - Log(show("< ", respLog, maxLimit)) - } else { - Log("* resp.body is skipped") + if v >= 4 { + buf, err := CopyResponseBody(resp) + if err != nil { + Log("! response error: %w", err) + return nil, err + } + Log(show("*", buf.Bytes(), maxLimit)) } - Log("* ") return resp, nil } } diff --git a/transport_test.go b/transport_test.go index f415dc6..df8b5c5 100644 --- a/transport_test.go +++ b/transport_test.go @@ -22,18 +22,20 @@ func Test_Middleware(t *testing.T) { t.Logf("session.ResponseEach end") return nil }), - requests.Setup(func(tripFunc requests.HttpRoundTripFunc) requests.HttpRoundTripFunc { + requests.Setup(func(fn requests.HttpRoundTripFunc) requests.HttpRoundTripFunc { return func(req *http.Request) (*http.Response, error) { t.Logf("session.Setup start") defer t.Logf("session.Setup defer end") + resp, err := fn(req) t.Logf("session.Setup end") - return tripFunc(req) + return resp, err } }), ) resp, err := sess.DoRequest( - context.Background(), requests.URL(ss.URL), requests.Body(`{"Hello":"World"}`), requests.Logf(requests.LogS), requests.TraceLv(3), + context.Background(), requests.URL(ss.URL), requests.Body(`{"Hello":"World"}`), + //requests.Logf(requests.LogS), requests.TraceLv(4), requests.RequestEach(func(ctx context.Context, r *http.Request) error { t.Logf("doRequest.RequestEach start") defer t.Logf("doRequest.RequestEach defer end") @@ -46,12 +48,13 @@ func Test_Middleware(t *testing.T) { t.Logf("doRequest.ResponseEach end") return nil }), - requests.Setup(func(tripFunc requests.HttpRoundTripFunc) requests.HttpRoundTripFunc { + requests.Setup(func(fn requests.HttpRoundTripFunc) requests.HttpRoundTripFunc { return func(req *http.Request) (*http.Response, error) { t.Logf("doRequest.Setup start") defer t.Logf("doRequest.Setup defer end") + resp, err := fn(req) t.Logf("doRequest.Setup end") - return tripFunc(req) + return resp, err } }), ) diff --git a/uid.go b/uid.go index dfa6d26..fd3bd56 100644 --- a/uid.go +++ b/uid.go @@ -14,6 +14,6 @@ func GenId(id ...string) string { if len(id) != 0 && id[0] != "" { return id[0] } - s := uint64(time.Now().UnixMicro()*1000 + source.Int63n(1000)) // % 4738381338321616895 - return strings.ToUpper(strconv.FormatUint(s, 36)) + i := time.Now().UnixMicro()*1000 + source.Int63n(1000) // % 4738381338321616895 + return strings.ToUpper(strconv.FormatUint(uint64(i), 36)) } diff --git a/util.go b/util.go index cc60be2..63ea0a3 100644 --- a/util.go +++ b/util.go @@ -23,24 +23,24 @@ func show(prompt string, b []byte, mLimit int) string { return str } -// drainBody reads all of b to memory and then returns two equivalent +// copyBody reads all of b to memory and then returns two equivalent // ReadClosers yielding the same bytes. // // It returns an error if the initial slurp of all bytes fails. It does not attempt // to make the returned ReadClosers have identical error-matching behavior. -func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { +func copyBody(b io.ReadCloser) (*bytes.Buffer, io.ReadCloser, error) { + var buf bytes.Buffer if b == nil || b == http.NoBody { // No copying needed. Preserve the magic sentinel meaning of NoBody. - return http.NoBody, http.NoBody, nil + return &buf, http.NoBody, nil } - var buf bytes.Buffer - if _, err = buf.ReadFrom(b); err != nil { - return nil, b, err + if _, err := buf.ReadFrom(b); err != nil { + return &buf, b, err } - if err = b.Close(); err != nil { - return nil, b, err + if err := b.Close(); err != nil { + return &buf, b, err } - return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil + return &buf, io.NopCloser(bytes.NewReader(buf.Bytes())), nil } // LogS supply default handle Stat, print to stdout.