Skip to content

Commit

Permalink
Merge pull request #137770 from cockroachdb/blathers/backport-release…
Browse files Browse the repository at this point in the history
…-24.1-137682

release-24.1: pgwire: rename ReadBuffer.GetString to GetUnsafeString
  • Loading branch information
fqazi authored Dec 19, 2024
2 parents eaf1601 + ace1300 commit 615d2c0
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pkg/sql/pgwire/auth_methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func scramAuthenticator(
// SASLResponse messages contain just the SASL payload.
//
rb := pgwirebase.ReadBuffer{Msg: resp}
reqMethod, err := rb.GetString()
reqMethod, err := rb.GetUnsafeString()
if err != nil {
c.LogAuthFailed(ctx, eventpb.AuthFailReason_PRE_HOOK_ERROR, err)
return err
Expand Down
16 changes: 8 additions & 8 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ func (c *conn) handleSimpleQuery(
timeReceived time.Time,
unqualifiedIntSize *types.T,
) error {
query, err := buf.GetString()
query, err := buf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down Expand Up @@ -492,11 +492,11 @@ func (c *conn) handleSimpleQuery(
// the connection should be considered toast.
func (c *conn) handleParse(ctx context.Context, nakedIntSize *types.T) error {
telemetry.Inc(sqltelemetry.ParseRequestCounter)
name, err := c.readBuf.GetString()
name, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
query, err := c.readBuf.GetString()
query, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down Expand Up @@ -607,7 +607,7 @@ func (c *conn) handleDescribe(ctx context.Context) error {
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
name, err := c.readBuf.GetString()
name, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -627,7 +627,7 @@ func (c *conn) handleClose(ctx context.Context) error {
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
name, err := c.readBuf.GetString()
name, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -649,11 +649,11 @@ var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText}
// the connection should be considered toast.
func (c *conn) handleBind(ctx context.Context) error {
telemetry.Inc(sqltelemetry.BindRequestCounter)
portalName, err := c.readBuf.GetString()
portalName, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
statementName, err := c.readBuf.GetString()
statementName, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down Expand Up @@ -775,7 +775,7 @@ func (c *conn) handleExecute(
ctx context.Context, timeReceived time.Time, followedBySync bool,
) error {
telemetry.Inc(sqltelemetry.ExecuteRequestCounter)
portalName, err := c.readBuf.GetString()
portalName, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down
16 changes: 14 additions & 2 deletions pkg/sql/pgwire/pgwirebase/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ func (b *ReadBuffer) ReadTypedMsg(rd BufferedReader) (ClientMessageType, int, er
return ClientMessageType(typ), n, err
}

// GetString reads a null-terminated string.
func (b *ReadBuffer) GetString() (string, error) {
// GetUnsafeString reads a null-terminated string as a reference.
// Note: The underlying buffer will be prevented from GCing, so long lived
// objects should never use this.
func (b *ReadBuffer) GetUnsafeString() (string, error) {
pos := bytes.IndexByte(b.Msg, 0)
if pos == -1 {
return "", NewProtocolViolationErrorf("NUL terminator not found")
Expand All @@ -224,6 +226,16 @@ func (b *ReadBuffer) GetString() (string, error) {
return *((*string)(unsafe.Pointer(&s))), nil
}

// GetSafeString reads a null-terminated string as a copy of the original data
// out.
func (b *ReadBuffer) GetSafeString() (string, error) {
s, err := b.GetUnsafeString()
if err != nil {
return "", err
}
return strings.Clone(s), nil
}

// GetPrepareType returns the buffer's contents as a PrepareType.
func (b *ReadBuffer) GetPrepareType() (PrepareType, error) {
v, err := b.GetBytes(1)
Expand Down
6 changes: 4 additions & 2 deletions pkg/sql/pgwire/pre_serve_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func parseClientProvidedSessionParameters(
hasTenantSelectOption := false
for {
// Read a key-value pair from the client.
key, err := buf.GetString()
// Note: GetSafeString is used since the key/value will live well past the
// life of the message.
key, err := buf.GetSafeString()
if err != nil {
return args, pgerror.Wrap(
err, pgcode.ProtocolViolation,
Expand All @@ -65,7 +67,7 @@ func parseClientProvidedSessionParameters(
// End of parameter list.
break
}
value, err := buf.GetString()
value, err := buf.GetSafeString()
if err != nil {
return args, pgerror.Wrapf(
err, pgcode.ProtocolViolation,
Expand Down

0 comments on commit 615d2c0

Please sign in to comment.