From 28db9fc297b7a789afba442ede9d7bbd865ba642 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Fri, 25 Oct 2024 00:01:53 +0200 Subject: [PATCH 1/3] add cause to canceled context Closes #7541  Conflicts:  internal/transport/transport.go --- internal/transport/context.go | 61 ++++++++++++++++++++++ internal/transport/context_test.go | 76 ++++++++++++++++++++++++++++ internal/transport/handler_server.go | 11 ++-- internal/transport/http2_server.go | 24 ++++----- 4 files changed, 151 insertions(+), 21 deletions(-) create mode 100644 internal/transport/context.go create mode 100644 internal/transport/context_test.go diff --git a/internal/transport/context.go b/internal/transport/context.go new file mode 100644 index 000000000000..907e8fe8689d --- /dev/null +++ b/internal/transport/context.go @@ -0,0 +1,61 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "context" + "errors" + "time" + + "golang.org/x/net/http2" + "google.golang.org/grpc/status" +) + +var ErrGrpcTimeout = errors.New("grpc-timeout") +var ErrRequestDone = errors.New("request is done processing") +var ErrServerTransportClosed = errors.New("server transport closed") +var ErrUnreachable = errors.New("unreachable") + +type RstCodeError struct { + RstCode http2.ErrCode +} + +func (e RstCodeError) Error() string { + return e.RstCode.String() +} + +type StatusError struct { + Status *status.Status +} + +func (e StatusError) Error() string { + return e.Status.String() +} + +func createContext(ctx context.Context, timeoutSet bool, timeout time.Duration) (context.Context, context.CancelCauseFunc) { + var timoutCancel context.CancelFunc = nil + if timeoutSet { + ctx, timoutCancel = context.WithTimeoutCause(ctx, timeout, ErrGrpcTimeout) + } + ctx, cancel := context.WithCancelCause(ctx) + if timoutCancel != nil { + context.AfterFunc(ctx, timoutCancel) + } + return ctx, cancel +} diff --git a/internal/transport/context_test.go b/internal/transport/context_test.go new file mode 100644 index 000000000000..3450c300eae7 --- /dev/null +++ b/internal/transport/context_test.go @@ -0,0 +1,76 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "context" + "reflect" + "testing" + "time" +) + +func Test_createContext(t *testing.T) { + tests := []struct { + name string + f func() context.Context + err error + cause error + }{ + {"cause when cancelled", + func() context.Context { + ctx, cancel := createContext(context.Background(), false, 0) + cancel(ErrRequestDone) + return ctx + }, + context.Canceled, + ErrRequestDone, + }, + {"cause when cancelled after deadline exceeded", + func() context.Context { + ctx, cancel := createContext(context.Background(), true, 0) + cancel(ErrRequestDone) + return ctx + }, + context.DeadlineExceeded, + ErrGrpcTimeout, + }, + {"cause when cancelled before deadline exceeded", + func() context.Context { + ctx, cancel := createContext(context.Background(), true, 1*time.Second) + cancel(ErrRequestDone) + return ctx + }, + context.Canceled, + ErrRequestDone, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.f() + err := ctx.Err() + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("ctx.Err() got %v, want %v", err, tt.cause) + } + cause := context.Cause(ctx) + if !reflect.DeepEqual(cause, tt.cause) { + t.Errorf("context.Cause(ctx) got = %v, want %v", cause, tt.cause) + } + }) + } +} diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 0ebe4a71cb9b..72efc67a5797 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -387,12 +387,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *ServerStream, md metadata.MD) e func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) { // With this transport type there will be exactly 1 stream: this HTTP request. - var cancel context.CancelFunc - if ht.timeoutSet { - ctx, cancel = context.WithTimeout(ctx, ht.timeout) - } else { - ctx, cancel = context.WithCancel(ctx) - } + ctx, cancel := createContext(ctx, ht.timeoutSet, ht.timeout) // requestOver is closed when the status has been written via WriteStatus. requestOver := make(chan struct{}) @@ -402,8 +397,8 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream case <-ht.closedCh: case <-ht.req.Context().Done(): } - cancel() - ht.Close(errors.New("request is done processing")) + cancel(ErrRequestDone) + ht.Close(ErrRequestDone) }() ctx = metadata.NewIncomingContext(ctx, ht.headerMD) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index f6faa29b9520..24052eebf3ae 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -530,11 +530,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade // s is just created by the caller. No lock needed. s.state = streamReadDone } - if timeoutSet { - s.ctx, s.cancel = context.WithTimeout(ctx, timeout) - } else { - s.ctx, s.cancel = context.WithCancel(ctx) - } + s.ctx, s.cancel = createContext(ctx, timeoutSet, timeout) // Attach the received metadata to the context. if len(mdata) > 0 { @@ -549,18 +545,19 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade t.mu.Lock() if t.state != reachable { t.mu.Unlock() - s.cancel() + s.cancel(ErrUnreachable) return nil } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() + rstCode := http2.ErrCodeRefusedStream t.controlBuf.put(&cleanupStream{ streamID: streamID, rst: true, - rstCode: http2.ErrCodeRefusedStream, + rstCode: rstCode, onWrite: func() {}, }) - s.cancel() + s.cancel(RstCodeError{rstCode}) return nil } if httpMethod != http.MethodPost { @@ -569,14 +566,15 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade if t.logger.V(logLevel) { t.logger.Infof("Aborting the stream early: %v", errMsg) } + status := status.New(codes.Internal, errMsg) t.controlBuf.put(&earlyAbortStream{ httpStatus: 405, streamID: streamID, contentSubtype: s.contentSubtype, - status: status.New(codes.Internal, errMsg), + status: status, rst: !frame.StreamEnded(), }) - s.cancel() + s.cancel(StatusError{status}) return nil } if t.inTapHandle != nil { @@ -1273,7 +1271,7 @@ func (t *http2Server) Close(err error) { channelz.RemoveEntry(t.channelz.ID) // Cancel all active streams. for _, s := range streams { - s.cancel() + s.cancel(ErrServerTransportClosed) } } @@ -1303,7 +1301,7 @@ func (t *http2Server) finishStream(s *ServerStream, rst bool, rstCode http2.ErrC // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be // called to interrupt the potential blocking on other goroutines. - s.cancel() + s.cancel(RstCodeError{rstCode}) oldState := s.swapState(streamDone) if oldState == streamDone { @@ -1327,7 +1325,7 @@ func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCo // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be // called to interrupt the potential blocking on other goroutines. - s.cancel() + s.cancel(RstCodeError{rstCode}) s.swapState(streamDone) t.deleteStream(s, eosReceived) From 77899b4774398067e7f3367d3ff4baeaac6bd303 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:04:15 +0100 Subject: [PATCH 2/3] set cause for context cancellation only in case of closing the stream with a http2 error code. Signed-off-by: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> --- internal/transport/context.go | 31 +++----------------------- internal/transport/context_test.go | 23 ++++++++++--------- internal/transport/handler_server.go | 6 ++--- internal/transport/http2_server.go | 33 ++++++++++++++++++---------- internal/transport/server_stream.go | 4 ++-- 5 files changed, 42 insertions(+), 55 deletions(-) diff --git a/internal/transport/context.go b/internal/transport/context.go index 907e8fe8689d..ebb0399744e5 100644 --- a/internal/transport/context.go +++ b/internal/transport/context.go @@ -20,38 +20,13 @@ package transport import ( "context" - "errors" "time" - - "golang.org/x/net/http2" - "google.golang.org/grpc/status" ) -var ErrGrpcTimeout = errors.New("grpc-timeout") -var ErrRequestDone = errors.New("request is done processing") -var ErrServerTransportClosed = errors.New("server transport closed") -var ErrUnreachable = errors.New("unreachable") - -type RstCodeError struct { - RstCode http2.ErrCode -} - -func (e RstCodeError) Error() string { - return e.RstCode.String() -} - -type StatusError struct { - Status *status.Status -} - -func (e StatusError) Error() string { - return e.Status.String() -} - -func createContext(ctx context.Context, timeoutSet bool, timeout time.Duration) (context.Context, context.CancelCauseFunc) { - var timoutCancel context.CancelFunc = nil +func createContextWithTimeout(ctx context.Context, timeoutSet bool, timeout time.Duration) (context.Context, context.CancelCauseFunc) { + var timoutCancel context.CancelFunc if timeoutSet { - ctx, timoutCancel = context.WithTimeoutCause(ctx, timeout, ErrGrpcTimeout) + ctx, timoutCancel = context.WithTimeout(ctx, timeout) } ctx, cancel := context.WithCancelCause(ctx) if timoutCancel != nil { diff --git a/internal/transport/context_test.go b/internal/transport/context_test.go index 3450c300eae7..dc90d80f5c19 100644 --- a/internal/transport/context_test.go +++ b/internal/transport/context_test.go @@ -20,12 +20,15 @@ package transport import ( "context" + "errors" "reflect" "testing" "time" ) -func Test_createContext(t *testing.T) { +var errRequestDone = errors.New("request is done processing") + +func Test_createContextWithTimeout(t *testing.T) { tests := []struct { name string f func() context.Context @@ -34,30 +37,30 @@ func Test_createContext(t *testing.T) { }{ {"cause when cancelled", func() context.Context { - ctx, cancel := createContext(context.Background(), false, 0) - cancel(ErrRequestDone) + ctx, cancel := createContextWithTimeout(context.Background(), false, 0) + cancel(errRequestDone) return ctx }, context.Canceled, - ErrRequestDone, + errRequestDone, }, {"cause when cancelled after deadline exceeded", func() context.Context { - ctx, cancel := createContext(context.Background(), true, 0) - cancel(ErrRequestDone) + ctx, cancel := createContextWithTimeout(context.Background(), true, 0) + cancel(errRequestDone) return ctx }, context.DeadlineExceeded, - ErrGrpcTimeout, + context.DeadlineExceeded, }, {"cause when cancelled before deadline exceeded", func() context.Context { - ctx, cancel := createContext(context.Background(), true, 1*time.Second) - cancel(ErrRequestDone) + ctx, cancel := createContextWithTimeout(context.Background(), true, 1*time.Second) + cancel(errRequestDone) return ctx }, context.Canceled, - ErrRequestDone, + errRequestDone, }, } for _, tt := range tests { diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 72efc67a5797..1b020c94c7f6 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -387,7 +387,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *ServerStream, md metadata.MD) e func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) { // With this transport type there will be exactly 1 stream: this HTTP request. - ctx, cancel := createContext(ctx, ht.timeoutSet, ht.timeout) + ctx, cancel := createContextWithTimeout(ctx, ht.timeoutSet, ht.timeout) // requestOver is closed when the status has been written via WriteStatus. requestOver := make(chan struct{}) @@ -397,8 +397,8 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream case <-ht.closedCh: case <-ht.req.Context().Done(): } - cancel(ErrRequestDone) - ht.Close(ErrRequestDone) + cancel(nil) + ht.Close(errors.New("request is done processing")) }() ctx = metadata.NewIncomingContext(ctx, ht.headerMD) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 24052eebf3ae..4e9fa7652076 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -530,7 +530,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade // s is just created by the caller. No lock needed. s.state = streamReadDone } - s.ctx, s.cancel = createContext(ctx, timeoutSet, timeout) + s.ctx, s.cancel = createContextWithTimeout(ctx, timeoutSet, timeout) // Attach the received metadata to the context. if len(mdata) > 0 { @@ -545,19 +545,18 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade t.mu.Lock() if t.state != reachable { t.mu.Unlock() - s.cancel(ErrUnreachable) + s.cancel(nil) return nil } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() - rstCode := http2.ErrCodeRefusedStream t.controlBuf.put(&cleanupStream{ streamID: streamID, rst: true, - rstCode: rstCode, + rstCode: http2.ErrCodeRefusedStream, onWrite: func() {}, }) - s.cancel(RstCodeError{rstCode}) + s.cancel(nil) return nil } if httpMethod != http.MethodPost { @@ -566,15 +565,14 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade if t.logger.V(logLevel) { t.logger.Infof("Aborting the stream early: %v", errMsg) } - status := status.New(codes.Internal, errMsg) t.controlBuf.put(&earlyAbortStream{ httpStatus: 405, streamID: streamID, contentSubtype: s.contentSubtype, - status: status, + status: status.New(codes.Internal, errMsg), rst: !frame.StreamEnded(), }) - s.cancel(StatusError{status}) + s.cancel(nil) return nil } if t.inTapHandle != nil { @@ -832,7 +830,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { // If the stream is not deleted from the transport's active streams map, then do a regular close stream. if s, ok := t.getStream(f); ok { - t.closeStream(s, false, 0, false) + t.closeStream(s, false, http2.ErrCodeNo, false) return } // If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map. @@ -1271,7 +1269,7 @@ func (t *http2Server) Close(err error) { channelz.RemoveEntry(t.channelz.ID) // Cancel all active streams. for _, s := range streams { - s.cancel(ErrServerTransportClosed) + s.cancel(nil) } } @@ -1301,7 +1299,7 @@ func (t *http2Server) finishStream(s *ServerStream, rst bool, rstCode http2.ErrC // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be // called to interrupt the potential blocking on other goroutines. - s.cancel(RstCodeError{rstCode}) + s.cancel(nil) oldState := s.swapState(streamDone) if oldState == streamDone { @@ -1325,7 +1323,7 @@ func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCo // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be // called to interrupt the potential blocking on other goroutines. - s.cancel(RstCodeError{rstCode}) + s.cancel(HTTP2CodeError{rstCode}) s.swapState(streamDone) t.deleteStream(s, eosReceived) @@ -1473,3 +1471,14 @@ func GetConnection(ctx context.Context) net.Conn { func SetConnection(ctx context.Context, conn net.Conn) context.Context { return context.WithValue(ctx, connectionKey{}, conn) } + +// HTTP2CodeError represents an error with an HTTP/2 error code. +type HTTP2CodeError struct { + // Code is the HTTP/2 error code associated with the error. + Code http2.ErrCode +} + +// Error returns the string representation of the HTTP/2 error code. +func (e HTTP2CodeError) Error() string { + return e.Code.String() +} diff --git a/internal/transport/server_stream.go b/internal/transport/server_stream.go index acbf014900bc..4eb94dfc59cf 100644 --- a/internal/transport/server_stream.go +++ b/internal/transport/server_stream.go @@ -33,8 +33,8 @@ type ServerStream struct { *Stream // Embed for common stream functionality. st ServerTransport - ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance) - cancel context.CancelFunc // invoked at the end of stream to cancel ctx. + ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance) + cancel context.CancelCauseFunc // invoked at the end of stream to cancel ctx. // Holds compressor names passed in grpc-accept-encoding metadata from the // client. From a9a5167b88abb5fc514409836f936e2749b32ae1 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:06:33 +0100 Subject: [PATCH 3/3] add example for context propagation in a server with goroutines. Signed-off-by: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> --- examples/features/context/README.md | 12 ++ examples/features/context/client/main.go | 161 +++++++++++++++++++++++ examples/features/context/server/main.go | 96 ++++++++++++++ examples/go.mod | 2 +- 4 files changed, 270 insertions(+), 1 deletion(-) create mode 100644 examples/features/context/README.md create mode 100644 examples/features/context/client/main.go create mode 100644 examples/features/context/server/main.go diff --git a/examples/features/context/README.md b/examples/features/context/README.md new file mode 100644 index 000000000000..815201b48e91 --- /dev/null +++ b/examples/features/context/README.md @@ -0,0 +1,12 @@ +# Context + +This example shows how servers can process requests in separate goroutines and +handle context cancellation. + +``` +go run server/main.go +``` + +``` +go run client/main.go +``` diff --git a/examples/features/context/client/main.go b/examples/features/context/client/main.go new file mode 100644 index 000000000000..8d20e8e5ffd4 --- /dev/null +++ b/examples/features/context/client/main.go @@ -0,0 +1,161 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Binary client demonstrates how to cancel in-flight RPCs by canceling the +// context passed to the RPC. +package main + +import ( + "context" + "flag" + "fmt" + "log" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + pb "google.golang.org/grpc/examples/features/proto/echo" + "google.golang.org/grpc/status" +) + +var addr = flag.String("addr", "localhost:50051", "the address to connect to") + +func sendMessage(stream pb.Echo_BidirectionalStreamingEchoClient, msg string) error { + fmt.Printf("sending message %q\n", msg) + return stream.Send(&pb.EchoRequest{Message: msg}) +} + +func recvMessage(stream pb.Echo_BidirectionalStreamingEchoClient, wantErrCode codes.Code) { + res, err := stream.Recv() + if status.Code(err) != wantErrCode { + log.Fatalf("stream.Recv() = %v, %v; want _, status.Code(err)=%v", res, err, wantErrCode) + } + if err != nil { + fmt.Printf("stream.Recv() returned expected error %v\n", err) + return + } + fmt.Printf("received message %q\n", res.GetMessage()) +} + +func cancelStream() { + fmt.Println("sending two messages and then canceling") + conn, err := grpc.NewClient(*addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + log.Fatalf("did not connect: %v", err) + } + defer conn.Close() + + c := pb.NewEchoClient(conn) + + // Initiate the stream with a context that supports cancellation. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + stream, err := c.BidirectionalStreamingEcho(ctx) + if err != nil { + log.Fatalf("error creating stream: %v", err) + } + + // Send some test messages. + if err := sendMessage(stream, "hello"); err != nil { + log.Fatalf("error sending on stream: %v", err) + } + if err := sendMessage(stream, "world"); err != nil { + log.Fatalf("error sending on stream: %v", err) + } + + // Ensure the RPC is working. + recvMessage(stream, codes.OK) + recvMessage(stream, codes.OK) + + fmt.Println("canceling context") + cancel() +} + +func closeConnection() { + fmt.Println("sending two messages and then closing connection") + conn, err := grpc.NewClient(*addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + log.Fatalf("did not connect: %v", err) + } + defer func() { + fmt.Println("closing connection") + conn.Close() + }() + + c := pb.NewEchoClient(conn) + + stream, err := c.BidirectionalStreamingEcho(context.Background()) + if err != nil { + log.Fatalf("error creating stream: %v", err) + } + + // Send some test messages. + if err := sendMessage(stream, "hello"); err != nil { + log.Fatalf("error sending on stream: %v", err) + } + if err := sendMessage(stream, "world"); err != nil { + log.Fatalf("error sending on stream: %v", err) + } + + // Ensure the RPC is working. + recvMessage(stream, codes.OK) + recvMessage(stream, codes.OK) +} + +func timout() { + fmt.Println("sending two messages and then waiting for timeout") + conn, err := grpc.NewClient(*addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + log.Fatalf("did not connect: %v", err) + } + defer conn.Close() + + c := pb.NewEchoClient(conn) + + // Initiate the stream with a context that supports cancellation. + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + stream, err := c.BidirectionalStreamingEcho(ctx) + if err != nil { + log.Fatalf("error creating stream: %v", err) + } + + // Send some test messages. + if err := sendMessage(stream, "hello"); err != nil { + log.Fatalf("error sending on stream: %v", err) + } + if err := sendMessage(stream, "world"); err != nil { + log.Fatalf("error sending on stream: %v", err) + } + + recvMessage(stream, codes.OK) + recvMessage(stream, codes.OK) + + fmt.Println("waiting for timeout") + <-ctx.Done() +} + +func main() { + flag.Parse() + + // simulate some client behaviors which terminate with different reasons + cancelStream() + closeConnection() + timout() +} diff --git a/examples/features/context/server/main.go b/examples/features/context/server/main.go new file mode 100644 index 000000000000..f02752fd1f85 --- /dev/null +++ b/examples/features/context/server/main.go @@ -0,0 +1,96 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Binary server demonstrates how to handle canceled contexts when a client +// cancels an in-flight RPC. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "io" + "log" + "net" + + "golang.org/x/net/http2" + "google.golang.org/grpc" + "google.golang.org/grpc/internal/transport" + + pb "google.golang.org/grpc/examples/features/proto/echo" +) + +var port = flag.Int("port", 50051, "the port to serve on") + +type server struct { + pb.UnimplementedEchoServer +} + +func workOnMessage(ctx context.Context, msg string) { + fmt.Printf("starting work on message: %q\n", msg) + i := 0 + for { + if err := ctx.Err(); err != nil { + cause := context.Cause(ctx) + fmt.Printf("'%v' with cause '%v', message worker for message %q stopping.\n", err, cause, msg) + var httpErr *transport.HTTP2CodeError + if errors.As(cause, &httpErr) { + switch httpErr.Code { + case http2.ErrCodeNo: + return + default: + fmt.Printf("unexpected HTTP/2 error: %v", httpErr) + } + } + return + } + // simulate work on message but don't flood the logs + i++ + } +} + +func (s *server) BidirectionalStreamingEcho(stream pb.Echo_BidirectionalStreamingEchoServer) error { + ctx := stream.Context() + for { + recv, err := stream.Recv() + if err != nil { + fmt.Printf("server: error receiving from stream: %v\n", err) + if err == io.EOF { + return nil + } + return err + } + msg := recv.Message + go workOnMessage(ctx, msg) + stream.Send(&pb.EchoResponse{Message: msg}) + } +} + +func main() { + flag.Parse() + + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + fmt.Printf("server listening at port %v\n", lis.Addr()) + s := grpc.NewServer() + pb.RegisterEchoServer(s, &server{}) + s.Serve(lis) +} diff --git a/examples/go.mod b/examples/go.mod index ed63e04cf9fd..48fb49c423e0 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -7,6 +7,7 @@ require ( github.com/prometheus/client_golang v1.20.5 go.opentelemetry.io/otel/exporters/prometheus v0.53.0 go.opentelemetry.io/otel/sdk/metric v1.31.0 + golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.23.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 google.golang.org/grpc v1.67.1 @@ -69,7 +70,6 @@ require ( go.opentelemetry.io/otel/sdk v1.31.0 // indirect go.opentelemetry.io/otel/trace v1.31.0 // indirect golang.org/x/crypto v0.28.0 // indirect - golang.org/x/net v0.30.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/text v0.19.0 // indirect