Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix head break connection #100

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions v3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package grab
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -81,7 +82,7 @@ func (c *Client) Do(req *Request) *Response {
resp := &Response{
Request: req,
Start: time.Now(),
Done: make(chan struct{}, 0),
Done: make(chan struct{}),
Filename: req.Filename,
ctx: ctx,
cancel: cancel,
Expand Down Expand Up @@ -208,7 +209,7 @@ func (c *Client) statFileInfo(resp *Response) stateFunc {
}
fi, err := os.Stat(resp.Filename)
if err != nil {
if os.IsNotExist(err) {
if errors.Is(err, os.ErrNotExist) {
return c.headRequest
}
resp.err = err
Expand Down Expand Up @@ -345,7 +346,7 @@ func (c *Client) headRequest(resp *Response) stateFunc {

resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
if resp.err != nil {
return c.closeResponse
return c.getRequest
}
resp.HTTPResponse.Body.Close()

Expand Down Expand Up @@ -447,17 +448,17 @@ func (c *Client) openWriter(resp *Response) stateFunc {
}

// open file
f, err := os.OpenFile(resp.Filename, flag, 0666)
f, err := os.OpenFile(resp.Filename, flag, 0o666)
if err != nil {
resp.err = err
return c.closeResponse
}
resp.writer = f

// seek to start or end
whence := os.SEEK_SET
whence := io.SeekStart
if resp.bytesResumed > 0 {
whence = os.SEEK_END
whence = io.SeekEnd
}
_, resp.err = f.Seek(0, whence)
if resp.err != nil {
Expand Down Expand Up @@ -504,7 +505,7 @@ func (c *Client) copyFile(resp *Response) stateFunc {
// the BeforeCopy didn't cancel the copy. If this was an existing
// file that is not going to be resumed, truncate the contents.
if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume {
t.Truncate(0)
_ = t.Truncate(0)
}

bytesCopied, resp.err = resp.transfer.copy()
Expand Down Expand Up @@ -557,7 +558,7 @@ func (c *Client) closeResponse(resp *Response) stateFunc {

resp.fi = nil
closeWriter(resp)
resp.closeResponseBody()
_ = resp.closeResponseBody()

resp.End = time.Now()
close(resp.Done)
Expand Down
58 changes: 35 additions & 23 deletions v3/client_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package grab

//nolint:gosec
import (
"bytes"
"context"
Expand All @@ -10,7 +11,6 @@ import (
"errors"
"fmt"
"hash"
"io/ioutil"
"math/rand"
"net/http"
"os"
Expand Down Expand Up @@ -42,7 +42,7 @@ func TestFilenameResolution(t *testing.T) {
{"Failure", "", "", "", ""},
}

err := os.Mkdir(".test", 0777)
err := os.Mkdir(".test", 0o777)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -78,6 +78,8 @@ func TestFilenameResolution(t *testing.T) {

// TestChecksums checks that checksum validation behaves as expected for valid
// and corrupted downloads.
//
//nolint:gosec
func TestChecksums(t *testing.T) {
tests := []struct {
size int
Expand Down Expand Up @@ -205,7 +207,7 @@ func TestContentLength(t *testing.T) {
func TestAutoResume(t *testing.T) {
segs := 8
size := 1048576
sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")
sum := grabtest.DefaultHandlerSHA256ChecksumBytes // grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")
filename := ".testAutoResume"

defer os.Remove(filename)
Expand Down Expand Up @@ -317,6 +319,17 @@ func TestAutoResume(t *testing.T) {
grabtest.HeaderBlacklist("Content-Length"),
)
})

t.Run("WithHeadRequestBreak", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
resp := DefaultClient.Do(req)
testComplete(t, resp)
},
grabtest.WithBreakHeadRequest(),
)
})

// TODO: test when existing file is corrupted
}

Expand Down Expand Up @@ -382,22 +395,21 @@ func TestBatch(t *testing.T) {
// listen for responses
Loop:
for i := 0; i < len(reqs); {
select {
case resp := <-responses:
if resp == nil {
break Loop
}
testComplete(t, resp)
if err := resp.Err(); err != nil {
t.Errorf("%s: %v", resp.Filename, err)
}
resp := <-responses
if resp == nil {
break Loop
}
testComplete(t, resp)
if err := resp.Err(); err != nil {
t.Errorf("%s: %v", resp.Filename, err)
}

// remove test file
if resp.IsComplete() {
os.Remove(resp.Filename) // ignore errors
}
i++
// remove test file
if resp.IsComplete() {
os.Remove(resp.Filename) // ignore errors
}
i++

}
}
},
Expand Down Expand Up @@ -426,7 +438,7 @@ func TestCancelContext(t *testing.T) {
time.Sleep(time.Millisecond * 500)
cancel()
for resp := range respch {
defer os.Remove(resp.Filename)
defer os.Remove(resp.Filename) //nolint:staticcheck

// err should be context.Canceled or http.errRequestCanceled
if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") {
Expand Down Expand Up @@ -516,7 +528,7 @@ func TestRemoteTime(t *testing.T) {
defer os.Remove(filename)

// random time between epoch and now
expect := time.Unix(rand.Int63n(time.Now().Unix()), 0)
expect := time.Unix(rand.Int63n(time.Now().Unix()), 0) //nolint:gosec
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest(filename, url))
fi, err := os.Stat(resp.Filename)
Expand Down Expand Up @@ -625,7 +637,7 @@ func TestBeforeCopyHook(t *testing.T) {
// Assert that an existing local file will not be truncated prior to the
// BeforeCopy hook has a chance to cancel the request
t.Run("NoTruncate", func(t *testing.T) {
tfile, err := ioutil.TempFile("", "grab_client_test.*.file")
tfile, err := os.CreateTemp("", "grab_client_test.*.file")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -808,7 +820,7 @@ func TestMissingContentLength(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(".testMissingContentLength", url)
req.SetChecksum(
md5.New(),
md5.New(), //nolint:gosec
grabtest.DefaultHandlerMD5ChecksumBytes,
false)
resp := DefaultClient.Do(req)
Expand Down Expand Up @@ -844,7 +856,7 @@ func TestNoStore(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.NoStore = true
req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true)
req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true) //nolint:gosec
resp := mustDo(req)

// ensure Response.Bytes is correct and can be reread
Expand Down Expand Up @@ -902,7 +914,7 @@ func TestNoStore(t *testing.T) {
req := mustNewRequest("", url)
req.NoStore = true
req.SetChecksum(
md5.New(),
md5.New(), //nolint:gosec
grabtest.MustHexDecodeString("deadbeefcafebabe"),
true)
resp := DefaultClient.Do(req)
Expand Down
2 changes: 1 addition & 1 deletion v3/go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/cavaliergopher/grab/v3

go 1.14
go 1.19
5 changes: 2 additions & 3 deletions v3/grab_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package grab

import (
"fmt"
"io/ioutil"
"log"
"os"
"testing"
Expand All @@ -17,15 +16,15 @@ func TestMain(m *testing.M) {
if err != nil {
panic(err)
}
tmpDir, err := ioutil.TempDir("", "grab-")
tmpDir, err := os.MkdirTemp("", "grab-")
if err != nil {
panic(err)
}
if err := os.Chdir(tmpDir); err != nil {
panic(err)
}
defer func() {
os.Chdir(cwd)
_ = os.Chdir(cwd)
if err := os.RemoveAll(tmpDir); err != nil {
panic(err)
}
Expand Down
6 changes: 2 additions & 4 deletions v3/pkg/grabtest/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/sha256"
"fmt"
"io"
"io/ioutil"
"net/http"
"testing"
)
Expand All @@ -15,7 +14,6 @@ func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int)
t.Errorf("expected status code: %d, got: %d", expect, resp.StatusCode)
return
}
ok = true
return true
}

Expand Down Expand Up @@ -48,7 +46,7 @@ func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (o
panic(err)
}
}()
b, err := ioutil.ReadAll(resp.Body)
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -77,7 +75,7 @@ func MustHTTPDo(req *http.Request) *http.Response {

func MustHTTPDoWithClose(req *http.Request) *http.Response {
resp := MustHTTPDo(req)
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
if _, err := io.Copy(io.Discard, resp.Body); err != nil {
panic(err)
}
if err := resp.Body.Close(); err != nil {
Expand Down
45 changes: 33 additions & 12 deletions v3/pkg/grabtest/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@ var (
type StatusCodeFunc func(req *http.Request) int

type handler struct {
statusCodeFunc StatusCodeFunc
methodWhitelist []string
headerBlacklist []string
contentLength int
acceptRanges bool
attachmentFilename string
lastModified time.Time
ttfb time.Duration
rateLimiter *time.Ticker
statusCodeFunc StatusCodeFunc
methodWhitelist []string
headerBlacklist []string
contentLength int
acceptRanges bool
attachmentFilename string
lastModified time.Time
ttfb time.Duration
rateLimiter *time.Ticker
withBreakHeadRequest bool
withBreakGetRequestCh chan struct{}
}

func NewHandler(options ...HandlerOption) (http.Handler, error) {
func NewHandler(options ...HandlerOption) (*handler, error) {
h := &handler{
statusCodeFunc: func(req *http.Request) int { return http.StatusOK },
methodWhitelist: []string{"GET", "HEAD"},
Expand All @@ -53,13 +55,28 @@ func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption)
return
}
s := httptest.NewServer(h)
go h.closeConnections(s)
defer func() {
h.(*handler).close()
h.close()
s.Close()
}()
f(s.URL)
}

func (h *handler) breakHeadRequest() {
if h.withBreakHeadRequest {
h.withBreakGetRequestCh <- struct{}{}
time.Sleep(time.Second)
}
}

func (h *handler) closeConnections(s *httptest.Server) {
if h.withBreakHeadRequest {
<-h.withBreakGetRequestCh
s.CloseClientConnections()
}
}

func (h *handler) close() {
if h.rateLimiter != nil {
h.rateLimiter.Stop()
Expand All @@ -72,6 +89,10 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
time.Sleep(h.ttfb)
}

if r.Method == "HEAD" {
h.breakHeadRequest()
}

// validate request method
allowed := false
for _, m := range h.methodWhitelist {
Expand Down Expand Up @@ -134,7 +155,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// use buffered io to reduce overhead on the reader
bw := bufio.NewWriterSize(w, 4096)
for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ {
bw.Write([]byte{byte(i)})
_, _ = bw.Write([]byte{byte(i)})
if h.rateLimiter != nil {
bw.Flush()
w.(http.Flusher).Flush() // force the server to send the data to the client
Expand Down
8 changes: 8 additions & 0 deletions v3/pkg/grabtest/handler_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,11 @@ func AttachmentFilename(filename string) HandlerOption {
return nil
}
}

func WithBreakHeadRequest() HandlerOption {
return func(h *handler) error {
h.withBreakHeadRequest = true
h.withBreakGetRequestCh = make(chan struct{})
return nil
}
}
4 changes: 2 additions & 2 deletions v3/pkg/grabtest/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package grabtest

import (
"fmt"
"io/ioutil"
"io"
"net/http"
"testing"
"time"
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestHandlerContentLength(t *testing.T) {

AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", test.ExpectHeaderLen)

b, err := ioutil.ReadAll(resp.Body)
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
Expand Down
Loading