diff --git a/server/detectionshandler_test.go b/server/detectionshandler_test.go index 832a5c52..dbd13d27 100644 --- a/server/detectionshandler_test.go +++ b/server/detectionshandler_test.go @@ -2178,7 +2178,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { Name string NewStatus string ReqBody []byte - InitMock func(*testing.T, *Server, *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) + InitMock func(*testing.T, *Server, *gomock.Controller, *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) Code int Response any Logs []EntryMatcher @@ -2188,7 +2188,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { Name: "Sunny Day - IDs", NewStatus: "enable", ReqBody: []byte(`{"ids":["123","456","789"]}`), - InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) { + InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, nonAsyncWG *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) { mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore) mAuth := srv.Authorizer.(*rbac.FakeAuthorizer) mHostAuth := srv.Host.Authorizer.(*rbac.FakeAuthorizer) @@ -2225,7 +2225,11 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { docIndexer := servermock.NewMockBulkIndexer(ctrl) - mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).Return(docIndexer, nil) + mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, logger *log.Entry) (esutil.BulkIndexer, error) { + nonAsyncWG.Wait() + + return docIndexer, nil + }) engElastAlert.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil) engSuricata.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil) @@ -2341,7 +2345,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { Name: "Sunny Day - Query", NewStatus: "enable", ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`), - InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) { + InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, nonAsyncWG *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) { mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore) mAuth := srv.Authorizer.(*rbac.FakeAuthorizer) mHostAuth := srv.Host.Authorizer.(*rbac.FakeAuthorizer) @@ -2380,7 +2384,11 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { docIndexer := servermock.NewMockBulkIndexer(ctrl) - mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).Return(docIndexer, nil) + mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, logger *log.Entry) (esutil.BulkIndexer, error) { + nonAsyncWG.Wait() + + return docIndexer, nil + }) engElastAlert.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil) engSuricata.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil) @@ -2496,7 +2504,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { Name: "Cannot Delete Community Rules - Ids", NewStatus: "delete", ReqBody: []byte(`{"ids":["123","456","789"]}`), - InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) { + InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) { mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore) mAuth := srv.Authorizer.(*rbac.FakeAuthorizer) @@ -2530,7 +2538,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { Name: "Cannot Delete Community Rules - Query", NewStatus: "delete", ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`), - InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) { + InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) { mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore) mAuth := srv.Authorizer.(*rbac.FakeAuthorizer) @@ -2570,7 +2578,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { Name: "Query Failure", NewStatus: "enable", ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`), - InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) { + InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) { mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore) mAuth := srv.Authorizer.(*rbac.FakeAuthorizer) @@ -2591,7 +2599,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { Name: "Unauthorized", NewStatus: "disable", ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`), - InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) { + InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) { return nil, nil }, Code: 401, @@ -2613,7 +2621,10 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { h := NewDetectionHandler(srv) - wg, mb := test.InitMock(t, srv, ctrl) + nonAsyncWG := &sync.WaitGroup{} + nonAsyncWG.Add(1) + + asyncWG, mb := test.InitMock(t, srv, ctrl, nonAsyncWG) if mb != nil { defer mb.Close() } @@ -2632,8 +2643,10 @@ func TestHandlerBulkUpdateDetection(t *testing.T) { r := httptest.NewRequestWithContext(ctx, "PUT", fmt.Sprintf("/detection/bulk/%s", test.NewStatus), bytes.NewReader(test.ReqBody)) h.BulkUpdateDetection(w, r) - if wg != nil { - wg.Wait() + nonAsyncWG.Done() + + if asyncWG != nil { + asyncWG.Wait() } assert.Equal(t, test.Code, w.Code)