Skip to content

Commit

Permalink
Fixing nits in bandit code
Browse files Browse the repository at this point in the history
  • Loading branch information
myleshorton committed Dec 17, 2024
1 parent 40b474f commit dd9ce9f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 41 deletions.
62 changes: 31 additions & 31 deletions dialer/bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (
bandit "github.com/alextanhongpin/go-bandit"
)

// BanditDialer is responsible for continually choosing the optimized dialer.
type BanditDialer struct {
// banditDialer is responsible for continually choosing the optimized dialer.
type banditDialer struct {
dialers []ProxyDialer
bandit bandit.Bandit
opts *Options
Expand Down Expand Up @@ -46,7 +46,7 @@ func NewBandit(opts *Options) (Dialer, error) {

var b bandit.Bandit
var err error
dialer := &BanditDialer{
dialer := &banditDialer{
dialers: dialers,
opts: opts,
banditRewardsMutex: &sync.Mutex{},
Expand Down Expand Up @@ -89,7 +89,7 @@ func NewBandit(opts *Options) (Dialer, error) {
return dialer, nil
}

func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
func (bd *banditDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
deadline, _ := ctx.Deadline()
log.Debugf("bandit::DialContext::time remaining: %v", time.Until(deadline))
// We can not create a multi-armed bandit with no arms.
Expand Down Expand Up @@ -138,8 +138,8 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (
time.AfterFunc(secondsForSample*time.Second, func() {
speed := normalizeReceiveSpeed(dataRecv.Load())
//log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed)
if err = bd.bandit.Update(chosenArm, speed); err != nil {
log.Errorf("unable to update bandit: %v", err)
if errUpdatingBanditReward := bd.bandit.Update(chosenArm, speed); errUpdatingBanditReward != nil {
log.Errorf("unable to update bandit: %v", errUpdatingBanditReward)
}
})

Expand All @@ -156,9 +156,9 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (
}
}

err = bd.updateBanditRewards(metrics)
if err != nil {
log.Errorf("unable to save bandit weights: %v", err)
errUpdatingBanditReward := bd.updateBanditRewards(metrics)
if errUpdatingBanditReward != nil {
log.Errorf("unable to save bandit weights: %v", errUpdatingBanditReward)
}
})

Expand All @@ -178,14 +178,14 @@ const (
// loadLastBanditRewards is a function that returns the last bandit rewards
// for each dialer. If this is set, the bandit will be initialized with the
// last metrics.
func (o *BanditDialer) loadLastBanditRewards() (map[string]banditMetrics, error) {
o.banditRewardsMutex.Lock()
defer o.banditRewardsMutex.Unlock()
if o.opts.BanditDir == "" {
func (bd *banditDialer) loadLastBanditRewards() (map[string]banditMetrics, error) {
bd.banditRewardsMutex.Lock()
defer bd.banditRewardsMutex.Unlock()
if bd.opts.BanditDir == "" {
return nil, log.Error("bandit directory is not set")
}

file := filepath.Join(o.opts.BanditDir, "rewards.csv")
file := filepath.Join(bd.opts.BanditDir, "rewards.csv")
data, err := os.Open(file)
if err != nil {
return nil, err
Expand Down Expand Up @@ -234,17 +234,17 @@ func (o *BanditDialer) loadLastBanditRewards() (map[string]banditMetrics, error)
return metrics, nil
}

func (o *BanditDialer) updateBanditRewards(newRewards map[string]banditMetrics) error {
if err := os.MkdirAll(o.opts.BanditDir, 0755); err != nil {
func (bd *banditDialer) updateBanditRewards(newRewards map[string]banditMetrics) error {
if err := os.MkdirAll(bd.opts.BanditDir, 0755); err != nil {
return log.Errorf("unable to create bandit directory: %v", err)
}

previousRewards, err := o.loadLastBanditRewards()
previousRewards, err := bd.loadLastBanditRewards()
if err != nil && !os.IsNotExist(err) {
return log.Errorf("couldn't load previous bandit rewards: %w", err)
}
o.banditRewardsMutex.Lock()
defer o.banditRewardsMutex.Unlock()
bd.banditRewardsMutex.Lock()
defer bd.banditRewardsMutex.Unlock()

// if there's previous rewards, we must overwrite current values
if previousRewards != nil {
Expand All @@ -255,11 +255,11 @@ func (o *BanditDialer) updateBanditRewards(newRewards map[string]banditMetrics)
previousRewards = newRewards
}

if o.opts.BanditDir == "" {
if bd.opts.BanditDir == "" {
return log.Error("bandit directory is not set")
}

file := filepath.Join(o.opts.BanditDir, "rewards.csv")
file := filepath.Join(bd.opts.BanditDir, "rewards.csv")

headers := []string{"dialer", "reward", "count", "updated at"}
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
Expand All @@ -284,26 +284,26 @@ func (o *BanditDialer) updateBanditRewards(newRewards map[string]banditMetrics)
return nil
}

func (o *BanditDialer) chooseDialerForDomain(network, addr string) (ProxyDialer, int) {
func (bd *banditDialer) chooseDialerForDomain(network, addr string) (ProxyDialer, int) {
// Loop through the number of dialers we have and select the one that is best
// for the given domain.
chosenArm := o.bandit.SelectArm(rand.Float64())
chosenArm := bd.bandit.SelectArm(rand.Float64())
var d ProxyDialer
notAllFailing := hasNotFailing(o.dialers)
for i := 0; i < (len(o.dialers) * 2); i++ {
d = o.dialers[chosenArm]
notAllFailing := hasNotFailing(bd.dialers)
for i := 0; i < (len(bd.dialers) * 2); i++ {
d = bd.dialers[chosenArm]
readyChan := d.Ready()
if readyChan != nil {
select {
case err := <-readyChan:
if err != nil {
log.Errorf("dialer %q failed to initialize with error %w, chossing different arm", d.Name(), err)
chosenArm = differentArm(chosenArm, len(o.dialers))
chosenArm = differentArm(chosenArm, len(bd.dialers))
continue
}
default:
log.Debugf("dialer %q is not ready, chossing different arm", d.Name())
chosenArm = differentArm(chosenArm, len(o.dialers))
chosenArm = differentArm(chosenArm, len(bd.dialers))
continue
}
}
Expand All @@ -313,7 +313,7 @@ func (o *BanditDialer) chooseDialerForDomain(network, addr string) (ProxyDialer,
//
// If the chosen dialer does not support the address, we should also
// choose a different dialer.
chosenArm = differentArm(chosenArm, len(o.dialers))
chosenArm = differentArm(chosenArm, len(bd.dialers))
continue
}
break
Expand Down Expand Up @@ -348,9 +348,9 @@ func normalizeReceiveSpeed(dataRecv uint64) float64 {
return (float64(dataRecv) / secondsForSample) / topExpectedBps
}

func (o *BanditDialer) Close() {
func (bd *banditDialer) Close() {
log.Debug("Closing all dialers")
for _, d := range o.dialers {
for _, d := range bd.dialers {
d.Stop()
}
}
Expand Down
14 changes: 7 additions & 7 deletions dialer/bandit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestBanditDialer_chooseDialerForDomain(t *testing.T) {
}
o, err := NewBandit(opts)
require.NoError(t, err)
got, got1 := o.(*BanditDialer).chooseDialerForDomain(tt.args.network, tt.args.addr)
got, got1 := o.(*banditDialer).chooseDialerForDomain(tt.args.network, tt.args.addr)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("BanditDialer.chooseDialerForDomain() got = %v, want %v", got, tt.want)
}
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestNewBandit(t *testing.T) {
assert: func(t *testing.T, got Dialer, err error, _ string) {
assert.NotNil(t, got)
assert.NoError(t, err)
assert.IsType(t, &BanditDialer{}, got)
assert.IsType(t, &banditDialer{}, got)
},
},
{
Expand All @@ -135,9 +135,9 @@ func TestNewBandit(t *testing.T) {
assert: func(t *testing.T, got Dialer, err error, dir string) {
assert.NotNil(t, got)
assert.NoError(t, err)
assert.IsType(t, &BanditDialer{}, got)
rewards := got.(*BanditDialer).bandit.GetRewards()
counts := got.(*BanditDialer).bandit.GetCounts()
assert.IsType(t, &banditDialer{}, got)
rewards := got.(*banditDialer).bandit.GetRewards()
counts := got.(*banditDialer).bandit.GetCounts()
// checking if the rewards are loaded correctly
assert.Equal(t, oldDialerMetric.Reward, rewards[0])
assert.Equal(t, oldDialerMetric.Count, counts[0])
Expand Down Expand Up @@ -373,7 +373,7 @@ func TestUpdateBanditRewards(t *testing.T) {
require.NoError(t, err)
defer os.RemoveAll(tempDir)

banditDialer := &BanditDialer{
banditDialer := &banditDialer{
opts: &Options{
BanditDir: tempDir,
},
Expand Down Expand Up @@ -421,7 +421,7 @@ func TestLoadLastBanditRewards(t *testing.T) {
err = os.WriteFile(filepath.Join(tempDir, "rewards.csv"), []byte(tt.given), 0644)
require.NoError(t, err)

banditDialer := &BanditDialer{
banditDialer := &banditDialer{
opts: &Options{
BanditDir: tempDir,
},
Expand Down
6 changes: 3 additions & 3 deletions dialer/fastconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ func (fcd *fastConnectDialer) connectAll(dialers []ProxyDialer) {
func (fcd *fastConnectDialer) parallelDial(dialers []ProxyDialer) {
log.Debug("Connecting to all dialers")
var wg sync.WaitGroup
for index, d := range dialers {
for _, d := range dialers {
wg.Add(1)
go func(pd ProxyDialer, index int) {
go func(pd ProxyDialer) {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand All @@ -159,7 +159,7 @@ func (fcd *fastConnectDialer) parallelDial(dialers []ProxyDialer) {

log.Debugf("Dialer %v succeeded in %v", pd.Name(), time.Since(start))
fcd.onConnected(pd, time.Since(start))
}(d, index)
}(d)
}
wg.Wait()
}
Expand Down

0 comments on commit dd9ce9f

Please sign in to comment.