diff --git a/pkg/sql/pgwire/auth_methods.go b/pkg/sql/pgwire/auth_methods.go index 90c3a3bff8ee..6729e6b304bc 100644 --- a/pkg/sql/pgwire/auth_methods.go +++ b/pkg/sql/pgwire/auth_methods.go @@ -345,7 +345,7 @@ func scramAuthenticator( // SASLResponse messages contain just the SASL payload. // rb := pgwirebase.ReadBuffer{Msg: resp} - reqMethod, err := rb.GetString() + reqMethod, err := rb.GetUnsafeString() if err != nil { c.LogAuthFailed(ctx, eventpb.AuthFailReason_PRE_HOOK_ERROR, err) return err diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 676e03af4e06..1b1aa49a3a94 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -345,7 +345,7 @@ func (c *conn) handleSimpleQuery( timeReceived time.Time, unqualifiedIntSize *types.T, ) error { - query, err := buf.GetString() + query, err := buf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -492,11 +492,11 @@ func (c *conn) handleSimpleQuery( // the connection should be considered toast. func (c *conn) handleParse(ctx context.Context, nakedIntSize *types.T) error { telemetry.Inc(sqltelemetry.ParseRequestCounter) - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - query, err := c.readBuf.GetString() + query, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -607,7 +607,7 @@ func (c *conn) handleDescribe(ctx context.Context) error { if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -627,7 +627,7 @@ func (c *conn) handleClose(ctx context.Context) error { if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -649,11 +649,11 @@ var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText} // the connection should be considered toast. func (c *conn) handleBind(ctx context.Context) error { telemetry.Inc(sqltelemetry.BindRequestCounter) - portalName, err := c.readBuf.GetString() + portalName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - statementName, err := c.readBuf.GetString() + statementName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -775,7 +775,7 @@ func (c *conn) handleExecute( ctx context.Context, timeReceived time.Time, followedBySync bool, ) error { telemetry.Inc(sqltelemetry.ExecuteRequestCounter) - portalName, err := c.readBuf.GetString() + portalName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } diff --git a/pkg/sql/pgwire/pgwirebase/encoding.go b/pkg/sql/pgwire/pgwirebase/encoding.go index 70684525c6e3..40237a6730c4 100644 --- a/pkg/sql/pgwire/pgwirebase/encoding.go +++ b/pkg/sql/pgwire/pgwirebase/encoding.go @@ -213,8 +213,10 @@ func (b *ReadBuffer) ReadTypedMsg(rd BufferedReader) (ClientMessageType, int, er return ClientMessageType(typ), n, err } -// GetString reads a null-terminated string. -func (b *ReadBuffer) GetString() (string, error) { +// GetUnsafeString reads a null-terminated string as a reference. +// Note: The underlying buffer will be prevented from GCing, so long lived +// objects should never use this. +func (b *ReadBuffer) GetUnsafeString() (string, error) { pos := bytes.IndexByte(b.Msg, 0) if pos == -1 { return "", NewProtocolViolationErrorf("NUL terminator not found") @@ -226,6 +228,16 @@ func (b *ReadBuffer) GetString() (string, error) { return s, nil } +// GetSafeString reads a null-terminated string as a copy of the original data +// out. +func (b *ReadBuffer) GetSafeString() (string, error) { + s, err := b.GetUnsafeString() + if err != nil { + return "", err + } + return strings.Clone(s), nil +} + // GetPrepareType returns the buffer's contents as a PrepareType. func (b *ReadBuffer) GetPrepareType() (PrepareType, error) { v, err := b.GetBytes(1) diff --git a/pkg/sql/pgwire/pre_serve_options.go b/pkg/sql/pgwire/pre_serve_options.go index e963d4cdc3f7..1e700d183f45 100644 --- a/pkg/sql/pgwire/pre_serve_options.go +++ b/pkg/sql/pgwire/pre_serve_options.go @@ -54,7 +54,9 @@ func parseClientProvidedSessionParameters( hasTenantSelectOption := false for { // Read a key-value pair from the client. - key, err := buf.GetString() + // Note: GetSafeString is used since the key/value will live well past the + // life of the message. + key, err := buf.GetSafeString() if err != nil { return args, pgerror.Wrap( err, pgcode.ProtocolViolation, @@ -65,7 +67,7 @@ func parseClientProvidedSessionParameters( // End of parameter list. break } - value, err := buf.GetString() + value, err := buf.GetSafeString() if err != nil { return args, pgerror.Wrapf( err, pgcode.ProtocolViolation,