From 358359b6f31aa71681452f26d24b0648fc2fa367 Mon Sep 17 00:00:00 2001 From: Nathan VanBenschoten Date: Mon, 9 Dec 2024 12:18:58 -0500 Subject: [PATCH 1/2] [DNM] rpc: enable compression even for loopback transport Don't merge, but helpful for benchmarking. --- pkg/rpc/context.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index d54574e1d3c5..adae5c73a564 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -1397,6 +1397,10 @@ func (rpcCtx *Context) dialOptsLocal() ([]grpc.DialOption, error) { return nil, err } + if rpcCtx.rpcCompression { + dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(grpc.UseCompressor((snappyCompressor{}).Name()))) + } + dialOpts = append(dialOpts, grpc.WithContextDialer( func(ctx context.Context, _ string) (net.Conn, error) { return rpcCtx.loopbackDialFn(ctx) From bb87c147845494d647007fa16af8b8a63c55a6a2 Mon Sep 17 00:00:00 2001 From: Nathan VanBenschoten Date: Thu, 19 Dec 2024 19:20:54 -0500 Subject: [PATCH 2/2] [WIP] rpc: recycle request/response memory across RPC calls This commit creates a better `grpc.PreparedMsg` (one that actually recycles memory) and uses it in both directions of the BatchStream RPC. 
Epic: None Release note: None --- pkg/rpc/mocks_generated_test.go | 86 ++++++++++++ pkg/rpc/stream_pool.go | 12 +- pkg/server/node.go | 8 +- pkg/util/grpcutil/BUILD.bazel | 5 + pkg/util/grpcutil/compressor_info.go | 70 ++++++++++ pkg/util/grpcutil/prepared_message.go | 182 ++++++++++++++++++++++++++ 6 files changed, 361 insertions(+), 2 deletions(-) create mode 100644 pkg/util/grpcutil/compressor_info.go create mode 100644 pkg/util/grpcutil/prepared_message.go diff --git a/pkg/rpc/mocks_generated_test.go b/pkg/rpc/mocks_generated_test.go index efa2c112a466..ea3f1f7bdc2c 100644 --- a/pkg/rpc/mocks_generated_test.go +++ b/pkg/rpc/mocks_generated_test.go @@ -13,6 +13,7 @@ import ( rpcpb "github.com/cockroachdb/cockroach/pkg/rpc/rpcpb" gomock "github.com/golang/mock/gomock" grpc "google.golang.org/grpc" + metadata "google.golang.org/grpc/metadata" ) // MockBatchStreamClient is a mock of BatchStreamClient interface. @@ -38,6 +39,49 @@ func (m *MockBatchStreamClient) EXPECT() *MockBatchStreamClientMockRecorder { return m.recorder } +// CloseSend mocks base method. +func (m *MockBatchStreamClient) CloseSend() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseSend") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseSend indicates an expected call of CloseSend. +func (mr *MockBatchStreamClientMockRecorder) CloseSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockBatchStreamClient)(nil).CloseSend)) +} + +// Context mocks base method. +func (m *MockBatchStreamClient) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. 
+func (mr *MockBatchStreamClientMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockBatchStreamClient)(nil).Context)) +} + +// Header mocks base method. +func (m *MockBatchStreamClient) Header() (metadata.MD, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Header") + ret0, _ := ret[0].(metadata.MD) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Header indicates an expected call of Header. +func (mr *MockBatchStreamClientMockRecorder) Header() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockBatchStreamClient)(nil).Header)) +} + // Recv mocks base method. func (m *MockBatchStreamClient) Recv() (*kvpb.BatchResponse, error) { m.ctrl.T.Helper() @@ -53,6 +97,20 @@ func (mr *MockBatchStreamClientMockRecorder) Recv() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockBatchStreamClient)(nil).Recv)) } +// RecvMsg mocks base method. +func (m *MockBatchStreamClient) RecvMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecvMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockBatchStreamClientMockRecorder) RecvMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockBatchStreamClient)(nil).RecvMsg), arg0) +} + // Send mocks base method. func (m *MockBatchStreamClient) Send(arg0 *kvpb.BatchRequest) error { m.ctrl.T.Helper() @@ -67,6 +125,34 @@ func (mr *MockBatchStreamClientMockRecorder) Send(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockBatchStreamClient)(nil).Send), arg0) } +// SendMsg mocks base method. 
+func (m *MockBatchStreamClient) SendMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockBatchStreamClientMockRecorder) SendMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockBatchStreamClient)(nil).SendMsg), arg0) +} + +// Trailer mocks base method. +func (m *MockBatchStreamClient) Trailer() metadata.MD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Trailer") + ret0, _ := ret[0].(metadata.MD) + return ret0 +} + +// Trailer indicates an expected call of Trailer. +func (mr *MockBatchStreamClientMockRecorder) Trailer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockBatchStreamClient)(nil).Trailer)) +} + // MockDialbacker is a mock of Dialbacker interface. type MockDialbacker struct { ctrl *gomock.Controller diff --git a/pkg/rpc/stream_pool.go b/pkg/rpc/stream_pool.go index 20773d714f0a..3b66595e9d70 100644 --- a/pkg/rpc/stream_pool.go +++ b/pkg/rpc/stream_pool.go @@ -12,6 +12,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/util/grpcutil" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" @@ -24,6 +25,7 @@ import ( type streamClient[Req, Resp any] interface { Send(Req) error Recv() (Resp, error) + grpc.ClientStream } // streamConstructor creates a new gRPC stream client over the provided client @@ -74,6 +76,8 @@ type pooledStream[Req, Resp any, Conn comparable] struct { reqC chan Req respC chan result[Resp] + + reqMsg *grpcutil.PreparedMsg } func newPooledStream[Req, Resp any, Conn comparable]( @@ -89,6 +93,7 @@ func newPooledStream[Req, Resp any, Conn comparable]( streamCancel: 
streamCancel, reqC: make(chan Req), respC: make(chan result[Resp], 1), + reqMsg: new(grpcutil.PreparedMsg), } } @@ -101,7 +106,12 @@ func (s *pooledStream[Req, Resp, Conn]) run(ctx context.Context) { func (s *pooledStream[Req, Resp, Conn]) runOnce(ctx context.Context) (loop bool) { select { case req := <-s.reqC: - err := s.stream.Send(req) + err := s.reqMsg.Encode(s.stream, req) + if err != nil { + s.respC <- result[Resp]{err: err} + return false + } + err = s.stream.SendMsg(s.reqMsg.AsGrpc()) if err != nil { // From grpc.ClientStream.SendMsg: // > On error, SendMsg aborts the stream. diff --git a/pkg/server/node.go b/pkg/server/node.go index a09c5c17ee81..f3bd295862e3 100644 --- a/pkg/server/node.go +++ b/pkg/server/node.go @@ -59,6 +59,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/admission/admissionpb" "github.com/cockroachdb/cockroach/pkg/util/buildutil" "github.com/cockroachdb/cockroach/pkg/util/envutil" + "github.com/cockroachdb/cockroach/pkg/util/grpcutil" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/limit" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -1876,6 +1877,7 @@ func (n *Node) Batch(ctx context.Context, args *kvpb.BatchRequest) (*kvpb.BatchR // BatchStream implements the kvpb.InternalServer interface. 
func (n *Node) BatchStream(stream kvpb.Internal_BatchStreamServer) error { ctx := stream.Context() + respMsg := new(grpcutil.PreparedMsg) for { argsAlloc := new(struct { args kvpb.BatchRequest @@ -1898,7 +1900,11 @@ func (n *Node) BatchStream(stream kvpb.Internal_BatchStreamServer) error { if err != nil { return err } - err = stream.Send(br) + err = respMsg.Encode(stream, br) + if err != nil { + return err + } + err = stream.SendMsg(respMsg.AsGrpc()) if err != nil { return err } diff --git a/pkg/util/grpcutil/BUILD.bazel b/pkg/util/grpcutil/BUILD.bazel index fb9f16c6a439..7409e3ff7c5b 100644 --- a/pkg/util/grpcutil/BUILD.bazel +++ b/pkg/util/grpcutil/BUILD.bazel @@ -3,11 +3,13 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "grpcutil", srcs = [ + "compressor_info.go", "fast_metadata.go", "grpc_err_redaction.go", "grpc_log.go", "grpc_log_legacy.go", "grpc_util.go", + "prepared_message.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/util/grpcutil", visibility = ["//visibility:public"], @@ -17,11 +19,14 @@ go_library( "//pkg/util/log", "//pkg/util/log/severity", "//pkg/util/netutil", + "//pkg/util/protoutil", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_errors//errbase", "@com_github_cockroachdb_redact//:redact", "@com_github_gogo_status//:status", + "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//encoding", "@org_golang_google_grpc//grpclog", "@org_golang_google_grpc//metadata", "@org_golang_google_grpc//status", diff --git a/pkg/util/grpcutil/compressor_info.go b/pkg/util/grpcutil/compressor_info.go new file mode 100644 index 000000000000..778d770ee442 --- /dev/null +++ b/pkg/util/grpcutil/compressor_info.go @@ -0,0 +1,70 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+ +package grpcutil + +import ( + "context" + "unsafe" + + "google.golang.org/grpc" + "google.golang.org/grpc/encoding" +) + +// RPCInfo exports grpc.rpcInfo. +type RPCInfo struct { + FailFast bool + PreloaderInfo *CompressorInfo +} + +// CompressorInfo exports grpc.compressorInfo. +type CompressorInfo struct { + Codec encoding.Codec + Cp grpc.Compressor + Comp encoding.Compressor +} + +// From runtime/runtime2.go:eface +type eface struct { + typ, data unsafe.Pointer +} + +// RPCInfoFromContext extracts the RPCInfo from the context. +func RPCInfoFromContext(ctx context.Context) (*RPCInfo, bool) { + v := ctx.Value(grpcInfoContextKeyObj) + if v == nil { + return nil, false + } + return (*RPCInfo)((*eface)(unsafe.Pointer(&v)).data), true +} + +// grpcInfoContextKeyObj is a copy of a value with the Go type +// `grpc.rpcInfoContextKey{}`. We cannot construct an object of that type +// directly, but we can "steal" it by forcing grpc to give it to us: +// `grpc.PreparedMsg.Encode` gives an instance of this object as parameter to +// the `Value` method of the context you give it as argument. We use a custom +// implementation of that to "steal" the argument of type `rpcInfoContextKey{}` +// given to us that way. +// +// This is the same trick that we pull with grpcIncomingKeyObj. 
+var grpcInfoContextKeyObj = func() interface{} { + var s fakeStream + _ = (*grpc.PreparedMsg)(nil).Encode(&s, nil) + if s.recordedKey == nil { + panic("PreparedMsg.Encode did not request a key") + } + return s.recordedKey +}() + +type fakeStream struct { + fakeContext +} + +var _ grpc.Stream = (*fakeStream)(nil) + +func (s *fakeStream) Context() context.Context { return &s.fakeContext } + +func (*fakeStream) SendMsg(interface{}) error { panic("unused") } +func (*fakeStream) RecvMsg(interface{}) error { panic("unused") } diff --git a/pkg/util/grpcutil/prepared_message.go b/pkg/util/grpcutil/prepared_message.go new file mode 100644 index 000000000000..5aad2e70fe8a --- /dev/null +++ b/pkg/util/grpcutil/prepared_message.go @@ -0,0 +1,182 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package grpcutil + +import ( + "bytes" + "encoding/binary" + "math" + "unsafe" + + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/gogo/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" +) + +// PreparedMsg is responsible for creating a Marshalled and Compressed object. +// +// The type is a copy of the grpc.PreparedMsg struct, but with the ability +// to recycle memory buffers across calls to Encode. +type PreparedMsg struct { + // Fields that mirror the grpc.PreparedMsg struct. + encodedData []byte + hdr []byte + payload []byte + + // Fields for memory reuse. + encodeBuf []byte + compressBuf bytes.Buffer + hdrBuf [headerLen]byte +} + +// AsGrpc returns the PreparedMsg as a *grpc.PreparedMsg. +// +// The returned value is only valid until the next call to Encode. +func (p *PreparedMsg) AsGrpc() *grpc.PreparedMsg { + return (*grpc.PreparedMsg)(unsafe.Pointer(p)) +} + +// Encode marshals and compresses the message using the codec and compressor for +// the stream. 
+// +// Mirrors the logic in grpc.PreparedMsg.Encode, but with the ability to recycle +// memory buffers across calls. +func (p *PreparedMsg) Encode(s grpc.Stream, msg interface{}) error { + defer p.discardLargeBuffers() + + ctx := s.Context() + rpcInfo, ok := RPCInfoFromContext(ctx) + if !ok { + return status.Errorf(codes.Internal, "grpc: unable to get rpcInfo") + } + + // Check if the context has the relevant information to prepareMsg. + if rpcInfo.PreloaderInfo == nil { + return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo is nil") + } + if rpcInfo.PreloaderInfo.Codec == nil { + return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil") + } + + // Prepare the msg. + data, err := p.encode(rpcInfo.PreloaderInfo.Codec, msg) + if err != nil { + return err + } + p.encodedData = data + compData, err := p.compress(data, rpcInfo.PreloaderInfo.Cp, rpcInfo.PreloaderInfo.Comp) + if err != nil { + return err + } + p.hdr, p.payload = p.msgHeader(data, compData) + return nil +} + +// encode serializes msg and returns a buffer containing the message, or an +// error if it is too large to be transmitted by grpc. If msg is nil, it +// generates an empty message. +func (p *PreparedMsg) encode(c encoding.Codec, msg interface{}) ([]byte, error) { + if msg == nil { + return nil, nil + } + // WIP: assume that the codec wants to use the protobuf encoding. 
+ _ = c + pb, ok := msg.(protoutil.Message) + if !ok { + return nil, status.Errorf(codes.Internal, "expected a protoutil.Message, got %T", msg) + } + size := pb.Size() + if uint(size) > math.MaxUint32 { + return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", size) + } + if cap(p.encodeBuf) < size { + p.encodeBuf = make([]byte, size) + } else { + p.encodeBuf = p.encodeBuf[:size] + } + _, err := protoutil.MarshalToSizedBuffer(pb, p.encodeBuf) + if err != nil { + return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + } + return p.encodeBuf, nil +} + +// compress returns the input bytes compressed by compressor or cp. If both +// compressors are nil, returns nil. +func (p *PreparedMsg) compress( + in []byte, cp grpc.Compressor, compressor encoding.Compressor, +) ([]byte, error) { + if compressor == nil && cp == nil { + return nil, nil + } + if cp != nil { + return nil, status.Errorf(codes.Internal, "expected a encoding.Compressor, got %T", cp) + } + wrapErr := func(err error) error { + return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + } + buf := &p.compressBuf + z, err := compressor.Compress(buf) + if err != nil { + return nil, wrapErr(err) + } + if _, err := z.Write(in); err != nil { + return nil, wrapErr(err) + } + if err := z.Close(); err != nil { + return nil, wrapErr(err) + } + return buf.Bytes(), nil +} + +// The format of the payload: compressed or not? +type payloadFormat uint8 + +const ( + compressionNone payloadFormat = 0 // no compression + compressionMade payloadFormat = 1 // compressed +) + +const ( + payloadLen = 1 + sizeLen = 4 + headerLen = payloadLen + sizeLen +) + +// msgHeader returns a 5-byte header for the message being transmitted and the +// payload, which is compData if non-nil or data otherwise. 
+func (p *PreparedMsg) msgHeader(data, compData []byte) (hdr []byte, payload []byte) { + hdr = p.hdrBuf[:] + if compData != nil { + hdr[0] = byte(compressionMade) + data = compData + } else { + hdr[0] = byte(compressionNone) + } + + // Write length of payload into buf. + binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data))) + return hdr, data +} + +// discardLargeBuffers resets the PreparedMsg for reuse. This prevents the +// PreparedMsg from holding onto excessively large buffers across calls to +// Encode. +func (p *PreparedMsg) discardLargeBuffers() { + const maxRecycleSize = 1 << 16 /* 64KB */ + if cap(p.encodeBuf) > maxRecycleSize { + p.encodeBuf = nil + } else { + p.encodeBuf = p.encodeBuf[:0] + } + if p.compressBuf.Cap() > maxRecycleSize { + p.compressBuf = bytes.Buffer{} + } else { + p.compressBuf.Reset() + } +}