Skip to content

Commit

Permalink
updating algorithm to calculate speed connection for MAB (#1466)
Browse files Browse the repository at this point in the history
* fix: recording time elapsed during connection reads and use it to calculate the connection speed

* fix: update normalize speed equation and unit tests to receive the new parameter

* fix: adding time durations on bandit dialer so we can make it testable

* feat: testing bandit dialer flow

* fix: deleting test dir after test execution

* fix: checking if rewards.csv was written

* fix: usinc the copy response as the number of bytes written at b

* fix: renaming type casting

* fix: using InEpsilon for float comparison
  • Loading branch information
WendelHime authored Dec 19, 2024
1 parent 27f3d46 commit 29111f0
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 26 deletions.
49 changes: 30 additions & 19 deletions dialer/bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ import (

// banditDialer is responsible for continually choosing the optimized dialer.
type banditDialer struct {
dialers []ProxyDialer
bandit bandit.Bandit
opts *Options
banditRewardsMutex *sync.Mutex
dialers []ProxyDialer
bandit bandit.Bandit
opts *Options
banditRewardsMutex *sync.Mutex
secondsUntilRewardSample time.Duration
secondsUntilSaveBanditRewards time.Duration
}

type banditMetrics struct {
Expand All @@ -47,9 +49,11 @@ func NewBandit(opts *Options) (Dialer, error) {
var b bandit.Bandit
var err error
dialer := &banditDialer{
dialers: dialers,
opts: opts,
banditRewardsMutex: &sync.Mutex{},
dialers: dialers,
opts: opts,
banditRewardsMutex: &sync.Mutex{},
secondsUntilRewardSample: secondsForSample * time.Second,
secondsUntilSaveBanditRewards: saveBanditRewardsAfter,
}

dialerWeights, err := dialer.loadLastBanditRewards()
Expand Down Expand Up @@ -134,16 +138,17 @@ func (bd *banditDialer) DialContext(ctx context.Context, network, addr string) (

// Tell the dialer to update the bandit with it's throughput after 5 seconds.
var dataRecv atomic.Uint64
dt := newDataTrackingConn(conn, &dataRecv)
time.AfterFunc(secondsForSample*time.Second, func() {
speed := normalizeReceiveSpeed(dataRecv.Load())
//log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed)
var elapsedTimeReading atomic.Int64
dt := newDataTrackingConn(conn, &dataRecv, &elapsedTimeReading)
time.AfterFunc(bd.secondsUntilRewardSample, func() {
speed := normalizeReceiveSpeed(dataRecv.Load(), elapsedTimeReading.Load())
// log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed)
if errUpdatingBanditReward := bd.bandit.Update(chosenArm, speed); errUpdatingBanditReward != nil {
log.Errorf("unable to update bandit: %v", errUpdatingBanditReward)
log.Errorf("unable to update bandit: %v", err)
}
})

time.AfterFunc(30*time.Second, func() {
time.AfterFunc(bd.secondsUntilSaveBanditRewards, func() {
log.Debugf("saving bandit rewards")
metrics := make(map[string]banditMetrics)
rewards := bd.bandit.GetRewards()
Expand Down Expand Up @@ -339,13 +344,15 @@ func differentArm(existingArm, numDialers int) int {

const secondsForSample = 6

const saveBanditRewardsAfter = 30 * time.Second

// A reasonable upper bound for the top expected bytes to receive per second.
// Anything over this will be normalized to over 1.
const topExpectedBps = 125000

func normalizeReceiveSpeed(dataRecv uint64) float64 {
func normalizeReceiveSpeed(dataRecv uint64, elapsedTimeReading int64) float64 {
// Record the bytes in relation to the top expected speed.
return (float64(dataRecv) / secondsForSample) / topExpectedBps
return (float64(dataRecv) / (float64(elapsedTimeReading) / 1000)) / topExpectedBps
}

func (bd *banditDialer) Close() {
Expand All @@ -355,20 +362,24 @@ func (bd *banditDialer) Close() {
}
}

func newDataTrackingConn(conn net.Conn, dataRecv *atomic.Uint64) *dataTrackingConn {
func newDataTrackingConn(conn net.Conn, dataRecv *atomic.Uint64, elapsedTimeReading *atomic.Int64) *dataTrackingConn {
return &dataTrackingConn{
Conn: conn,
dataRecv: dataRecv,
Conn: conn,
dataRecv: dataRecv,
elapsedTimeReading: elapsedTimeReading,
}
}

type dataTrackingConn struct {
net.Conn
dataRecv *atomic.Uint64
dataRecv *atomic.Uint64
elapsedTimeReading *atomic.Int64 // elapsedTimeReading store in milliseconds the time the connection took to read data
}

func (c *dataTrackingConn) Read(b []byte) (int, error) {
startedReading := time.Now()
n, err := c.Conn.Read(b)
c.dataRecv.Add(uint64(n))
c.elapsedTimeReading.Add(time.Since(startedReading).Milliseconds())
return n, err
}
82 changes: 75 additions & 7 deletions dialer/bandit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestBanditDialer_chooseDialerForDomain(t *testing.T) {
Expand Down Expand Up @@ -230,7 +231,8 @@ func TestBanditDialer_DialContext(t *testing.T) {

func Test_normalizeReceiveSpeed(t *testing.T) {
type args struct {
dataRecv uint64
dataRecv uint64
elapsedTimeReading int64
}
tests := []struct {
name string
Expand All @@ -240,7 +242,8 @@ func Test_normalizeReceiveSpeed(t *testing.T) {
{
name: "should return 0 if no data received",
args: args{
dataRecv: 0,
dataRecv: 0,
elapsedTimeReading: secondsForSample * 1000,
},
want: func(got float64) bool {
return got == 0
Expand All @@ -249,7 +252,8 @@ func Test_normalizeReceiveSpeed(t *testing.T) {
{
name: "should return 1 if pretty fast",
args: args{
dataRecv: topExpectedBps * secondsForSample,
dataRecv: topExpectedBps * secondsForSample,
elapsedTimeReading: secondsForSample * 1000,
},
want: func(got float64) bool {
return got == 1
Expand All @@ -258,17 +262,18 @@ func Test_normalizeReceiveSpeed(t *testing.T) {
{
name: "should return 1 if super fast",
args: args{
dataRecv: topExpectedBps * 50,
dataRecv: topExpectedBps * 50,
elapsedTimeReading: secondsForSample * 1000,
},
want: func(got float64) bool {
return got > 1
},
},

{
name: "should return <1 if sorta fast",
args: args{
dataRecv: 2000,
dataRecv: 2000,
elapsedTimeReading: secondsForSample * 1000,
},
want: func(got float64) bool {
return got > 0 && got < 1
Expand All @@ -277,7 +282,7 @@ func Test_normalizeReceiveSpeed(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := normalizeReceiveSpeed(tt.args.dataRecv); !tt.want(got) {
if got := normalizeReceiveSpeed(tt.args.dataRecv, tt.args.elapsedTimeReading); !assert.True(t, tt.want(got)) {
t.Errorf("unexpected normalizeReceiveSpeed() = %v", got)
}
})
Expand Down Expand Up @@ -453,6 +458,7 @@ type tcpConnDialer struct {
client net.Conn
server net.Conn
name string
dial func() (net.Conn, bool, error)
}

func (*tcpConnDialer) Ready() <-chan error {
Expand Down Expand Up @@ -508,6 +514,10 @@ func (t *tcpConnDialer) DialContext(ctx context.Context, network string, addr st
if t.shouldFail {
return nil, true, io.EOF
}

if t.dial != nil {
return t.dial()
}
return &net.TCPConn{}, false, nil
}

Expand Down Expand Up @@ -600,3 +610,61 @@ func (*tcpConnDialer) Trusted() bool {
// WriteStats implements Dialer.
func (*tcpConnDialer) WriteStats(w io.Writer) {
}

//go:generate mockgen -package=dialer -destination=mocks_test.go net Conn

func TestBanditDialerIntegration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

baseDialer := newTcpConnDialer()
message := "hello"
connSleepTime := 200 * time.Millisecond

baseDialer.(*tcpConnDialer).dial = func() (net.Conn, bool, error) {
conn := NewMockConn(ctrl)
conn.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
time.Sleep(connSleepTime)
return copy(b, []byte(message)), io.EOF
}).AnyTimes()
return conn, false, nil
}

banditDir, err := os.MkdirTemp("", "bandit_dial_test")
require.NoError(t, err)
defer os.RemoveAll(banditDir)

opts := &Options{
Dialers: []ProxyDialer{baseDialer},
BanditDir: banditDir,
}
bandit, err := NewBandit(opts)
require.NoError(t, err)
banditDialer := bandit.(*banditDialer)
banditDialer.secondsUntilRewardSample = 1 * time.Second
banditDialer.secondsUntilSaveBanditRewards = 1200 * time.Millisecond

ctx := context.Background()
banditConn, err := banditDialer.DialContext(ctx, "tcp", "localhost:8080")
require.NoError(t, err)

got, err := io.ReadAll(banditConn)
assert.NoError(t, err)
assert.Equal(t, message, string(got[:len(message)]))

// waiting so reward is sampled and bandit rewards are stored
time.Sleep(1400 * time.Millisecond)

rewards := banditDialer.bandit.GetRewards()
counts := banditDialer.bandit.GetCounts()

// there's only one dialer
assert.Len(t, counts, 1)
assert.Len(t, rewards, 1)
// since there's only one dialer and one Dial call, we're expecting one count
assert.Equal(t, 1, counts[0])
assert.InEpsilon(t, normalizeReceiveSpeed(uint64(len(got)), connSleepTime.Milliseconds()), rewards[0], 0.2)

// check if rewards.csv was written
assert.FileExists(t, filepath.Join(banditDir, "rewards.csv"))
}
Loading

0 comments on commit 29111f0

Please sign in to comment.