diff --git a/authz/audit/audit_logging_test.go b/authz/audit/audit_logging_test.go index ea84db099608..91cad273e959 100644 --- a/authz/audit/audit_logging_test.go +++ b/authz/audit/audit_logging_test.go @@ -24,7 +24,6 @@ import ( "crypto/x509" "encoding/json" "io" - "net" "os" "testing" "time" @@ -240,23 +239,24 @@ func (s) TestAuditLogger(t *testing.T) { wantStreamingCallCode: codes.PermissionDenied, }, } - // Construct the credentials for the tests and the stub server - serverCreds := loadServerCreds(t) - clientCreds := loadClientCreds(t) - ss := &stubserver.StubServer{ - UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - return &testpb.SimpleResponse{}, nil - }, - FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { - _, err := stream.Recv() - if err != io.EOF { - return err - } - return nil - }, - } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { + // Construct the credentials for the tests and the stub server + serverCreds := loadServerCreds(t) + clientCreds := loadClientCreds(t) + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + _, err := stream.Recv() + if err != io.EOF { + return err + } + return nil + }, + } // Setup test statAuditLogger, gRPC test server with authzPolicy, unary // and stream interceptors. lb := &loggerBuilder{ @@ -266,25 +266,18 @@ func (s) TestAuditLogger(t *testing.T) { audit.RegisterLoggerBuilder(lb) i, _ := authz.NewStatic(test.authzPolicy) - s := grpc.NewServer( - grpc.Creds(serverCreds), - grpc.ChainUnaryInterceptor(i.UnaryInterceptor), - grpc.ChainStreamInterceptor(i.StreamInterceptor)) + s := grpc.NewServer(grpc.Creds(serverCreds), grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor)) defer s.Stop() - testgrpc.RegisterTestServiceServer(s, ss) - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Error listening: %v", err) - } - go s.Serve(lis) + ss.S = s + stubserver.StartTestService(t, ss) // Setup gRPC test client with certificates containing a SPIFFE Id. - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(clientCreds)) + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", ss.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -296,7 +289,7 @@ func (s) TestAuditLogger(t *testing.T) { } stream, err := client.StreamingInputCall(ctx) if err != nil { - t.Fatalf("StreamingInputCall failed:%v", err) + t.Fatalf("StreamingInputCall failed: %v", err) } req := &testpb.StreamingInputCallRequest{ Payload: &testpb.Payload{ @@ -304,7 +297,7 @@ func (s) TestAuditLogger(t *testing.T) { }, } if err := stream.Send(req); err != nil && err != io.EOF { - t.Fatalf("stream.Send failed:%v", err) + t.Fatalf("stream.Send failed: %v", err) } if _, err := stream.CloseAndRecv(); status.Code(err) != test.wantStreamingCallCode { t.Errorf("Unexpected stream.CloseAndRecv fail: got %v want %v", err, test.wantStreamingCallCode) diff --git a/authz/grpc_authz_end2end_test.go b/authz/grpc_authz_end2end_test.go index 4e798f7ca3d7..fc68a6e68e2e 100644 --- a/authz/grpc_authz_end2end_test.go +++ b/authz/grpc_authz_end2end_test.go @@ -23,7 +23,6 @@ import ( "crypto/tls" "crypto/x509" "io" - "net" "os" "testing" "time" @@ -34,6 +33,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/grpc/testdata" @@ -42,26 +42,6 @@ import ( testpb "google.golang.org/grpc/interop/grpc_testing" ) -type testServer struct { - testgrpc.UnimplementedTestServiceServer -} - -func (s *testServer) UnaryCall(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - return &testpb.SimpleResponse{}, nil -} - -func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error { - for { - _, err := stream.Recv() - if err == io.EOF { - return stream.SendAndClose(&testpb.StreamingInputCallResponse{}) - } - if err != nil { - return err - } - } -} - type s struct { grpctest.Tester } @@ -313,25 +293,34 @@ func (s) TestStaticPolicyEnd2End(t *testing.T) { t.Run(name, func(t *testing.T) { // Start a gRPC server with gRPC authz unary and stream server interceptors. i, _ := authz.NewStatic(test.authzPolicy) - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor), - grpc.ChainStreamInterceptor(i.StreamInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) + stub := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&testpb.StreamingInputCallResponse{}) + } + if err != nil { + return err + } + } + }, + S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor)), } - go s.Serve(lis) + stubserver.StartTestService(t, stub) + defer stub.Stop() // Establish a connection to the server. - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -383,29 +372,27 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnTLSAuthenticatedConnection(t * if err != nil { t.Fatalf("failed to generate credentials: %v", err) } - s := grpc.NewServer( - grpc.Creds(creds), - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) + stub := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + S: grpc.NewServer(grpc.Creds(creds), grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), } - go s.Serve(lis) + stubserver.StartTestService(t, stub) + defer stub.S.Stop() // Establish a connection to the server. creds, err = credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") if err != nil { t.Fatalf("failed to load credentials: %v", err) } - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(creds)) + cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(creds)) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -448,17 +435,14 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnMTLSAuthenticatedConnection(t Certificates: []tls.Certificate{cert}, ClientCAs: certPool, }) - s := grpc.NewServer( - grpc.Creds(creds), - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) + stub := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + S: grpc.NewServer(grpc.Creds(creds), grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), } - go s.Serve(lis) + stubserver.StartTestService(t, stub) + defer stub.Stop() // Establish a connection to the server. cert, err = tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) @@ -478,12 +462,12 @@ func (s) TestAllowsRPCRequestWithPrincipalsFieldOnMTLSAuthenticatedConnection(t RootCAs: roots, ServerName: "x.test.example.com", }) - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(creds)) + cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(creds)) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -501,27 +485,34 @@ func (s) TestFileWatcherEnd2End(t *testing.T) { i, _ := authz.NewFileWatcher(file, 1*time.Second) defer i.Close() - // Start a gRPC server with gRPC authz unary and stream server interceptors. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor), - grpc.ChainStreamInterceptor(i.StreamInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) + stub := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&testpb.StreamingInputCallResponse{}) + } + if err != nil { + return err + } + } + }, + // Start a gRPC server with gRPC authz unary and stream server interceptors. + S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor)), } - defer lis.Close() - go s.Serve(lis) + stubserver.StartTestService(t, stub) + defer stub.Stop() // Establish a connection to the server. - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -536,7 +527,7 @@ func (s) TestFileWatcherEnd2End(t *testing.T) { // Verifying authorization decision for Streaming RPC. stream, err := client.StreamingInputCall(ctx) if err != nil { - t.Fatalf("failed StreamingInputCall err: %v", err) + t.Fatalf("failed StreamingInputCall : %v", err) } req := &testpb.StreamingInputCallRequest{ Payload: &testpb.Payload{ @@ -544,7 +535,7 @@ func (s) TestFileWatcherEnd2End(t *testing.T) { }, } if err := stream.Send(req); err != nil && err != io.EOF { - t.Fatalf("failed stream.Send err: %v", err) + t.Fatalf("failed stream.Send : %v", err) } _, err = stream.CloseAndRecv() if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { @@ -571,26 +562,23 @@ func (s) TestFileWatcher_ValidPolicyRefresh(t *testing.T) { i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) defer i.Close() - // Start a gRPC server with gRPC authz unary server interceptor. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) + stub := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + // Start a gRPC server with gRPC authz unary server interceptor. + S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), } - defer lis.Close() - go s.Serve(lis) + stubserver.StartTestService(t, stub) + defer stub.Stop() // Establish a connection to the server. - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -619,26 +607,23 @@ func (s) TestFileWatcher_InvalidPolicySkipReload(t *testing.T) { i, _ := authz.NewFileWatcher(file, 20*time.Millisecond) defer i.Close() - // Start a gRPC server with gRPC authz unary server interceptors. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) + stub := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + // Start a gRPC server with gRPC authz unary server interceptors. + S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), } - defer lis.Close() - go s.Serve(lis) + stubserver.StartTestService(t, stub) + defer stub.Stop() // Establish a connection to the server. - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -670,26 +655,22 @@ func (s) TestFileWatcher_RecoversFromReloadFailure(t *testing.T) { i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) defer i.Close() - // Start a gRPC server with gRPC authz unary server interceptors. - s := grpc.NewServer( - grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) - defer s.Stop() - testgrpc.RegisterTestServiceServer(s, &testServer{}) - - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("error listening: %v", err) + stub := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + S: grpc.NewServer(grpc.ChainUnaryInterceptor(i.UnaryInterceptor)), } - defer lis.Close() - go s.Serve(lis) + stubserver.StartTestService(t, stub) + defer stub.Stop() // Establish a connection to the server. - clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.NewClient(stub.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err) + t.Fatalf("grpc.NewClient(%v) failed: %v", stub.Address, err) } - defer clientConn.Close() - client := testgrpc.NewTestServiceClient(clientConn) + defer cc.Close() + client := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/internal/stubserver/stubserver.go b/internal/stubserver/stubserver.go index 2e404e294bf6..3c3f4fb067f2 100644 --- a/internal/stubserver/stubserver.go +++ b/internal/stubserver/stubserver.go @@ -56,9 +56,11 @@ type StubServer struct { testgrpc.TestServiceServer // Customizable implementations of server handlers. - EmptyCallF func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) - UnaryCallF func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) - FullDuplexCallF func(stream testgrpc.TestService_FullDuplexCallServer) error + EmptyCallF func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) + UnaryCallF func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) + FullDuplexCallF func(stream testgrpc.TestService_FullDuplexCallServer) error + StreamingInputCallF func(stream testgrpc.TestService_StreamingInputCallServer) error + StreamingOutputCallF func(req *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error // A client connected to this service the test may use. Created in Start(). Client testgrpc.TestServiceClient @@ -101,6 +103,16 @@ func (ss *StubServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallS return ss.FullDuplexCallF(stream) } +// StreamingInputCall is the handler for testpb.StreamingInputCall +func (ss *StubServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error { + return ss.StreamingInputCallF(stream) +} + +// StreamingOutputCall is the handler for testpb.StreamingOutputCall +func (ss *StubServer) StreamingOutputCall(req *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error { + return ss.StreamingOutputCallF(req, stream) +} + // Start starts the server and creates a client connected to it. func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) error { if err := ss.StartServer(sopts...); err != nil {