Skip to content

Commit

Permalink
Add minimal, multi stream support to minimal-download client (#38)
Browse files Browse the repository at this point in the history
* Add minimal multistream option
* Add all conn logic to download

This change consolidates websocket logic into the download() method so
that connection start and shutdown can happen concurrently across
multiple streams. As such, we checkpoint firstStart firstClose and
lastStart and lastClose times as well as byte counts at significant
events. With these variables, we can calculate various avg rates or a
peak rates.
  • Loading branch information
stephen-soltesz authored Jan 12, 2024
1 parent fa019b5 commit b0d0b68
Showing 1 changed file with 125 additions and 44 deletions.
169 changes: 125 additions & 44 deletions cmd/minimal-download/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"net/url"
"path"
"runtime"
"sync"
"sync/atomic"
"time"

"github.com/google/uuid"
Expand All @@ -27,14 +29,16 @@ const (
)

var (
flagCC = flag.String("cc", "bbr", "Congestion control algorithm to use")
flagDuration = flag.Duration("duration", 5*time.Second, "Length of the last stream")
flagByteLimit = flag.Int("bytes", 0, "Byte limit to request to the server")
flagNoVerify = flag.Bool("no-verify", false, "Skip TLS certificate verification")
flagServerURL = flag.String("server.url", "", "URL to directly target")
flagMID = flag.String("mid", uuid.NewString(), "Measurement ID to use")
flagScheme = flag.String("scheme", "wss", "Websocket scheme (wss or ws)")
flagLocateURL = flag.String("locate.url", locateURL, "The base url for the Locate API")
flagCC = flag.String("cc", "bbr", "Congestion control algorithm to use")
flagDuration = flag.Duration("duration", 5*time.Second, "Length of the last stream")
flagMaxDuration = flag.Duration("max-duration", 15*time.Second, "Maximum length of all connections")
flagByteLimit = flag.Int("bytes", 0, "Byte limit to request to the server")
flagNoVerify = flag.Bool("no-verify", false, "Skip TLS certificate verification")
flagServerURL = flag.String("server.url", "", "URL to directly target")
flagMID = flag.String("server.mid", uuid.NewString(), "Measurement ID to use")
flagScheme = flag.String("locate.scheme", "wss", "Websocket scheme (wss or ws)")
flagLocateURL = flag.String("locate.url", locateURL, "The base url for the Locate API")
flagStreams = flag.Int("streams", 1, "The number of concurrent streams to create")
)

// WireMeasurement is a wrapper for Measurement structs that contains
Expand Down Expand Up @@ -110,9 +114,9 @@ func init() {
}

// connect to the given msak server URL, returning a *websocket.Conn.
func connect(ctx context.Context, s *url.URL) (*websocket.Conn, error) {
func prepareHeaders(ctx context.Context, s *url.URL) (string, http.Header) {
q := s.Query()
q.Set("streams", fmt.Sprintf("%d", 1))
q.Set("streams", fmt.Sprintf("%d", *flagStreams))
q.Set("cc", *flagCC)
q.Set("bytes", fmt.Sprintf("%d", *flagByteLimit))
q.Set("duration", fmt.Sprintf("%d", (*flagDuration).Milliseconds()))
Expand All @@ -126,8 +130,7 @@ func connect(ctx context.Context, s *url.URL) (*websocket.Conn, error) {
headers := http.Header{}
headers.Add("Sec-WebSocket-Protocol", "net.measurementlab.throughput.v1")
headers.Add("User-Agent", clientName+"/"+clientVersion)
conn, _, err := localDialer.DialContext(ctx, s.String(), headers)
return conn, err
return s.String(), headers
}

// formatMessage reports a WireMeasurement in a human readable format.
Expand Down Expand Up @@ -204,38 +207,65 @@ func getDownloadServer(ctx context.Context) (*url.URL, error) {
return nil, errors.New("no server")
}

// getConn connects to a download server, returning the *websocket.Conn.
func getConn(ctx context.Context) (*websocket.Conn, error) {
srv, err := getDownloadServer(ctx)
if err != nil {
return nil, err
}
// Connect to server.
return connect(ctx, srv)
type sharedResults struct {
bytesTotal atomic.Int64 // total bytes seen over the life of all connections.
bytesAtLastStart atomic.Int64 // total bytes seen when the last connection starts.
bytesAtFirstStop atomic.Int64 // total bytes seen when the first connection stops/closes.
minRTT atomic.Int64 // minimum of all MinRTT values from all connections.
mu sync.Mutex
started atomic.Bool // set true after first connection opens.
firstStartTime time.Time
lastStartTime time.Time
stopped atomic.Bool // set true after first connection closes (may be different than start conn).
firstStopTime time.Time
lastStopTime time.Time
}

func main() {
flag.Parse()

ctx, cancel := context.WithTimeout(context.Background(), *flagDuration*2)
defer cancel()

conn, err := getConn(ctx)
func (s *sharedResults) download(ctx context.Context, u string, headers http.Header, wg *sync.WaitGroup, streamCount int, stream int) {
// Connect to server.
conn, _, err := localDialer.DialContext(ctx, u, headers)
if err != nil {
log.Fatal(err)
log.Println("skipping one stream; fialed to connect:", err)
return
}
defer func(conn *websocket.Conn) {
// Close on return.
conn.Close()
// On return, record first and last stop times.
s.mu.Lock() // protect stopTime.
now := time.Now()
if !s.stopped.Load() {
// Stop after first connect close.
s.stopped.Store(true)
s.firstStopTime = now
s.bytesAtFirstStop.Store(s.bytesTotal.Load())
}
// This will update for every closed stream, but the last stream to close will be the correct "lastStopTime".
s.lastStopTime = now
s.mu.Unlock()
wg.Done()
}(conn)

// Record first and last start times.
s.mu.Lock()
now := time.Now()
if !s.started.Load() {
s.started.Store(true)
// record start time as first open connection.
s.firstStartTime = now
}
defer conn.Close()
// This will update for every stream, but the last stream to update will be the correct "lastStartTime".
s.lastStartTime = now
s.bytesAtLastStart.Store(s.bytesTotal.Load())
s.mu.Unlock()

// Max runtime.
deadline := time.Now().Add(*flagDuration * 2)
// Set absolute deadline for connections.
deadline := time.Now().Add(*flagMaxDuration)
conn.SetWriteDeadline(deadline)
conn.SetReadDeadline(deadline)

// receive from text & binary messages from conn until the context expires or conn closes.
var applicationBytesReceived int64
var minRTT int64
start := time.Now()
outer:
// Receive text & binary messages from conn until the context expires or conn closes.
for {
select {
case <-ctx.Done():
Expand All @@ -256,28 +286,79 @@ outer:
log.Println("error", err)
return
}
applicationBytesReceived += size
s.bytesTotal.Add(size)
case websocket.TextMessage:
data, err := io.ReadAll(reader)
if err != nil {
log.Println("error", err)
return
}
applicationBytesReceived += int64(len(data))
s.bytesTotal.Add(int64(len(data)))

var m WireMeasurement
if err := json.Unmarshal(data, &m); err != nil {
log.Println("error", err)
return
}
formatMessage("Download server", 1, m)
minRTT = m.TCPInfo["MinRTT"]
if m.TCPInfo["MinRTT"] < s.minRTT.Load() || s.minRTT.Load() == 0 {
// NOTE: this will be the minimum of MinRTT across all streams.
s.minRTT.Store(m.TCPInfo["MinRTT"])
}

switch {
case streamCount == 1:
// Use server metrics for single stream tests.
formatMessage("Download server", 1, m)
case streamCount > 1 && stream == 0:
// Only do this for one stream.
elapsed := time.Since(s.firstStartTime)
log.Printf("Download client #1 - Avg %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
8*float64(s.bytesTotal.Load())/1e6/elapsed.Seconds(), // as mbps.
float64(s.minRTT.Load())/1000.0, // as ms.
elapsed.Seconds(), 0, s.bytesTotal.Load())
}
}
}
}
since := time.Since(start)
log.Printf("Download client #1 - Avg %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
8*float64(applicationBytesReceived)/1e6/since.Seconds(), // as mbps.
float64(minRTT)/1000.0, // as ms.
since.Seconds(), 0, applicationBytesReceived)
}

func main() {
flag.Parse()

ctx, cancel := context.WithTimeout(context.Background(), *flagMaxDuration)
defer cancel()

srv, err := getDownloadServer(ctx)
if err != nil {
log.Fatal(err)
}
// Get common URL and headers.
u, headers := prepareHeaders(ctx, srv)
log.Printf("Connecting: %s://%s%s?...", srv.Scheme, srv.Host, srv.Path)

s := &sharedResults{}
wg := &sync.WaitGroup{}
for i := 0; i < *flagStreams; i++ {
wg.Add(1)
go s.download(ctx, u, headers, wg, *flagStreams, i)
}
wg.Wait()

log.Println("------")
elapsedAvg := s.firstStopTime.Sub(s.firstStartTime)
bytesAvg := s.bytesAtFirstStop.Load() // like msak-client, bytes during first-start to first-stop.
log.Printf("Download client #1 - Avg %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
8*float64(bytesAvg)/1e6/elapsedAvg.Seconds(), // as mbps.
float64(s.minRTT.Load())/1000.0, // as ms.
elapsedAvg.Seconds(), 0, bytesAvg)

// TODO: we assume connections all overlap during peak periods.
elapsedPeak := s.firstStopTime.Sub(s.lastStartTime)
bytesPeak := s.bytesAtFirstStop.Load() - s.bytesAtLastStart.Load() // bytes during of peak period.
if *flagStreams > 1 && bytesPeak > 0 && elapsedPeak > 0 {
log.Printf("Download client #1 - Peak %0.2f Mbps, MinRTT %5.2fms, elapsed %0.4fs, application r/w: %d/%d\n",
8*float64(bytesPeak)/1e6/elapsedPeak.Seconds(), // as mbps.
float64(s.minRTT.Load())/1000.0, // as ms.
elapsedPeak.Seconds(), 0, bytesPeak)
}
}

0 comments on commit b0d0b68

Please sign in to comment.