Skip to content

Commit

Permalink
Broadcast Testing, Organizing, More Tests
Browse files Browse the repository at this point in the history
While writing tests for DeleteDetection, I noticed the store wasn't checking permissions, this means a user without write/detections permissions could delete a detection and then receive a 403 when attempting to sync (because the sync properly checks permissions).

Added the ability to test for broadcasting and write matchers similar to how we test log messages.

Previously we let 200 responses be delivered automatically by not setting a status code on the response writer, but this skips a log statement that is useful for request timing. As I go, I'm adding a 200 response to any handlers that are missing them for the sake of the log.

Added more tests.
  • Loading branch information
coreyogburn committed Jan 9, 2025
1 parent f9693da commit 78ab44b
Show file tree
Hide file tree
Showing 5 changed files with 880 additions and 127 deletions.
145 changes: 145 additions & 0 deletions server/broadcast_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package server

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/gorilla/websocket"
)

type MockBroadcaster struct {
Server *httptest.Server
Messages []BroadcastMessage
onmsg func(BroadcastMessage)
}

type BroadcastMessage struct {
Type int
Kind string
Object map[string]interface{}
}

func (mb *MockBroadcaster) collectMessages(w http.ResponseWriter, r *http.Request) {
c, err := (&websocket.Upgrader{}).Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()

for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}

m := map[string]interface{}{}
err = json.Unmarshal(message, &m)
if err != nil {
break
}

kind, ok := m["Kind"].(string)
if !ok {
break
}

obj, ok := m["Object"].(map[string]interface{})
if !ok {
break
}

msg := BroadcastMessage{Type: mt, Kind: kind, Object: obj}
mb.Messages = append(mb.Messages, msg)
if mb.onmsg != nil {
mb.onmsg(msg)
}
}
}

func (mb *MockBroadcaster) Close() {
mb.Server.Close()
}

type BroadcastMatcher struct {
validators []BroadcastValidator
}

type BroadcastValidator func(BroadcastMessage) error

func NewBroadcastMatcher(opts ...BroadcastValidator) BroadcastMatcher {
m := BroadcastMatcher{
validators: opts,
}

return m
}

func (m *BroadcastMatcher) Validate(msg BroadcastMessage) error {
for _, v := range m.validators {
if err := v(msg); err != nil {
return err
}
}

return nil
}

func BroadcastKindEq(kind string) BroadcastValidator {
return func(msg BroadcastMessage) error {
if msg.Kind != kind {
return fmt.Errorf("expected kind %q, got %q", kind, msg.Kind)
}

return nil
}
}

func BroadcastObjectFieldExists(key string) BroadcastValidator {
return func(msg BroadcastMessage) error {
_, ok := msg.Object[key]
if !ok {
return fmt.Errorf("expected field %q to exist", key)
}

return nil
}
}

func BroadcastObjectFieldEq(key string, value interface{}) BroadcastValidator {
switch v := value.(type) {
case int:
value = float64(v)
}

return func(msg BroadcastMessage) error {
if msg.Object[key] != value {
return fmt.Errorf("expected field %q to be %v, got %v", key, value, msg.Object[key])
}

return nil
}
}

