diff --git a/spop/agent.go b/spop/agent.go index c3770a5..5e6bdac 100644 --- a/spop/agent.go +++ b/spop/agent.go @@ -1,6 +1,7 @@ package spop import ( + "context" "fmt" "log" "net" @@ -9,6 +10,7 @@ import ( type Agent struct { Addr string Handler Handler + Context context.Context } func ListenAndServe(addr string, handler Handler) error { @@ -28,6 +30,9 @@ func (a *Agent) ListenAndServe() error { func (a *Agent) Serve(l net.Listener) error { a.Addr = l.Addr().String() + if a.Context == nil { + a.Context = context.Background() + } for { nc, err := l.Accept() @@ -35,7 +40,7 @@ func (a *Agent) Serve(l net.Listener) error { return fmt.Errorf("accepting conn: %w", err) } - p := NewProtocolClient(nc, a.Handler) + p := newProtocolClient(a.Context, nc, a.Handler) go func() { defer nc.Close() defer p.Close() diff --git a/spop/protocol.go b/spop/protocol.go index c72e67d..19acd97 100644 --- a/spop/protocol.go +++ b/spop/protocol.go @@ -14,10 +14,10 @@ import ( type asyncScheduler struct { // TODO: replace with a circular blocking queue q *queue.Blocking[*frame] - pc *ProtocolClient + pc *protocolClient } -func newAsyncScheduler(pc *ProtocolClient) *asyncScheduler { +func newAsyncScheduler(pc *protocolClient) *asyncScheduler { a := asyncScheduler{ q: queue.NewBlocking[*frame](nil, queue.WithCapacity(runtime.NumCPU()*2)), pc: pc, @@ -44,49 +44,39 @@ func (a *asyncScheduler) schedule(f *frame) { a.q.OfferWait(f) } -// ProtocolClientOption is not used right now, but we want to be able to -// expand the capabilities without breaking the api -type ProtocolClientOption interface { - apply() +func newProtocolClient(ctx context.Context, rw io.ReadWriter, handler Handler) *protocolClient { + var c protocolClient + c.RW = rw + c.Handler = handler + c.Context, c.ctxCancel = context.WithCancel(ctx) + c.as = newAsyncScheduler(&c) + return &c } -func NewProtocolClient(rw io.ReadWriter, handler Handler, opts ...ProtocolClientOption) *ProtocolClient { - ctx, cancel := context.WithCancel(context.Background()) - pc := &ProtocolClient{ - rw: rw, - handler: handler, - ctx: ctx, - ctxCancel: cancel, - } - pc.as = newAsyncScheduler(pc) - - return pc -} +type protocolClient struct { + RW io.ReadWriter + Handler Handler + Context context.Context -type ProtocolClient struct { - rw io.ReadWriter - handler Handler - ctx context.Context ctxCancel context.CancelFunc - - as *asyncScheduler + as *asyncScheduler gotHello bool maxFrameSize uint32 engineID string } -func (c *ProtocolClient) Close() error { +func (c *protocolClient) Close() error { errDisconnect := (&AgentDisconnectFrame{ ErrCode: ErrorUnknown, - }).Write(c.rw) + }).Write(c.RW) c.ctxCancel() - return errors.Join(errDisconnect, c.ctx.Err()) + return errors.Join(errDisconnect, c.Context.Err()) } -func (c *ProtocolClient) frameHandler(f *frame) error { +func (c *protocolClient) frameHandler(f *frame) error { defer releaseFrame(f) switch f.frameType { @@ -101,10 +91,10 @@ func (c *ProtocolClient) frameHandler(f *frame) error { } } -func (c *ProtocolClient) Serve() error { +func (c *protocolClient) Serve() error { for { f := acquireFrame() - if _, err := f.ReadFrom(c.rw); err != nil { + if _, err := f.ReadFrom(c.RW); err != nil { if errors.Is(err, io.EOF) { return nil } @@ -124,7 +114,7 @@ const ( maxFrameSize = 64<<10 - 1 ) -func (c *ProtocolClient) onHAProxyHello(f *frame) error { +func (c *protocolClient) onHAProxyHello(f *frame) error { if c.gotHello { panic("duplicate hello frame") } @@ -160,10 +150,10 @@ func (c *ProtocolClient) onHAProxyHello(f *frame) error { Version: version, MaxFrameSize: c.maxFrameSize, Capabilities: []string{capabilityNamePipelining, capabilityNameAsync}, - }).Write(c.rw) + }).Write(c.RW) } -func (c *ProtocolClient) runHandler(ctx context.Context, w *encoding.ActionWriter, m *encoding.Message, handler HandlerFunc) { +func (c *protocolClient) runHandler(ctx context.Context, w *encoding.ActionWriter, m *encoding.Message, handler HandlerFunc) { didPanic := true defer func() { if didPanic { @@ -180,7 +170,7 @@ func (c *ProtocolClient) runHandler(ctx context.Context, w *encoding.ActionWrite didPanic = false } -func (c *ProtocolClient) onNotify(f *frame) error { +func (c *protocolClient) onNotify(f *frame) error { s := encoding.AcquireMessageScanner(f.buf.ReadBytes()) defer encoding.ReleaseMessageScanner(s) @@ -189,7 +179,7 @@ func (c *ProtocolClient) onNotify(f *frame) error { fn := func(w *encoding.ActionWriter) error { for s.Next(m) { - c.runHandler(c.ctx, w, m, c.handler.HandleSPOE) + c.runHandler(c.Context, w, m, c.Handler.HandleSPOE) if err := m.KV.Discard(); err != nil { return err @@ -203,10 +193,10 @@ func (c *ProtocolClient) onNotify(f *frame) error { FrameID: f.meta.FrameID, StreamID: f.meta.StreamID, ActionWriterCallback: fn, - }).Write(c.rw) + }).Write(c.RW) } -func (c *ProtocolClient) onHAProxyDisconnect(f *frame) error { +func (c *protocolClient) onHAProxyDisconnect(f *frame) error { //TODO: read disconnect reason and return error if required? return nil } diff --git a/spop/server_test.go b/spop/server_test.go index b79c16d..f4def2e 100644 --- a/spop/server_test.go +++ b/spop/server_test.go @@ -46,7 +46,7 @@ func TestFakeCon(t *testing.T) { cancel() }) - pc := NewProtocolClient(pipeConn, handler) + pc := newProtocolClient(context.Background(), pipeConn, handler) defer pc.Close() defer pipe.Close() go pc.Serve()