diff --git a/framework/regression_framework.go b/framework/regression_framework.go
index 48a7342..2eb0ba6 100644
--- a/framework/regression_framework.go
+++ b/framework/regression_framework.go
@@ -6,6 +6,7 @@ import (
 	"math/rand"
 	"os"
 	"strconv"
+	"sync"
 
 	"github.com/go-distributed/meritop"
 )
@@ -168,6 +169,7 @@ type dummySlave struct {
 
 	param, gradient *dummyData
 	fromChildren    map[uint64]*dummyData
+	gradientReady   *countDownLatch
 }
 
 // This is useful to bring the task up to speed from scratch or if it recovers.
@@ -197,6 +199,7 @@ func (t *dummySlave) SetEpoch(epoch uint64) {
 	t.logger.Printf("slave SetEpoch, task: %d, epoch: %d\n", t.taskID, epoch)
 	t.param = &dummyData{}
 	t.gradient = &dummyData{}
+	t.gradientReady = newCountDownLatch(1)
 	t.epoch = epoch
 
 	// Make sure we have a clean slate.
@@ -225,11 +228,14 @@ func (t *dummySlave) ParentDataReady(parentID uint64, req string, resp []byte) {
 	if t.testablyFail("ParentDataReady") {
 		return
 	}
+	if t.gradientReady.Count() == 0 {
+		return
+	}
 	t.param = new(dummyData)
 	json.Unmarshal(resp, t.param)
 
-	// We need to carry out local compuation.
 	t.gradient.Value = t.param.Value * int32(t.framework.GetTaskID())
+	t.gradientReady.CountDown()
 
 	// If this task has children, flag meta so that children can start pull
 	// parameter.
@@ -255,6 +261,9 @@ func (t *dummySlave) ChildDataReady(childID uint64, req string, resp []byte) {
 	// But this really means that we get all the events from children, we
 	// should go into the next epoch now.
 	if len(t.fromChildren) == len(t.framework.GetTopology().GetChildren(t.epoch)) {
+		// If a new node restarts and finds both parent and child meta ready, it will
+		// request both data simultaneously. We need to wait until the gradient data is ready.
+		t.gradientReady.Await()
 		// In real ML, we add the gradient first.
 		for _, g := range t.fromChildren {
 			t.gradient.Value += g.Value
@@ -333,3 +342,45 @@ func (tc SimpleTaskBuilder) GetTask(taskID uint64) meritop.Task {
 		config: tc.SlaveConfig,
 	}
 }
+
+// I am writing this count down latch because sync.WaitGroup doesn't support
+// decrementing the counter when it is already 0.
+type countDownLatch struct {
+	sync.Mutex
+	cond    *sync.Cond
+	counter int
+}
+
+func newCountDownLatch(count int) *countDownLatch {
+	c := new(countDownLatch)
+	c.cond = sync.NewCond(c)
+	c.counter = count
+	return c
+}
+
+func (c *countDownLatch) Count() int {
+	c.Lock()
+	defer c.Unlock()
+	return c.counter
+}
+
+func (c *countDownLatch) CountDown() {
+	c.Lock()
+	defer c.Unlock()
+	if c.counter == 0 {
+		return
+	}
+	c.counter--
+	if c.counter == 0 {
+		c.cond.Broadcast()
+	}
+}
+
+func (c *countDownLatch) Await() {
+	c.Lock()
+	defer c.Unlock()
+	// Wait in a loop, as sync.Cond recommends, until the counter drops to 0.
+	for c.counter > 0 {
+		c.cond.Wait()
+	}
+}
diff --git a/pkg/etcdutil/cluster.go b/pkg/etcdutil/cluster.go
index 9176132..9daba93 100644
--- a/pkg/etcdutil/cluster.go
+++ b/pkg/etcdutil/cluster.go
@@ -28,12 +28,20 @@ import (
 
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/etcdhttp"
+	"github.com/coreos/etcd/pkg/testutil"
+	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
+	"github.com/coreos/etcd/rafthttp"
 )
 
 const (
-	tickDuration = 10 * time.Millisecond
-	clusterName  = "etcd"
+	tickDuration   = 10 * time.Millisecond
+	clusterName    = "etcd"
+	requestTimeout = 2 * time.Second
+)
+
+var (
+	electionTicks = 10
 )
 
 func newLocalListener(t *testing.T) net.Listener {
@@ -44,22 +52,19 @@ func newLocalListener(t *testing.T) net.Listener {
 	return l
 }
 
-type member struct {
-	etcdserver.ServerConfig
-	PeerListeners, ClientListeners []net.Listener
-
-	s   *etcdserver.EtcdServer
-	hss []*httptest.Server
-}
-
 func StartNewEtcdServer(t *testing.T, name string) *member {
 	m := MustNewMember(t, name)
 	m.Launch()
 	return m
 }
 
-func (m *member) URL() string {
-	return fmt.Sprintf("http://%s", m.ClientListeners[0].Addr().String())
+type member struct {
+	etcdserver.ServerConfig
+	PeerListeners, ClientListeners []net.Listener
+
+	raftHandler *testutil.PauseableHandler
+	s           *etcdserver.EtcdServer
+	hss         []*httptest.Server
 }
 
 func MustNewMember(t *testing.T, name string) *member {
@@ -92,7 +97,8 @@ func MustNewMember(t *testing.T, name string) *member {
 		t.Fatal(err)
 	}
 	m.NewCluster = true
-	m.Transport = newTransport()
+	m.Transport = mustNewTransport(t)
+	m.ElectionTimeoutTicks = electionTicks
 	return m
 }
 
@@ -107,10 +113,12 @@ func (m *member) Launch() error {
 	m.s.SyncTicker = time.Tick(500 * time.Millisecond)
 	m.s.Start()
 
+	m.raftHandler = &testutil.PauseableHandler{Next: etcdhttp.NewPeerHandler(m.s.Cluster, m.s.RaftHandler())}
+
 	for _, ln := range m.PeerListeners {
 		hs := &httptest.Server{
 			Listener: ln,
-			Config:   &http.Server{Handler: etcdhttp.NewPeerHandler(m.s)},
+			Config:   &http.Server{Handler: m.raftHandler},
 		}
 		hs.Start()
 		m.hss = append(m.hss, hs)
@@ -126,6 +134,8 @@ func (m *member) Launch() error {
 	return nil
 }
 
+func (m *member) URL() string { return m.ClientURLs[0].String() }
+
 // Terminate stops the member and removes the data dir.
 func (m *member) Terminate(t *testing.T) {
 	m.s.Stop()
@@ -137,11 +147,10 @@ func (m *member) Terminate(t *testing.T) {
 		t.Fatal(err)
 	}
 }
-
-func newTransport() *http.Transport {
-	tr := &http.Transport{}
-	// TODO: need the support of graceful stop in Sender to remove this
-	tr.DisableKeepAlives = true
-	tr.Dial = (&net.Dialer{Timeout: 100 * time.Millisecond}).Dial
+func mustNewTransport(t *testing.T) *http.Transport {
+	tr, err := transport.NewTimeoutTransport(transport.TLSInfo{}, rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout)
+	if err != nil {
+		t.Fatal(err)
+	}
 	return tr
 }
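
Note (not part of the patch): a minimal sketch of how the countDownLatch above is intended to be used, assuming it is compiled into the same package. The producer computes a value and then calls CountDown, mirroring ParentDataReady; the consumer calls Await before reading, mirroring ChildDataReady on a restarted node that sees both parent and child meta ready.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	latch := newCountDownLatch(1) // countDownLatch from the patch above
	gradient := 0

	done := make(chan struct{})
	go func() {
		// Consumer: blocks until the producer has counted down,
		// the way ChildDataReady waits on gradientReady.Await().
		latch.Await()
		fmt.Println("gradient ready:", gradient)
		close(done)
	}()

	// Producer: compute first, then release waiters, the way
	// ParentDataReady computes the gradient before CountDown().
	time.Sleep(10 * time.Millisecond)
	gradient = 42
	latch.CountDown()
	<-done

	// Unlike a sync.WaitGroup, counting down an already-drained
	// latch is a harmless no-op, and Await returns immediately.
	latch.CountDown()
	latch.Await()
}
```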
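
The comment about sync.WaitGroup refers to the fact that a WaitGroup panics when its counter goes negative, which is exactly what a duplicate ParentDataReady event could otherwise trigger. A small stand-alone illustration (my own, not from the patch):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	wg.Done()

	// A second Done drives the counter negative and panics with
	// "sync: negative WaitGroup counter"; countDownLatch.CountDown
	// simply returns in the same situation.
	defer func() {
		if r := recover(); r != nil {
			fmt.Println("recovered:", r)
		}
	}()
	wg.Done()
}
```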