diff --git a/src/control/cmd/daos_agent/security_rpc_test.go b/src/control/cmd/daos_agent/security_rpc_test.go index 43d51f51e069..1c682aff1ca3 100644 --- a/src/control/cmd/daos_agent/security_rpc_test.go +++ b/src/control/cmd/daos_agent/security_rpc_test.go @@ -73,7 +73,7 @@ func setupTestUnixConn(t *testing.T) (*net.UnixConn, func()) { return newConn, cleanup } -func getClientConn(t *testing.T, path string) *drpc.ClientConnection { +func getClientConn(t *testing.T, path string) drpc.DomainSocketClient { client := drpc.NewClientConnection(path) if err := client.Connect(test.Context(t)); err != nil { t.Fatalf("Failed to connect: %v", err) diff --git a/src/control/drpc/drpc_client.go b/src/control/drpc/drpc_client.go index 86291b31a635..f3d324c330ad 100644 --- a/src/control/drpc/drpc_client.go +++ b/src/control/drpc/drpc_client.go @@ -186,7 +186,7 @@ func (c *ClientConnection) GetSocketPath() string { } // NewClientConnection creates a new dRPC client -func NewClientConnection(socket string) *ClientConnection { +func NewClientConnection(socket string) DomainSocketClient { return &ClientConnection{ socketPath: socket, dialer: &clientDialer{}, diff --git a/src/control/drpc/drpc_client_test.go b/src/control/drpc/drpc_client_test.go index 06301c51b44c..10bf5a527a6b 100644 --- a/src/control/drpc/drpc_client_test.go +++ b/src/control/drpc/drpc_client_test.go @@ -72,12 +72,13 @@ func TestNewClientConnection(t *testing.T) { t.Fatal("Expected a real client") return } - test.AssertEqual(t, client.socketPath, testSockPath, + clientConn := client.(*ClientConnection) + test.AssertEqual(t, clientConn.socketPath, testSockPath, "Should match the path we passed in") test.AssertFalse(t, client.IsConnected(), "Shouldn't be connected yet") // Dialer should be the private implementation type - _ = client.dialer.(*clientDialer) + _ = clientConn.dialer.(*clientDialer) } func TestClient_Connect_Success(t *testing.T) { diff --git a/src/control/server/ctl_ranks_rpc_test.go b/src/control/server/ctl_ranks_rpc_test.go index 7819e4994932..f585d51d5b89 100644 --- a/src/control/server/ctl_ranks_rpc_test.go +++ b/src/control/server/ctl_ranks_rpc_test.go @@ -217,7 +217,9 @@ func TestServer_CtlSvc_PrepShutdownRanks(t *testing.T) { cfg.setResponseDelay(tc.responseDelay) } } - srv.setDrpcClient(newMockDrpcClient(cfg)) + srv.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + return newMockDrpcClient(cfg) + } } var cancel context.CancelFunc @@ -580,7 +582,9 @@ func TestServer_CtlSvc_PingRanks(t *testing.T) { cfg.setResponseDelay(tc.responseDelay) } } - srv.setDrpcClient(newMockDrpcClient(cfg)) + srv.getDrpcClientFn = func(string) drpc.DomainSocketClient { + return newMockDrpcClient(cfg) + } } ctx, outerCancel := context.WithCancel(test.Context(t)) @@ -1092,7 +1096,9 @@ func TestServer_CtlSvc_SetEngineLogMasks(t *testing.T) { cfg.setResponseDelay(tc.responseDelay) } } - srv.setDrpcClient(newMockDrpcClient(cfg)) + srv.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + return newMockDrpcClient(cfg) + } } gotResp, gotErr := svc.SetEngineLogMasks(test.Context(t), tc.req) diff --git a/src/control/server/ctl_smd_rpc_test.go b/src/control/server/ctl_smd_rpc_test.go index 8374f64c3756..dad3946d6324 100644 --- a/src/control/server/ctl_smd_rpc_test.go +++ b/src/control/server/ctl_smd_rpc_test.go @@ -737,7 +737,10 @@ func TestServer_CtlSvc_SmdQuery(t *testing.T) { cfg.setSendMsgResponseList(t, mock) } } - srv.setDrpcClient(newMockDrpcClient(cfg)) + cli := newMockDrpcClient(cfg) + srv.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + return cli + } srv.ready.SetTrue() } if tc.harnessStopped { @@ -1627,7 +1630,10 @@ func TestServer_CtlSvc_SmdManage(t *testing.T) { cfg.setSendMsgResponseList(t, mock) } } - srv.setDrpcClient(newMockDrpcClient(cfg)) + cli := newMockDrpcClient(cfg) + srv.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + return cli + } srv.ready.SetTrue() } if tc.harnessStopped { diff --git a/src/control/server/ctl_storage_rpc_test.go b/src/control/server/ctl_storage_rpc_test.go index ecd6c5f66a68..46164188a67c 100644 --- a/src/control/server/ctl_storage_rpc_test.go +++ b/src/control/server/ctl_storage_rpc_test.go @@ -1626,7 +1626,10 @@ func TestServer_CtlSvc_StorageScan_PostEngineStart(t *testing.T) { } else { t.Fatal("drpc response mocks unpopulated") } - te.setDrpcClient(newMockDrpcClient(dcc)) + cli := newMockDrpcClient(dcc) + te.getDrpcClientFn = func(string) drpc.DomainSocketClient { + return cli + } te._superblock.Rank = ranklist.NewRankPtr(uint32(idx + 1)) for _, tc := range te.storage.GetBdevConfigs() { tc.Bdev.DeviceRoles.OptionBits = storage.OptionBits(storage.BdevRoleAll) diff --git a/src/control/server/ctl_svc_test.go b/src/control/server/ctl_svc_test.go index 11995e06671f..afe106435e9c 100644 --- a/src/control/server/ctl_svc_test.go +++ b/src/control/server/ctl_svc_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2023 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -66,6 +66,7 @@ func mockControlService(t *testing.T, log logging.Logger, cfg *config.Server, bm }) if started { ei.ready.SetTrue() + ei.setDrpcSocket("/dontcare") } if err := cs.harness.AddInstance(ei); err != nil { t.Fatal(err) diff --git a/src/control/server/harness_test.go b/src/control/server/harness_test.go index 716bbd8c44e4..7939ba0c3b8a 100644 --- a/src/control/server/harness_test.go +++ b/src/control/server/harness_test.go @@ -294,13 +294,17 @@ func TestServer_Harness_Start(t *testing.T) { } instances := harness.Instances() - + mockDrpcClients := make([]*mockDrpcClient, 0, len(instances)) // set mock dRPC client to record call details for _, e := range instances { ei := e.(*EngineInstance) - ei.setDrpcClient(newMockDrpcClient(&mockDrpcClientConfig{ + cli := newMockDrpcClient(&mockDrpcClientConfig{ SendMsgResponse: &drpc.Response{}, - })) + }) + mockDrpcClients = append(mockDrpcClients, cli) + ei.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + return cli + } } ctx, cancel := context.WithCancel(test.Context(t)) @@ -417,13 +421,10 @@ func TestServer_Harness_Start(t *testing.T) { defer joinMu.Unlock() // verify expected RPCs were made, ranks allocated and // members added to membership - for _, e := range instances { + for i, e := range instances { ei := e.(*EngineInstance) - dc, err := ei.getDrpcClient() - if err != nil { - t.Fatal(err) - } - gotDrpcCalls := dc.(*mockDrpcClient).CalledMethods() + dc := mockDrpcClients[i] + gotDrpcCalls := dc.CalledMethods() AssertEqual(t, tc.expDrpcCalls[ei.Index()], gotDrpcCalls, fmt.Sprintf("%s: unexpected dRPCs for instance %d", name, ei.Index())) diff --git a/src/control/server/instance.go b/src/control/server/instance.go index 14f53cf3b5b8..7ef74a4d1261 100644 --- a/src/control/server/instance.go +++ b/src/control/server/instance.go @@ -58,12 +58,13 @@ type EngineInstance struct { onStorageReady []onStorageReadyFn onReady []onReadyFn onInstanceExit []onInstanceExitFn + getDrpcClientFn func(string) drpc.DomainSocketClient sync.RWMutex // these must be protected by a mutex in order to // avoid racy access. + _drpcSocket string _cancelCtx context.CancelFunc - _drpcClient drpc.DomainSocketClient _superblock *Superblock _lastErr error // populated when harness receives signal } @@ -162,11 +163,7 @@ func (ei *EngineInstance) Index() uint32 { func (ei *EngineInstance) removeSocket() error { fMsg := fmt.Sprintf("removing instance %d socket file", ei.Index()) - dc, err := ei.getDrpcClient() - if err != nil { - return errors.Wrap(err, fMsg) - } - engineSock := dc.GetSocketPath() + engineSock := ei.getDrpcSocket() if err := checkDrpcClientSocketPath(engineSock); err != nil { return errors.Wrap(err, fMsg) diff --git a/src/control/server/instance_drpc.go b/src/control/server/instance_drpc.go index 542b636e2f07..dfd10e8d3e05 100644 --- a/src/control/server/instance_drpc.go +++ b/src/control/server/instance_drpc.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2020-2023 Intel Corporation. +// (C) Copyright 2020-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -28,23 +28,29 @@ import ( ) var ( - errDRPCNotReady = errors.New("no dRPC client set (data plane not started?)") + errDRPCNotReady = errors.New("dRPC socket not ready (data plane not started?)") errEngineNotReady = errors.New("engine not ready yet") ) -func (ei *EngineInstance) setDrpcClient(c drpc.DomainSocketClient) { +func (ei *EngineInstance) setDrpcSocket(sock string) { ei.Lock() defer ei.Unlock() - ei._drpcClient = c + ei._drpcSocket = sock } -func (ei *EngineInstance) getDrpcClient() (drpc.DomainSocketClient, error) { +func (ei *EngineInstance) getDrpcSocket() string { ei.RLock() defer ei.RUnlock() - if ei._drpcClient == nil { - return nil, errDRPCNotReady + return ei._drpcSocket +} + +func (ei *EngineInstance) getDrpcClient() drpc.DomainSocketClient { + ei.Lock() + defer ei.Unlock() + if ei.getDrpcClientFn == nil { + ei.getDrpcClientFn = drpc.NewClientConnection } - return ei._drpcClient, nil + return ei.getDrpcClientFn(ei._drpcSocket) } // NotifyDrpcReady receives a ready message from the running Engine @@ -52,8 +58,7 @@ func (ei *EngineInstance) getDrpcClient() (drpc.DomainSocketClient, error) { func (ei *EngineInstance) NotifyDrpcReady(msg *srvpb.NotifyReadyReq) { ei.log.Debugf("%s instance %d drpc ready: %v", build.DataPlaneName, ei.Index(), msg) - // activate the dRPC client connection to this engine - ei.setDrpcClient(drpc.NewClientConnection(msg.DrpcListenerSock)) + ei.setDrpcSocket(msg.DrpcListenerSock) go func() { ei.drpcReady <- msg @@ -67,11 +72,12 @@ func (ei *EngineInstance) awaitDrpcReady() chan *srvpb.NotifyReadyReq { return ei.drpcReady } +func (ei *EngineInstance) isDrpcSocketReady() bool { + return ei.getDrpcSocket() != "" +} + func (ei *EngineInstance) callDrpc(ctx context.Context, method drpc.Method, body proto.Message) (*drpc.Response, error) { - dc, err := ei.getDrpcClient() - if err != nil { - return nil, err - } + dc := ei.getDrpcClient() rankMsg := "" if sb := ei.getSuperblock(); sb != nil && sb.Rank != nil { @@ -94,6 +100,9 @@ func (ei *EngineInstance) CallDrpc(ctx context.Context, method drpc.Method, body if !ei.IsReady() { return nil, errEngineNotReady } + if !ei.isDrpcSocketReady() { + return nil, errDRPCNotReady + } return ei.callDrpc(ctx, method, body) } diff --git a/src/control/server/instance_drpc_test.go b/src/control/server/instance_drpc_test.go index c42d9ce87f1b..6383ddcf236a 100644 --- a/src/control/server/instance_drpc_test.go +++ b/src/control/server/instance_drpc_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2020-2023 Intel Corporation. +// (C) Copyright 2020-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -7,7 +7,9 @@ package server import ( + "context" "fmt" + "sync" "testing" "time" @@ -15,6 +17,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/proto" + "github.com/daos-stack/daos/src/control/common/proto/mgmt" mgmtpb "github.com/daos-stack/daos/src/control/common/proto/mgmt" srvpb "github.com/daos-stack/daos/src/control/common/proto/srv" "github.com/daos-stack/daos/src/control/common/test" @@ -52,10 +55,7 @@ func TestEngineInstance_NotifyDrpcReady(t *testing.T) { instance.NotifyDrpcReady(req) - dc, err := instance.getDrpcClient() - if err != nil || dc == nil { - t.Fatal("Expected a dRPC client connection") - } + test.AssertEqual(t, req.DrpcListenerSock, instance.getDrpcSocket(), "expected socket value set") waitForEngineReady(t, instance) } @@ -64,6 +64,7 @@ func TestEngineInstance_CallDrpc(t *testing.T) { for name, tc := range map[string]struct { notStarted bool notReady bool + noSocket bool noClient bool resp *drpc.Response expErr error @@ -76,8 +77,8 @@ func TestEngineInstance_CallDrpc(t *testing.T) { notReady: true, expErr: errEngineNotReady, }, - "no client configured": { - noClient: true, + "drpc not ready": { + noSocket: true, expErr: errDRPCNotReady, }, "success": { @@ -94,11 +95,15 @@ func TestEngineInstance_CallDrpc(t *testing.T) { instance := NewEngineInstance(log, nil, nil, runner) instance.ready.Store(!tc.notReady) - if !tc.noClient { - cfg := &mockDrpcClientConfig{ - SendMsgResponse: tc.resp, - } - instance.setDrpcClient(newMockDrpcClient(cfg)) + if !tc.noSocket { + instance.setDrpcSocket("/something") + } + + cfg := &mockDrpcClientConfig{ + SendMsgResponse: tc.resp, + } + instance.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + return newMockDrpcClient(cfg) } _, err := instance.CallDrpc(test.Context(t), @@ -108,6 +113,108 @@ func TestEngineInstance_CallDrpc(t *testing.T) { } } +type sendMsgDrpcClient struct { + sync.Mutex + sendMsgFn func(context.Context, *drpc.Call) (*drpc.Response, error) +} + +func (c *sendMsgDrpcClient) IsConnected() bool { + return true +} + +func (c *sendMsgDrpcClient) Connect(_ context.Context) error { + return nil +} + +func (c *sendMsgDrpcClient) Close() error { + return nil +} + +func (c *sendMsgDrpcClient) SendMsg(ctx context.Context, call *drpc.Call) (*drpc.Response, error) { + return c.sendMsgFn(ctx, call) +} + +func (c *sendMsgDrpcClient) GetSocketPath() string { + return "" +} + +func TestEngineInstance_CallDrpc_Parallel(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + // This test starts with one long-running drpc client that should remain in the SendMsg + // function until all other clients complete, demonstrating a single dRPC call cannot + // block the channel. + + numClients := 100 + numFastClients := numClients - 1 + + doneCh := make(chan struct{}, numFastClients) + longClient := &sendMsgDrpcClient{ + sendMsgFn: func(ctx context.Context, _ *drpc.Call) (*drpc.Response, error) { + numDone := 0 + + for numDone < numFastClients { + select { + case <-ctx.Done(): + t.Fatalf("context done before test finished: %s", ctx.Err()) + case <-doneCh: + numDone++ + t.Logf("%d/%d finished", numDone, numFastClients) + } + } + t.Log("long running client finished") + return &drpc.Response{}, nil + }, + } + + clientCh := make(chan drpc.DomainSocketClient, numClients) + go func(t *testing.T) { + t.Log("starting client producer thread...") + t.Log("adding long-running client") + clientCh <- longClient + for i := 0; i < numFastClients; i++ { + t.Logf("adding client %d", i) + clientCh <- &sendMsgDrpcClient{ + sendMsgFn: func(ctx context.Context, _ *drpc.Call) (*drpc.Response, error) { + doneCh <- struct{}{} + return &drpc.Response{}, nil + }, + } + } + t.Log("closing client channel") + close(clientCh) + }(t) + + t.Log("setting up engine...") + trc := engine.TestRunnerConfig{} + trc.Running.Store(true) + runner := engine.NewTestRunner(&trc, engine.MockConfig()) + instance := NewEngineInstance(log, nil, nil, runner) + instance.ready.Store(true) + + instance.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + t.Log("fetching drpc client") + cli := <-clientCh + t.Log("got drpc client") + return cli + } + + var wg sync.WaitGroup + wg.Add(numClients) + for i := 0; i < numClients; i++ { + go func(t *testing.T, j int) { + t.Logf("%d: CallDrpc", j) + _, err := instance.CallDrpc(test.Context(t), drpc.MethodPoolCreate, &mgmt.PoolCreateReq{}) + if err != nil { + t.Logf("%d: error: %s", j, err.Error()) + } + wg.Done() + }(t, i) + } + wg.Wait() +} + func TestEngineInstance_DrespToRankResult(t *testing.T) { dRank := Rank(1) diff --git a/src/control/server/mgmt_cont_test.go b/src/control/server/mgmt_cont_test.go index 55efb4176494..cc4db24ec3b0 100644 --- a/src/control/server/mgmt_cont_test.go +++ b/src/control/server/mgmt_cont_test.go @@ -91,7 +91,7 @@ func TestMgmt_ListContainers(t *testing.T) { }, "drpc error": { setupDrpc: func(t *testing.T, svc *mgmtSvc) { - setupMockDrpcClient(svc, nil, errors.New("mock drpc")) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(nil, errors.New("mock drpc"))) }, req: validListContReq(), expErr: errors.New("mock drpc"), @@ -99,23 +99,24 @@ func TestMgmt_ListContainers(t *testing.T) { "bad drpc resp": { setupDrpc: func(t *testing.T, svc *mgmtSvc) { badBytes := makeBadBytes(16) - setupMockDrpcClientBytes(svc, badBytes, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, nil)) }, req: validListContReq(), expErr: errors.New("unmarshal"), }, "success; zero containers": { setupDrpc: func(t *testing.T, svc *mgmtSvc) { - setupMockDrpcClient(svc, &mgmtpb.ListContResp{}, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(&mgmtpb.ListContResp{}, nil)) }, req: validListContReq(), expResp: &mgmtpb.ListContResp{}, }, "success; multiple containers": { setupDrpc: func(t *testing.T, svc *mgmtSvc) { - setupMockDrpcClient(svc, &mgmtpb.ListContResp{ - Containers: multiConts, - }, nil) + setupSvcDrpcClient(svc, 0, + getMockDrpcClient(&mgmtpb.ListContResp{ + Containers: multiConts, + }, nil)) }, req: validListContReq(), expResp: &mgmtpb.ListContResp{ @@ -190,7 +191,7 @@ func TestMgmt_ContSetOwner(t *testing.T) { }, "drpc error": { setupDrpc: func(t *testing.T, svc *mgmtSvc) { - setupMockDrpcClient(svc, nil, errors.New("mock drpc")) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(nil, errors.New("mock drpc"))) }, req: validContSetOwnerReq(), expErr: errors.New("mock drpc"), @@ -198,14 +199,14 @@ func TestMgmt_ContSetOwner(t *testing.T) { "bad drpc resp": { setupDrpc: func(t *testing.T, svc *mgmtSvc) { badBytes := makeBadBytes(16) - setupMockDrpcClientBytes(svc, badBytes, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, nil)) }, req: validContSetOwnerReq(), expErr: errors.New("unmarshal"), }, "success": { setupDrpc: func(t *testing.T, svc *mgmtSvc) { - setupMockDrpcClient(svc, &mgmtpb.ContSetOwnerResp{}, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(&mgmtpb.ContSetOwnerResp{}, nil)) }, req: validContSetOwnerReq(), expResp: &mgmtpb.ContSetOwnerResp{ diff --git a/src/control/server/mgmt_pool_test.go b/src/control/server/mgmt_pool_test.go index f5384ff7adeb..931458728d11 100644 --- a/src/control/server/mgmt_pool_test.go +++ b/src/control/server/mgmt_pool_test.go @@ -143,7 +143,8 @@ func TestServer_MgmtSvc_PoolCreateAlreadyExists(t *testing.T) { defer test.ShowBufferOnFailure(t, buf) svc := newTestMgmtSvc(t, log) - setupMockDrpcClient(svc, tc.queryResp, tc.queryErr) + mdc := getMockDrpcClient(tc.queryResp, tc.queryErr) + setupSvcDrpcClient(svc, 0, mdc) if _, err := svc.membership.Add(system.MockMember(t, 1, system.MemberStateJoined)); err != nil { t.Fatal(err) } @@ -377,7 +378,7 @@ func TestServer_MgmtSvc_PoolCreate(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, err)) }, expErr: errors.New("unmarshal"), }, @@ -492,6 +493,7 @@ func TestServer_MgmtSvc_PoolCreate(t *testing.T) { mp := storage.NewProvider(log, 0, &engineCfg.Storage, nil, nil, nil, nil) srv := NewEngineInstance(log, mp, nil, r) + srv.setDrpcSocket("/dontcare") srv.ready.SetTrue() harness := NewEngineHarness(log) @@ -563,7 +565,7 @@ func TestServer_MgmtSvc_PoolCreateDownRanks(t *testing.T) { dc := newMockDrpcClient(&mockDrpcClientConfig{IsConnectedBool: true}) dc.cfg.setSendMsgResponse(drpc.Status_SUCCESS, nil, nil) - mgmtSvc.harness.instances[0].(*EngineInstance)._drpcClient = dc + mgmtSvc.harness.instances[0].(*EngineInstance).getDrpcClientFn = func(s string) drpc.DomainSocketClient { return dc } for _, m := range []*system.Member{ system.MockMember(t, 0, system.MemberStateJoined), @@ -955,7 +957,8 @@ func TestServer_MgmtSvc_PoolDestroy(t *testing.T) { setupMockDrpcClient(tc.mgmtSvc, tc.expResp, tc.expErr) } } - tc.setupMockDrpc(tc.mgmtSvc, tc.expErr) + mdc := getMockDrpcClient(tc.expResp, tc.expErr) + setupSvcDrpcClient(tc.mgmtSvc, 0, mdc) if tc.req != nil && tc.req.Sys == "" { tc.req.Sys = build.DefaultSystemName @@ -991,7 +994,7 @@ func TestServer_MgmtSvc_PoolDestroy(t *testing.T) { if tc.expDrpcReq != nil { gotReq := new(mgmtpb.PoolDestroyReq) - if err := proto.Unmarshal(getLastMockCall(tc.mgmtSvc).Body, gotReq); err != nil { + if err := proto.Unmarshal(getLastMockCall(mdc).Body, gotReq); err != nil { t.Fatal(err) } if diff := cmp.Diff(tc.expDrpcReq, gotReq, cmpOpts...); diff != "" { @@ -1000,7 +1003,7 @@ func TestServer_MgmtSvc_PoolDestroy(t *testing.T) { } if tc.expDrpcEvReq != nil { gotReq := new(mgmtpb.PoolEvictReq) - if err := proto.Unmarshal(getLastMockCall(tc.mgmtSvc).Body, gotReq); err != nil { + if err := proto.Unmarshal(getLastMockCall(mdc).Body, gotReq); err != nil { t.Fatal(err) } if diff := cmp.Diff(tc.expDrpcEvReq, gotReq, cmpOpts...); diff != "" { @@ -1009,7 +1012,7 @@ func TestServer_MgmtSvc_PoolDestroy(t *testing.T) { } if tc.expDrpcListContReq != nil { gotReq := new(mgmtpb.ListContReq) - if err := proto.Unmarshal(getLastMockCall(tc.mgmtSvc).Body, gotReq); err != nil { + if err := proto.Unmarshal(getLastMockCall(mdc).Body, gotReq); err != nil { t.Fatal(err) } if diff := cmp.Diff(tc.expDrpcListContReq, gotReq, cmpOpts...); diff != "" { @@ -1072,7 +1075,7 @@ func TestServer_MgmtSvc_PoolExtend(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, err)) }, expErr: errors.New("unmarshal"), }, @@ -1098,7 +1101,7 @@ func TestServer_MgmtSvc_PoolExtend(t *testing.T) { if tc.setupMockDrpc == nil { tc.setupMockDrpc = func(svc *mgmtSvc, err error) { - setupMockDrpcClient(tc.mgmtSvc, tc.expResp, tc.expErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(tc.expResp, tc.expErr)) } } tc.setupMockDrpc(tc.mgmtSvc, tc.expErr) @@ -1174,7 +1177,7 @@ func TestServer_MgmtSvc_PoolDrain(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, err)) }, expErr: errors.New("unmarshal"), }, @@ -1198,7 +1201,7 @@ func TestServer_MgmtSvc_PoolDrain(t *testing.T) { if tc.setupMockDrpc == nil { tc.setupMockDrpc = func(svc *mgmtSvc, err error) { - setupMockDrpcClient(tc.mgmtSvc, tc.expResp, tc.expErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(tc.expResp, tc.expErr)) } } tc.setupMockDrpc(tc.mgmtSvc, tc.expErr) @@ -1266,7 +1269,7 @@ func TestServer_MgmtSvc_PoolEvict(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, err)) }, expErr: errors.New("unmarshal"), }, @@ -1290,7 +1293,7 @@ func TestServer_MgmtSvc_PoolEvict(t *testing.T) { if tc.setupMockDrpc == nil { tc.setupMockDrpc = func(svc *mgmtSvc, err error) { - setupMockDrpcClient(tc.mgmtSvc, tc.expResp, tc.expErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(tc.expResp, tc.expErr)) } } tc.setupMockDrpc(tc.mgmtSvc, tc.expErr) @@ -1425,7 +1428,7 @@ func TestPoolGetACL_Success(t *testing.T) { Entries: []string{"A::OWNER@:rw", "A:g:GROUP@:r"}, }, } - setupMockDrpcClient(svc, expectedResp, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(expectedResp, nil)) resp, err := svc.PoolGetACL(test.Context(t), newTestGetACLReq()) @@ -1446,7 +1449,7 @@ func TestPoolGetACL_DrpcFailed(t *testing.T) { svc := newTestMgmtSvc(t, log) addTestPools(t, svc.sysdb, mockUUID) expectedErr := errors.New("mock error") - setupMockDrpcClient(svc, nil, expectedErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(nil, expectedErr)) resp, err := svc.PoolGetACL(test.Context(t), newTestGetACLReq()) @@ -1466,7 +1469,7 @@ func TestPoolGetACL_BadDrpcResp(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(12) - setupMockDrpcClientBytes(svc, badBytes, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, nil)) resp, err := svc.PoolGetACL(test.Context(t), newTestGetACLReq()) @@ -1509,7 +1512,7 @@ func TestPoolOverwriteACL_DrpcFailed(t *testing.T) { svc := newTestMgmtSvc(t, log) addTestPools(t, svc.sysdb, mockUUID) expectedErr := errors.New("mock error") - setupMockDrpcClient(svc, nil, expectedErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(nil, expectedErr)) resp, err := svc.PoolOverwriteACL(test.Context(t), newTestModifyACLReq()) @@ -1529,7 +1532,7 @@ func TestPoolOverwriteACL_BadDrpcResp(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(16) - setupMockDrpcClientBytes(svc, badBytes, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, nil)) resp, err := svc.PoolOverwriteACL(test.Context(t), newTestModifyACLReq()) @@ -1553,7 +1556,7 @@ func TestPoolOverwriteACL_Success(t *testing.T) { Entries: []string{"A::OWNER@:rw", "A:g:GROUP@:r"}, }, } - setupMockDrpcClient(svc, expectedResp, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(expectedResp, nil)) resp, err := svc.PoolOverwriteACL(test.Context(t), newTestModifyACLReq()) @@ -1589,7 +1592,7 @@ func TestPoolUpdateACL_DrpcFailed(t *testing.T) { svc := newTestMgmtSvc(t, log) addTestPools(t, svc.sysdb, mockUUID) expectedErr := errors.New("mock error") - setupMockDrpcClient(svc, nil, expectedErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(nil, expectedErr)) resp, err := svc.PoolUpdateACL(test.Context(t), newTestModifyACLReq()) @@ -1609,7 +1612,7 @@ func TestPoolUpdateACL_BadDrpcResp(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(16) - setupMockDrpcClientBytes(svc, badBytes, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, nil)) resp, err := svc.PoolUpdateACL(test.Context(t), newTestModifyACLReq()) @@ -1633,7 +1636,7 @@ func TestPoolUpdateACL_Success(t *testing.T) { Entries: []string{"A::OWNER@:rw", "A:g:GROUP@:r"}, }, } - setupMockDrpcClient(svc, expectedResp, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(expectedResp, nil)) resp, err := svc.PoolUpdateACL(test.Context(t), newTestModifyACLReq()) @@ -1677,7 +1680,7 @@ func TestPoolDeleteACL_DrpcFailed(t *testing.T) { svc := newTestMgmtSvc(t, log) addTestPools(t, svc.sysdb, mockUUID) expectedErr := errors.New("mock error") - setupMockDrpcClient(svc, nil, expectedErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(nil, expectedErr)) resp, err := svc.PoolDeleteACL(test.Context(t), newTestDeleteACLReq()) @@ -1697,7 +1700,7 @@ func TestPoolDeleteACL_BadDrpcResp(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(16) - setupMockDrpcClientBytes(svc, badBytes, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, nil)) resp, err := svc.PoolDeleteACL(test.Context(t), newTestDeleteACLReq()) @@ -1721,7 +1724,7 @@ func TestPoolDeleteACL_Success(t *testing.T) { Entries: []string{"A::OWNER@:rw", "A:G:readers@:r"}, }, } - setupMockDrpcClient(svc, expectedResp, nil) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(expectedResp, nil)) resp, err := svc.PoolDeleteACL(test.Context(t), newTestDeleteACLReq()) @@ -1787,7 +1790,7 @@ func TestServer_MgmtSvc_PoolQuery(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, err)) }, expErr: errors.New("unmarshal"), }, @@ -1818,7 +1821,7 @@ func TestServer_MgmtSvc_PoolQuery(t *testing.T) { if tc.setupMockDrpc == nil { tc.setupMockDrpc = func(svc *mgmtSvc, err error) { - setupMockDrpcClient(tc.mgmtSvc, tc.expResp, tc.expErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(tc.expResp, tc.expErr)) } } tc.setupMockDrpc(tc.mgmtSvc, tc.expErr) @@ -1841,22 +1844,17 @@ func TestServer_MgmtSvc_PoolQuery(t *testing.T) { } } -func getLastMockCall(svc *mgmtSvc) *drpc.Call { - mi := svc.harness.instances[0].(*EngineInstance) - if mi == nil || mi._drpcClient == nil { - return nil - } - - return mi._drpcClient.(*mockDrpcClient).SendMsgInputCall +func getLastMockCall(mdc *mockDrpcClient) *drpc.Call { + return mdc.SendMsgInputCall } func TestServer_MgmtSvc_PoolSetProp(t *testing.T) { for name, tc := range map[string]struct { - setupMockDrpc func(_ *mgmtSvc, _ error) - drpcResp *mgmtpb.PoolSetPropResp - req *mgmtpb.PoolSetPropReq - expDrpcReq *mgmtpb.PoolSetPropReq - expErr error + getMockDrpc func(error) *mockDrpcClient + drpcResp *mgmtpb.PoolSetPropResp + req *mgmtpb.PoolSetPropReq + expDrpcReq *mgmtpb.PoolSetPropReq + expErr error }{ "wrong system": { req: &mgmtpb.PoolSetPropReq{Id: mockUUID, Sys: "bad"}, @@ -1872,11 +1870,11 @@ func TestServer_MgmtSvc_PoolSetProp(t *testing.T) { }, }, }, - setupMockDrpc: func(svc *mgmtSvc, err error) { + getMockDrpc: func(err error) *mockDrpcClient { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + return getMockDrpcClientBytes(badBytes, err) }, expErr: errors.New("unmarshal"), }, @@ -1961,12 +1959,14 @@ func TestServer_MgmtSvc_PoolSetProp(t *testing.T) { if tc.req.Id != mockUUID { addTestPools(t, ms.sysdb, tc.req.Id) } - if tc.setupMockDrpc == nil { - tc.setupMockDrpc = func(svc *mgmtSvc, err error) { - setupMockDrpcClient(svc, tc.drpcResp, tc.expErr) + + if tc.getMockDrpc == nil { + tc.getMockDrpc = func(err error) *mockDrpcClient { + return getMockDrpcClient(tc.drpcResp, err) } } - tc.setupMockDrpc(ms, tc.expErr) + mdc := tc.getMockDrpc(tc.expErr) + setupSvcDrpcClient(ms, 0, mdc) if tc.req != nil && tc.req.Sys == "" { tc.req.Sys = build.DefaultSystemName @@ -1978,7 +1978,7 @@ func TestServer_MgmtSvc_PoolSetProp(t *testing.T) { } lastReq := new(mgmtpb.PoolSetPropReq) - if err := proto.Unmarshal(getLastMockCall(ms).Body, lastReq); err != nil { + if err := proto.Unmarshal(getLastMockCall(mdc).Body, lastReq); err != nil { t.Fatal(err) } if diff := cmp.Diff(tc.expDrpcReq, lastReq, test.DefaultCmpOpts()...); diff != "" { @@ -2013,7 +2013,7 @@ func TestServer_MgmtSvc_PoolGetProp(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, err)) }, expErr: errors.New("unmarshal"), }, @@ -2046,7 +2046,7 @@ func TestServer_MgmtSvc_PoolGetProp(t *testing.T) { } if tc.setupMockDrpc == nil { tc.setupMockDrpc = func(svc *mgmtSvc, err error) { - setupMockDrpcClient(svc, tc.drpcResp, tc.expErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(tc.drpcResp, tc.expErr)) } } tc.setupMockDrpc(ms, tc.expErr) @@ -2108,7 +2108,7 @@ func TestServer_MgmtSvc_PoolUpgrade(t *testing.T) { // dRPC call returns junk in the message body badBytes := makeBadBytes(42) - setupMockDrpcClientBytes(svc, badBytes, err) + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(badBytes, err)) }, expErr: errors.New("unmarshal"), }, @@ -2132,7 +2132,7 @@ func TestServer_MgmtSvc_PoolUpgrade(t *testing.T) { if tc.setupMockDrpc == nil { tc.setupMockDrpc = func(svc *mgmtSvc, err error) { - setupMockDrpcClient(tc.mgmtSvc, tc.expResp, tc.expErr) + setupSvcDrpcClient(svc, 0, getMockDrpcClient(tc.expResp, tc.expErr)) } } tc.setupMockDrpc(tc.mgmtSvc, tc.expErr) diff --git a/src/control/server/mgmt_system_test.go b/src/control/server/mgmt_system_test.go index ced317d66526..b274044662cf 100644 --- a/src/control/server/mgmt_system_test.go +++ b/src/control/server/mgmt_system_test.go @@ -199,7 +199,10 @@ func TestServer_MgmtSvc_GetAttachInfo(t *testing.T) { if err := harness.AddInstance(srv); err != nil { t.Fatal(err) } - srv.setDrpcClient(newMockDrpcClient(nil)) + + srv.getDrpcClientFn = func(s string) drpc.DomainSocketClient { + return newMockDrpcClient(nil) + } harness.started.SetTrue() db := raft.MockDatabaseWithAddr(t, log, msReplica.Addr) @@ -2031,7 +2034,8 @@ func TestServer_MgmtSvc_Join(t *testing.T) { } peerCtx := peer.NewContext(test.Context(t), &peer.Peer{Addr: peerAddr}) - setupMockDrpcClient(svc, tc.guResp, nil) + mdc := getMockDrpcClient(tc.guResp, nil) + setupSvcDrpcClient(svc, 0, mdc) gotResp, gotErr := svc.Join(peerCtx, tc.req) test.CmpErr(t, tc.expErr, gotErr) @@ -2047,8 +2051,6 @@ func TestServer_MgmtSvc_Join(t *testing.T) { return } - ei := svc.harness.instances[0].(*EngineInstance) - mdc := ei._drpcClient.(*mockDrpcClient) gotGuReq := new(mgmtpb.GroupUpdateReq) calls := mdc.calls.get() // wait for GroupUpdate diff --git a/src/control/server/util_test.go b/src/control/server/util_test.go index f34c6d16f669..974a9c447326 100644 --- a/src/control/server/util_test.go +++ b/src/control/server/util_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2019-2023 Intel Corporation. +// (C) Copyright 2019-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -171,17 +171,31 @@ func newMockDrpcClient(cfg *mockDrpcClientConfig) *mockDrpcClient { // setupMockDrpcClientBytes sets up the dRPC client for the mgmtSvc to return // a set of bytes as a response. func setupMockDrpcClientBytes(svc *mgmtSvc, respBytes []byte, err error) { - mi := svc.harness.instances[0] + setupSvcDrpcClient(svc, 0, getMockDrpcClientBytes(respBytes, err)) +} + +func getMockDrpcClientBytes(respBytes []byte, err error) *mockDrpcClient { cfg := &mockDrpcClientConfig{} cfg.setSendMsgResponse(drpc.Status_SUCCESS, respBytes, err) - mi.(*EngineInstance).setDrpcClient(newMockDrpcClient(cfg)) + return newMockDrpcClient(cfg) } // setupMockDrpcClient sets up the dRPC client for the mgmtSvc to return // a valid protobuf message as a response. func setupMockDrpcClient(svc *mgmtSvc, resp proto.Message, err error) { + setupSvcDrpcClient(svc, 0, getMockDrpcClient(resp, err)) +} + +// getMockDrpcClient sets up the dRPC client to return a valid protobuf message as a response. +func getMockDrpcClient(resp proto.Message, err error) *mockDrpcClient { respBytes, _ := proto.Marshal(resp) - setupMockDrpcClientBytes(svc, respBytes, err) + return getMockDrpcClientBytes(respBytes, err) +} + +func setupSvcDrpcClient(svc *mgmtSvc, engineIdx int, mdc *mockDrpcClient) { + svc.harness.instances[engineIdx].(*EngineInstance).getDrpcClientFn = func(_ string) drpc.DomainSocketClient { + return mdc + } } // newTestEngine returns an EngineInstance configured for testing. @@ -201,6 +215,7 @@ func newTestEngine(log logging.Logger, isAP bool, provider *storage.Provider, en r := engine.NewTestRunner(rCfg, engineCfg[0]) srv := NewEngineInstance(log, provider, nil, r) + srv.setDrpcSocket("/dontcare") srv.setSuperblock(&Superblock{ Rank: ranklist.NewRankPtr(0), })