From d97a4c4391b0f911f2340d32edfa4bcaa56ef342 Mon Sep 17 00:00:00 2001 From: Faizan Qazi Date: Wed, 18 Dec 2024 13:13:00 +0000 Subject: [PATCH] pgwire: rename ReadBuffer.GetString to GetUnsafeString Previously, it was not clear that GetString would not copy data from the ReadBuffer. This could be problematic if the object was long lived, since the entire buffer would have been kept alive. To reduce risk, this patch renames GetString to GetUnsafeString. It also adds a GetSafeString method for cases where a copy is needed. The latter is adopted inside: parseClientProvidedSessionParameters Informs: #137627 Release note: None --- pkg/sql/pgwire/auth_methods.go | 2 +- pkg/sql/pgwire/conn.go | 16 ++++++++-------- pkg/sql/pgwire/pgwirebase/encoding.go | 16 ++++++++++++++-- pkg/sql/pgwire/pre_serve_options.go | 6 ++++-- 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/pkg/sql/pgwire/auth_methods.go b/pkg/sql/pgwire/auth_methods.go index 90c3a3bff8ee..6729e6b304bc 100644 --- a/pkg/sql/pgwire/auth_methods.go +++ b/pkg/sql/pgwire/auth_methods.go @@ -345,7 +345,7 @@ func scramAuthenticator( // SASLResponse messages contain just the SASL payload. // rb := pgwirebase.ReadBuffer{Msg: resp} - reqMethod, err := rb.GetString() + reqMethod, err := rb.GetUnsafeString() if err != nil { c.LogAuthFailed(ctx, eventpb.AuthFailReason_PRE_HOOK_ERROR, err) return err diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 676e03af4e06..1b1aa49a3a94 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -345,7 +345,7 @@ func (c *conn) handleSimpleQuery( timeReceived time.Time, unqualifiedIntSize *types.T, ) error { - query, err := buf.GetString() + query, err := buf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -492,11 +492,11 @@ func (c *conn) handleSimpleQuery( // the connection should be considered toast. func (c *conn) handleParse(ctx context.Context, nakedIntSize *types.T) error { telemetry.Inc(sqltelemetry.ParseRequestCounter) - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - query, err := c.readBuf.GetString() + query, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -607,7 +607,7 @@ func (c *conn) handleDescribe(ctx context.Context) error { if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -627,7 +627,7 @@ func (c *conn) handleClose(ctx context.Context) error { if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - name, err := c.readBuf.GetString() + name, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -649,11 +649,11 @@ var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText} // the connection should be considered toast. func (c *conn) handleBind(ctx context.Context) error { telemetry.Inc(sqltelemetry.BindRequestCounter) - portalName, err := c.readBuf.GetString() + portalName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } - statementName, err := c.readBuf.GetString() + statementName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } @@ -775,7 +775,7 @@ func (c *conn) handleExecute( ctx context.Context, timeReceived time.Time, followedBySync bool, ) error { telemetry.Inc(sqltelemetry.ExecuteRequestCounter) - portalName, err := c.readBuf.GetString() + portalName, err := c.readBuf.GetUnsafeString() if err != nil { return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) } diff --git a/pkg/sql/pgwire/pgwirebase/encoding.go b/pkg/sql/pgwire/pgwirebase/encoding.go index 70684525c6e3..40237a6730c4 100644 --- a/pkg/sql/pgwire/pgwirebase/encoding.go +++ b/pkg/sql/pgwire/pgwirebase/encoding.go @@ -213,8 +213,10 @@ func (b *ReadBuffer) ReadTypedMsg(rd BufferedReader) (ClientMessageType, int, er return ClientMessageType(typ), n, err } -// GetString reads a null-terminated string. -func (b *ReadBuffer) GetString() (string, error) { +// GetUnsafeString reads a null-terminated string as a reference. +// Note: The underlying buffer will be prevented from GCing, so long lived +// objects should never use this. +func (b *ReadBuffer) GetUnsafeString() (string, error) { pos := bytes.IndexByte(b.Msg, 0) if pos == -1 { return "", NewProtocolViolationErrorf("NUL terminator not found") @@ -226,6 +228,16 @@ func (b *ReadBuffer) GetString() (string, error) { return s, nil } +// GetSafeString reads a null-terminated string as a copy of the original data +// out. +func (b *ReadBuffer) GetSafeString() (string, error) { + s, err := b.GetUnsafeString() + if err != nil { + return "", err + } + return strings.Clone(s), nil +} + // GetPrepareType returns the buffer's contents as a PrepareType. func (b *ReadBuffer) GetPrepareType() (PrepareType, error) { v, err := b.GetBytes(1) diff --git a/pkg/sql/pgwire/pre_serve_options.go b/pkg/sql/pgwire/pre_serve_options.go index e963d4cdc3f7..1e700d183f45 100644 --- a/pkg/sql/pgwire/pre_serve_options.go +++ b/pkg/sql/pgwire/pre_serve_options.go @@ -54,7 +54,9 @@ func parseClientProvidedSessionParameters( hasTenantSelectOption := false for { // Read a key-value pair from the client. - key, err := buf.GetString() + // Note: GetSafeString is used since the key/value will live well past the + // life of the message. + key, err := buf.GetSafeString() if err != nil { return args, pgerror.Wrap( err, pgcode.ProtocolViolation, @@ -65,7 +67,7 @@ func parseClientProvidedSessionParameters( // End of parameter list. break } - value, err := buf.GetString() + value, err := buf.GetSafeString() if err != nil { return args, pgerror.Wrapf( err, pgcode.ProtocolViolation,