Skip to content

Commit

Permalink
Merge pull request #105 from fengjingchao/master
Browse files Browse the repository at this point in the history
Regression Framework: Wait on gradient data ready in ChildDataReady
  • Loading branch information
Hongchao Deng committed Jan 13, 2015
2 parents ad6736b + a7ef868 commit 7e62eb3
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 21 deletions.
53 changes: 52 additions & 1 deletion framework/regression_framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math/rand"
"os"
"strconv"
"sync"

"github.com/go-distributed/meritop"
)
Expand Down Expand Up @@ -168,6 +169,7 @@ type dummySlave struct {

param, gradient *dummyData
fromChildren map[uint64]*dummyData
gradientReady *countDownLatch
}

// This is useful to bring the task up to speed from scratch or if it recovers.
Expand Down Expand Up @@ -197,6 +199,7 @@ func (t *dummySlave) SetEpoch(epoch uint64) {
t.logger.Printf("slave SetEpoch, task: %d, epoch: %d\n", t.taskID, epoch)
t.param = &dummyData{}
t.gradient = &dummyData{}
t.gradientReady = newCountDownLatch(1)

t.epoch = epoch
// Make sure we have a clean slate.
Expand Down Expand Up @@ -225,11 +228,14 @@ func (t *dummySlave) ParentDataReady(parentID uint64, req string, resp []byte) {
if t.testablyFail("ParentDataReady") {
return
}
if t.gradientReady.Count() == 0 {
return
}
t.param = new(dummyData)
json.Unmarshal(resp, t.param)

// We need to carry out local computation.
t.gradient.Value = t.param.Value * int32(t.framework.GetTaskID())
t.gradientReady.CountDown()

// If this task has children, flag meta so that children can start pull
// parameter.
Expand All @@ -255,6 +261,9 @@ func (t *dummySlave) ChildDataReady(childID uint64, req string, resp []byte) {
// But this really means that we get all the events from children, we
// should go into the next epoch now.
if len(t.fromChildren) == len(t.framework.GetTopology().GetChildren(t.epoch)) {
// If a node restarts and finds that both parent and child meta are ready, it will
// request both sets of data simultaneously. We need to wait until the gradient data is there.
t.gradientReady.Await()
// In real ML, we add the gradient first.
for _, g := range t.fromChildren {
t.gradient.Value += g.Value
Expand Down Expand Up @@ -333,3 +342,45 @@ func (tc SimpleTaskBuilder) GetTask(taskID uint64) meritop.Task {
config: tc.SlaveConfig,
}
}

// countDownLatch lets goroutines block until a counter has been counted down
// to zero. It is used instead of sync.WaitGroup because WaitGroup panics when
// its counter goes negative, whereas calling CountDown on a latch that is
// already at zero is a harmless no-op.
//
// The embedded mutex guards counter and also serves as the Locker for cond.
type countDownLatch struct {
	sync.Mutex
	cond    *sync.Cond
	counter int
}

// newCountDownLatch returns a latch whose counter starts at count.
func newCountDownLatch(count int) *countDownLatch {
	c := new(countDownLatch)
	c.cond = sync.NewCond(c)
	c.counter = count
	return c
}

// Count returns the current value of the counter.
func (c *countDownLatch) Count() int {
	c.Lock()
	defer c.Unlock()
	return c.counter
}

// CountDown decrements the counter by one and, when it reaches zero, wakes
// every goroutine blocked in Await. It does nothing if the counter is already
// zero.
func (c *countDownLatch) CountDown() {
	c.Lock()
	defer c.Unlock()
	if c.counter == 0 {
		return
	}
	c.counter--
	if c.counter == 0 {
		c.cond.Broadcast()
	}
}

// Await blocks until the counter is zero. It returns immediately if the
// counter is already zero.
func (c *countDownLatch) Await() {
	c.Lock()
	defer c.Unlock()
	// Re-check the predicate after every wake-up: sync.Cond does not
	// guarantee the condition holds when Wait returns, so waiting in a loop
	// keeps Await correct regardless of who signals the condition variable.
	for c.counter != 0 {
		c.cond.Wait()
	}
}
49 changes: 29 additions & 20 deletions pkg/etcdutil/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,20 @@ import (

"github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/etcdserver/etcdhttp"
"github.com/coreos/etcd/pkg/testutil"
"github.com/coreos/etcd/pkg/transport"
"github.com/coreos/etcd/pkg/types"
"github.com/coreos/etcd/rafthttp"
)

const (
tickDuration = 10 * time.Millisecond
clusterName = "etcd"
tickDuration = 10 * time.Millisecond
clusterName = "etcd"
requestTimeout = 2 * time.Second
)

// electionTicks is the value MustNewMember assigns to each member's
// ElectionTimeoutTicks.
var (
electionTicks = 10
)

func newLocalListener(t *testing.T) net.Listener {
Expand All @@ -44,22 +52,19 @@ func newLocalListener(t *testing.T) net.Listener {
return l
}

// member bundles an etcd server configuration with the listeners and
// HTTP test servers used to run it inside a test.
type member struct {
etcdserver.ServerConfig
// Listeners for peer traffic and client traffic respectively.
PeerListeners, ClientListeners []net.Listener

// s is the etcd server instance; hss holds the httptest servers
// started for this member.
s *etcdserver.EtcdServer
hss []*httptest.Server
}

// StartNewEtcdServer creates a member called name via MustNewMember and
// launches it. If Launch reports an error the test is aborted with t.Fatal
// instead of returning a member that never started.
func StartNewEtcdServer(t *testing.T, name string) *member {
	m := MustNewMember(t, name)
	if err := m.Launch(); err != nil {
		t.Fatal(err)
	}
	return m
}

func (m *member) URL() string {
return fmt.Sprintf("http://%s", m.ClientListeners[0].Addr().String())
// member bundles an etcd server configuration with the listeners and
// HTTP test servers used to run it inside a test.
type member struct {
etcdserver.ServerConfig
// Listeners for peer traffic and client traffic respectively.
PeerListeners, ClientListeners []net.Listener

// raftHandler is the pauseable wrapper around the peer handler that Launch
// installs on every peer listener.
raftHandler *testutil.PauseableHandler
// s is the etcd server instance started by Launch.
s *etcdserver.EtcdServer
// hss collects the httptest servers Launch starts for this member.
hss []*httptest.Server
}

func MustNewMember(t *testing.T, name string) *member {
Expand Down Expand Up @@ -92,7 +97,8 @@ func MustNewMember(t *testing.T, name string) *member {
t.Fatal(err)
}
m.NewCluster = true
m.Transport = newTransport()
m.Transport = mustNewTransport(t)
m.ElectionTimeoutTicks = electionTicks
return m
}

Expand All @@ -107,10 +113,12 @@ func (m *member) Launch() error {
m.s.SyncTicker = time.Tick(500 * time.Millisecond)
m.s.Start()

m.raftHandler = &testutil.PauseableHandler{Next: etcdhttp.NewPeerHandler(m.s.Cluster, m.s.RaftHandler())}

for _, ln := range m.PeerListeners {
hs := &httptest.Server{
Listener: ln,
Config: &http.Server{Handler: etcdhttp.NewPeerHandler(m.s)},
Config: &http.Server{Handler: m.raftHandler},
}
hs.Start()
m.hss = append(m.hss, hs)
Expand All @@ -126,6 +134,8 @@ func (m *member) Launch() error {
return nil
}

// URL returns the string form of the member's first client URL.
func (m *member) URL() string {
	first := m.ClientURLs[0]
	return first.String()
}

// Terminate stops the member and removes the data dir.
func (m *member) Terminate(t *testing.T) {
m.s.Stop()
Expand All @@ -137,11 +147,10 @@ func (m *member) Terminate(t *testing.T) {
t.Fatal(err)
}
}

func newTransport() *http.Transport {
tr := &http.Transport{}
// TODO: need the support of graceful stop in Sender to remove this
tr.DisableKeepAlives = true
tr.Dial = (&net.Dialer{Timeout: 100 * time.Millisecond}).Dial
// mustNewTransport builds an HTTP transport with transport.NewTimeoutTransport,
// using an empty TLSInfo and the rafthttp connection read/write timeouts.
// A construction error aborts the test via t.Fatal.
func mustNewTransport(t *testing.T) *http.Transport {
	timeoutTr, newErr := transport.NewTimeoutTransport(
		transport.TLSInfo{},
		rafthttp.ConnReadTimeout,
		rafthttp.ConnWriteTimeout,
	)
	if newErr != nil {
		t.Fatal(newErr)
	}
	return timeoutTr
}

0 comments on commit 7e62eb3

Please sign in to comment.