Skip to content

Commit

Permalink
pgwire: rename ReadBuffer.GetString to GetUnsafeString
Browse files Browse the repository at this point in the history
Previously, it was not clear that GetString would not copy data from the
ReadBuffer. This could be problematic if the object was long lived,
since the entire buffer would have been kept alive. To reduce risk, this
patch renames GetString to GetUnsafeString. It also adds a GetSafeString
method for cases where a copy is needed. The latter is adopted inside:
parseClientProvidedSessionParameters

Informs: #137627
Release note: None
  • Loading branch information
fqazi committed Dec 18, 2024
1 parent eaf1601 commit ace1300
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pkg/sql/pgwire/auth_methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func scramAuthenticator(
// SASLResponse messages contain just the SASL payload.
//
rb := pgwirebase.ReadBuffer{Msg: resp}
reqMethod, err := rb.GetString()
reqMethod, err := rb.GetUnsafeString()
if err != nil {
c.LogAuthFailed(ctx, eventpb.AuthFailReason_PRE_HOOK_ERROR, err)
return err
Expand Down
16 changes: 8 additions & 8 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ func (c *conn) handleSimpleQuery(
timeReceived time.Time,
unqualifiedIntSize *types.T,
) error {
query, err := buf.GetString()
query, err := buf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down Expand Up @@ -492,11 +492,11 @@ func (c *conn) handleSimpleQuery(
// the connection should be considered toast.
func (c *conn) handleParse(ctx context.Context, nakedIntSize *types.T) error {
telemetry.Inc(sqltelemetry.ParseRequestCounter)
name, err := c.readBuf.GetString()
name, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
query, err := c.readBuf.GetString()
query, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down Expand Up @@ -607,7 +607,7 @@ func (c *conn) handleDescribe(ctx context.Context) error {
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
name, err := c.readBuf.GetString()
name, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -627,7 +627,7 @@ func (c *conn) handleClose(ctx context.Context) error {
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
name, err := c.readBuf.GetString()
name, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand All @@ -649,11 +649,11 @@ var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText}
// the connection should be considered toast.
func (c *conn) handleBind(ctx context.Context) error {
telemetry.Inc(sqltelemetry.BindRequestCounter)
portalName, err := c.readBuf.GetString()
portalName, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
statementName, err := c.readBuf.GetString()
statementName, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down Expand Up @@ -775,7 +775,7 @@ func (c *conn) handleExecute(
ctx context.Context, timeReceived time.Time, followedBySync bool,
) error {
telemetry.Inc(sqltelemetry.ExecuteRequestCounter)
portalName, err := c.readBuf.GetString()
portalName, err := c.readBuf.GetUnsafeString()
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
Expand Down
16 changes: 14 additions & 2 deletions pkg/sql/pgwire/pgwirebase/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ func (b *ReadBuffer) ReadTypedMsg(rd BufferedReader) (ClientMessageType, int, er
return ClientMessageType(typ), n, err
}

// GetString reads a null-terminated string.
func (b *ReadBuffer) GetString() (string, error) {
// GetUnsafeString reads a null-terminated string as a reference.
// Note: The underlying buffer will be prevented from GCing, so long lived
// objects should never use this.
func (b *ReadBuffer) GetUnsafeString() (string, error) {
pos := bytes.IndexByte(b.Msg, 0)
if pos == -1 {
return "", NewProtocolViolationErrorf("NUL terminator not found")
Expand All @@ -224,6 +226,16 @@ func (b *ReadBuffer) GetString() (string, error) {
return *((*string)(unsafe.Pointer(&s))), nil
}

// GetSafeString reads a null-terminated string as a copy of the original data
// out.
func (b *ReadBuffer) GetSafeString() (string, error) {
s, err := b.GetUnsafeString()
if err != nil {
return "", err
}
return strings.Clone(s), nil
}

// GetPrepareType returns the buffer's contents as a PrepareType.
func (b *ReadBuffer) GetPrepareType() (PrepareType, error) {
v, err := b.GetBytes(1)
Expand Down
6 changes: 4 additions & 2 deletions pkg/sql/pgwire/pre_serve_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func parseClientProvidedSessionParameters(
hasTenantSelectOption := false
for {
// Read a key-value pair from the client.
key, err := buf.GetString()
// Note: GetSafeString is used since the key/value will live well past the
// life of the message.
key, err := buf.GetSafeString()
if err != nil {
return args, pgerror.Wrap(
err, pgcode.ProtocolViolation,
Expand All @@ -65,7 +67,7 @@ func parseClientProvidedSessionParameters(
// End of parameter list.
break
}
value, err := buf.GetString()
value, err := buf.GetSafeString()
if err != nil {
return args, pgerror.Wrapf(
err, pgcode.ProtocolViolation,
Expand Down

0 comments on commit ace1300

Please sign in to comment.