diff --git a/.circleci/config.yml b/.circleci/config.yml index 8d55d8e..59ef17e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -52,7 +52,6 @@ jobs: - run: package_cloud push pantheon/public/fedora/29 ./dist/*.rpm - run: package_cloud push pantheon/public/el/7 ./dist/*.rpm - workflows: version: 2 test-build-release: @@ -66,10 +65,6 @@ workflows: - sig-go-release requires: - test - filters: - branches: - only: - - master - publish-rpm: requires: - build-release @@ -77,4 +72,3 @@ workflows: branches: only: - master - diff --git a/audit.go b/audit.go index e16acf2..006317d 100644 --- a/audit.go +++ b/audit.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "strings" + "syscall" "github.com/pantheon-systems/pauditd/pkg/marshaller" "github.com/pantheon-systems/pauditd/pkg/metric" @@ -175,7 +176,16 @@ func main() { slog.Error.Fatal(err) } - nlClient, err := NewNetlinkClient(config.GetInt("socket_buffer.receive")) + recvSize := 0 + rmemMax := fetchRmemMax() + // If the value is 0, use the default value from the config + recvSize = rmemMax + if rmemMax == 0 { + recvSize = config.GetInt("socket_buffer.receive") + } + slog.Info.Printf("Setting the receive buffer size to %d\n", recvSize) + + nlClient, err := NewNetlinkClient(recvSize) if err != nil { slog.Error.Fatal(err) } @@ -201,7 +211,6 @@ func main() { //Main loop. Get data from netlink and send it to the json lib for processing for { msg, err := nlClient.Receive() - timing := metric.GetClient().NewTiming() // measure latency from recipt of message if err != nil { if err.Error() == "no buffer space available" { metric.GetClient().Increment("messages.netlink_dropped") @@ -209,13 +218,41 @@ func main() { slog.Error.Printf("Error during message receive: %+v\n", err) continue } - - metric.GetClient().Increment("messages.total") if msg == nil { continue } + // As soon as we have a message, spawn a goroutine to handle it and free up the main loop + go handleMsg(msg, marshaller) + } +} + +// Fetch the max value we can set from /proc/sys/net/core/rmem_max +// This value is mounted in from the host via the kube yaml +func fetchRmemMax() int { + var rmemMax int + file, err := os.Open("/proc/sys/net/core/rmem_max") + if err != nil { + slog.Error.Println(fmt.Sprintf("Error opening rmem_max: [%v]", err)) + } + defer file.Close() - marshaller.Consume(msg) - timing.Send("latency") + _, err = fmt.Fscanf(file, "%d", &rmemMax) + if err != nil { + slog.Error.Println(fmt.Sprintf("Error reading the rmem_max value: [%v]", err)) } + return rmemMax +} + +func handleMsg(msg *syscall.NetlinkMessage, marshaller *marshaller.AuditMarshaller) { + defer func() { + if r := recover(); r != nil { + slog.Error.Printf("Panic occurred in handleMsg: %v", r) + } + }() + + timing := metric.GetClient().NewTiming() // measure latency from recipt of message + metric.GetClient().Increment("messages.total") + + marshaller.Consume(msg) + timing.Send("latency") } diff --git a/client.go b/client.go index e111db0..189bcb6 100644 --- a/client.go +++ b/client.go @@ -20,7 +20,7 @@ const ( MAX_AUDIT_MESSAGE_LENGTH = 8970 ) -//TODO: this should live in a marshaller +// TODO: this should live in a marshaller type AuditStatusPayload struct { Mask uint32 Enabled uint32 @@ -55,7 +55,6 @@ func NewNetlinkClient(recvSize int) (*NetlinkClient, error) { n := &NetlinkClient{ fd: fd, address: &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: 0, Pid: 0}, - buf: make([]byte, MAX_AUDIT_MESSAGE_LENGTH), cancelKeepConnection: make(chan struct{}), } @@ -120,7 +119,28 @@ func (n *NetlinkClient) Send(np *NetlinkPacket, a *AuditStatusPayload) error { // Receive will receive a packet from a netlink socket func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) { - nlen, _, err := syscall.Recvfrom(n.fd, n.buf, 0) + // Large message handling + // See https://mdlayher.com/blog/linux-netlink-and-go-part-1-netlink/ + // Use a new buffer every time since this is spawned in a goroutine + buf := make([]byte, MAX_AUDIT_MESSAGE_LENGTH) + for { + // Peek at the buffer to see how many bytes are available. + b, _, err := syscall.Recvfrom(n.fd, buf, syscall.MSG_PEEK) + if err != nil { + return nil, err + } + + // Break when we can read all messages. + if b < len(buf) { + break + } + + // Double in size if not enough bytes. + buf = make([]byte, len(buf)*2) + } + + // Read out all available messages. + nlen, _, err := syscall.Recvfrom(n.fd, buf, 0) if err != nil { return nil, err } @@ -131,13 +151,13 @@ func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) { msg := &syscall.NetlinkMessage{ Header: syscall.NlMsghdr{ - Len: Endianness.Uint32(n.buf[0:4]), - Type: Endianness.Uint16(n.buf[4:6]), - Flags: Endianness.Uint16(n.buf[6:8]), - Seq: Endianness.Uint32(n.buf[8:12]), - Pid: Endianness.Uint32(n.buf[12:16]), + Len: Endianness.Uint32(buf[0:4]), + Type: Endianness.Uint16(buf[4:6]), + Flags: Endianness.Uint16(buf[6:8]), + Seq: Endianness.Uint32(buf[8:12]), + Pid: Endianness.Uint32(buf[12:16]), }, - Data: n.buf[syscall.SizeofNlMsghdr:nlen], + Data: buf[syscall.SizeofNlMsghdr:nlen], } return msg, nil diff --git a/client_test.go b/client_test.go index 689bed6..e084412 100644 --- a/client_test.go +++ b/client_test.go @@ -103,8 +103,6 @@ func TestNewNetlinkClient(t *testing.T) { assert.True(t, (n.fd > 0), "No file descriptor") assert.True(t, (n.address != nil), "Address was nil") assert.Equal(t, uint32(0), n.seq, "Seq should start at 0") - assert.True(t, MAX_AUDIT_MESSAGE_LENGTH >= len(n.buf), "Client buffer is too small") - assert.Equal(t, "Socket receive buffer size: ", lb.String()[:28], "Expected some nice log lines") assert.Equal(t, "", elb.String(), "Did not expect any error messages") }