Skip to content

Commit

Permalink
Test the addRule function, remove cpuprofile
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Dec 2, 2016
1 parent 3478433 commit 434a0db
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 48 deletions.
57 changes: 31 additions & 26 deletions audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"flag"
"fmt"
"github.com/pkg/profile"
"github.com/spf13/viper"
"log"
"log/syslog"
Expand All @@ -19,7 +18,16 @@ import (
var l = log.New(os.Stdout, "", 0)
var el = log.New(os.Stderr, "", 0)

func loadConfig(config *viper.Viper) {
type executor func (string, ...string) error

func lExec (s string, a ...string) error {
return exec.Command(s, a...).Run()
}

func loadConfig(configFile string) (*viper.Viper, error) {
config := viper.New()
config.SetConfigFile(configFile)

config.SetDefault("message_tracking.enabled", true)
config.SetDefault("message_tracking.log_out_of_order", false)
config.SetDefault("message_tracking.max_out_of_order", 500)
Expand All @@ -29,18 +37,20 @@ func loadConfig(config *viper.Viper) {
config.SetDefault("output.syslog.attempts", "3")
config.SetDefault("log.flags", 0)

err := config.ReadInConfig() // Find and read the config file
if err != nil { // Handle errors reading the config file
el.Printf("Config file has an error: %s\n", err)
os.Exit(1)
if err := config.ReadInConfig(); err != nil {
return nil, err
}

l.SetFlags(config.GetInt("log.flags"))
el.SetFlags(config.GetInt("log.flags"))

return config, nil
}

func setRules(config *viper.Viper) {
func setRules(config *viper.Viper, e executor) error {
// Clear existing rules
err := exec.Command("auditctl", "-D").Run()
if err != nil {
el.Fatalf("Failed to flush existing audit rules. Error: %s\n", err)
if err := e("auditctl", "-D"); err != nil {
return errors.New(fmt.Sprintf("Failed to flush existing audit rules. Error: %s", err))
}

l.Println("Flushed existing audit rules")
Expand All @@ -53,16 +63,17 @@ func setRules(config *viper.Viper) {
continue
}

err := exec.Command("auditctl", strings.Fields(v)...).Run()
if err != nil {
el.Fatalf("Failed to add rule #%d. Error: %s \n", i+1, err)
if err := e("auditctl", strings.Fields(v)...); err != nil {
return errors.New(fmt.Sprintf("Failed to add rule #%d. Error: %s", i+1, err))
}

l.Printf("Added audit rule #%d\n", i+1)
}
} else {
el.Fatalln("No audit rules found. exiting")
return errors.New("No audit rules found.")
}

return nil
}

func createOutput(config *viper.Viper) (*AuditWriter, error) {
Expand Down Expand Up @@ -246,9 +257,7 @@ func createFilters(config *viper.Viper) []AuditFilter {
}

func main() {
config := viper.New()
configFile := flag.String("config", "", "Config file location")
cpuProfile := flag.Bool("cpuprofile", false, "Enable cpu profiling")

flag.Parse()

Expand All @@ -258,17 +267,13 @@ func main() {
os.Exit(1)
}

config.SetConfigFile(*configFile)
loadConfig(config)

l.SetFlags(config.GetInt("log.flags"))
el.SetFlags(config.GetInt("log.flags"))

setRules(config)
config, err := loadConfig(*configFile)
if err != nil {
el.Fatal(err)
}

if *cpuProfile {
l.Println("Enabling CPU profile ./cpu.pprof")
defer profile.Start(profile.Quiet, profile.ProfilePath(".")).Stop()
if err := setRules(config, lExec); err != nil {
el.Fatal(err)
}

writer, err := createOutput(config)
Expand Down
85 changes: 63 additions & 22 deletions audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@ import (
"strconv"
"syscall"
"testing"
"errors"
)

func Test_loadConfig(t *testing.T) {
config := viper.New()

file := createTempFile(t, "defaultValues.test.yaml", "")
defer os.Remove(file)

config.SetConfigFile(file)
loadConfig(config)
// defaults
config, err := loadConfig(file)
assert.Equal(t, true, config.GetBool("message_tracking.enabled"), "message_tracking.enabled should default to true")
assert.Equal(t, false, config.GetBool("message_tracking.log_out_of_order"), "message_tracking.log_out_of_order should default to false")
assert.Equal(t, 500, config.GetInt("message_tracking.max_out_of_order"), "message_tracking.max_out_of_order should default to 500")
Expand All @@ -30,28 +29,70 @@ func Test_loadConfig(t *testing.T) {
assert.Equal(t, "go-audit", config.GetString("output.syslog.tag"), "output.syslog.tag should default to go-audit")
assert.Equal(t, 3, config.GetInt("output.syslog.attempts"), "output.syslog.attempts should default to 3")
assert.Equal(t, 0, config.GetInt("log.flags"), "log.flags should default to 0")
assert.Equal(t, 0, l.Flags(), "stdout log flags was wrong")
assert.Equal(t, 0, el.Flags(), "stderr log flags was wrong")
assert.Nil(t, err)

//TODO: this doesn't work because loadConfig calls os.Exit
//lb, elb := hookLogger()
//defer resetLogger()
//
//file = createTempFile(t, "defaultValues.test.yaml", "this is bad")
//loadConfig(config, file)
//assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
//assert.Equal(t, "Error occurred while trying to keep the connection: bad file descriptor\n", elb.String(), "Figured we would have an error")
}

func Test_loadConfig_fail(t *testing.T) {
//TODO: test that we exit if the config file doesn't exist or is poorly formed
t.Skip("Not implemented")
// parse error
file = createTempFile(t, "defaultValues.test.yaml", "this is bad")
config, err = loadConfig(file)
assert.EqualError(t, err, "While parsing config: yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `this is...` into map[string]interface {}")
assert.Nil(t, config)
}

func Test_setRules(t *testing.T) {
//TODO: Test rules are flushed first (success/fail)
//TODO: Test rules are added (success/fail)
//TODO: Test empty rule lines are skipped
//TODO: Test fatal if no rules
t.Skip("Not implemented")
defer resetLogger()

// fail to flush rules
config := viper.New()

err := setRules(config, func (s string, a ...string) error {
if s == "auditctl" && a[0] == "-D" {
return errors.New("testing")
}

return nil
})

assert.EqualError(t, err, "Failed to flush existing audit rules. Error: testing")

// fail on 0 rules
err = setRules(config, func (s string, a ...string) error { return nil })
assert.EqualError(t, err, "No audit rules found.")

// failure to set rule
r := 0
config.Set("rules", []string{"-a -1 -2", "", "-a -3 -4"})
err = setRules(config, func (s string, a ...string) error {
if a[0] != "-D" {
return errors.New("testing rule")
}

r++

return nil
})

assert.Equal(t, 1, r, "Wrong number of rule set attempts")
assert.EqualError(t, err, "Failed to add rule #1. Error: testing rule")

// properly set rules
r = 0
err = setRules(config, func (s string, a ...string) error {
// Skip the flush rules
if a[0] != "-a" {
return nil
}

if (a[1] == "-1" && a[2] == "-2") || (a[1] == "-3" && a[2] == "-4") {
r++
}

return nil
})

assert.Equal(t, 2, r, "Wrong number of correct rule set attempts")
assert.Nil(t, err)
}

func Test_createFileOutput(t *testing.T) {
Expand Down

0 comments on commit 434a0db

Please sign in to comment.