Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement optional concurrent "Range" requests (refs #86) #102

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 69 additions & 17 deletions v3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package grab
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -48,11 +50,23 @@ type Client struct {

// NewClient returns a new file download Client, using default configuration.
func NewClient() *Client {
dialer := &net.Dialer{}
return &Client{
UserAgent: "grab",
HTTPClient: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
conn, err := dialer.DialContext(ctx, network, addr)
if err == nil {
// Default net.TCPConn calls SetNoDelay(true)
// which likely could be an impact on performance
// with large file downloads, and many ACKs on networks
// with higher latency
err = conn.(*net.TCPConn).SetNoDelay(false)
}
return conn, err
},
},
},
}
Expand Down Expand Up @@ -86,6 +100,7 @@ func (c *Client) Do(req *Request) *Response {
ctx: ctx,
cancel: cancel,
bufferSize: req.BufferSize,
transfer: (*transfer)(nil),
}
if resp.bufferSize == 0 {
// default to Client.BufferSize
Expand Down Expand Up @@ -330,13 +345,19 @@ func (c *Client) headRequest(resp *Response) stateFunc {
}
resp.optionsKnown = true

if resp.Request.NoResume {
return c.getRequest
}
// If we are going to do a range request, then we need to perform
// the HEAD req to check for support.
// Otherwise, we may not need to do the HEAD request if we have
// enough information already.
if resp.Request.RangeRequestMax <= 0 {
if resp.Request.NoResume {
return c.getRequest
}

if resp.Filename != "" && resp.fi == nil {
// destination path is already known and does not exist
return c.getRequest
if resp.Filename != "" && resp.fi == nil {
// destination path is already known and does not exist
return c.getRequest
}
}

hreq := new(http.Request)
Expand All @@ -349,7 +370,8 @@ func (c *Client) headRequest(resp *Response) stateFunc {
}
resp.HTTPResponse.Body.Close()

if resp.HTTPResponse.StatusCode != http.StatusOK {
if resp.HTTPResponse.StatusCode != http.StatusOK &&
resp.HTTPResponse.StatusCode != http.StatusPartialContent {
return c.getRequest
}

Expand All @@ -365,6 +387,13 @@ func (c *Client) headRequest(resp *Response) stateFunc {
}

func (c *Client) getRequest(resp *Response) stateFunc {
if resp.isRangeRequest() {
// For a concurrent range request, we don't do a single
// GET request here. It will be handled later in the transfer,
// based on the HEAD response
return c.openWriter
}

resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
if resp.err != nil {
return c.closeResponse
Expand Down Expand Up @@ -410,11 +439,12 @@ func (c *Client) readResponse(resp *Response) stateFunc {
resp.Filename = filepath.Join(resp.Request.Filename, filename)
}

if !resp.Request.NoStore && resp.requestMethod() == "HEAD" {
if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
resp.CanResume = true
if resp.requestMethod() == "HEAD" {
resp.acceptRanges = resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes"
if !resp.Request.NoStore {
resp.CanResume = resp.acceptRanges
return c.statFileInfo
}
return c.statFileInfo
}
return c.openWriter
}
Expand All @@ -431,13 +461,19 @@ func (c *Client) openWriter(resp *Response) stateFunc {
}
}

if resp.bufferSize < 1 {
resp.bufferSize = 32 * 1024
}

var writerAt io.WriterAt

if resp.Request.NoStore {
resp.writer = &resp.storeBuffer
} else {
// compute write flags
flag := os.O_CREATE | os.O_WRONLY
if resp.fi != nil {
if resp.DidResume {
if resp.DidResume && !resp.isRangeRequest() {
flag = os.O_APPEND | os.O_WRONLY
} else {
// truncate later in copyFile, if not cancelled
Expand All @@ -453,11 +489,12 @@ func (c *Client) openWriter(resp *Response) stateFunc {
return c.closeResponse
}
resp.writer = f
writerAt = f

// seek to start or end
whence := os.SEEK_SET
whence := io.SeekStart
if resp.bytesResumed > 0 {
whence = os.SEEK_END
whence = io.SeekEnd
}
_, resp.err = f.Seek(0, whence)
if resp.err != nil {
Expand All @@ -469,13 +506,19 @@ func (c *Client) openWriter(resp *Response) stateFunc {
if resp.bufferSize < 1 {
resp.bufferSize = 32 * 1024
}
b := make([]byte, resp.bufferSize)

if resp.isRangeRequest() && writerAt != nil {
resp.transfer = newTransferRanges(c.HTTPClient, resp, writerAt)
// next step is copyFile, but this will be called later in another goroutine
return nil
}

resp.transfer = newTransfer(
resp.Request.Context(),
resp.Request.RateLimiter,
resp.writer,
resp.HTTPResponse.Body,
b)
resp.bufferSize)

// next step is copyFile, but this will be called later in another goroutine
return nil
Expand Down Expand Up @@ -507,8 +550,17 @@ func (c *Client) copyFile(resp *Response) stateFunc {
t.Truncate(0)
}

bytesCopied, resp.err = resp.transfer.copy()
bytesCopied, resp.err = resp.transfer.Copy()
if resp.err != nil {
// If we ran parallel ranges and some of them failed, we need
// to truncate the file to the lowest successful range to avoid
// having any gaps during a subsequent resume operation.
var rangesErr transferRangesErr
if errors.As(resp.err, &rangesErr) {
if t, ok := resp.writer.(truncater); ok {
t.Truncate(rangesErr.LastOffsetEnd)
}
}
return c.closeResponse
}
closeWriter(resp)
Expand Down
165 changes: 164 additions & 1 deletion v3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -206,8 +207,8 @@ func TestAutoResume(t *testing.T) {
segs := 8
size := 1048576
sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")
filename := ".testAutoResume"

filename := ".testAutoResume"
defer os.Remove(filename)

for i := 0; i < segs; i++ {
Expand All @@ -229,6 +230,31 @@ func TestAutoResume(t *testing.T) {
})
}

filename2 := ".testAutoResumeRange"
defer os.Remove(filename2)

for i := 0; i < segs; i++ {
segsize := (i + 1) * (size / segs)
t.Run(fmt.Sprintf("RangeWith%vBytes", segsize), func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename2, url)
req.RangeRequestMinSize = 1
req.RangeRequestMax = 5
if i == segs-1 {
req.SetChecksum(sha256.New(), sum, false)
}
resp := mustDo(req)
if i > 0 && !resp.DidResume {
t.Errorf("expected Response.DidResume to be true")
}
testComplete(t, resp)
},
grabtest.ContentLength(segsize),
grabtest.StatusCode(func(r *http.Request) int { return http.StatusPartialContent }),
)
})
}

t.Run("WithFailure", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
// request smaller segment
Expand Down Expand Up @@ -912,3 +938,140 @@ func TestNoStore(t *testing.T) {
})
})
}

// TestRangeRequest exercises the optional parallel "Range" request download
// mode. It covers the chunk-count boundaries (negative, zero, one, many) and
// verifies that RangeRequestMinSize suppresses range requests for resources
// whose Content-Length is below the threshold.
func TestRangeRequest(t *testing.T) {
	size := int64(32768)
	testCases := []struct {
		Name       string
		Chunks     int
		StatusCode int
		MinSize    int64
	}{
		{Name: "NumChunksNeg", Chunks: -1, StatusCode: http.StatusOK},
		{Name: "NumChunks0", StatusCode: http.StatusOK},
		{Name: "NumChunks1", Chunks: 1, StatusCode: http.StatusPartialContent},
		{Name: "NumChunks5", Chunks: 5, StatusCode: http.StatusPartialContent},

		// should not run a Range request because the Content-Length is
		// not large enough
		{Name: "RangeRequestMinSize", Chunks: 5, MinSize: size + 1, StatusCode: http.StatusOK},
	}

	for _, test := range testCases {
		t.Run(test.Name, func(t *testing.T) {
			opts := []grabtest.HandlerOption{
				grabtest.ContentLength(int(size)),
				// Serve 206 Partial Content only when the client is
				// expected to issue range requests for this case.
				grabtest.StatusCode(func(r *http.Request) int {
					if test.Chunks > 0 && size >= test.MinSize {
						return http.StatusPartialContent
					}
					return http.StatusOK
				}),
			}

			grabtest.WithTestServer(t, func(url string) {
				name := fmt.Sprintf(".testRangeRequest-%s", test.Name)
				req := mustNewRequest(name, url)
				req.RangeRequestMax = test.Chunks
				req.RangeRequestMinSize = test.MinSize

				resp := DefaultClient.Do(req)
				defer os.Remove(resp.Filename)

				err := resp.Err()
				switch {
				case errors.Is(err, ErrBadLength):
					// was: err == ErrBadLength — errors.Is also
					// matches a wrapped sentinel
					t.Errorf("error: %v", err)
				case err != nil:
					// was: panic(err) — fail this subtest instead
					// of aborting the entire test binary
					t.Fatalf("error: %v", err)
				case resp.Size() != size:
					t.Errorf("expected %v bytes, got %v bytes", size, resp.Size())
				}

				if resp.HTTPResponse.StatusCode != test.StatusCode {
					// %d for both values: was "%v ... %d" for the
					// same int type
					t.Errorf("expected status code %d, got %d", test.StatusCode, resp.HTTPResponse.StatusCode)
				}

				if bps := resp.BytesPerSecond(); bps <= 0 {
					t.Errorf("expected BytesPerSecond > 0, got %v", bps)
				}

				testComplete(t, resp)
			}, opts...)
		})
	}
}

type rangeTestClient struct {
fn func(req *http.Request) (*http.Response, error)
}

func (c *rangeTestClient) Do(req *http.Request) (*http.Response, error) {
return c.fn(req)
}

// TestRangeRequestAutoResume verifies that when one of several parallel range
// requests fails, the partially written file is truncated back to the start of
// the failed chunk (BadChunkStart) so a later resume operation has no gaps.
func TestRangeRequestAutoResume(t *testing.T) {
	const (
		NumChunks     = 8
		Size          = 1048576
		BadChunkStart = 393216
	)
	sum := grabtest.DefaultHandlerSHA256ChecksumBytes
	// was fmt.Errorf with no format verbs; errors.New is the idiomatic form
	expectErr := errors.New("TEST: cancelled")

	client := NewClient()
	var wg sync.WaitGroup
	client.HTTPClient = &rangeTestClient{func(req *http.Request) (*http.Response, error) {
		wg.Add(1)
		// Catch a range in the middle and wait for the other
		// ranges to finish. Then, fail this range.
		// NOTE(review): wg.Add above can run concurrently with wg.Wait
		// below; sync.WaitGroup requires Add-before-Wait when the
		// counter is zero — confirm chunk scheduling makes this safe.
		if strings.HasPrefix(req.Header.Get("Range"), fmt.Sprintf("bytes=%v", BadChunkStart)) {
			go func() {
				time.Sleep(100 * time.Millisecond)
				wg.Done()
			}()
			wg.Wait()
			return nil, expectErr
		}
		defer wg.Done()
		return DefaultClient.HTTPClient.Do(req)
	}}

	filename := ".testRangeRequestAutoResume"
	defer os.Remove(filename)

	opts := []grabtest.HandlerOption{
		grabtest.ContentLength(int(Size)),
		grabtest.StatusCode(func(r *http.Request) int {
			return http.StatusPartialContent
		}),
	}

	grabtest.WithTestServer(t, func(url string) {
		// run a request with parallel range chunks, where a
		// chunk in the middle is not written
		req := mustNewRequest(filename, url)
		req.RangeRequestMinSize = 1
		req.RangeRequestMax = NumChunks
		req.SetChecksum(sha256.New(), sum, false)
		resp := client.Do(req)
		if err := resp.Err(); !errors.Is(err, expectErr) {
			// was t.Fatal(err.Error()): that panics with a nil
			// pointer dereference when err is nil — exactly the
			// unexpected-success case this assertion guards
			t.Fatalf("expected error %q, got: %v", expectErr, err)
		}
		testComplete(t, resp)
		if resp.BytesComplete() >= resp.Size() {
			t.Fatalf("Expected BytesComplete() [%v] < Size() [%v]", resp.BytesComplete(), resp.Size())
		}

		st, err := os.Stat(resp.Filename)
		if err != nil {
			// was t.Fatalf(err.Error()): an error string must not be
			// used as a printf format (go vet printf check)
			t.Fatal(err)
		}
		if st.Size() > BadChunkStart {
			t.Fatalf("Partially written file size %v is not <= %v", st.Size(), BadChunkStart)
		}
	},
		opts...,
	)
}
2 changes: 2 additions & 0 deletions v3/go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/cavaliergopher/grab/v3

go 1.14

require golang.org/x/sync v0.3.0
2 changes: 2 additions & 0 deletions v3/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
Loading