Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: isolate statements and portals #75

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ type Statement struct {
columns Columns
}

func DefaultStatementCacheFn() StatementCache {
return &DefaultStatementCache{}
}

type DefaultStatementCache struct {
statements map[string]*Statement
mu sync.RWMutex
Expand Down Expand Up @@ -63,6 +67,10 @@ type Portal struct {
formats []FormatCode
}

func DefaultPortalCacheFn() PortalCache {
return &DefaultPortalCache{}
}

type DefaultPortalCache struct {
portals map[string]*Portal
mu sync.RWMutex
Expand Down
32 changes: 19 additions & 13 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ func newErrClientCopyFailed(desc string) error {
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Uncategorized), psqlerr.LevelError)
}

type Session struct {
*Server
Statements StatementCache
Portals PortalCache
}

// consumeCommands consumes incoming commands sent over the Postgres wire connection.
// Commands consumed from the connection are returned through a go channel.
// Responses for the given message type are written back to the client.
// This method keeps consuming messages until the client issues a close message
// or the connection is terminated.
func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
srv.logger.Debug("ready for query... starting to consume commands")

// TODO: Include a value to identify unique connections
Expand All @@ -77,7 +83,7 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
}
}

func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
func (srv *Session) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
t, length, err := reader.ReadTypedMsg()
if err == io.EOF {
return nil
Expand Down Expand Up @@ -141,7 +147,7 @@ func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exc
// message type and reader buffer containing the actual message. The type
// indecates a action executed by the client.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -236,7 +242,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli
}
}

func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
if srv.parse == nil {
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientSimpleQuery))
}
Expand Down Expand Up @@ -287,7 +293,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
return readyForQuery(writer, types.ServerIdle)
}

func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
if srv.parse == nil || srv.Statements == nil {
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientParse))
}
Expand Down Expand Up @@ -337,7 +343,7 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write
return writer.End()
}

func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
d, err := reader.GetBytes(1)
if err != nil {
return err
Expand Down Expand Up @@ -385,7 +391,7 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
}

// https://www.postgresql.org/docs/15/protocol-message-formats.html
func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error {
func (srv *Session) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error {
writer.Start(types.ServerParameterDescription)
writer.AddInt16(int16(len(parameters)))

Expand All @@ -400,7 +406,7 @@ func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters [
// back to the writer buffer. Information about the returned columns is written
// to the client.
// https://www.postgresql.org/docs/15/protocol-message-formats.html
func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error {
func (srv *Session) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error {
if len(columns) == 0 {
writer.Start(types.ServerNoData)
return writer.End()
Expand All @@ -409,7 +415,7 @@ func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Wr
return columns.Define(ctx, writer, formats)
}

func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
name, err := reader.GetString()
if err != nil {
return err
Expand Down Expand Up @@ -451,7 +457,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
// readParameters attempts to read all incoming parameters from the given
// reader. The parameters are parsed and returned.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
func (srv *Session) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
// NOTE: read the total amount of parameter format length that will be send
// by the client. This can be zero to indicate that there are no parameters
// or that the parameters all use the default format (text); or one, in
Expand Down Expand Up @@ -516,7 +522,7 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
return parameters, nil
}

func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
func (srv *Session) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
length, err := reader.GetUint16()
if err != nil {
return nil, err
Expand All @@ -537,7 +543,7 @@ func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error)
return columns, nil
}

func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
if srv.Statements == nil {
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientExecute))
}
Expand Down Expand Up @@ -565,7 +571,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri
return nil
}

func (srv *Server) handleConnTerminate(ctx context.Context) error {
func (srv *Session) handleConnTerminate(ctx context.Context) error {
if srv.TerminateConn == nil {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion examples/session/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func main() {
srv, err := wire.NewServer(handler, wire.Session(session))
srv, err := wire.NewServer(handler, wire.SessionMiddleware(session))
if err != nil {
panic(err)
}
Expand Down
12 changes: 6 additions & 6 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,18 @@ type OptionFn func(*Server) error

// Statements sets the statement cache used to cache statements for later use. By
// default [DefaultStatementCache] is used.
func Statements(cache StatementCache) OptionFn {
func Statements(handler func() StatementCache) OptionFn {
return func(srv *Server) error {
srv.Statements = cache
srv.Statements = handler
return nil
}
}

// Portals sets the portals cache used to cache statements for later use. By
// default [DefaultPortalCache] is used.
func Portals(cache PortalCache) OptionFn {
func Portals(handler func() PortalCache) OptionFn {
return func(srv *Server) error {
srv.Portals = cache
srv.Portals = handler
return nil
}
}
Expand Down Expand Up @@ -199,10 +199,10 @@ func ExtendTypes(fn func(*pgtype.Map)) OptionFn {
}
}

// Session sets the given session handler within the underlying server. The
// SessionMiddleware sets the given session handler within the underlying server. The
// session handler is called when a new connection is opened and authenticated
// allowing for additional metadata to be wrapped around the connection context.
func Session(fn SessionHandler) OptionFn {
func SessionMiddleware(fn SessionHandler) OptionFn {
return func(srv *Server) error {
if srv.Session == nil {
srv.Session = fn
Expand Down
6 changes: 3 additions & 3 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ func TestSessionHandler(t *testing.T) {

tests := map[string]test{
"single": {
Session(func(ctx context.Context) (context.Context, error) {
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
return context.WithValue(ctx, mock, value), nil
}),
},
"nested": {
Session(func(ctx context.Context) (context.Context, error) {
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
return ctx, nil
}),
Session(func(ctx context.Context) (context.Context, error) {
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
return context.WithValue(ctx, mock, value), nil
}),
},
Expand Down
16 changes: 11 additions & 5 deletions wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func NewServer(parse ParseFn, options ...OptionFn) (*Server, error) {
logger: slog.Default(),
closer: make(chan struct{}),
types: pgtype.NewMap(),
Statements: &DefaultStatementCache{},
Portals: &DefaultPortalCache{},
Statements: DefaultStatementCacheFn,
Portals: DefaultPortalCacheFn,
Session: func(ctx context.Context) (context.Context, error) { return ctx, nil },
}

Expand All @@ -62,8 +62,8 @@ type Server struct {
TLSConfig *tls.Config
parse ParseFn
Session SessionHandler
Statements StatementCache
Portals PortalCache
Statements func() StatementCache
Portals func() PortalCache
CloseConn CloseFn
TerminateConn CloseFn
Version string
Expand Down Expand Up @@ -162,7 +162,13 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error {
return err
}

return srv.consumeCommands(ctx, conn, reader, writer)
session := &Session{
Server: srv,
Statements: srv.Statements(),
Portals: srv.Portals(),
}

return session.consumeCommands(ctx, conn, reader, writer)
}

// Close gracefully closes the underlaying Postgres server.
Expand Down
Loading