From f7957a31dafcfc9ff23479229b5bc306956c7c25 Mon Sep 17 00:00:00 2001 From: Jeroen Rinzema Date: Sat, 24 Aug 2024 16:18:52 +0200 Subject: [PATCH] refactor: expose more internal methods within the buffered reader --- pkg/buffer/reader.go | 45 ++++++++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/pkg/buffer/reader.go b/pkg/buffer/reader.go index 5e981d0..2444789 100644 --- a/pkg/buffer/reader.go +++ b/pkg/buffer/reader.go @@ -54,6 +54,7 @@ func NewReader(logger *slog.Logger, reader io.Reader, bufferSize int) *Reader { func (reader *Reader) reset(size int) { if reader.Msg != nil { reader.Msg = reader.Msg[len(reader.Msg):] + return } if cap(reader.Msg) >= size { @@ -68,12 +69,22 @@ func (reader *Reader) reset(size int) { reader.Msg = make([]byte, size, allocSize) } +// ReadType reads the client message type from the provided reader. +func (reader *Reader) ReadType() (types.ClientMessage, error) { + b, err := reader.Buffer.ReadByte() + if err != nil { + return 0, err + } + + return types.ClientMessage(b), nil +} + // ReadTypedMsg reads a message from the provided reader, returning its type code and body. // It returns the message type, number of bytes read, and an error if there was one. func (reader *Reader) ReadTypedMsg() (types.ClientMessage, int, error) { - b, err := reader.Buffer.ReadByte() + typed, err := reader.ReadType() if err != nil { - return 0, 0, err + return typed, 0, err } n, err := reader.ReadUntypedMsg() @@ -81,7 +92,7 @@ func (reader *Reader) ReadTypedMsg() (types.ClientMessage, int, error) { return 0, 0, err } - return types.ClientMessage(b), n, nil + return typed, n, nil } // Slurp reads the remaining @@ -107,8 +118,22 @@ func (reader *Reader) Slurp(size int) error { return nil } +// ReadMsgSize reads the length of the next message from the provided reader. +func (reader *Reader) ReadMsgSize() (int, error) { + nread, err := io.ReadFull(reader.Buffer, reader.header[:]) + if err != nil { + return nread, err + } + + size := int(binary.BigEndian.Uint32(reader.header[:])) + // size includes itself. + size -= 4 + + return size, nil +} + // ReadUntypedMsg reads a length-prefixed message. It is only used directly -// during the authentication phase of the protocol; ReadTypedMsg is used at all +// during the authentication phase of the protocol; [ReadTypedMsg] is used at all // other times. This returns the number of bytes read and an error, if there // was one. The number of bytes returned can be non-zero even with an error // (e.g. if data was read but didn't validate) so that we can more accurately @@ -117,22 +142,18 @@ func (reader *Reader) Slurp(size int) error { // If the error is related to consuming a buffer that is larger than the // maxMessageSize, the remaining bytes will be read but discarded. func (reader *Reader) ReadUntypedMsg() (int, error) { - nread, err := io.ReadFull(reader.Buffer, reader.header[:]) + size, err := reader.ReadMsgSize() if err != nil { - return nread, err + return 0, err } - size := int(binary.BigEndian.Uint32(reader.header[:])) - // size includes itself. - size -= 4 - if size > reader.MaxMessageSize || size < 0 { - return nread, NewMessageSizeExceeded(reader.MaxMessageSize, size) + return size, NewMessageSizeExceeded(reader.MaxMessageSize, size) } reader.reset(size) n, err := io.ReadFull(reader.Buffer, reader.Msg) - return nread + n, err + return n, err } // GetString reads a null-terminated string.