diff --git a/v3/client.go b/v3/client.go index d960815..bd27980 100644 --- a/v3/client.go +++ b/v3/client.go @@ -3,6 +3,7 @@ package grab import ( "bytes" "context" + "errors" "fmt" "io" "net/http" @@ -81,7 +82,7 @@ func (c *Client) Do(req *Request) *Response { resp := &Response{ Request: req, Start: time.Now(), - Done: make(chan struct{}, 0), + Done: make(chan struct{}), Filename: req.Filename, ctx: ctx, cancel: cancel, @@ -208,7 +209,7 @@ func (c *Client) statFileInfo(resp *Response) stateFunc { } fi, err := os.Stat(resp.Filename) if err != nil { - if os.IsNotExist(err) { + if errors.Is(err, os.ErrNotExist) { return c.headRequest } resp.err = err @@ -345,7 +346,7 @@ func (c *Client) headRequest(resp *Response) stateFunc { resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq) if resp.err != nil { - return c.closeResponse + return c.getRequest } resp.HTTPResponse.Body.Close() @@ -447,7 +448,7 @@ func (c *Client) openWriter(resp *Response) stateFunc { } // open file - f, err := os.OpenFile(resp.Filename, flag, 0666) + f, err := os.OpenFile(resp.Filename, flag, 0o666) if err != nil { resp.err = err return c.closeResponse @@ -455,9 +456,9 @@ func (c *Client) openWriter(resp *Response) stateFunc { resp.writer = f // seek to start or end - whence := os.SEEK_SET + whence := io.SeekStart if resp.bytesResumed > 0 { - whence = os.SEEK_END + whence = io.SeekEnd } _, resp.err = f.Seek(0, whence) if resp.err != nil { @@ -504,7 +505,7 @@ func (c *Client) copyFile(resp *Response) stateFunc { // the BeforeCopy didn't cancel the copy. If this was an existing // file that is not going to be resumed, truncate the contents. if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume { - t.Truncate(0) + _ = t.Truncate(0) } bytesCopied, resp.err = resp.transfer.copy() @@ -557,7 +558,7 @@ func (c *Client) closeResponse(resp *Response) stateFunc { resp.fi = nil closeWriter(resp) - resp.closeResponseBody() + _ = resp.closeResponseBody() resp.End = time.Now() close(resp.Done) diff --git a/v3/client_test.go b/v3/client_test.go index cbd0a81..19f39f9 100644 --- a/v3/client_test.go +++ b/v3/client_test.go @@ -1,5 +1,6 @@ package grab +//nolint:gosec import ( "bytes" "context" @@ -10,7 +11,6 @@ import ( "errors" "fmt" "hash" - "io/ioutil" "math/rand" "net/http" "os" @@ -42,7 +42,7 @@ func TestFilenameResolution(t *testing.T) { {"Failure", "", "", "", ""}, } - err := os.Mkdir(".test", 0777) + err := os.Mkdir(".test", 0o777) if err != nil { panic(err) } @@ -78,6 +78,8 @@ func TestFilenameResolution(t *testing.T) { // TestChecksums checks that checksum validation behaves as expected for valid // and corrupted downloads. +// +//nolint:gosec func TestChecksums(t *testing.T) { tests := []struct { size int @@ -205,7 +207,7 @@ func TestContentLength(t *testing.T) { func TestAutoResume(t *testing.T) { segs := 8 size := 1048576 - sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83") + sum := grabtest.DefaultHandlerSHA256ChecksumBytes // grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83") filename := ".testAutoResume" defer os.Remove(filename) @@ -317,6 +319,17 @@ func TestAutoResume(t *testing.T) { grabtest.HeaderBlacklist("Content-Length"), ) }) + + t.Run("WithHeadRequestBreak", func(t *testing.T) { + grabtest.WithTestServer(t, func(url string) { + req := mustNewRequest(filename, url) + resp := DefaultClient.Do(req) + testComplete(t, resp) + }, + grabtest.WithBreakHeadRequest(), + ) + }) + // TODO: test when existing file is corrupted } @@ -382,22 +395,21 @@ func TestBatch(t *testing.T) { // listen for responses Loop: for i := 0; i < len(reqs); { - select { - case resp := <-responses: - if resp == nil { - break Loop - } - testComplete(t, resp) - if err := resp.Err(); err != nil { - t.Errorf("%s: %v", resp.Filename, err) - } + resp := <-responses + if resp == nil { + break Loop + } + testComplete(t, resp) + if err := resp.Err(); err != nil { + t.Errorf("%s: %v", resp.Filename, err) + } - // remove test file - if resp.IsComplete() { - os.Remove(resp.Filename) // ignore errors - } - i++ + // remove test file + if resp.IsComplete() { + os.Remove(resp.Filename) // ignore errors } + i++ + } } }, @@ -426,7 +438,7 @@ func TestCancelContext(t *testing.T) { time.Sleep(time.Millisecond * 500) cancel() for resp := range respch { - defer os.Remove(resp.Filename) + defer os.Remove(resp.Filename) //nolint:staticcheck // err should be context.Canceled or http.errRequestCanceled if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") { @@ -516,7 +528,7 @@ func TestRemoteTime(t *testing.T) { defer os.Remove(filename) // random time between epoch and now - expect := time.Unix(rand.Int63n(time.Now().Unix()), 0) + expect := time.Unix(rand.Int63n(time.Now().Unix()), 0) //nolint:gosec grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest(filename, url)) fi, err := os.Stat(resp.Filename) @@ -625,7 +637,7 @@ func TestBeforeCopyHook(t *testing.T) { // Assert that an existing local file will not be truncated prior to the // BeforeCopy hook has a chance to cancel the request t.Run("NoTruncate", func(t *testing.T) { - tfile, err := ioutil.TempFile("", "grab_client_test.*.file") + tfile, err := os.CreateTemp("", "grab_client_test.*.file") if err != nil { t.Fatal(err) } @@ -808,7 +820,7 @@ func TestMissingContentLength(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(".testMissingContentLength", url) req.SetChecksum( - md5.New(), + md5.New(), //nolint:gosec grabtest.DefaultHandlerMD5ChecksumBytes, false) resp := DefaultClient.Do(req) @@ -844,7 +856,7 @@ func TestNoStore(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.NoStore = true - req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true) + req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true) //nolint:gosec resp := mustDo(req) // ensure Response.Bytes is correct and can be reread @@ -902,7 +914,7 @@ func TestNoStore(t *testing.T) { req := mustNewRequest("", url) req.NoStore = true req.SetChecksum( - md5.New(), + md5.New(), //nolint:gosec grabtest.MustHexDecodeString("deadbeefcafebabe"), true) resp := DefaultClient.Do(req) diff --git a/v3/go.mod b/v3/go.mod index be7f202..70b4473 100644 --- a/v3/go.mod +++ b/v3/go.mod @@ -1,3 +1,3 @@ module github.com/cavaliergopher/grab/v3 -go 1.14 +go 1.19 diff --git a/v3/grab_test.go b/v3/grab_test.go index 6209f7c..abafe39 100644 --- a/v3/grab_test.go +++ b/v3/grab_test.go @@ -2,7 +2,6 @@ package grab import ( "fmt" - "io/ioutil" "log" "os" "testing" @@ -17,7 +16,7 @@ func TestMain(m *testing.M) { if err != nil { panic(err) } - tmpDir, err := ioutil.TempDir("", "grab-") + tmpDir, err := os.MkdirTemp("", "grab-") if err != nil { panic(err) } @@ -25,7 +24,7 @@ func TestMain(m *testing.M) { panic(err) } defer func() { - os.Chdir(cwd) + _ = os.Chdir(cwd) if err := os.RemoveAll(tmpDir); err != nil { panic(err) } diff --git a/v3/pkg/grabtest/assert.go b/v3/pkg/grabtest/assert.go index 10f4c99..6463942 100644 --- a/v3/pkg/grabtest/assert.go +++ b/v3/pkg/grabtest/assert.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "fmt" "io" - "io/ioutil" "net/http" "testing" ) @@ -15,7 +14,6 @@ func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int) t.Errorf("expected status code: %d, got: %d", expect, resp.StatusCode) return } - ok = true return true } @@ -48,7 +46,7 @@ func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (o panic(err) } }() - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { panic(err) } @@ -77,7 +75,7 @@ func MustHTTPDo(req *http.Request) *http.Response { func MustHTTPDoWithClose(req *http.Request) *http.Response { resp := MustHTTPDo(req) - if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil { + if _, err := io.Copy(io.Discard, resp.Body); err != nil { panic(err) } if err := resp.Body.Close(); err != nil { diff --git a/v3/pkg/grabtest/handler.go b/v3/pkg/grabtest/handler.go index efe829c..6712216 100644 --- a/v3/pkg/grabtest/handler.go +++ b/v3/pkg/grabtest/handler.go @@ -20,18 +20,20 @@ var ( type StatusCodeFunc func(req *http.Request) int type handler struct { - statusCodeFunc StatusCodeFunc - methodWhitelist []string - headerBlacklist []string - contentLength int - acceptRanges bool - attachmentFilename string - lastModified time.Time - ttfb time.Duration - rateLimiter *time.Ticker + statusCodeFunc StatusCodeFunc + methodWhitelist []string + headerBlacklist []string + contentLength int + acceptRanges bool + attachmentFilename string + lastModified time.Time + ttfb time.Duration + rateLimiter *time.Ticker + withBreakHeadRequest bool + withBreakGetRequestCh chan struct{} } -func NewHandler(options ...HandlerOption) (http.Handler, error) { +func NewHandler(options ...HandlerOption) (*handler, error) { h := &handler{ statusCodeFunc: func(req *http.Request) int { return http.StatusOK }, methodWhitelist: []string{"GET", "HEAD"}, @@ -53,13 +55,28 @@ func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) return } s := httptest.NewServer(h) + go h.closeConnections(s) defer func() { - h.(*handler).close() + h.close() s.Close() }() f(s.URL) } +func (h *handler) breakHeadRequest() { + if h.withBreakHeadRequest { + h.withBreakGetRequestCh <- struct{}{} + time.Sleep(time.Second) + } +} + +func (h *handler) closeConnections(s *httptest.Server) { + if h.withBreakHeadRequest { + <-h.withBreakGetRequestCh + s.CloseClientConnections() + } +} + func (h *handler) close() { if h.rateLimiter != nil { h.rateLimiter.Stop() @@ -72,6 +89,10 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { time.Sleep(h.ttfb) } + if r.Method == "HEAD" { + h.breakHeadRequest() + } + // validate request method allowed := false for _, m := range h.methodWhitelist { @@ -134,7 +155,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // use buffered io to reduce overhead on the reader bw := bufio.NewWriterSize(w, 4096) for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ { - bw.Write([]byte{byte(i)}) + _, _ = bw.Write([]byte{byte(i)}) if h.rateLimiter != nil { bw.Flush() w.(http.Flusher).Flush() // force the server to send the data to the client diff --git a/v3/pkg/grabtest/handler_option.go b/v3/pkg/grabtest/handler_option.go index c6bd931..f822ad1 100644 --- a/v3/pkg/grabtest/handler_option.go +++ b/v3/pkg/grabtest/handler_option.go @@ -90,3 +90,11 @@ func AttachmentFilename(filename string) HandlerOption { return nil } } + +func WithBreakHeadRequest() HandlerOption { + return func(h *handler) error { + h.withBreakHeadRequest = true + h.withBreakGetRequestCh = make(chan struct{}) + return nil + } +} diff --git a/v3/pkg/grabtest/handler_test.go b/v3/pkg/grabtest/handler_test.go index a1bc8d0..af38f0b 100644 --- a/v3/pkg/grabtest/handler_test.go +++ b/v3/pkg/grabtest/handler_test.go @@ -2,7 +2,7 @@ package grabtest import ( "fmt" - "io/ioutil" + "io" "net/http" "testing" "time" @@ -84,7 +84,7 @@ func TestHandlerContentLength(t *testing.T) { AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", test.ExpectHeaderLen) - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { panic(err) } diff --git a/v3/pkg/grabui/console_client.go b/v3/pkg/grabui/console_client.go index 276f813..b0b7750 100644 --- a/v3/pkg/grabui/console_client.go +++ b/v3/pkg/grabui/console_client.go @@ -155,7 +155,7 @@ func byteString(n int64) string { } func etaString(eta time.Time) string { - d := eta.Sub(time.Now()) + d := time.Until(eta) if d < time.Second { return "<1s" } diff --git a/v3/response.go b/v3/response.go index 05bbca1..5a7981d 100644 --- a/v3/response.go +++ b/v3/response.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "io" - "io/ioutil" "net/http" "os" "sync/atomic" @@ -169,7 +168,7 @@ func (c *Response) Duration() time.Duration { return c.End.Sub(c.Start) } - return time.Now().Sub(c.Start) + return time.Since(c.Start) } // ETA returns the estimated time at which the the download will complete, given @@ -204,7 +203,7 @@ func (c *Response) Open() (io.ReadCloser, error) { func (c *Response) openUnsafe() (io.ReadCloser, error) { if c.Request.NoStore { - return ioutil.NopCloser(bytes.NewReader(c.storeBuffer.Bytes())), nil + return io.NopCloser(bytes.NewReader(c.storeBuffer.Bytes())), nil } return os.Open(c.Filename) } @@ -226,7 +225,7 @@ func (c *Response) Bytes() ([]byte, error) { return nil, err } defer f.Close() - return ioutil.ReadAll(f) + return io.ReadAll(f) } func (c *Response) requestMethod() string {