From ace13001a5b5a76aeac1c4bb776a8d5e19d7ba7f Mon Sep 17 00:00:00 2001 From: Faizan Qazi Date: Wed, 18 Dec 2024 13:13:00 +0000 Subject: [PATCH] pgwire: rename ReadBuffer.GetString to GetUnsafeString Previously, it was not clear that GetString would not copy data from the ReadBuffer. This could be problematic if the object was long lived, since the entire buffer would have been kept alive. To reduce risk, this patch renames GetString to GetUnsafeString. It also adds a GetSafeString method for cases where a copy is needed. The latter is adopted inside: parseClientProvidedSessionParameters Informs: #137627 Release note: None --- pkg/sql/pgwire/auth_methods.go | 2 +- pkg/sql/pgwire/conn.go | 16 ++++++++-------- pkg/sql/pgwire/pgwirebase/encoding.go | 16 ++++++++++++++-- pkg/sql/pgwire/pre_serve_options.go | 6 ++++-- 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/pkg/sql/pgwire/auth_methods.go b/pkg/sql/pgwire/auth_methods.go index b9398b010b46..17a041aa6b4f 100644 --- a/pkg/sql/pgwire/auth_methods.go +++ b/pkg/sql/pgwire/auth_methods.go @@ -343,7 +343,7 @@ func scramAuthenticator( // SASLResponse messages contain just the SASL payload. // rb := pgwirebase.ReadBuffer{Msg: resp} - reqMethod, err := rb.GetString() + reqMethod, err := rb.GetUnsafeString() if err != nil { c.LogAuthFailed(ctx, eventpb.AuthFailReason_PRE_HOOK_ERROR, err) return err diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index e78c08cb02dc..3118e743eb3f 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -345,7 +345,7 @@ func (c *conn) handleSimpleQuery( timeReceived time.Time, unqualifiedIntSize *types.T, ) error { - query, err := buf.GetString() + query, err := buf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -492,11 +492,11 @@ func (c *conn) handleSimpleQuery( // the connection should be considered toast. func (c *conn) handleParse(ctx context.Context, nakedIntSize *types.T) error { telemetry.Inc(sqltelemetry.ParseRequestCounter) - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - query, err := c.readBuf.GetString() + query, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -607,7 +607,7 @@ func (c *conn) handleDescribe(ctx context.Context) error { if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -627,7 +627,7 @@ func (c *conn) handleClose(ctx context.Context) error { if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -649,11 +649,11 @@ var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText} // the connection should be considered toast. func (c *conn) handleBind(ctx context.Context) error { telemetry.Inc(sqltelemetry.BindRequestCounter) - portalName, err := c.readBuf.GetString() + portalName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - statementName, err := c.readBuf.GetString() + statementName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -775,7 +775,7 @@ func (c *conn) handleExecute( ctx context.Context, timeReceived time.Time, followedBySync bool, ) error { telemetry.Inc(sqltelemetry.ExecuteRequestCounter) - portalName, err := c.readBuf.GetString() + portalName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } diff --git a/pkg/sql/pgwire/pgwirebase/encoding.go b/pkg/sql/pgwire/pgwirebase/encoding.go index ad1377fd818a..8c24d403d6da 100644 --- a/pkg/sql/pgwire/pgwirebase/encoding.go +++ b/pkg/sql/pgwire/pgwirebase/encoding.go @@ -210,8 +210,10 @@ func (b *ReadBuffer) ReadTypedMsg(rd BufferedReader) (ClientMessageType, int, er return ClientMessageType(typ), n, err } -// GetString reads a null-terminated string. -func (b *ReadBuffer) GetString() (string, error) { +// GetUnsafeString reads a null-terminated string as a reference. +// Note: The underlying buffer will be prevented from GCing, so long lived +// objects should never use this. +func (b *ReadBuffer) GetUnsafeString() (string, error) { pos := bytes.IndexByte(b.Msg, 0) if pos == -1 { return "", NewProtocolViolationErrorf("NUL terminator not found") @@ -224,6 +226,16 @@ func (b *ReadBuffer) GetString() (string, error) { return *((*string)(unsafe.Pointer(&s))), nil } +// GetSafeString reads a null-terminated string as a copy of the original data +// out. +func (b *ReadBuffer) GetSafeString() (string, error) { + s, err := b.GetUnsafeString() + if err != nil { + return "", err + } + return strings.Clone(s), nil +} + // GetPrepareType returns the buffer's contents as a PrepareType. func (b *ReadBuffer) GetPrepareType() (PrepareType, error) { v, err := b.GetBytes(1) diff --git a/pkg/sql/pgwire/pre_serve_options.go b/pkg/sql/pgwire/pre_serve_options.go index 4f9eba518e69..84624366d192 100644 --- a/pkg/sql/pgwire/pre_serve_options.go +++ b/pkg/sql/pgwire/pre_serve_options.go @@ -54,7 +54,9 @@ func parseClientProvidedSessionParameters( hasTenantSelectOption := false for { // Read a key-value pair from the client. - key, err := buf.GetString() + // Note: GetSafeString is used since the key/value will live well past the + // life of the message. + key, err := buf.GetSafeString() if err != nil { return args, pgerror.Wrap( err, pgcode.ProtocolViolation, @@ -65,7 +67,7 @@ func parseClientProvidedSessionParameters( // End of parameter list. break } - value, err := buf.GetString() + value, err := buf.GetSafeString() if err != nil { return args, pgerror.Wrapf( err, pgcode.ProtocolViolation,