func MockBroadcast(t *testing.T, srv *Server, onmsg func(BroadcastMessage)) *MockBroadcaster {
mb := &MockBroadcaster{onmsg: onmsg}
s := httptest.NewServer(http.HandlerFunc(mb.collectMessages))

// Convert http://127.0.0.1 to ws://127.0.0.1
u := "ws" + strings.TrimPrefix(s.URL, "http")

// Connect to the server
wsconn, _, err := websocket.DefaultDialer.Dial(u, nil)
if err != nil {
t.Fatalf("%v", err)
}

srv.Host.AddConnection("00000000-0000-0000-0000-000000000000", wsconn, "127.0.0.1")

mb.Server = s

return mb
}
58 changes: 35 additions & 23 deletions server/detectionhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ func RegisterDetectionRoutes(srv *Server, r chi.Router, prefix string) {
r.Post("/convert", h.convertContent)

r.Put("/", h.UpdateDetection)
r.Put("/{id}/override/{overrideIndex}/note", h.updateOverrideNote)
r.Put("/{id}/override/{overrideIndex}/note", h.UpdateOverrideNote)

r.Delete("/{id}", h.deleteDetection)
r.Delete("/{id}", h.DeleteDetection)

r.Post("/bulk/{newStatus}", h.bulkUpdateDetection)
r.Post("/bulk/{newStatus}", h.BulkUpdateDetection)
r.Post("/sync/{engine}/{type}", h.syncEngineDetections)

r.Get("/{engine}/genpublicid", h.genPublicId)
Expand Down Expand Up @@ -532,7 +532,7 @@ func (h *DetectionHandler) UpdateDetection(w http.ResponseWriter, r *http.Reques
// @Failure 403 "Insufficient permissions for this request"
// @Failure 500 "Internal SOC error; review SOC logs"
// @Router /connect/detection/{id}/override/{overrideIndex}/note [put]
func (h *DetectionHandler) updateOverrideNote(w http.ResponseWriter, r *http.Request) {
func (h *DetectionHandler) UpdateOverrideNote(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

detectId := chi.URLParam(r, "id")
Expand Down Expand Up @@ -563,6 +563,8 @@ func (h *DetectionHandler) updateOverrideNote(w http.ResponseWriter, r *http.Req

return
}

web.Respond(w, r, http.StatusOK, nil)
}

// @Summary Delete Detection
Expand All @@ -574,16 +576,22 @@ func (h *DetectionHandler) updateOverrideNote(w http.ResponseWriter, r *http.Req
// @Failure 400 "The provided input object or parameters are malformed or invalid"
// @Failure 401 "Request was not properly authenticated"
// @Failure 403 "Insufficient permissions for this request"
// @Failure 404 "Detection not found"
// @Failure 500 "Internal SOC error; review SOC logs"
// @Router /connect/detection/{id} [delete]
func (h *DetectionHandler) deleteDetection(w http.ResponseWriter, r *http.Request) {
func (h *DetectionHandler) DeleteDetection(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

id := chi.URLParam(r, "id")

det, err := h.server.Detectionstore.GetDetection(ctx, id)
if err != nil {
web.Respond(w, r, http.StatusInternalServerError, err)
if err.Error() == "Object not found" {
web.Respond(w, r, http.StatusNotFound, nil)
} else {
web.Respond(w, r, http.StatusInternalServerError, err)
}

return
}

Expand All @@ -594,7 +602,13 @@ func (h *DetectionHandler) deleteDetection(w http.ResponseWriter, r *http.Reques

old, err := h.server.Detectionstore.DeleteDetection(ctx, id)
if err != nil {
web.Respond(w, r, http.StatusInternalServerError, err)
unauth := &model.Unauthorized{}
if errors.As(err, &unauth) {
web.Respond(w, r, http.StatusForbidden, err)
} else {
web.Respond(w, r, http.StatusInternalServerError, err)
}

return
}

Expand Down Expand Up @@ -622,11 +636,11 @@ func (h *DetectionHandler) deleteDetection(w http.ResponseWriter, r *http.Reques
// @Failure 403 "Insufficient permissions for this request"
// @Failure 500 "Internal SOC error; review SOC logs"
// @Router /connect/detection/bulk/{newStatus} [post]
func (h *DetectionHandler) bulkUpdateDetection(w http.ResponseWriter, r *http.Request) {
func (h *DetectionHandler) BulkUpdateDetection(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
logger := log.FromContext(ctx)

newStatus := chi.URLParam(r, "newStatus") // "enable" or "disable"
newStatus := chi.URLParam(r, "newStatus")

var enabled bool
var delete bool
Expand All @@ -636,7 +650,7 @@ func (h *DetectionHandler) bulkUpdateDetection(w http.ResponseWriter, r *http.Re
case "delete":
delete = true
default:
web.Respond(w, r, http.StatusBadRequest, fmt.Errorf("invalid status; must be 'enable' or 'disable'"))
web.Respond(w, r, http.StatusBadRequest, fmt.Errorf("invalid status; must be 'enable', 'disable', or 'delete'"))
return
}

Expand Down Expand Up @@ -782,11 +796,20 @@ func (h *DetectionHandler) bulkUpdateDetectionAsync(ctx context.Context, body *B
detect := detects[i]
id := detect.Id

engine, ok := h.server.DetectionEngines[detect.Engine]
if !ok {
logger.WithFields(log.Fields{
"publicId": detect.PublicID,
"engine": detect.Engine,
}).Error("detection has unsupported engine, skipping")
errMap[detect.PublicID] = "unsupported engine"

continue
}

if !body.Delete {
detect.IsEnabled = body.NewStatus

engine := h.server.DetectionEngines[detect.Engine]

filterApplied, err := engine.ApplyFilters(detect)
if err != nil {
logger.WithError(err).WithFields(log.Fields{
Expand All @@ -802,17 +825,6 @@ func (h *DetectionHandler) bulkUpdateDetectionAsync(ctx context.Context, body *B
}
}

engine, ok := h.server.DetectionEngines[detect.Engine]
if !ok {
logger.WithFields(log.Fields{
"publicId": detect.PublicID,
"engine": detect.Engine,
}).Error("detection has unsupported engine, skipping")
errMap[detect.PublicID] = "unsupported engine"

continue
}

exErr := engine.ExtractDetails(detect)
if exErr != nil {
logger.WithField("publicId", detect.PublicID).WithError(exErr).Warn("unable to extract details from detection, skipping")
Expand Down
Loading

0 comments on commit 78ab44b

Please sign in to comment.