diff --git a/pkg/rpc/client.go b/pkg/rpc/client.go new file mode 100644 index 0000000..d112550 --- /dev/null +++ b/pkg/rpc/client.go @@ -0,0 +1,424 @@ +package rpc + +import ( + "context" + "maps" + "net" + "os" + "slices" + "sync" + "time" + + "github.com/pancsta/rpc2" + + am "github.com/pancsta/asyncmachine-go/pkg/machine" + "github.com/pancsta/asyncmachine-go/pkg/rpc/rpcnames" + ss "github.com/pancsta/asyncmachine-go/pkg/rpc/states/client" +) + +type Client struct { + *ExceptionHandler + + Mach *am.Machine + // Worker is a remote Worker instance + Worker *Worker + Payloads map[string]*ArgsPayload + CallCount uint64 + LogEnabled bool + CallTimeout time.Duration + + clockMx sync.Mutex + workerAddr string + rpc *rpc2.Client + stateNames am.S + stateStruct am.Struct + conn net.Conn +} + +// interfaces +var _ clientRpcMethods = &Client{} +var _ clientServerMethods = &Client{} + +func NewClient( + ctx context.Context, workerAddr string, id string, stateStruct am.Struct, + stateNames am.S, +) (*Client, error) { + if id == "" { + id = "rpc" + } + + c := &Client{ + ExceptionHandler: &ExceptionHandler{}, + Payloads: map[string]*ArgsPayload{}, + LogEnabled: os.Getenv("AM_RPC_LOG_CLIENT") != "", + CallTimeout: 3 * time.Second, + + workerAddr: workerAddr, + stateNames: slices.Clone(stateNames), + stateStruct: maps.Clone(stateStruct), + } + + if os.Getenv("AM_DEBUG") != "" { + c.CallTimeout = 100 * time.Second + } + + // state machine + mach, err := am.NewCommon(ctx, "c-"+id, ss.States, ss.Names, + c, nil, nil) + if err != nil { + return nil, err + } + c.Mach = mach + + return c, nil +} + +// ///// ///// ///// + +// ///// HANDLERS + +// ///// ///// ///// + +func (c *Client) StartEnd(e *am.Event) { + // gather state from before the transition + before := e.Transition().TimeBefore + mach := e.Machine + wasConn := before.Is1(mach.Index(ss.Connecting)) || + before.Is1(mach.Index(ss.Connected)) + + // graceful disconnect + if wasConn { + c.Mach.Add1(ss.Disconnecting, nil) + } +} + +func (c *Client) ConnectingState(e *am.Event) { + ctx := c.Mach.NewStateCtx(ss.Connecting) + + // async + go func() { + if ctx.Err() != nil { + return // expired + } + + // net dial + timeout := 1 * time.Second + if os.Getenv("AM_DEBUG") != "" || os.Getenv("AM_TEST") != "" { + timeout = 100 * time.Second + } + // TODO TLS + d := net.Dialer{ + Timeout: timeout, + } + conn, err := d.DialContext(ctx, "tcp", c.workerAddr) + if err != nil { + errNetwork(c.Mach, err) + return + } + c.conn = conn + + // rpc + c.bindRpcHandlers(conn) + go c.rpc.Run() + + c.Mach.Add1(ss.Connected, nil) + }() +} + +func (c *Client) DisconnectingEnter(e *am.Event) bool { + return c.rpc != nil && c.conn != nil +} + +func (c *Client) DisconnectingState(e *am.Event) { + ctx := c.Mach.NewStateCtx(ss.Disconnecting) + + go func() { + if ctx.Err() != nil { + return // expired + } + + c.notify(ctx, rpcnames.Bye.Encode(), &Empty{}) + time.Sleep(1 * time.Second) + + c.Mach.Add1(ss.Disconnected, nil) + }() +} + +func (c *Client) ConnectedState(e *am.Event) { + ctx := c.Mach.NewStateCtx(ss.Connected) + + go func() { + select { + + case <-ctx.Done(): + return // expired + + case <-c.rpc.DisconnectNotify(): + c.Mach.Add1(ss.Disconnecting, nil) + } + }() +} + +func (c *Client) DisconnectedEnter(e *am.Event) bool { + // graceful disconnect + willDisconn := e.Machine.IsQueued(am.MutationAdd, am.S{ss.Disconnecting}, + false, false, 0) + + return willDisconn <= -1 +} + +func (c *Client) DisconnectedState(e *am.Event) { + // ignore the error when disconnecting + if c.rpc != 
nil {
+		_ = c.rpc.Close()
+	}
+	if c.conn != nil {
+		_ = c.conn.Close()
+	}
+}
+
+func (c *Client) HandshakingState(e *am.Event) {
+	ctx := c.Mach.NewStateCtx(ss.Connected)
+
+	go func() {
+		// call rpc
+		var resp = &RespHandshake{}
+		if !c.call(ctx, rpcnames.Handshake.Encode(), Empty{}, resp) {
+			return
+		}
+
+		// validate
+		if len(resp.StateNames) == 0 {
+			errResponseStr(c.Mach, "states missing")
+			return
+		}
+		if resp.ID == "" {
+			errResponseStr(c.Mach, "ID missing")
+			return
+		}
+
+		// compare states
+		diff := am.DiffStates(c.stateNames, resp.StateNames)
+		if len(diff) > 0 || len(resp.StateNames) != len(c.stateNames) {
+			errResponseStr(c.Mach, "States differ on client/server")
+			return
+		}
+
+		// confirm the handshake
+		if !c.notify(ctx, rpcnames.HandshakeAck.Encode(), true) {
+			return
+		}
+
+		// finalize
+		c.Mach.Add1(ss.HandshakeDone, am.A{
+			"ID":      resp.ID,
+			"am.Time": resp.Time,
+		})
+	}()
+}
+
+func (c *Client) HandshakeDoneState(e *am.Event) {
+	// TODO validate on Enter
+	// TODO enum names
+
+	id := e.Args["ID"].(string)
+	clock := e.Args["am.Time"].(am.Time)
+
+	// handshake creates the worker
+	// TODO extract to NewWorker
+	c.Worker = &Worker{
+		c:             c,
+		ID:            id,
+		Ctx:           c.Mach.Ctx,
+		states:        c.stateStruct,
+		stateNames:    c.stateNames,
+		clockTime:     clock,
+		indexWhen:     am.IndexWhen{},
+		indexStateCtx: am.IndexStateCtx{},
+		indexWhenTime: am.IndexWhenTime{},
+		whenDisposed:  make(chan struct{}),
+	}
+}
+
+// ///// ///// /////
+
+// ///// METHODS
+
+// ///// ///// /////
+
+// Start connects the client to the server and initializes the worker.
+// Results in the Ready state.
+func (c *Client) Start() am.Result {
+	return c.Mach.Add1(ss.Start, nil)
+}
+
+// Stop disconnects the client from the server and disposes the worker.
+// waitTillExit: if passed, waits for the client to disconnect using the
+// context.
+func (c *Client) Stop(waitTillExit context.Context, dispose bool) am.Result {
+	res := c.Mach.Remove1(ss.Start, nil)
+	// wait for the client to disconnect
+	if res != am.Canceled && waitTillExit != nil {
+		<-c.Mach.When1(ss.Disconnected, waitTillExit)
+	}
+	if dispose {
+		c.log("disposing")
+		c.Mach.Dispose()
+		c.Worker.Dispose()
+	}
+
+	return res
+}
+
+// Get requests predefined data from the server's getter function.
+func (c *Client) Get(ctx context.Context, name string) (*RespGet, error) {
+	// call rpc
+	resp := RespGet{}
+	if !c.call(ctx, rpcnames.Get.Encode(), name, &resp) {
+		return nil, c.Mach.Err()
+	}
+
+	return &resp, nil
+}
+
+// GetKind returns a kind of RPC component (server / client).
+func (c *Client) GetKind() Kind {
+	return KindClient
+}
+
+func (c *Client) log(msg string, args ...any) {
+	if !c.LogEnabled {
+		return
+	}
+	c.Mach.Log(msg, args...)
+} + +func (c *Client) bindRpcHandlers(conn net.Conn) { + c.rpc = rpc2.NewClient(conn) + c.rpc.Handle(rpcnames.ClientSetClock.Encode(), c.RemoteSetClock) + c.rpc.Handle(rpcnames.ClientSendPayload.Encode(), c.RemoteSendPayload) + // TODO check if test suite passes without it + c.rpc.SetBlocking(true) +} + +func (c *Client) updateClock(msg ClockMsg, t am.Time) { + if c.Mach.Not1(ss.Ready) { + return + } + + c.clockMx.Lock() + var clock am.Time + if msg != nil { + // diff clock update + clock = ClockFromMsg(c.Worker.clockTime, msg) + } else { + // full clock update + clock = t + } + + if clock == nil { + c.clockMx.Unlock() + return + } + + var sum uint64 + for _, v := range clock { + sum += v + } + + if msg != nil { + c.log("updateClock from msg %dt: %v", sum, msg) + } else { + c.log("updateClock full %d: %v", sum, t) + } + + timeBefore := c.Worker.clockTime + activeBefore := c.Worker.activeStatesUnlocked() + c.Worker.clockTime = clock + c.clockMx.Unlock() + + // process clock-based indexes + c.Worker.processWhenBindings(activeBefore) + c.Worker.processWhenTimeBindings(timeBefore) + c.Worker.processStateCtxBindings(activeBefore) +} + +func (c *Client) call( + ctx context.Context, method string, args, resp any, +) bool { + defer c.Mach.PanicToErr(nil) + + callCtx, cancel := context.WithTimeout(ctx, c.CallTimeout) + defer cancel() + + c.CallCount++ + err := c.rpc.CallWithContext(callCtx, method, args, resp) + if ctx.Err() != nil { + return false // expired + } + if callCtx.Err() != nil { + errAuto(c.Mach, rpcnames.Decode(method).String(), ErrNetworkTimeout) + return false + } + + if err != nil { + errAuto(c.Mach, rpcnames.Decode(method).String(), err) + return false + } + + return true +} + +func (c *Client) notify( + ctx context.Context, method string, args any, +) bool { + defer c.Mach.PanicToErr(nil) + + // TODO timeout + + c.CallCount++ + err := c.rpc.Notify(method, args) + if ctx.Err() != nil { + return false + } + if err != nil { + errAuto(c.Mach, method, err) + return false + } + + return true +} + +// ///// ///// ///// + +// ///// REMOTE METHODS + +// ///// ///// ///// + +// RemoteSetClock updates the client's clock. Only called by the server. +func (c *Client) RemoteSetClock( + _ *rpc2.Client, clock ClockMsg, _ *Empty, +) error { + // validate + if clock == nil { + errParams(c.Mach, nil) + return nil + } + + // execute + c.updateClock(clock, nil) + + return nil +} + +// RemoteSendPayload receives a payload from the server. Only called by the +// server. 
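+// A minimal consumer sketch (hedged; "report" is a hypothetical payload
+// name, and the payload must have already been pushed by the server):
+//
+//	if p, ok := c.Payloads["report"]; ok {
+//		fmt.Printf("got %d bytes\n", len(p.Data))
+//	}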
+func (c *Client) RemoteSendPayload(
+	_ *rpc2.Client, file *ArgsPayload, _ *Empty,
+) error {
+	// TODO test
+	c.log("RemoteSendPayload %s", file.Name)
+	c.Payloads[file.Name] = file
+
+	return nil
+}
diff --git a/pkg/rpc/rpcnames/rpcnames.go b/pkg/rpc/rpcnames/rpcnames.go
new file mode 100644
index 0000000..89d9376
--- /dev/null
+++ b/pkg/rpc/rpcnames/rpcnames.go
@@ -0,0 +1,64 @@
+package rpcnames
+
+type Name int
+
+// TODO separate and reserve IDs forever
+
+const (
+	// Server
+
+	Add Name = iota + 1
+	AddNS
+	Remove
+	Set
+	Handshake
+	HandshakeAck
+	Log
+	Sync
+	Get
+	Bye
+
+	// Client
+
+	ClientSetClock
+	ClientSendPayload
+)
+
+func (n Name) Encode() string {
+	return string(rune(n))
+}
+
+func (n Name) String() string {
+	switch n {
+	case Add:
+		return "Add"
+	case AddNS:
+		return "AddNS"
+	case Remove:
+		return "Remove"
+	case Set:
+		return "Set"
+	case Handshake:
+		return "Handshake"
+	case HandshakeAck:
+		return "HandshakeAck"
+	case Log:
+		return "Log"
+	case Sync:
+		return "Sync"
+	case Get:
+		return "Get"
+	case ClientSetClock:
+		return "ClientSetClock"
+	case ClientSendPayload:
+		return "ClientSendPayload"
+	case Bye:
+		return "Bye"
+	}
+
+	return "!UNKNOWN!"
+}
+
+func Decode(s string) Name {
+	return Name(s[0])
+}
diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go
new file mode 100644
index 0000000..35e616f
--- /dev/null
+++ b/pkg/rpc/server.go
@@ -0,0 +1,558 @@
+package rpc
+
+import (
+	"context"
+	"encoding/gob"
+	"fmt"
+	"net"
+	"os"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/pancsta/rpc2"
+
+	amh "github.com/pancsta/asyncmachine-go/pkg/helpers"
+	am "github.com/pancsta/asyncmachine-go/pkg/machine"
+	"github.com/pancsta/asyncmachine-go/pkg/rpc/rpcnames"
+	ss "github.com/pancsta/asyncmachine-go/pkg/rpc/states/server"
+)
+
+type GetterFunc func(string) any
+
+type Server struct {
+	*ExceptionHandler
+
+	Addr string
+	Mach *am.Machine
+	// ClockInterval is the interval for clock updates, effectively throttling
+	// the number of updates sent to the client within the interval window.
+	// 0 means pushes are disabled. Setting to a very small value will make
+	// pushes instant.
+	ClockInterval time.Duration
+	// Listener can be set manually before starting the server.
+	Listener net.Listener
+	LogEnabled bool
+	CallCount uint64
+
+	// w is the worker machine
+	w *am.Machine
+	rpcServer *rpc2.Server
+	rpcClient *rpc2.Client
+	// lastClockHTime is the last (human) time a clock update was sent to the
+	// client.
+	lastClockHTime time.Time
+	lastClock am.Time
+	ticker *time.Ticker
+	clockMx sync.Mutex
+	// mutMx is a lock preventing mutation methods from racing each other.
+	mutMx sync.Mutex
+	lastClockSum uint64
+	skipClockPush atomic.Bool
+	lastClockMsg ClockMsg
+	getter GetterFunc
+}
+
+// interfaces
+var _ serverRpcMethods = &Server{}
+var _ clientServerMethods = &Server{}
+
+func NewServer(
+	ctx context.Context, addr string, id string, worker *am.Machine,
+	getter GetterFunc,
+) (*Server, error) {
+	if !worker.StatesVerified {
+		return nil, fmt.Errorf("states not verified")
+	}
+	if id == "" {
+		id = "rpc"
+	}
+
+	gob.Register(am.Relation(0))
+
+	s := &Server{
+		ExceptionHandler: &ExceptionHandler{},
+		Addr:             addr,
+		ClockInterval:    250 * time.Millisecond,
+		LogEnabled:       os.Getenv("AM_RPC_LOG_SERVER") != "",
+		w:                worker,
+		getter:           getter,
+	}
+
+	// state machine
+	mach, err := am.NewCommon(ctx, "s-"+id, ss.States, ss.Names,
+		s, nil, nil)
+	if err != nil {
+		return nil, err
+	}
+	s.Mach = mach
+
+	// bind to worker via Tracer API
+	s.traceWorker()
+
+	return s, nil
+}
+
+// ///// ///// /////
+
+// ///// HANDLERS
+
+// ///// ///// /////
+
+func (s *Server) RpcStartingState(e *am.Event) {
+	ctx := s.Mach.NewStateCtx(ss.RpcStarting)
+	ctxStart := s.Mach.NewStateCtx(ss.Start)
+	s.log("Listening on %s", s.Addr)
+
+	go func() {
+		if ctx.Err() != nil {
+			return // expired
+		}
+
+		if s.Listener == nil {
+			// use Start as the context for the listener
+			cfg := net.ListenConfig{}
+			lis, err := cfg.Listen(ctxStart, "tcp", s.Addr)
+			if err != nil {
+				// add err to mach
+				errNetwork(s.Mach, err)
+				// add outcome to mach
+				s.Mach.Remove1(ss.RpcStarting, nil)
+
+				return
+			}
+
+			s.Listener = lis
+		} else {
+			// update Addr if an external listener was provided
+			s.Addr = s.Listener.Addr().String()
+		}
+
+		s.bindRpcHandlers()
+		go s.rpcServer.Accept(s.Listener)
+		if ctx.Err() != nil {
+			return // expired
+		}
+
+		// bind to client events
+		s.rpcServer.OnDisconnect(func(client *rpc2.Client) {
+			if ctx.Err() != nil {
+				return // expired
+			}
+			s.Mach.Add1(ss.ClientDisconn, am.A{"client": client})
+		})
+		// TODO flaky conn event
+		s.rpcServer.OnConnect(func(client *rpc2.Client) {
+			s.Mach.Add1(ss.ClientConn, am.A{"client": client})
+		})
+
+		s.Mach.Add1(ss.RpcReady, nil)
+	}()
+}
+
+// RpcReadyState starts a ticker to compensate for clock push debounces.
+func (s *Server) RpcReadyState(e *am.Event) {
+	// no ticker for instant clocks
+	if s.ClockInterval == 0 {
+		return
+	}
+
+	ctx := s.Mach.NewStateCtx(ss.RpcReady)
+	if s.ticker == nil {
+		s.ticker = time.NewTicker(s.ClockInterval)
+	}
+
+	// avoid dispose
+	t := s.ticker
+
+	go func() {
+		if ctx.Err() != nil {
+			return // expired
+		}
+
+		// push clock updates, debounced by genClockUpdate
+		for {
+			select {
+			case <-ctx.Done():
+				s.ticker = nil
+				return
+
+			case <-t.C:
+				s.pushClockUpdate()
+			}
+		}
+	}()
+}
+
+func (s *Server) RpcReadyEnd(e *am.Event) {
+	// TODO gracefully disconn from the client
+	_ = s.Listener.Close()
+	s.rpcServer = nil
+}
+
+func (s *Server) HandshakeDoneEnd(e *am.Event) {
+	_ = s.rpcClient.Close()
+	s.rpcClient = nil
+}
+
+// ///// ///// /////
+
+// ///// METHODS
+
+// ///// ///// /////
+
+// Start starts the server, optionally creates a listener if not provided and
+// results in the Ready state.
+func (s *Server) Start() am.Result {
+	return s.Mach.Add1(ss.Start, nil)
+}
+
+// Stop stops the server, and optionally disposes resources.
+func (s *Server) Stop(dispose bool) am.Result {
+	res := s.Mach.Remove1(ss.Start, nil)
+	if dispose {
+		s.log("disposing")
+		s.Mach.Dispose()
+		s.Listener = nil
+		s.rpcServer = nil
+	}
+
+	return res
+}
+
+// SendPayload sends a payload to the client.
+func (s *Server) SendPayload(ctx context.Context, file *ArgsPayload) error {
+	// TODO bind to an async state
+	// TODO test SendPayload
+	defer s.Mach.PanicToErr(nil)
+
+	return s.rpcClient.CallWithContext(ctx, rpcnames.ClientSendPayload.Encode(),
+		file, nil)
+}
+
+// GetKind returns a kind of RPC component (server / client).
+func (s *Server) GetKind() Kind {
+	return KindServer
+}
+
+// NewMirrorMach returns a new machine instance of the same kind as the
+// remote worker. It does support handlers, but only final ones, not
+// negotiation ones, as all the write ops go through the remote worker. It
+// does, however, support relation-based negotiation, to locally reject
+// mutations.
+// TODO add Opts.NoNegHandlers to am
+// TODO add full clock stream to make sure all final handlers are triggered
+// func (s *Server) NewMirrorMach() *am.Machine {
+// 	return KindServer
+// }
+
+func (s *Server) log(msg string, args ...any) {
+	if !s.LogEnabled {
+		return
+	}
+	s.Mach.Log(msg, args...)
+}
+
+func (s *Server) traceWorker() bool {
+	// reg a new tracer via an eval window (not mid-tx)
+	ok := s.w.Eval("traceWorker", func() {
+		s.w.Tracers = append(s.w.Tracers, &WorkerTracer{s: s})
+	}, s.Mach.Ctx)
+
+	// TODO handle dispose and close the connection
+
+	return ok
+}
+
+func (s *Server) bindRpcHandlers() {
+	// new RPC instance, release prev resources
+	s.rpcServer = rpc2.NewServer()
+
+	s.rpcServer.Handle(rpcnames.Handshake.Encode(), s.RemoteHandshake)
+	s.rpcServer.Handle(rpcnames.HandshakeAck.Encode(), s.RemoteHandshakeAck)
+	s.rpcServer.Handle(rpcnames.Add.Encode(), s.RemoteAdd)
+	s.rpcServer.Handle(rpcnames.AddNS.Encode(), s.RemoteAddNS)
+	s.rpcServer.Handle(rpcnames.Remove.Encode(), s.RemoteRemove)
+	s.rpcServer.Handle(rpcnames.Set.Encode(), s.RemoteSet)
+	s.rpcServer.Handle(rpcnames.Sync.Encode(), s.RemoteSync)
+	s.rpcServer.Handle(rpcnames.Get.Encode(), s.RemoteGet)
+	s.rpcServer.Handle(rpcnames.Bye.Encode(), s.RemoteBye)
+
+	// TODO RemoteLog, RemoteWhenArgs, RemoteGetMany
+
+	// s.rpcServer.Handle("RemoteLog", s.RemoteLog)
+	// s.rpcServer.Handle("RemoteWhenArgs", s.RemoteWhenArgs)
+}
+
+func (s *Server) pushClockUpdate() {
+	if s.skipClockPush.Load() || s.Mach.Not1(ss.ClientConn) ||
+		s.Mach.Not1(ss.HandshakeDone) {
+		s.log("force-skip clock push")
+		return
+	}
+
+	// disabled
+	if s.ClockInterval == 0 {
+		return
+	}
+
+	clock := s.genClockUpdate(false)
+	// debounce
+	if clock == nil {
+		return
+	}
+
+	// notify without a response
+	defer s.Mach.PanicToErr(nil)
+	s.log("pushClockUpdate")
+	s.CallCount++
+	err := s.rpcClient.Notify(rpcnames.ClientSetClock.Encode(), clock)
+	if err != nil {
+		errAuto(s.Mach, "pushClockUpdate", err)
+	}
+}
+
+func (s *Server) genClockUpdate(skipTimeCheck bool) ClockMsg {
+	// TODO cache based on time sum (track the history)
+	s.clockMx.Lock()
+	defer s.clockMx.Unlock()
+
+	// exit if too often
+	if !skipTimeCheck && (time.Since(s.lastClockHTime) < s.ClockInterval) {
+		s.log("genClockUpdate: too soon")
+		return nil
+	}
+	hTime := time.Now()
+	mTime := s.w.Time(nil)
+
+	// exit if no change since the last sync
+	var sum uint64
+	for _, v := range mTime {
+		sum += v
+	}
+	if sum == s.lastClockSum && s.ClockInterval != 0 {
+		// s.log("genClockUpdate: same sum")
+		return nil
+	}
+
+	// proceed - valid clock update
+	s.lastClockMsg = NewClockMsg(s.lastClock, mTime)
+	s.lastClock = mTime
+	s.lastClockHTime = hTime
+	s.lastClockSum = sum
+
+	return s.lastClockMsg
+}
+
+// ///// ///// /////
+
+// ///// REMOTE METHODS
+
+// ///// ///// /////
+
+func (s *Server) 
RemoteHandshake( + client *rpc2.Client, _ *Empty, resp *RespHandshake, +) error { + // TODO GetStruct and Time inside Eval + // TODO check if client here is the same as RespHandshakeAck + + mTime := s.w.Time(nil) + *resp = RespHandshake{ + ID: s.w.ID, + StateNames: s.w.StateNames(), + Time: mTime, + } + + s.Mach.Add1(ss.Handshaking, nil) + var sum uint64 + for _, v := range mTime { + sum += v + } + s.lastClock = mTime + s.lastClockSum = sum + s.lastClockHTime = time.Now() + + // TODO timeout for RemoteHandshakeAck + + return nil +} + +func (s *Server) RemoteHandshakeAck( + client *rpc2.Client, done *bool, _ *Empty, +) error { + if done == nil || !*done { + s.Mach.Remove1(ss.Handshaking, nil) + errResponseStr(s.Mach, "handshake failed") + return nil + } + + // accept the client + s.rpcClient = client + // TODO pass as param + s.Mach.Add1(ss.HandshakeDone, nil) + + return nil +} + +func (s *Server) RemoteAdd( + _ *rpc2.Client, args *ArgsMut, resp *RespResult, +) error { + s.mutMx.Lock() + defer s.mutMx.Unlock() + + // validate + if args.States == nil { + return fmt.Errorf("%w", ErrInvalidParams) + } + + // execute + s.skipClockPush.Store(true) + val := s.w.Add(amh.IndexesToStates(s.w, args.States), args.Args) + + // return + *resp = RespResult{ + Result: val, + Clock: s.genClockUpdate(true), + } + s.skipClockPush.Store(false) + return nil +} + +func (s *Server) RemoteAddNS( + _ *rpc2.Client, args *ArgsMut, _ *Empty, +) error { + s.mutMx.Lock() + defer s.mutMx.Unlock() + + // validate + if args.States == nil { + return fmt.Errorf("%w", ErrInvalidParams) + } + + // execute + s.skipClockPush.Store(true) + _ = s.w.Add(amh.IndexesToStates(s.w, args.States), args.Args) + s.skipClockPush.Store(false) + + return nil +} + +func (s *Server) RemoteRemove( + _ *rpc2.Client, args *ArgsMut, resp *RespResult, +) error { + s.mutMx.Lock() + defer s.mutMx.Unlock() + + // validate + if args.States == nil { + return fmt.Errorf("%w", ErrInvalidParams) + } + + // execute + s.skipClockPush.Store(true) + val := s.w.Remove(amh.IndexesToStates(s.w, args.States), args.Args) + s.skipClockPush.Store(false) + + // return + *resp = RespResult{ + Result: val, + Clock: s.genClockUpdate(true), + } + return nil +} + +func (s *Server) RemoteSet( + _ *rpc2.Client, args *ArgsMut, resp *RespResult, +) error { + s.mutMx.Lock() + defer s.mutMx.Unlock() + + // validate + if args.States == nil { + return fmt.Errorf("%w", ErrInvalidParams) + } + + // execute + s.skipClockPush.Store(true) + val := s.w.Set(amh.IndexesToStates(s.w, args.States), args.Args) + s.skipClockPush.Store(false) + + // return + *resp = RespResult{ + Result: val, + Clock: s.genClockUpdate(true), + } + return nil +} + +func (s *Server) RemoteSync( + _ *rpc2.Client, sum uint64, resp *RespSync, +) error { + s.log("RemoteSync") + + if s.w.TimeSum(nil) > sum { + *resp = RespSync{ + Time: s.w.Time(nil), + } + } else { + *resp = RespSync{} + } + + s.log("RemoteSync: %v", resp.Time) + + return nil +} + +func (s *Server) RemoteGet( + _ *rpc2.Client, name string, resp *RespGet, +) error { + s.log("RemoteGet: %s", rpcnames.Decode(name)) + + if s.getter == nil { + // TODO error + *resp = RespGet{"no_getter"} + } else { + *resp = RespGet{s.getter(name)} + } + return nil +} + +func (s *Server) RemoteBye( + _ *rpc2.Client, _ *Empty, _ *Empty, +) error { + s.log("RemoteBye") + + // TODO ClientBye to keep it in sync + s.Mach.Add1(ss.ClientDisconn, nil) + go func() { + time.Sleep(100 * time.Millisecond) + s.Mach.Remove1(ss.HandshakeDone, nil) + }() + + return nil +} + +// ///// 
///// ///// + +// ///// MISC + +// ///// ///// ///// + +type WorkerTracer struct { + *am.NoOpTracer + + s *Server +} + +func (t *WorkerTracer) TransitionEnd(_ *am.Transition) { + go func() { + t.s.mutMx.Lock() + defer t.s.mutMx.Unlock() + + t.s.pushClockUpdate() + }() +} + +// TODO implement as an optimization +// func (t *WorkerTracer) QueueEnd(_ *am.Transition) { +// t.s.pushClockUpdate() +// } diff --git a/pkg/rpc/shared.go b/pkg/rpc/shared.go new file mode 100644 index 0000000..8a5e992 --- /dev/null +++ b/pkg/rpc/shared.go @@ -0,0 +1,307 @@ +package rpc + +import ( + "errors" + "fmt" + "io" + "net" + "slices" + "strings" + "sync" + "sync/atomic" + + ssCli "github.com/pancsta/asyncmachine-go/pkg/rpc/states/client" + "github.com/pancsta/rpc2" + + am "github.com/pancsta/asyncmachine-go/pkg/machine" + ss "github.com/pancsta/asyncmachine-go/pkg/rpc/states" +) + +// ///// ///// ///// + +// ///// TYPES + +// ///// ///// ///// + +// ArgsMut is args for mutation methods. +type ArgsMut struct { + States []int + Args am.A +} + +type ArgsGet struct { + Name string +} + +type ArgsLog struct { + Msg string + Args []any +} + +type ArgsPayload struct { + Name string + Data []byte +} + +type RespHandshake = am.Serialized + +type RespResult struct { + Clock ClockMsg + Result am.Result +} + +type RespSync struct { + Time am.Time +} + +type RespGet struct { + Value any +} + +type Empty struct{} + +type ClockMsg [][2]int + +// clientServerMethods is a shared interface for RPC client/server. +type clientServerMethods interface { + GetKind() Kind +} + +type Kind string + +const ( + KindClient Kind = "client" + KindServer Kind = "server" +) + +// // DEBUG for perf testing TODO tag +// type ClockMsg am.Time + +// ///// ///// ///// + +// ///// RPC APIS + +// ///// ///// ///// + +// serverRpcMethods is an RPC server for controlling RemoteMachine. +// TODO verify parity with RemoteMachine via reflection +type serverRpcMethods interface { + // rpc + + RemoteHandshake(client *rpc2.Client, args *Empty, resp *RespHandshake) error + + // mutations + + RemoteAdd(client *rpc2.Client, args *ArgsMut, resp *RespResult) error + RemoteRemove(client *rpc2.Client, args *ArgsMut, resp *RespResult) error + RemoteSet(client *rpc2.Client, args *ArgsMut, reply *RespResult) error +} + +// clientRpcMethods is the RPC server exposed by the RPC client for bi-di comm. +type clientRpcMethods interface { + RemoteSetClock(worker *rpc2.Client, args ClockMsg, resp *Empty) error + RemoteSendPayload(worker *rpc2.Client, file *ArgsPayload, resp *Empty) error +} + +// ///// ///// ///// + +// ///// ERRORS + +// ///// ///// ///// + +// sentinel errors + +var ( + // ErrClient group + + ErrInvalidParams = errors.New("invalid params") + ErrInvalidResp = errors.New("invalid response") + ErrRpc = errors.New("rpc") + + // ErrNetwork group + + ErrNetwork = errors.New("network error") + ErrNetworkTimeout = errors.New("network timeout") +) + +// wrapping error setters + +func errResponse(mach *am.Machine, err error) { + mach.AddErr(fmt.Errorf("%w: %w", ErrInvalidResp, err), nil) +} + +func errResponseStr(mach *am.Machine, msg string) { + mach.AddErr(fmt.Errorf("%w: %s", ErrInvalidResp, msg), nil) +} + +func errParams(mach *am.Machine, err error) { + mach.AddErr(fmt.Errorf("%w: %w", ErrInvalidParams, err), nil) +} + +func errNetwork(mach *am.Machine, err error) { + mach.AddErr(fmt.Errorf("%w: %w", ErrNetwork, err), nil) +} + +// errAuto detects sentinels from error msgs and wraps. 
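+// For example, an EOF from a dropped connection lands in the ErrNetwork
+// group (an illustration, not part of the original code):
+//
+//	errAuto(mach, "Add", errors.New("unexpected EOF"))
+//	// mach gets: "network error: Add: unexpected EOF"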
+func errAuto(mach *am.Machine, msg string, err error) { + + // detect group from text + var errGroup error + if strings.HasPrefix(err.Error(), "gob: ") { + errGroup = ErrInvalidResp + } else if strings.Contains(err.Error(), "rpc2: can't find method") { + errGroup = ErrRpc + } else if strings.Contains(err.Error(), "connection is shut down") || + strings.Contains(err.Error(), "unexpected EOF") { + errGroup = ErrNetwork + } + + // wrap in a group + if errGroup != nil { + mach.AddErr(fmt.Errorf("%w: %s: %w", errGroup, msg, err), nil) + return + } + + // Exception state fallback + if msg == "" { + mach.AddErr(err, nil) + } else { + mach.AddErr(fmt.Errorf("%s: %w", msg, err), nil) + } +} + +// ExceptionHandler is a shared exception handler for RPC server and +// client. +type ExceptionHandler struct { + *am.ExceptionHandler +} + +func (h *ExceptionHandler) ExceptionEnter(e *am.Event) bool { + err := e.Args["err"].(error) + + mach := e.Machine + isClient := mach.Has(am.S{ssCli.Disconnecting, ssCli.Disconnected}) + if errors.Is(err, ErrNetwork) && isClient && + mach.Any1(ssCli.Disconnecting, ssCli.Disconnected) { + + // skip network errors on client disconnect + return false + } + + return true +} + +func (h *ExceptionHandler) ExceptionState(e *am.Event) { + // call super + h.ExceptionHandler.ExceptionState(e) + mach := e.Machine + err := e.Args["err"].(error) + + // handle sentinel errors to states + // TODO handle rpc2.ErrShutdown + if errors.Is(err, am.ErrHandlerTimeout) { + // TODO activate ErrSlowHandlers + } else if errors.Is(err, ErrNetwork) || errors.Is(err, ErrNetworkTimeout) { + mach.Add1(ss.ErrNetwork, nil) + } else if errors.Is(err, ErrInvalidParams) { + mach.Add1(ss.ErrRpc, nil) + } else if errors.Is(err, ErrInvalidResp) { + mach.Add1(ss.ErrRpc, nil) + } else if errors.Is(err, ErrRpc) { + mach.Add1(ss.ErrRpc, nil) + } +} + +// ///// ///// ///// + +// ///// CLOCK + +// ///// ///// ///// + +// // DEBUG for perf testing +// func NewClockMsg(before, after am.Time) ClockMsg { +// return ClockMsg(after) +// } +// +// // DEBUG for perf testing +// func ClockFromMsg(before am.Time, msg ClockMsg) am.Time { +// return am.Time(msg) +// } + +func NewClockMsg(before, after am.Time) ClockMsg { + var val [][2]int + + for k := range after { + if before == nil { + // TODO test this path + val = append(val, [2]int{k, int(after[k])}) + } else if before[k] != after[k] { + val = append(val, [2]int{k, int(after[k] - before[k])}) + } + } + + return val +} + +func ClockFromMsg(before am.Time, msg ClockMsg) am.Time { + after := slices.Clone(before) + + for _, v := range msg { + key := v[0] + val := v[1] + after[key] += uint64(val) + } + + return after +} + +func TrafficMeter( + listener net.Listener, fwdTo string, counter chan<- int64, + end <-chan struct{}, +) { + defer listener.Close() + // fmt.Println("Listening on " + listenOn) + + // call the destination + destination, err := net.Dial("tcp", fwdTo) + if err != nil { + fmt.Println("Error connecting to destination:", err.Error()) + return + } + defer destination.Close() + + // wait for the connection + conn, err := listener.Accept() + if err != nil { + fmt.Println("Error accepting connection:", err.Error()) + return + } + defer conn.Close() + + // forward data bidirectionally + wg := sync.WaitGroup{} + wg.Add(2) + bytes := atomic.Int64{} + go func() { + c, _ := io.Copy(destination, conn) + bytes.Add(c) + wg.Done() + }() + go func() { + c, _ := io.Copy(conn, destination) + bytes.Add(c) + wg.Done() + }() + + // wait for the test and forwarding to finish + 
<-end
+	// fmt.Printf("Closing counter...\n")
+	_ = listener.Close()
+	_ = destination.Close()
+	_ = conn.Close()
+	wg.Wait()
+
+	c := bytes.Load()
+	// fmt.Printf("Forwarded %d bytes\n", c)
+	counter <- c
+}
diff --git a/pkg/rpc/states/client/ss_client.go b/pkg/rpc/states/client/ss_client.go
new file mode 100644
index 0000000..966c93f
--- /dev/null
+++ b/pkg/rpc/states/client/ss_client.go
@@ -0,0 +1,80 @@
+package states
+
+import (
+	am "github.com/pancsta/asyncmachine-go/pkg/machine"
+	ss "github.com/pancsta/asyncmachine-go/pkg/rpc/states"
+)
+
+// S is a type alias for a list of state names.
+type S = am.S
+
+// States map defines relations and properties of states.
+// Based on the shared RPC states.
+var States = am.StructMerge(ss.States, am.Struct{
+	ss.Start: {Add: S{Connecting}},
+	ss.Ready: {
+		Auto:    true,
+		Require: S{HandshakeDone},
+	},
+
+	// Connection
+	Connecting: {
+		Require: S{Start},
+		Remove:  GroupConnected,
+	},
+	Connected: {
+		Require: S{Start},
+		Remove:  GroupConnected,
+		Add:     S{Handshaking},
+	},
+	Disconnecting: {
+		Remove: GroupConnected,
+	},
+	Disconnected: {
+		Auto:   true,
+		Remove: GroupConnected,
+	},
+
+	// Add a dependency on Connected to HandshakeDone.
+	HandshakeDone: am.StateAdd(ss.States[ss.HandshakeDone], am.State{
+		Require: S{Connected},
+	}),
+})
+
+// Groups of mutually exclusive states.
+
+var (
+	GroupConnected = S{Connecting, Connected, Disconnecting, Disconnected}
+)
+
+// #region boilerplate defs
+
+// Names of all the states (pkg enum).
+
+const (
+	// Shared
+
+	ErrOnClient = ss.ErrOnClient
+
+	Start = ss.Start
+	Ready = ss.Ready
+	HandshakeDone = ss.HandshakeDone
+	Handshaking = ss.Handshaking
+
+	// Client
+
+	Connecting = "Connecting"
+	Connected = "Connected"
+	Disconnecting = "Disconnecting"
+	Disconnected = "Disconnected"
+)
+
+// Names is an ordered list of all the state names.
+var Names = am.SAdd(ss.Names, S{
+	Connecting,
+	Connected,
+	Disconnecting,
+	Disconnected,
+})
+
+// #endregion
diff --git a/pkg/rpc/states/server/ss_server.go b/pkg/rpc/states/server/ss_server.go
new file mode 100644
index 0000000..7fa61c6
--- /dev/null
+++ b/pkg/rpc/states/server/ss_server.go
@@ -0,0 +1,101 @@
+package states
+
+import (
+	am "github.com/pancsta/asyncmachine-go/pkg/machine"
+	ss "github.com/pancsta/asyncmachine-go/pkg/rpc/states"
)
+
+// S is a type alias for a list of state names.
+type S = am.S
+
+// SMerge is a func alias for merging lists of states.
+var SMerge = am.SAdd
+
+// StructMerge is a func alias for extending an existing state structure.
+var StructMerge = am.StructMerge
+
+// StateAdd is a func alias for adding to an existing state definition.
+var StateAdd = am.StateAdd
+
+// States map defines relations and properties of states.
+// Based on the shared RPC states.
+var States = StructMerge(ss.States, am.Struct{
+	// Errors
+
+	// Add a removal of ClientConn to ErrNetwork.
+	ss.ErrNetwork: StateAdd(ss.States[ss.ErrNetwork], am.State{
+		Remove: S{ClientConn},
+	}),
+
+	// Server
+
+	ss.Start: {Add: S{RpcStarting}},
+	ss.Ready: {
+		Auto:    true,
+		Require: S{HandshakeDone, RpcReady},
+	},
+
+	RpcStarting: {
+		Require: S{ss.Start},
+		Remove:  GroupRPC,
+	},
+	RpcReady: {
+		Require: S{ss.Start},
+		Remove:  GroupRPC,
+	},
+
+	ClientConn: {
+		Require: S{RpcReady},
+		Remove:  GroupClientConn,
+	},
+	ClientDisconn: {
+		Auto:    true,
+		Require: S{RpcReady},
+		Remove:  GroupClientConn,
+	},
+	// TODO ClientBye for graceful shutdowns
+})
+
+// Groups of mutually exclusive states.
+ +var ( + GroupRPC = S{RpcStarting, RpcReady} + GroupClientConn = S{ClientConn, ClientDisconn} +) + +// #region boilerplate defs + +// Names of all the states (pkg enum). + +const ( + // Shared + + Start = ss.Start + Ready = ss.Ready + HandshakeDone = ss.HandshakeDone + Handshaking = ss.Handshaking + + // Server + + RpcStarting = "RpcStarting" + RpcReady = "RpcReady" + + ClientConn = "ClientConn" + ClientDisconn = "ClientDisconn" + + // Errors + + ErrNetwork = ss.ErrNetwork + ErrClient = ss.ErrRpc +) + +// Names is an ordered list of all the state names. +var Names = SMerge(ss.Names, S{ + RpcStarting, + RpcReady, + + ClientConn, + ClientDisconn, +}) + +// #endregion diff --git a/pkg/rpc/states/ss_shared.go b/pkg/rpc/states/ss_shared.go new file mode 100644 index 0000000..abae63e --- /dev/null +++ b/pkg/rpc/states/ss_shared.go @@ -0,0 +1,68 @@ +package states + +import am "github.com/pancsta/asyncmachine-go/pkg/machine" + +// S is a type alias for a list of state names. +type S = am.S + +// States map defines relations and properties of states. +var States = am.Struct{ + // Errors + ErrNetwork: {Require: S{am.Exception}}, + ErrRpc: {Require: S{am.Exception}}, + ErrOnClient: {Require: S{am.Exception}}, + + Start: {}, + + // Handshake + Handshaking: { + Require: S{Start}, + Remove: GroupHandshake, + }, + HandshakeDone: { + Require: S{Start}, + Remove: GroupHandshake, + }, +} + +// Groups of mutually exclusive states. + +var ( + GroupHandshake = S{Handshaking, HandshakeDone} +) + +// #region boilerplate defs + +// Names of all the states (pkg enum). + +const ( + ErrNetwork = "ErrNetwork" + ErrRpc = "ErrRpc" + ErrOnClient = "ErrOnClient" + + Start = "Start" + Ready = "Ready" + + HandshakeDone = "HandshakeDone" + Handshaking = "Handshaking" +) + +// Names is an ordered list of all the state names. +var Names = S{ + am.Exception, + + ErrNetwork, + ErrRpc, + + // ErrOnClient indicates that the error happened on the RPC client, and not + // on the remote machine. + ErrOnClient, + + Start, + Ready, + + HandshakeDone, + Handshaking, +} + +// #endregion diff --git a/pkg/rpc/worker.go b/pkg/rpc/worker.go new file mode 100644 index 0000000..703b398 --- /dev/null +++ b/pkg/rpc/worker.go @@ -0,0 +1,1244 @@ +package rpc + +import ( + "context" + "fmt" + "maps" + "slices" + "strings" + "sync" + "sync/atomic" + + amh "github.com/pancsta/asyncmachine-go/pkg/helpers" + am "github.com/pancsta/asyncmachine-go/pkg/machine" + "github.com/pancsta/asyncmachine-go/pkg/rpc/rpcnames" + ss "github.com/pancsta/asyncmachine-go/pkg/rpc/states/client" + "github.com/pancsta/asyncmachine-go/pkg/types" +) + +// Worker is a subset of `pkg/machine#Machine` for RPC. Lacks the queue and +// other local methods. Most methods are clock-based, thus executed locally. +type Worker struct { + ID string + Ctx context.Context + // TODO remote push + Disposed atomic.Bool + + c *Client + err error + states am.Struct + clockTime am.Time + stateNames am.S + activeStatesLock sync.RWMutex + indexStateCtx am.IndexStateCtx + indexWhen am.IndexWhen + indexWhenTime am.IndexWhenTime + // TODO indexWhenArgs + indexWhenArgs am.IndexWhenArgs + whenDisposed chan struct{} +} + +// Worker implements MachineApi +var _ types.MachineApi = &Worker{} + +// ///// RPC methods + +// Sync requests fresh clock values from the remote machine. Useful to call +// after a batch of no-sync methods, eg AddNS. 
+func (w *Worker) Sync() am.Time {
+	w.c.Mach.Log("Sync")
+
+	// call rpc
+	resp := &RespSync{}
+	if !w.c.call(w.c.Mach.Ctx, rpcnames.Sync.Encode(), w.TimeSum(nil), resp) {
+		return nil
+	}
+
+	// validate
+	if len(resp.Time) > 0 && len(resp.Time) != len(w.stateNames) {
+		errResponseStr(w.c.Mach, "wrong clock len")
+		return nil
+	}
+
+	// process if time is returned
+	if len(resp.Time) > 0 {
+		w.c.updateClock(nil, resp.Time)
+	}
+
+	return w.clockTime
+}
+
+// ///// Mutations (remote)
+
+// Add activates a list of states in the machine, returning the result of the
+// transition (Executed, Queued, Canceled).
+// Like every mutation method, it will resolve relations and trigger handlers.
+func (w *Worker) Add(states am.S, args am.A) am.Result {
+	// call rpc
+	resp := &RespResult{}
+	rpcArgs := &ArgsMut{States: amh.StatesToIndexes(w, states), Args: args}
+	if !w.c.call(w.c.Mach.Ctx, rpcnames.Add.Encode(), rpcArgs, resp) {
+		return am.ResultNoOp
+	}
+
+	// validate
+	if resp.Result == 0 {
+		errResponseStr(w.c.Mach, "no Result")
+		return am.ResultNoOp
+	}
+
+	// process
+	w.c.updateClock(resp.Clock, nil)
+
+	return resp.Result
+}
+
+// Add1 is a shorthand method to add a single state with the passed args.
+func (w *Worker) Add1(state string, args am.A) am.Result {
+	return w.Add(am.S{state}, args)
+}
+
+// AddNS is a NoSync method - an efficient way of adding states, as it
+// neither waits for nor transfers a response, and therefore doesn't update
+// the clock. Use Sync() to update the clock after a batch of AddNS calls.
+func (w *Worker) AddNS(states am.S, args am.A) am.Result {
+	w.c.log("AddNS")
+
+	// call rpc
+	rpcArgs := &ArgsMut{States: amh.StatesToIndexes(w, states), Args: args}
+	if !w.c.notify(w.c.Mach.Ctx, rpcnames.AddNS.Encode(), rpcArgs) {
+		return am.ResultNoOp
+	}
+
+	return am.Executed
+}
+
+// Add1NS is a single state version of AddNS.
+func (w *Worker) Add1NS(state string, args am.A) am.Result {
+	return w.AddNS(am.S{state}, args)
+}
+
+// Remove de-activates a list of states in the machine, returning the result of
+// the transition (Executed, Queued, Canceled).
+// Like every mutation method, it will resolve relations and trigger handlers.
+func (w *Worker) Remove(states am.S, args am.A) am.Result {
+	// call rpc
+	resp := &RespResult{}
+	rpcArgs := &ArgsMut{States: amh.StatesToIndexes(w, states), Args: args}
+	if !w.c.call(w.c.Mach.Ctx, rpcnames.Remove.Encode(), rpcArgs, resp) {
+		return am.ResultNoOp
+	}
+
+	// validate
+	if resp.Result == 0 {
+		errResponseStr(w.c.Mach, "no Result")
+		return am.ResultNoOp
+	}
+
+	// process
+	w.c.updateClock(resp.Clock, nil)
+
+	return resp.Result
+}
+
+// Remove1 is a shorthand method to remove a single state with the passed args.
+// See Remove().
+func (w *Worker) Remove1(state string, args am.A) am.Result {
+	return w.Remove(am.S{state}, args)
+}
+
+// Set activates only the passed list of states, de-activating all the other
+// active states, and returns the result of the transition (Executed, Queued,
+// Canceled).
+// Like every mutation method, it will resolve relations and trigger handlers.
+func (w *Worker) Set(states am.S, args am.A) am.Result {
+	// call rpc
+	resp := &RespResult{}
+	rpcArgs := &ArgsMut{States: amh.StatesToIndexes(w, states), Args: args}
+	if !w.c.call(w.c.Mach.Ctx, rpcnames.Set.Encode(), rpcArgs, resp) {
+		return am.ResultNoOp
+	}
+
+	// validate
+	if resp.Result == 0 {
+		errResponseStr(w.c.Mach, "no Result")
+		return am.ResultNoOp
+	}
+
+	// process
+	w.c.updateClock(resp.Clock, nil)
+
+	return resp.Result
+}
+
+// AddErr is a dedicated method to add the Exception state with the passed
+// error and optional arguments.
+// Like every mutation method, it will resolve relations and trigger handlers.
+// AddErr produces a stack trace of the error, if LogStackTrace is enabled.
+func (w *Worker) AddErr(err error, args am.A) am.Result {
+	return w.AddErrState(am.Exception, err, args)
+}
+
+// AddErrState adds a dedicated error state, along with the built-in Exception
+// state.
+// Like every mutation method, it will resolve relations and trigger handlers.
+// AddErrState produces a stack trace of the error, if LogStackTrace is enabled.
+func (w *Worker) AddErrState(state string, err error, args am.A) am.Result {
+	if w.Disposed.Load() {
+		return am.Canceled
+	}
+	// TODO remove once remote errors are implemented
+	w.err = err
+
+	// TODO stack traces
+	// var trace string
+	// if m.LogStackTrace {
+	// 	trace = captureStackTrace()
+	// }
+
+	// build args
+	if args == nil {
+		args = am.A{}
+	} else {
+		args = maps.Clone(args)
+	}
+	args["err"] = err
+	// args["err.trace"] = trace
+
+	// mark errors added locally with ErrOnClient
+	return w.Add(am.S{ss.ErrOnClient, state, am.Exception}, args)
+}
+
+// ///// Checking (local)
+
+// Is checks if all the passed states are currently active.
+//
+//	machine.StringAll() // ()[Foo:0 Bar:0 Baz:0]
+//	machine.Add(S{"Foo"})
+//	machine.Is(S{"Foo"}) // true
+//	machine.Is(S{"Foo", "Bar"}) // false
+func (w *Worker) Is(states am.S) bool {
+	w.activeStatesLock.Lock()
+	defer w.activeStatesLock.Unlock()
+
+	return w.is(states)
+}
+
+// Is1 is a shorthand method to check if a single state is currently active.
+// See Is().
+func (w *Worker) Is1(state string) bool {
+	return w.Is(am.S{state})
+}
+
+func (w *Worker) is(states am.S) bool {
+	w.c.clockMx.Lock()
+	defer w.c.clockMx.Unlock()
+
+	for _, state := range states {
+		idx := slices.Index(w.stateNames, state)
+		if !am.IsActiveTick(w.clockTime[idx]) {
+			return false
+		}
+	}
+
+	return true
+}
+
+// IsErr checks if the machine has the Exception state currently active.
+func (w *Worker) IsErr() bool {
+	return w.Is1(am.Exception)
+}
+
+// Not checks if **none** of the passed states are currently active.
+//
+//	machine.StringAll()
+//	// -> ()[A:0 B:0 C:0 D:0]
+//	machine.Add(S{"A", "B"})
+//
+//	// not(A) and not(C)
+//	machine.Not(S{"A", "C"})
+//	// -> false
+//
+//	// not(C) and not(D)
+//	machine.Not(S{"C", "D"})
+//	// -> true
+func (w *Worker) Not(states am.S) bool {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	return slicesNone(w.MustParseStates(states), w.activeStates())
+}
+
+// Not1 is a shorthand method to check if a single state is currently inactive.
+// See Not().
+func (w *Worker) Not1(state string) bool {
+	return w.Not(am.S{state})
+}
+
+// Any is a group call to Is, returning true if any of the params return true
+// from Is.
+//
+//	machine.StringAll() // ()[Foo:0 Bar:0 Baz:0]
+//	machine.Add(S{"Foo"})
+//	// is(Foo, Bar) or is(Bar)
+//	machine.Any(S{"Foo", "Bar"}, S{"Bar"}) // false
+//	// is(Foo) or is(Bar)
+//	machine.Any(S{"Foo"}, S{"Bar"}) // true
+func (w *Worker) Any(states ...am.S) bool {
+	for _, s := range states {
+		if w.Is(s) {
+			return true
+		}
+	}
+	return false
+}
+
+// Any1 is a group call to Is1(), returning true if any of the params return
+// true from Is1().
+func (w *Worker) Any1(states ...string) bool {
+	for _, s := range states {
+		if w.Is1(s) {
+			return true
+		}
+	}
+	return false
+}
+
+// Has returns true if the passed states are registered in the machine.
+func (w *Worker) Has(states am.S) bool {
+	return slicesEvery(w.StateNames(), states)
+}
+
+// Has1 is a shorthand for Has. It returns true if the passed state is
+// registered in the machine.
+func (w *Worker) Has1(state string) bool {
+	return w.Has(am.S{state})
+}
+
+// IsClock checks if the machine's clock still matches the passed one, i.e.
+// none of the given states have changed. Returns false if at least one state
+// has changed since.
+func (w *Worker) IsClock(clock am.Clock) bool {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	for state, tick := range clock {
+		if w.clockTime[w.Index(state)] != tick {
+			return false
+		}
+	}
+
+	return true
+}
+
+// IsTime checks if the machine's time still matches the passed one (a list
+// of ticks). Returns false if at least one state has changed since. The
+// states param is optional and can be used to check only a subset of states.
+func (w *Worker) IsTime(t am.Time, states am.S) bool {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	if states == nil {
+		states = w.stateNames
+	}
+
+	for i, tick := range t {
+		if w.clockTime[w.Index(states[i])] != tick {
+			return false
+		}
+	}
+
+	return true
+}
+
+// Switch returns the first state from the passed list that is currently
+// active, making it useful for switch statements.
+//
+//	switch mach.Switch(ss.GroupPlaying...) {
+//	case "Playing":
+//	case "Paused":
+//	case "Stopped":
+//	}
+func (w *Worker) Switch(states ...string) string {
+	activeStates := w.ActiveStates()
+
+	for _, state := range states {
+		if slices.Contains(activeStates, state) {
+			return state
+		}
+	}
+
+	return ""
+}
+
+// ///// Waiting (local)
+
+// When returns a channel that will be closed when all the passed states
+// become active or the machine gets disposed.
+//
+// ctx: optional context that will close the channel when done. Useful when
+// listening on 2 When() channels within the same `select` to GC the 2nd one.
+// TODO re-use channels with the same state set and context +func (w *Worker) When(states am.S, ctx context.Context) <-chan struct{} { + ch := make(chan struct{}) + if w.Disposed.Load() { + close(ch) + return ch + } + + w.activeStatesLock.Lock() + defer w.activeStatesLock.Unlock() + + // if all active, close early + if w.is(states) { + close(ch) + return ch + } + + setMap := am.StateIsActive{} + matched := 0 + for _, s := range states { + setMap[s] = w.is(am.S{s}) + if setMap[s] { + matched++ + } + } + + // add the binding to an index of each state + binding := &am.WhenBinding{ + Ch: ch, + Negation: false, + States: setMap, + Total: len(states), + Matched: matched, + } + + // dispose with context + disposeWithCtx(w, ctx, ch, states, binding, &w.activeStatesLock, w.indexWhen) + + // insert the binding + for _, s := range states { + if _, ok := w.indexWhen[s]; !ok { + w.indexWhen[s] = []*am.WhenBinding{binding} + } else { + w.indexWhen[s] = append(w.indexWhen[s], binding) + } + } + + return ch +} + +// When1 is an alias to When() for a single state. +// See When. +func (w *Worker) When1(state string, ctx context.Context) <-chan struct{} { + return w.When(am.S{state}, ctx) +} + +// WhenNot returns a channel that will be closed when all the passed states +// become inactive or the machine gets disposed. +// +// ctx: optional context that will close the channel when done. Useful when +// listening on 2 WhenNot() channels within the same `select` to GC the 2nd one. +func (w *Worker) WhenNot(states am.S, ctx context.Context) <-chan struct{} { + ch := make(chan struct{}) + if w.Disposed.Load() { + close(ch) + return ch + } + + w.activeStatesLock.Lock() + defer w.activeStatesLock.Unlock() + + // if all inactive, close early + if !w.is(states) { + close(ch) + return ch + } + + setMap := am.StateIsActive{} + matched := 0 + for _, s := range states { + setMap[s] = w.is(am.S{s}) + if !setMap[s] { + matched++ + } + } + + // add the binding to an index of each state + binding := &am.WhenBinding{ + Ch: ch, + Negation: true, + States: setMap, + Total: len(states), + Matched: matched, + } + + // dispose with context + disposeWithCtx(w, ctx, ch, states, binding, &w.activeStatesLock, w.indexWhen) + + // insert the binding + for _, s := range states { + if _, ok := w.indexWhen[s]; !ok { + w.indexWhen[s] = []*am.WhenBinding{binding} + } else { + w.indexWhen[s] = append(w.indexWhen[s], binding) + } + } + + return ch +} + +// WhenNot1 is an alias to WhenNot() for a single state. +// See WhenNot. +func (w *Worker) WhenNot1(state string, ctx context.Context) <-chan struct{} { + return w.WhenNot(am.S{state}, ctx) +} + +// WhenTime returns a channel that will be closed when all the passed states +// have passed the specified time. The time is a logical clock of the state. +// Machine time can be sourced from the Time() method, or Clock() for a specific +// state. 
+func (w *Worker) WhenTime( + states am.S, times am.Time, ctx context.Context, +) <-chan struct{} { + ch := make(chan struct{}) + valid := len(states) == len(times) + if w.Disposed.Load() || !valid { + if !valid { + // TODO local log + w.log(am.LogDecisions, "[when:time] times for all passed states required") + } + close(ch) + return ch + } + + w.activeStatesLock.Lock() + defer w.activeStatesLock.Unlock() + + // if all times passed, close early + passed := true + for i, s := range states { + if w.tick(s) < times[i] { + passed = false + break + } + } + if passed { + close(ch) + return ch + } + + completed := am.StateIsActive{} + matched := 0 + index := map[string]int{} + for i, s := range states { + completed[s] = w.tick(s) >= times[i] + if completed[s] { + matched++ + } + index[s] = i + } + + // add the binding to an index of each state + binding := &am.WhenTimeBinding{ + Ch: ch, + Index: index, + Completed: completed, + Total: len(states), + Matched: matched, + Times: times, + } + + // dispose with context + disposeWithCtx(w, ctx, ch, states, binding, &w.activeStatesLock, + w.indexWhenTime) + + // insert the binding + for _, s := range states { + if _, ok := w.indexWhenTime[s]; !ok { + w.indexWhenTime[s] = []*am.WhenTimeBinding{binding} + } else { + w.indexWhenTime[s] = append(w.indexWhenTime[s], binding) + } + } + + return ch +} + +// WhenTicks waits N ticks of a single state (relative to now). Uses WhenTime +// underneath. +func (w *Worker) WhenTicks( + state string, ticks int, ctx context.Context, +) <-chan struct{} { + return w.WhenTime(am.S{state}, am.Time{uint64(ticks) + w.Tick(state)}, ctx) +} + +// WhenTicksEq waits till ticks for a single state equal the given absolute +// value (or more). Uses WhenTime underneath. +func (w *Worker) WhenTicksEq( + state string, ticks uint64, ctx context.Context, +) <-chan struct{} { + return w.WhenTime(am.S{state}, am.Time{ticks}, ctx) +} + +// WhenErr returns a channel that will be closed when the machine is in the +// Exception state. +// +// ctx: optional context defaults to the machine's context. +func (w *Worker) WhenErr(ctx context.Context) <-chan struct{} { + return w.When([]string{am.Exception}, ctx) +} + +// ///// Waiting (remote) + +// WhenArgs returns a channel that will be closed when the passed state +// becomes active with all the passed args. Args are compared using the native +// '=='. It's meant to be used with async Multi states, to filter out +// a specific completion. +func (w *Worker) WhenArgs( + state string, args am.A, ctx context.Context, +) <-chan struct{} { + // TODO implement me + panic("implement me") +} + +// ///// Getters (remote) + +// Err returns the last error. +func (w *Worker) Err() error { + // TODO return remote errors + return w.err +} + +// ///// Getters (local) + +// StateNames returns a copy of all the state names. +func (w *Worker) StateNames() am.S { + return w.stateNames +} + +// ActiveStates returns a copy of the currently active states. 
+func (w *Worker) ActiveStates() am.S {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	return w.activeStates()
+}
+
+func (w *Worker) activeStates() am.S {
+	ret := am.S{}
+	for _, state := range w.stateNames {
+		if am.IsActiveTick(w.tick(state)) {
+			ret = append(ret, state)
+		}
+	}
+
+	return ret
+}
+
+func (w *Worker) activeStatesUnlocked() am.S {
+	ret := am.S{}
+	for _, state := range w.stateNames {
+		if am.IsActiveTick(w.tickUnlocked(state)) {
+			ret = append(ret, state)
+		}
+	}
+
+	return ret
+}
+
+// Tick returns the current tick for a given state.
+func (w *Worker) Tick(state string) uint64 {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	return w.tick(state)
+}
+
+func (w *Worker) tick(state string) uint64 {
+	w.c.clockMx.Lock()
+	defer w.c.clockMx.Unlock()
+
+	return w.tickUnlocked(state)
+}
+
+func (w *Worker) tickUnlocked(state string) uint64 {
+	idx := slices.Index(w.stateNames, state)
+
+	return w.clockTime[idx]
+}
+
+// Clock returns the machine's current clock, a state-keyed map of ticks. If
+// states are passed, only the ticks of the passed states are returned.
+func (w *Worker) Clock(states am.S) am.Clock {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	return w.clock(states)
+}
+
+func (w *Worker) clock(states am.S) am.Clock {
+	if states == nil {
+		states = w.stateNames
+	}
+
+	ret := am.Clock{}
+	for _, state := range states {
+		idx := slices.Index(w.stateNames, state)
+		ret[state] = w.clockTime[idx]
+	}
+
+	return ret
+}
+
+// Time returns the machine's time, a list of ticks per state. The returned
+// value includes the specified states, or all the states if nil.
+func (w *Worker) Time(states am.S) am.Time {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	return w.time(states)
+}
+
+func (w *Worker) time(states am.S) am.Time {
+	if states == nil {
+		states = w.stateNames
+	}
+
+	ret := am.Time{}
+	for _, state := range states {
+		idx := slices.Index(w.stateNames, state)
+		ret = append(ret, w.clockTime[idx])
+	}
+
+	return ret
+}
+
+// TimeSum returns the sum of the machine's time (ticks per state).
+// The returned value includes the specified states, or all the states if nil.
+// It's a very inaccurate, yet simple way to measure the machine's time.
+// TODO handle overflow
+func (w *Worker) TimeSum(states am.S) uint64 {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	if states == nil {
+		states = w.stateNames
+	}
+
+	var sum uint64
+	for _, state := range states {
+		idx := slices.Index(w.stateNames, state)
+		sum += w.clockTime[idx]
+	}
+
+	return sum
+}
+
+// NewStateCtx returns a new sub-context, bound to the current clock's tick of
+// the passed state.
+//
+// Context cancels when the state has been de-activated, or right away,
+// if it isn't currently active.
+//
+// State contexts are used to check state expirations and should be checked
+// often inside goroutines.
+// TODO reuse existing ctxs
+func (w *Worker) NewStateCtx(state string) context.Context {
+	w.activeStatesLock.Lock()
+	defer w.activeStatesLock.Unlock()
+
+	stateCtx, cancel := context.WithCancel(w.c.Mach.Ctx)
+
+	// close early
+	if !w.is(am.S{state}) {
+		cancel()
+		return stateCtx
+	}
+
+	// add an index
+	if _, ok := w.indexStateCtx[state]; !ok {
+		w.indexStateCtx[state] = []context.CancelFunc{cancel}
+	} else {
+		w.indexStateCtx[state] = append(w.indexStateCtx[state], cancel)
+	}
+
+	return stateCtx
+}
+
+// ///// MISC
+
+// Log logs an [extern] message unless LogNothing is set (default).
+// Optionally redirects to a custom logger from SetLogger.
+func (w *Worker) Log(msg string, args ...any) {
+	// call rpc
+	resp := &RespResult{}
+	rpcArgs := &ArgsLog{Msg: msg, Args: args}
+	if !w.c.call(w.c.Mach.Ctx, rpcnames.Log.Encode(), rpcArgs, resp) {
+		return
+	}
+	// TODO local log?
+}
+
+// String returns a one line representation of the currently active states,
+// with their clock values. Inactive states are omitted.
+// E.g.: (Foo:1 Bar:3)
+func (w *Worker) String() string {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	active := w.activeStates()
+	ret := "("
+	for _, state := range w.stateNames {
+		if !slices.Contains(active, state) {
+			continue
+		}
+
+		if ret != "(" {
+			ret += " "
+		}
+		idx := slices.Index(w.stateNames, state)
+		ret += fmt.Sprintf("%s:%d", state, w.clockTime[idx])
+	}
+
+	return ret + ")"
+}
+
+// StringAll returns a one line representation of all the states, with their
+// clock values. Inactive states are in square brackets.
+// E.g.: (Foo:1 Bar:3)[Baz:2]
+func (w *Worker) StringAll() string {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	activeStates := w.activeStates()
+	ret := "("
+	ret2 := "["
+	for _, state := range w.stateNames {
+		idx := slices.Index(w.stateNames, state)
+
+		if slices.Contains(activeStates, state) {
+			if ret != "(" {
+				ret += " "
+			}
+			ret += fmt.Sprintf("%s:%d", state, w.clockTime[idx])
+			continue
+		}
+
+		if ret2 != "[" {
+			ret2 += " "
+		}
+		ret2 += fmt.Sprintf("%s:%d", state, w.clockTime[idx])
+	}
+
+	return ret + ")" + ret2 + "]"
+}
+
+// Inspect returns a multi-line string representation of the machine (states,
+// relations, clock).
+// states: param for ordered or partial results.
+func (w *Worker) Inspect(states am.S) string {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	if states == nil {
+		states = w.stateNames
+	}
+
+	activeStates := w.activeStates()
+	ret := ""
+	for _, name := range states {
+
+		state := w.states[name]
+		active := "false"
+		if slices.Contains(activeStates, name) {
+			active = "true"
+		}
+
+		idx := slices.Index(w.stateNames, name)
+		ret += name + ":\n"
+		ret += fmt.Sprintf("  State:   %s %d\n", active, w.clockTime[idx])
+		if state.Auto {
+			ret += "  Auto:    true\n"
+		}
+		if state.Multi {
+			ret += "  Multi:   true\n"
+		}
+		if state.Add != nil {
+			ret += "  Add:     " + j(state.Add) + "\n"
+		}
+		if state.Require != nil {
+			ret += "  Require: " + j(state.Require) + "\n"
+		}
+		if state.Remove != nil {
+			ret += "  Remove:  " + j(state.Remove) + "\n"
+		}
+		if state.After != nil {
+			ret += "  After:   " + j(state.After) + "\n"
+		}
+		ret += "\n"
+	}
+
+	return ret
+}
+
+// log forwards a log msg to the client's machine, respecting its log level.
+func (w *Worker) log(level am.LogLevel, msg string, args ...any) {
+	if w.c.Mach.GetLogLevel() < level {
+		return
+	}
+
+	// TODO get log level from the remote worker
+	msg = "[worker] " + msg
+	msg = strings.ReplaceAll(msg, "] [", ":")
+	// TODO replace {} with [] once #101 is fixed
+	msg = strings.ReplaceAll(strings.ReplaceAll(msg, "[", "{"), "]", "}")
+	w.c.Mach.Log(msg, args...)
+}
+
+// MustParseStates parses the states and returns them as a list.
+// Panics when a state is not defined. It's an unsafe equivalent of
+// VerifyStates.
+
+// MustParseStates parses the states and returns them as a list.
+// Panics when a state is not defined. It's an unsafe equivalent of
+// VerifyStates.
+func (w *Worker) MustParseStates(states am.S) am.S {
+	// check if all the states are defined in the worker's struct
+	for _, s := range states {
+		if _, ok := w.states[s]; !ok {
+			panic(fmt.Sprintf("state %s is not defined", s))
+		}
+	}
+
+	return slicesUniq(states)
+}
+
+func (w *Worker) processStateCtxBindings(statesBefore am.S) {
+	active := w.ActiveStates()
+
+	w.activeStatesLock.RLock()
+	deactivated := am.DiffStates(statesBefore, active)
+
+	var toCancel []context.CancelFunc
+	for _, s := range deactivated {
+
+		toCancel = append(toCancel, w.indexStateCtx[s]...)
+		delete(w.indexStateCtx, s)
+	}
+
+	w.activeStatesLock.RUnlock()
+
+	// cancel all the state contexts outside the critical zone
+	for _, cancel := range toCancel {
+		cancel()
+	}
+}
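+
+// For reference (an illustrative note): am.DiffStates returns the elements
+// of the first list that are missing from the second one, e.g.:
+//
+//	am.DiffStates(am.S{"A", "B"}, am.S{"B"}) // am.S{"A"}
+//
+// which is how the deactivated states are derived above.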
+
+func (w *Worker) processWhenBindings(statesBefore am.S) {
+	active := w.ActiveStates()
+
+	w.activeStatesLock.Lock()
+
+	// calculate activated and deactivated states
+	activated := am.DiffStates(active, statesBefore)
+	deactivated := am.DiffStates(statesBefore, active)
+
+	// merge all states
+	all := am.S{}
+	all = append(all, activated...)
+	all = append(all, deactivated...)
+
+	var toClose []chan struct{}
+	for _, s := range all {
+		for k, binding := range w.indexWhen[s] {
+
+			if slices.Contains(activated, s) {
+
+				// state activated, check the index
+				if !binding.Negation {
+					// match for When(
+					if !binding.States[s] {
+						binding.Matched++
+					}
+				} else {
+					// match for WhenNot(
+					if !binding.States[s] {
+						binding.Matched--
+					}
+				}
+
+				// update index: mark as active
+				binding.States[s] = true
+			} else {
+
+				// state deactivated
+				if !binding.Negation {
+					// match for When(
+					if binding.States[s] {
+						binding.Matched--
+					}
+				} else {
+					// match for WhenNot(
+					if binding.States[s] {
+						binding.Matched++
+					}
+				}
+
+				// update index: mark as inactive
+				binding.States[s] = false
+			}
+
+			// if not all matched, ignore for now
+			if binding.Matched < binding.Total {
+				continue
+			}
+
+			// completed - close and delete indexes for all involved states
+			var names []string
+			for state := range binding.States {
+				names = append(names, state)
+
+				if len(w.indexWhen[state]) == 1 {
+					delete(w.indexWhen, state)
+					continue
+				}
+
+				if state == s {
+					w.indexWhen[s] = append(w.indexWhen[s][:k], w.indexWhen[s][k+1:]...)
+					continue
+				}
+
+				w.indexWhen[state] = slices.Delete(w.indexWhen[state], k, k+1)
+			}
+
+			w.log(am.LogDecisions, "[when] match for (%s)", j(names))
+			// close outside the critical zone
+			toClose = append(toClose, binding.Ch)
+		}
+	}
+	w.activeStatesLock.Unlock()
+
+	// notify outside the critical zone
+	for _, ch := range toClose {
+		closeSafe(ch)
+	}
+}
+
+func (w *Worker) processWhenTimeBindings(timeBefore am.Time) {
+	w.activeStatesLock.Lock()
+	indexWhenTime := w.indexWhenTime
+	var toClose []chan struct{}
+
+	// collect all the ticked states
+	all := am.S{}
+	for idx, t := range timeBefore {
+
+		// if changed, collect to check
+		if w.clockTime[idx] != t {
+			all = append(all, w.stateNames[idx])
+		}
+	}
+
+	// check all the bindings for all the ticked states
+	for _, s := range all {
+
+		for k, binding := range indexWhenTime[s] {
+
+			// check if the requested time has passed
+			if !binding.Completed[s] &&
+				w.clockTime[w.Index(s)] >= binding.Times[binding.Index[s]] {
+				binding.Matched++
+				// mark in the index as completed
+				binding.Completed[s] = true
+			}
+
+			// if not all matched, ignore for now
+			if binding.Matched < binding.Total {
+				continue
+			}
+
+			// completed - close and delete indexes for all involved states
+			var names []string
+			for state := range binding.Index {
+				names = append(names, state)
+				if len(indexWhenTime[state]) == 1 {
+					delete(indexWhenTime, state)
+					continue
+				}
+				if state == s {
+					indexWhenTime[s] = append(indexWhenTime[s][:k],
+						indexWhenTime[s][k+1:]...)
+					continue
+				}
+
+				indexWhenTime[state] = slices.Delete(indexWhenTime[state], k, k+1)
+			}
+
+			w.log(am.LogDecisions, "[when:time] match for (%s)", j(names))
+			// close outside the critical zone
+			toClose = append(toClose, binding.Ch)
+		}
+	}
+	w.activeStatesLock.Unlock()
+
+	// notify outside the critical zone
+	for _, ch := range toClose {
+		closeSafe(ch)
+	}
+}
+
+// Index returns the index of a state in the machine's StateNames() list.
+func (w *Worker) Index(state string) int {
+	return slices.Index(w.stateNames, state)
+}
+
+// Dispose disposes the machine and all its emitters. You can wait for the
+// completion of the disposal with `<-mach.WhenDisposed()`.
+func (w *Worker) Dispose() {
+	if !w.Disposed.CompareAndSwap(false, true) {
+		return
+	}
+	closeSafe(w.whenDisposed)
+}
+
+// WhenDisposed returns a channel that will be closed when the machine is
+// disposed. Requires bound handlers. Use Machine.Disposed in case no handlers
+// have been bound.
+func (w *Worker) WhenDisposed() <-chan struct{} {
+	return w.whenDisposed
+}
+
+// Export exports the machine state: ID, time and state names.
+func (w *Worker) Export() *am.Serialized {
+	w.activeStatesLock.RLock()
+	defer w.activeStatesLock.RUnlock()
+
+	w.log(am.LogChanges, "[export] exported at %d ticks", w.time(nil))
+
+	return &am.Serialized{
+		ID:         w.ID,
+		Time:       w.time(nil),
+		StateNames: w.stateNames,
+	}
+}
+
+func (w *Worker) GetStruct() am.Struct {
+	return w.states
+}
+
+// ///// ///// /////
+
+// ///// UTILS
+
+// ///// ///// /////
+
+// j joins state names into a single string.
+func j(states []string) string {
+	return strings.Join(states, " ")
+}
+
+// jw joins state names into a single string with a separator.
+func jw(states []string, sep string) string {
+	return strings.Join(states, sep)
+}
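+
+// For illustration, the join helpers above behave as follows:
+//
+//	j([]string{"Foo", "Bar"})       // "Foo Bar"
+//	jw([]string{"Foo", "Bar"}, ",") // "Foo,Bar"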
+
+// disposeWithCtx handles early binding disposal caused by a canceled context.
+// It's used by most of the "when" methods.
+func disposeWithCtx[T comparable](
+	mach *Worker, ctx context.Context, ch chan struct{}, states am.S, binding T,
+	lock *sync.RWMutex, index map[string][]T,
+) {
+	if ctx == nil {
+		return
+	}
+	go func() {
+		select {
+		case <-ch:
+			return
+		case <-mach.Ctx.Done():
+			return
+		case <-ctx.Done():
+		}
+		// GC only if needed
+		if mach.Disposed.Load() {
+			return
+		}
+
+		// TODO track
+		closeSafe(ch)
+
+		lock.Lock()
+		defer lock.Unlock()
+
+		for _, s := range states {
+			if _, ok := index[s]; ok {
+				if len(index[s]) == 1 {
+					delete(index, s)
+				} else {
+					index[s] = slicesWithout(index[s], binding)
+				}
+			}
+		}
+	}()
+}
+
+func closeSafe[T any](ch chan T) {
+	select {
+	case <-ch:
+	default:
+		close(ch)
+	}
+}
+
+func slicesWithout[S ~[]E, E comparable](coll S, el E) S {
+	idx := slices.Index(coll, el)
+	ret := slices.Clone(coll)
+	if idx == -1 {
+		return ret
+	}
+	return slices.Delete(ret, idx, idx+1)
+}
+
+// slicesNone returns true if none of the elements of coll2 are in coll1.
+func slicesNone[S1 ~[]E, S2 ~[]E, E comparable](coll1 S1, coll2 S2) bool {
+	for _, el := range coll2 {
+		if slices.Contains(coll1, el) {
+			return false
+		}
+	}
+	return true
+}
+
+// slicesEvery returns true if all the elements of coll2 are in coll1.
+func slicesEvery[S1 ~[]E, S2 ~[]E, E comparable](coll1 S1, coll2 S2) bool {
+	for _, el := range coll2 {
+		if !slices.Contains(coll1, el) {
+			return false
+		}
+	}
+	return true
+}
+
+func slicesFilter[S ~[]E, E any](coll S, fn func(item E, i int) bool) S {
+	var ret S
+	for i, el := range coll {
+		if fn(el, i) {
+			ret = append(ret, el)
+		}
+	}
+	return ret
+}
+
+func slicesReverse[S ~[]E, E any](coll S) S {
+	ret := make(S, len(coll))
+	for i := range coll {
+		ret[i] = coll[len(coll)-1-i]
+	}
+	return ret
+}
+
+func slicesUniq[S ~[]E, E comparable](coll S) S {
+	var ret S
+	for _, el := range coll {
+		if !slices.Contains(ret, el) {
+			ret = append(ret, el)
+		}
+	}
+	return ret
+}
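+
+// For illustration, the generic slice helpers above behave as follows
+// (the values are hypothetical):
+//
+//	slicesUniq([]string{"a", "b", "a"})            // []string{"a", "b"}
+//	slicesWithout([]string{"a", "b"}, "a")         // []string{"b"}
+//	slicesNone([]string{"a"}, []string{"b"})       // true
+//	slicesEvery([]string{"a", "b"}, []string{"a"}) // true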