diff --git a/cmd/zoekt-sourcegraph-indexserver/main.go b/cmd/zoekt-sourcegraph-indexserver/main.go index aff3d2c95..f720cb7a2 100644 --- a/cmd/zoekt-sourcegraph-indexserver/main.go +++ b/cmd/zoekt-sourcegraph-indexserver/main.go @@ -1482,6 +1482,7 @@ func dialGRPCClient(addr string, logger sglog.Logger, additionalOpts ...grpc.Dia grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithChainStreamInterceptor( metrics.StreamClientInterceptor(), + messagesize.StreamClientInterceptor, internalActorStreamInterceptor(), internalerrs.LoggingStreamClientInterceptor(logger), internalerrs.PrometheusStreamClientInterceptor, @@ -1489,6 +1490,7 @@ func dialGRPCClient(addr string, logger sglog.Logger, additionalOpts ...grpc.Dia ), grpc.WithChainUnaryInterceptor( metrics.UnaryClientInterceptor(), + messagesize.UnaryClientInterceptor, internalActorUnaryInterceptor(), internalerrs.LoggingUnaryClientInterceptor(logger), internalerrs.PrometheusUnaryClientInterceptor, diff --git a/cmd/zoekt-webserver/main.go b/cmd/zoekt-webserver/main.go index d406d5b3b..0181ae3f0 100644 --- a/cmd/zoekt-webserver/main.go +++ b/cmd/zoekt-webserver/main.go @@ -648,11 +648,13 @@ func newGRPCServer(logger sglog.Logger, streamer zoekt.Streamer, additionalOpts grpc.ChainStreamInterceptor( otelgrpc.StreamServerInterceptor(), metrics.StreamServerInterceptor(), + messagesize.StreamServerInterceptor, internalerrs.LoggingStreamServerInterceptor(logger), ), grpc.ChainUnaryInterceptor( otelgrpc.UnaryServerInterceptor(), metrics.UnaryServerInterceptor(), + messagesize.UnaryServerInterceptor, internalerrs.LoggingUnaryServerInterceptor(logger), ), } diff --git a/grpc/messagesize/prometheus.go b/grpc/messagesize/prometheus.go new file mode 100644 index 000000000..881f8ffb3 --- /dev/null +++ b/grpc/messagesize/prometheus.go @@ -0,0 +1,321 @@ +package messagesize + +import ( + "context" + "sync" + "sync/atomic" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/sourcegraph/zoekt/grpc/grpcutil" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +var ( + metricServerSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "grpc_server_sent_individual_message_size_bytes_per_rpc", + Help: "Size of individual messages sent by the server per RPC.", + Buckets: sizeBuckets, + }, []string{ + "grpc_service", // e.g. "gitserver.v1.GitserverService" + "grpc_method", // e.g. "Exec" + }) + + metricServerTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "grpc_server_sent_bytes_per_rpc", + Help: "Total size of all the messages sent by the server during the course of a single RPC call", + Buckets: sizeBuckets, + }, []string{ + "grpc_service", // e.g. "gitserver.v1.GitserverService" + "grpc_method", // e.g. "Exec" + }) + + metricClientSingleMessageSize = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "grpc_client_sent_individual_message_size_per_rpc_bytes", + Help: "Size of individual messages sent by the client per RPC.", + Buckets: sizeBuckets, + }, []string{ + "grpc_service", // e.g. "gitserver.v1.GitserverService" + "grpc_method", // e.g. "Exec" + }) + + metricClientTotalSentPerRPCBytes = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "grpc_client_sent_bytes_per_rpc", + Help: "Total size of all the messages sent by the client during the course of a single RPC call", + Buckets: sizeBuckets, + }, []string{ + "grpc_service", // e.g. "gitserver.v1.GitserverService" + "grpc_method", // e.g. "Exec" + }) +) + +const ( + B = 1 + KB = 1024 * B + MB = 1024 * KB + GB = 1024 * MB +) + +var sizeBuckets = []float64{ + 0, + 1 * KB, + 10 * KB, + 50 * KB, + 100 * KB, + 500 * KB, + 1 * MB, + 5 * MB, + 10 * MB, + 50 * MB, + 100 * MB, + 500 * MB, + 1 * GB, + 5 * GB, + 10 * GB, +} + +// UnaryServerInterceptor is a grpc.UnaryServerInterceptor that records Prometheus metrics that observe the size of +// the response message sent back by the server for a single RPC call. +func UnaryServerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + observer := newServerMessageSizeObserver(info.FullMethod) + + return unaryServerInterceptor(observer, req, ctx, info, handler) +} + +func unaryServerInterceptor(observer *messageSizeObserver, req any, ctx context.Context, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + defer observer.FinishRPC() + + r, err := handler(ctx, req) + if err != nil { + return r, err + } + + response, ok := r.(proto.Message) + if !ok { + return r, nil + } + + observer.Observe(response) + return response, nil +} + +// StreamServerInterceptor is a grpc.StreamServerInterceptor that records Prometheus metrics that observe both the sizes of the +// individual response messages and the cumulative response size of all the message sent back by the server over the course +// of a single RPC call. +func StreamServerInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + observer := newServerMessageSizeObserver(info.FullMethod) + + return streamServerInterceptor(observer, srv, ss, info, handler) +} + +func streamServerInterceptor(observer *messageSizeObserver, srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + defer observer.FinishRPC() + + wrappedStream := newObservingServerStream(ss, observer) + + return handler(srv, wrappedStream) +} + +// UnaryClientInterceptor is a grpc.UnaryClientInterceptor that records Prometheus metrics that observe the size of +// the request message sent by client for a single RPC call. +func UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + o := newClientMessageSizeObserver(method) + return unaryClientInterceptor(o, ctx, method, req, reply, cc, invoker, opts...) +} + +func unaryClientInterceptor(observer *messageSizeObserver, ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + defer observer.FinishRPC() + + err := invoker(ctx, method, req, reply, cc, opts...) + if err != nil { + // Don't record the size of the message if there was an error sending it, since it may not have been sent. + return err + } + + // Observe the size of the request message. + request, ok := req.(proto.Message) + if !ok { + return nil + } + + observer.Observe(request) + return nil +} + +// StreamClientInterceptor is a grpc.StreamClientInterceptor that records Prometheus metrics that observe both the sizes of the +// individual request messages and the cumulative request size of all the message sent by the client over the course +// of a single RPC call. +func StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + observer := newClientMessageSizeObserver(method) + + return streamClientInterceptor(observer, ctx, desc, cc, method, streamer, opts...) +} + +func streamClientInterceptor(observer *messageSizeObserver, ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + s, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, err + } + + wrappedStream := newObservingClientStream(s, observer) + return wrappedStream, nil +} + +type observingServerStream struct { + grpc.ServerStream + + observer *messageSizeObserver +} + +func newObservingServerStream(s grpc.ServerStream, observer *messageSizeObserver) grpc.ServerStream { + return &observingServerStream{ + ServerStream: s, + observer: observer, + } +} + +func (s *observingServerStream) SendMsg(m any) error { + err := s.ServerStream.SendMsg(m) + if err != nil { + // Don't record the size of the message if there was an error sending it, since it may not have been sent. + // + // However, the stream aborts on an error, + // so we need to record the total size of the messages sent during the course of the RPC call. + s.observer.FinishRPC() + return err + } + + // Observe the size of the sent message. + message, ok := m.(proto.Message) + if !ok { + return nil + } + + s.observer.Observe(message) + return nil +} + +type observingClientStream struct { + grpc.ClientStream + + observer *messageSizeObserver +} + +func newObservingClientStream(s grpc.ClientStream, observer *messageSizeObserver) grpc.ClientStream { + return &observingClientStream{ + ClientStream: s, + observer: observer, + } +} + +func (s *observingClientStream) SendMsg(m any) error { + err := s.ClientStream.SendMsg(m) + if err != nil { + // Don't record the size of the message if there was an error sending it, since it may not have been sent. + // + // However, the stream aborts on an error, + // so we need to record the total size of the messages sent during the course of the RPC call. + s.observer.FinishRPC() + return err + } + + // Observe the size of the sent message. + message, ok := m.(proto.Message) + if !ok { + return nil + } + + s.observer.Observe(message) + return nil +} + +func (s *observingClientStream) CloseSend() error { + err := s.ClientStream.CloseSend() + + s.observer.FinishRPC() + return err +} + +func (s *observingClientStream) RecvMsg(m any) error { + err := s.ClientStream.RecvMsg(m) + if err != nil { + // Record the total size of the messages sent during the course of the RPC call, even if there was an error. + s.observer.FinishRPC() + } + + return err +} + +func newServerMessageSizeObserver(fullMethod string) *messageSizeObserver { + serviceName, methodName := grpcutil.SplitMethodName(fullMethod) + + onSingle := func(messageSize uint64) { + metricServerSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) + } + + onFinish := func(messageSize uint64) { + metricServerTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) + } + + return &messageSizeObserver{ + onSingleFunc: onSingle, + onFinishFunc: onFinish, + } +} + +func newClientMessageSizeObserver(fullMethod string) *messageSizeObserver { + serviceName, methodName := grpcutil.SplitMethodName(fullMethod) + + onSingle := func(messageSize uint64) { + metricClientSingleMessageSize.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) + } + + onFinish := func(messageSize uint64) { + metricClientTotalSentPerRPCBytes.WithLabelValues(serviceName, methodName).Observe(float64(messageSize)) + } + + return &messageSizeObserver{ + onSingleFunc: onSingle, + onFinishFunc: onFinish, + } +} + +// messageSizeObserver is a utility that records Prometheus metrics that observe the size of each sent message and the +// cumulative size of all sent messages during the course of a single RPC call. +type messageSizeObserver struct { + onSingleFunc func(messageSizeBytes uint64) + + finishOnce sync.Once + onFinishFunc func(totalSizeBytes uint64) + + totalSizeBytes atomic.Uint64 +} + +// Observe records the size of a single message. +func (o *messageSizeObserver) Observe(message proto.Message) { + s := uint64(proto.Size(message)) + o.onSingleFunc(s) + + o.totalSizeBytes.Add(s) +} + +// FinishRPC records the total size of all sent messages during the course of a single RPC call. +// This function should only be called once the RPC call has completed. +func (o *messageSizeObserver) FinishRPC() { + o.finishOnce.Do(func() { + o.onFinishFunc(o.totalSizeBytes.Load()) + }) +} + +var ( + _ grpc.ServerStream = &observingServerStream{} + _ grpc.ClientStream = &observingClientStream{} +) + +var ( + _ grpc.UnaryServerInterceptor = UnaryServerInterceptor + _ grpc.StreamServerInterceptor = StreamServerInterceptor + _ grpc.UnaryClientInterceptor = UnaryClientInterceptor + _ grpc.StreamClientInterceptor = StreamClientInterceptor +) diff --git a/grpc/messagesize/prometheus_test.go b/grpc/messagesize/prometheus_test.go new file mode 100644 index 000000000..739dda598 --- /dev/null +++ b/grpc/messagesize/prometheus_test.go @@ -0,0 +1,758 @@ +package messagesize + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/timestamppb" + + newspb "github.com/sourcegraph/zoekt/grpc/testprotos/news/v1" +) + +var ( + binaryMessage = &newspb.BinaryAttachment{ + Name: "data", + Data: []byte(strings.Repeat("x", 1*1024*1024)), + } + + keyValueMessage = &newspb.KeyValueAttachment{ + Name: "data", + Data: map[string]string{ + "key1": strings.Repeat("x", 1*1024*1024), + "key2": "value2", + }, + } + + articleMessage = &newspb.Article{ + Author: "author", + Date: ×tamppb.Timestamp{Seconds: 1234567890}, + Title: "title", + Content: "content", + Status: newspb.Article_STATUS_PUBLISHED, + Attachments: []*newspb.Attachment{ + {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, + {Contents: &newspb.Attachment_KeyValueAttachment{KeyValueAttachment: keyValueMessage}}, + {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, + {Contents: &newspb.Attachment_BinaryAttachment{BinaryAttachment: binaryMessage}}, + }, + } +) + +func BenchmarkObserverBinary(b *testing.B) { + o := messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) {}, + onFinishFunc: func(totalSizeBytes uint64) {}, + } + + benchmarkObserver(b, &o, binaryMessage) +} + +func BenchmarkObserverKeyValue(b *testing.B) { + o := messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) {}, + onFinishFunc: func(totalSizeBytes uint64) {}, + } + + benchmarkObserver(b, &o, keyValueMessage) +} + +func BenchmarkObserverArticle(b *testing.B) { + o := messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) {}, + onFinishFunc: func(totalSizeBytes uint64) {}, + } + + benchmarkObserver(b, &o, articleMessage) +} + +func benchmarkObserver(b *testing.B, observer *messageSizeObserver, message proto.Message) { + b.ReportAllocs() + + for n := 0; n < b.N; n++ { + observer.Observe(message) + } + + observer.FinishRPC() +} + +func TestUnaryServerInterceptor(t *testing.T) { + ctx := context.Background() + + request := &newspb.BinaryAttachment{ + Data: bytes.Repeat([]byte("request"), 3), + } + + response := &newspb.BinaryAttachment{ + Data: bytes.Repeat([]byte("response"), 7), + } + + info := &grpc.UnaryServerInfo{ + FullMethod: "news.v1.NewsService/GetArticle", + } + + sentinelError := errors.New("expected error") + + tests := []struct { + name string + handler func(ctx context.Context, req any) (any, error) + expectedError error + expectedResult any + expectedSize uint64 + }{ + { + name: "invoker successful - observe response", + handler: func(ctx context.Context, req any) (any, error) { + return response, nil + }, + expectedError: nil, + expectedResult: response, + expectedSize: uint64(proto.Size(response)), + }, + { + name: "invoker error - observe a zero-sized response", + handler: func(ctx context.Context, req any) (any, error) { + return nil, sentinelError + }, + expectedError: sentinelError, + expectedResult: nil, + expectedSize: uint64(0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + onFinishCalledCount := 0 + + observer := messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) {}, + onFinishFunc: func(totalSizeBytes uint64) { + onFinishCalledCount++ + + if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { + t.Error("totalSizeBytes mismatch (-want +got):\n", diff) + } + }, + } + + actualResult, err := unaryServerInterceptor(&observer, request, ctx, info, test.handler) + if err != test.expectedError { + t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) + } + + if diff := cmp.Diff(test.expectedResult, actualResult, protocmp.Transform()); diff != "" { + t.Error("response mismatch (-want +got):\n", diff) + } + + if diff := cmp.Diff(1, onFinishCalledCount); diff != "" { + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) + } + }) + } +} + +func TestStreamServerInterceptor(t *testing.T) { + + response1 := &newspb.BinaryAttachment{ + Name: "", + Data: []byte("response"), + } + response2 := &newspb.BinaryAttachment{ + Name: "", + Data: bytes.Repeat([]byte("response"), 3), + } + response3 := &newspb.BinaryAttachment{ + Name: "", + Data: bytes.Repeat([]byte("response"), 7), + } + + info := &grpc.StreamServerInfo{ + FullMethod: "news.v1.NewsService/GetArticle", + } + + sentinelError := errors.New("expected error") + + tests := []struct { + name string + + mockSendMsg func(m any) error + handler func(srv any, stream grpc.ServerStream) error + + expectedError error + expectedResponses []any + expectedSize uint64 + }{ + { + name: "invoker successful - observe all 3 responses", + + mockSendMsg: func(m any) error { + return nil // no error + }, + + handler: func(srv any, stream grpc.ServerStream) error { + for _, r := range []proto.Message{response1, response2, response3} { + if err := stream.SendMsg(r); err != nil { + return err + } + } + + return nil + }, + + expectedError: nil, + expectedResponses: []any{response1, response2, response3}, + expectedSize: uint64(proto.Size(response1) + proto.Size(response2) + proto.Size(response3)), + }, + + { + name: "invoker fails on 3rd response - only observe first 2", + + mockSendMsg: func(m any) error { + if m == response3 { + return sentinelError + } + + return nil + }, + handler: func(srv any, stream grpc.ServerStream) error { + for _, r := range []proto.Message{response1, response2, response3} { + if err := stream.SendMsg(r); err != nil { + return err + } + } + + return nil + }, + + expectedError: sentinelError, + expectedResponses: []any{response1, response2, response3}, // response 3 should still be attempted to be sent + expectedSize: uint64(proto.Size(response1) + proto.Size(response2)), // response 3 should not be counted since an error occurred while sending it + }, + + { + name: "invoker fails immediately - should still observe a zero-sized response", + + mockSendMsg: func(m any) error { + return errors.New("should not be called") + }, + + handler: func(srv any, stream grpc.ServerStream) error { + return sentinelError + }, + + expectedError: sentinelError, + expectedResponses: []any{}, // there are no responses + expectedSize: uint64(0), // there are no responses, so the size is 0 + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + onFinishCallCount := 0 + + observer := messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) {}, + onFinishFunc: func(totalSizeBytes uint64) { + onFinishCallCount++ + + if totalSizeBytes != test.expectedSize { + t.Errorf("totalSizeBytes mismatch (wanted: %d, got: %d)", test.expectedSize, totalSizeBytes) + } + }, + } + + var actualResponses []any + + ss := &mockServerStream{ + mockSendMsg: func(m any) error { + actualResponses = append(actualResponses, m) + + return test.mockSendMsg(m) + }, + } + + err := streamServerInterceptor(&observer, nil, ss, info, test.handler) + if err != test.expectedError { + t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) + } + + if diff := cmp.Diff(test.expectedResponses, actualResponses, protocmp.Transform(), cmpopts.EquateEmpty()); diff != "" { + t.Error("responses mismatch (-want +got):\n", diff) + } + + if diff := cmp.Diff(1, onFinishCallCount); diff != "" { + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) + } + }) + } +} + +func TestUnaryClientInterceptor(t *testing.T) { + ctx := context.Background() + + request := &newspb.BinaryAttachment{ + Name: "data", + Data: bytes.Repeat([]byte("request"), 3), + } + + method := "news.v1.NewsService/GetArticle" + + sentinelError := errors.New("expected error") + + tests := []struct { + name string + invoker grpc.UnaryInvoker + + expectedError error + expectedRequest any + expectedSize uint64 + }{ + { + name: "invoker successful - observe request size", + invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return nil + }, + + expectedError: nil, + expectedRequest: request, + expectedSize: uint64(proto.Size(request)), + }, + + { + name: "invoker error - observe a zero-sized response", + invoker: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return sentinelError + }, + + expectedError: sentinelError, + expectedRequest: request, + expectedSize: uint64(0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + onFinishCallCount := 0 + + observer := messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) {}, + onFinishFunc: func(totalSizeBytes uint64) { + onFinishCallCount++ + + if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { + t.Error("totalSizeBytes mismatch (-want +got):\n", diff) + } + }, + } + + var actualRequest any + + invokerCalled := false + invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + invokerCalled = true + + actualRequest = req + return test.invoker(ctx, method, req, reply, cc, opts...) + } + + err := unaryClientInterceptor(&observer, ctx, method, request, nil, nil, invoker) + if err != test.expectedError { + t.Errorf("error mismatch (wanted: %q, got: %q)", test.expectedError, err) + } + + if !invokerCalled { + t.Fatal("invoker not called") + } + + if diff := cmp.Diff(test.expectedRequest, actualRequest, protocmp.Transform()); diff != "" { + t.Error("request mismatch (-want +got):\n", diff) + } + + if diff := cmp.Diff(1, onFinishCallCount); diff != "" { + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) + } + }) + } +} + +func TestStreamingClientInterceptor(t *testing.T) { + ctx := context.Background() + + request1 := &newspb.BinaryAttachment{ + Name: "data", + Data: bytes.Repeat([]byte("request"), 3), + } + + request2 := &newspb.BinaryAttachment{ + Name: "data", + Data: bytes.Repeat([]byte("request"), 7), + } + + request3 := &newspb.BinaryAttachment{ + Name: "data", + Data: bytes.Repeat([]byte("request"), 13), + } + + method := "news.v1.NewsService/GetArticle" + + sentinelError := errors.New("expected error") + + type stepType int + + const ( + stepSend stepType = iota + stepRecv + stepCloseSend + ) + + type step struct { + stepType stepType + + message any + streamErr error + } + + tests := []struct { + name string + + steps []step + expectedSize uint64 + }{ + { + name: "invoker successful - observe request size", + steps: []step{ + { + stepType: stepSend, + + message: request1, + streamErr: nil, + }, + { + stepType: stepSend, + + message: request2, + streamErr: nil, + }, + { + stepType: stepSend, + + message: request3, + streamErr: nil, + }, + { + stepType: stepRecv, + + message: nil, + streamErr: io.EOF, // end of stream + }, + }, + + expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), + }, + { + name: "2nd send failed - stream aborts and should only observe first request", + steps: []step{ + { + stepType: stepSend, + message: request1, + streamErr: nil, + }, + { + stepType: stepSend, + message: request2, + streamErr: sentinelError, + }, + }, + + expectedSize: uint64(proto.Size(request1)), + }, + { + name: "recv message fails with non io.EOF error - should still observe all requests", + steps: []step{ + { + stepType: stepSend, + + message: request1, + streamErr: nil, + }, + { + stepType: stepSend, + + message: request2, + streamErr: nil, + }, + { + stepType: stepSend, + + message: request3, + streamErr: nil, + }, + { + stepType: stepRecv, + + message: nil, + streamErr: sentinelError, + }, + }, + + expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), + }, + + { + name: "close send called - should observe all requests", + steps: []step{ + { + stepType: stepSend, + + message: request1, + streamErr: nil, + }, + { + stepType: stepSend, + + message: request2, + streamErr: nil, + }, + { + stepType: stepSend, + + message: request3, + streamErr: nil, + }, + { + stepType: stepCloseSend, + + message: nil, + streamErr: nil, + }, + }, + + expectedSize: uint64(proto.Size(request1) + proto.Size(request2) + proto.Size(request3)), + }, + { + name: "close send called immediately - should observe zero-sized response", + steps: []step{ + { + stepType: stepCloseSend, + + message: nil, + streamErr: nil, + }, + }, + + expectedSize: uint64(0), + }, + { + name: "first send fails - stream should abort and observe zero-sized response", + steps: []step{ + { + stepType: stepSend, + + message: request1, + streamErr: sentinelError, + }, + }, + + expectedSize: uint64(0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + onFinishCallCount := 0 + + observer := messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) {}, + onFinishFunc: func(totalSizeBytes uint64) { + onFinishCallCount++ + + if diff := cmp.Diff(totalSizeBytes, test.expectedSize); diff != "" { + t.Error("totalSizeBytes mismatch (-want +got):\n", diff) + } + }, + } + + baseStream := &mockClientStream{} + streamerCalled := false + streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + streamerCalled = true + + return baseStream, nil + } + + ss, err := streamClientInterceptor(&observer, ctx, nil, nil, method, streamer) + require.NoError(t, err) + + // Run through all the steps, preparing the mockClientStream to return the expected errors + for _, step := range test.steps { + baseStreamCalled := false + var streamErr error + + switch step.stepType { + case stepSend: + baseStream.mockSendMsg = func(m any) error { + baseStreamCalled = true + return step.streamErr + } + + streamErr = ss.SendMsg(step.message) + case stepRecv: + baseStream.mockRecvMsg = func(_ any) error { + baseStreamCalled = true + return step.streamErr + } + + streamErr = ss.RecvMsg(step.message) + + case stepCloseSend: + baseStream.mockCloseSend = func() error { + baseStreamCalled = true + return step.streamErr + } + + streamErr = ss.CloseSend() + default: + t.Fatalf("unknown step type: %v", step.stepType) + } + + // ensure that the baseStream was called and errors are propagated + require.True(t, baseStreamCalled) + require.Equal(t, step.streamErr, streamErr) + } + + if !streamerCalled { + t.Fatal("streamer not called") + } + + if diff := cmp.Diff(1, onFinishCallCount); diff != "" { + t.Error("onFinishFunc not called expected number of times (-want +got):\n", diff) + } + }) + } +} + +func TestObserver(t *testing.T) { + testCases := []struct { + name string + messages []proto.Message + }{ + { + name: "single message", + messages: []proto.Message{&newspb.BinaryAttachment{ + Name: "data1", + Data: []byte("sample data"), + }}, + }, + { + name: "multiple messages", + messages: []proto.Message{ + &newspb.BinaryAttachment{ + Name: "data1", + Data: []byte("sample data"), + }, + &newspb.KeyValueAttachment{ + Name: "data2", + Data: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + }}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var singleMessageSizes []uint64 + var totalSize uint64 + + // Create a new observer with custom onSingleFunc and onFinishFunc + obs := &messageSizeObserver{ + onSingleFunc: func(messageSizeBytes uint64) { + singleMessageSizes = append(singleMessageSizes, messageSizeBytes) + }, + onFinishFunc: func(totalSizeBytes uint64) { + totalSize = totalSizeBytes + }, + } + + // Call ObserveSingle for each message + for _, msg := range tc.messages { + obs.Observe(msg) + } + + // Check that the singleMessageSizes are correct + for i, msg := range tc.messages { + expectedSize := uint64(proto.Size(msg)) + require.Equal(t, expectedSize, singleMessageSizes[i]) + } + + // Call FinishRPC + obs.FinishRPC() + + // Check that the totalSize is correct + expectedTotalSize := uint64(0) + for _, size := range singleMessageSizes { + expectedTotalSize += size + } + require.EqualValues(t, expectedTotalSize, totalSize) + }) + } +} + +type mockServerStream struct { + mockSendMsg func(m any) error + + grpc.ServerStream +} + +func (s *mockServerStream) SendMsg(m any) error { + if s.mockSendMsg != nil { + return s.mockSendMsg(m) + } + + return errors.New("send msg not implemented") +} + +type mockClientStream struct { + mockRecvMsg func(m any) error + mockSendMsg func(m any) error + mockCloseSend func() error + + grpc.ClientStream +} + +func (s *mockClientStream) SendMsg(m any) error { + if s.mockSendMsg != nil { + return s.mockSendMsg(m) + } + + return errors.New("send msg not implemented") +} + +func (s *mockClientStream) RecvMsg(m any) error { + if s.mockRecvMsg != nil { + return s.mockRecvMsg(m) + } + + return errors.New("recv msg not implemented") +} + +func (s *mockClientStream) CloseSend() error { + if s.mockCloseSend != nil { + return s.mockCloseSend() + } + + return errors.New("close send not implemented") +} + +var _ grpc.ServerStream = &mockServerStream{} +var _ grpc.ClientStream = &mockClientStream{}