diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index c0385d2..6a6d891 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -24,12 +24,12 @@ jobs: with: go-version: 1.21 - - name: Vet - run: go vet ./... - - name: Run generators run: go generate ./... + - name: Vet + run: go vet ./... + - name: Build run: go build -ldflags="-s -w" -o build/shipshape . && ls -lh build/shipshape diff --git a/.gitignore b/.gitignore index a67215f..feff60a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ +# Generated files +registry_gen.go +pkg/result/breach_gen.go + build dist/ docker-compose.override.yml node_modules -registry_gen.go venom-output diff --git a/cmd/gen.go b/cmd/gen.go index 44c27ee..719cf1f 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -12,6 +12,7 @@ import ( var ( arg string checkpackage string + breachTypes []string ) func main() { @@ -25,11 +26,18 @@ func main() { } gen.Registry(checkpackage) break + case "breach-type": + if len(breachTypes) == 0 { + log.Fatal("missing flags; struct is required") + } + gen.BreachType(breachTypes) + break } } func parseFlags() { pflag.StringVar(&checkpackage, "checkpackage", "", "The package to which the check belongs") + pflag.StringSliceVar(&breachTypes, "type", []string{}, "The breach type") pflag.Parse() } diff --git a/cmd/gen/breachtype.go b/cmd/gen/breachtype.go new file mode 100644 index 0000000..c89a027 --- /dev/null +++ b/cmd/gen/breachtype.go @@ -0,0 +1,77 @@ +package gen + +import ( + "bytes" + "log" + "os" + "path/filepath" + "strings" + "text/template" +) + +func BreachType(breachTypes []string) { + log.Println("Generating breach type funcs -", strings.Join(breachTypes, ",")) + + breachTypeFile := "breach_gen.go" + breachTypeFullFilePath := filepath.Join(getScriptPath(), "..", "..", "pkg", "result", breachTypeFile) + if err := os.Remove(breachTypeFullFilePath); err != nil && !os.IsNotExist(err) { + log.Fatalln(err) + } + createFile(breachTypeFullFilePath, "package result\n") + + for _, bt := range breachTypes { + appendFileContent(breachTypeFullFilePath, breachTypeFuncs(bt)) + } +} + +func breachTypeFuncs(bt string) string { + tmplStr := ` +/* + * {{.BreachType}}Breach + */ +func (b *{{.BreachType}}Breach) GetCheckName() string { + return b.CheckName +} + +func (b *{{.BreachType}}Breach) GetCheckType() string { + return b.CheckType +} + +func (b *{{.BreachType}}Breach) GetRemediation() *Remediation { + return &b.Remediation +} + +func (b *{{.BreachType}}Breach) GetSeverity() string { + return b.Severity +} + +func (b *{{.BreachType}}Breach) GetType() BreachType { + return BreachType{{.BreachType}} +} + +func (b *{{.BreachType}}Breach) SetCommonValues(checkType string, checkName string, severity string) { + b.BreachType = b.GetType() + b.CheckType = checkType + b.CheckName = checkName + b.Severity = severity +} + +func (b *{{.BreachType}}Breach) SetRemediation(status RemediationStatus, msg string) { + b.Remediation.Status = status + if msg != "" { + b.Remediation.Messages = []string{msg} + } +} +` + tmpl, err := template.New("breachTypeFuncs").Parse(tmplStr) + if err != nil { + log.Fatalln(err) + } + + buf := &bytes.Buffer{} + err = tmpl.Execute(buf, struct{ BreachType string }{bt}) + if err != nil { + log.Fatalln(err) + } + return buf.String() +} diff --git a/cmd/gen/helpers.go b/cmd/gen/helpers.go new file mode 100644 index 0000000..f546452 --- /dev/null +++ b/cmd/gen/helpers.go @@ -0,0 +1,79 @@ +package gen + +import ( + "log" + "os" + "path/filepath" + "runtime" + "strings" +) + +func getScriptPath() string { + _, b, _, _ := runtime.Caller(0) + return filepath.Dir(b) +} + +func createFile(fullpath string, firstTimeContent string) { + if f, err := os.Stat(fullpath); err == nil && !f.IsDir() { + return + } else if !os.IsNotExist(err) { + log.Fatalln(err) + } + + f, err := os.OpenFile(fullpath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal(err) + } + defer func() { + f.Close() + }() + + if firstTimeContent == "" { + return + } + + if _, err := f.Write([]byte(firstTimeContent)); err != nil { + log.Fatal(err) + } +} + +func getFileLines(fullpath string) []string { + input, err := os.ReadFile(fullpath) + if err != nil { + log.Fatalln(err) + } + return strings.Split(string(input), "\n") +} + +func writeFileContent(fullpath string, content string) { + err := os.WriteFile(fullpath, []byte(content), 0644) + if err != nil { + log.Fatalln(err) + } +} + +func appendFileContent(fullpath string, content string) { + input, err := os.ReadFile(fullpath) + if err != nil { + log.Fatalln(err) + } + output := string(input) + content + writeFileContent(fullpath, output) +} + +func writeFileLines(fullpath string, lines []string) { + output := strings.Join(lines, "\n") + err := os.WriteFile(fullpath, []byte(output), 0644) + if err != nil { + log.Fatalln(err) + } +} + +func stringSliceMatch(slice []string, item string) bool { + for _, s := range slice { + if strings.Contains(s, item) { + return true + } + } + return false +} diff --git a/cmd/gen/registry.go b/cmd/gen/registry.go index eb0f793..29acc61 100644 --- a/cmd/gen/registry.go +++ b/cmd/gen/registry.go @@ -3,68 +3,20 @@ package gen import ( "fmt" "log" - "os" "path/filepath" - "runtime" - "strings" ) -var registryFile = "registry_gen.go" -var fullRegistryFilePath string - // Registry adds the checks for a package to the registry. func Registry(chkPkg string) { - fullRegistryFilePath = filepath.Join(getScriptPath(), "../../", registryFile) - createFile() - addImportLine(chkPkg) -} - -func getScriptPath() string { - _, b, _, _ := runtime.Caller(0) - return filepath.Dir(b) -} - -func createFile() { - if f, err := os.Stat(fullRegistryFilePath); err == nil && !f.IsDir() { - return - } else if !os.IsNotExist(err) { - log.Fatal(err) - } - - f, err := os.OpenFile(fullRegistryFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal(err) - } - defer func() { - f.Close() - }() - - if _, err := f.Write([]byte("package main\n\n")); err != nil { - log.Fatal(err) - } -} - -func getFileLines() []string { - input, err := os.ReadFile(fullRegistryFilePath) - if err != nil { - log.Fatalln(err) - } + log.Println("Generating checks registry - adding", chkPkg) - return strings.Split(string(input), "\n") -} + registryFile := "registry_gen.go" + registryFullFilePath := filepath.Join(getScriptPath(), "..", "..", registryFile) + createFile(registryFullFilePath, "package main\n\n") -func writeFileLines(lines []string) { - output := strings.Join(lines, "\n") - err := os.WriteFile(fullRegistryFilePath, []byte(output), 0644) - if err != nil { - log.Fatalln(err) - } -} - -func addImportLine(chkPkg string) { pkgFullName := fmt.Sprintf("github.com/salsadigitalauorg/shipshape/pkg/checks/%s", chkPkg) - fileLines := getFileLines() + fileLines := getFileLines(registryFullFilePath) if stringSliceMatch(fileLines, pkgFullName) { return } @@ -72,19 +24,12 @@ func addImportLine(chkPkg string) { importLine := fmt.Sprintf("import _ \"%s\"", pkgFullName) newFileLines := []string{} for i, line := range fileLines { - if i == 2 { + // Add the import line before the last line, + // so that the last line is always a newline. + if i == len(fileLines)-1 { newFileLines = append(newFileLines, importLine) } newFileLines = append(newFileLines, line) } - writeFileLines(newFileLines) -} - -func stringSliceMatch(slice []string, item string) bool { - for _, s := range slice { - if strings.Contains(s, item) { - return true - } - } - return false + writeFileLines(registryFullFilePath, newFileLines) } diff --git a/pkg/checks/crawler/crawlercheck.go b/pkg/checks/crawler/crawlercheck.go index 393aafe..f2c1e30 100644 --- a/pkg/checks/crawler/crawlercheck.go +++ b/pkg/checks/crawler/crawlercheck.go @@ -79,7 +79,7 @@ func (c *CrawlerCheck) RunCheck() { crawler.OnError(func(r *colly.Response, err error) { c.Result.Status = result.Fail - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ Key: fmt.Sprintf("%v", r.Request.URL), ValueLabel: "invalid response", Value: fmt.Sprintf("%d", r.StatusCode), diff --git a/pkg/checks/crawler/crawlercheck_test.go b/pkg/checks/crawler/crawlercheck_test.go index 2015ee0..6f3ed1d 100644 --- a/pkg/checks/crawler/crawlercheck_test.go +++ b/pkg/checks/crawler/crawlercheck_test.go @@ -59,7 +59,7 @@ func TestCrawlerCheck(t *testing.T) { c.Init(Crawler) c.RunCheck() assert.ElementsMatch( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, CheckType: "crawler", Severity: "normal", diff --git a/pkg/checks/docker/baseimagecheck.go b/pkg/checks/docker/baseimagecheck.go index ae5bb42..69b2dce 100644 --- a/pkg/checks/docker/baseimagecheck.go +++ b/pkg/checks/docker/baseimagecheck.go @@ -87,7 +87,7 @@ func (c *BaseImageCheck) RunCheck() { } if len(c.Allowed) > 0 && !utils.PackageCheckString(c.Allowed, match[1], match[2]) { - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ Key: name, ValueLabel: "invalid base image", Value: match[1], @@ -108,7 +108,7 @@ func (c *BaseImageCheck) RunCheck() { } if !utils.PackageCheckString(c.Allowed, match[1], match[2]) { - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ Key: name, ValueLabel: "invalid base image", Value: def.Image, diff --git a/pkg/checks/docker/baseimagecheck_test.go b/pkg/checks/docker/baseimagecheck_test.go index dd2305c..1a19e9b 100644 --- a/pkg/checks/docker/baseimagecheck_test.go +++ b/pkg/checks/docker/baseimagecheck_test.go @@ -53,9 +53,10 @@ func TestInvalidDockerfileCheck(t *testing.T) { Paths: []string{"./fixtures/compose-dockerfile"}, } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, Key: "service1", ValueLabel: "invalid base image", @@ -86,9 +87,10 @@ func TestInvalidDockerfileImageVersion(t *testing.T) { Paths: []string{"./fixtures/compose-dockerfile"}, } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, Key: "service1", ValueLabel: "invalid base image", @@ -157,9 +159,10 @@ func TestInvalidImageCheck(t *testing.T) { Paths: []string{"./fixtures/compose-image"}, } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, Key: "service4", ValueLabel: "invalid base image", @@ -180,16 +183,17 @@ func TestInvalidImageVersions(t *testing.T) { Paths: []string{"./fixtures/compose-image"}, } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.ElementsMatch( []result.Breach{ - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, Key: "service2", ValueLabel: "invalid base image", Value: "bitnami/postgresql@16", }, - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, Key: "service4", ValueLabel: "invalid base image", diff --git a/pkg/checks/drupal/dbadmincheck.go b/pkg/checks/drupal/dbadmincheck.go index 5a3b6b0..b49fffa 100644 --- a/pkg/checks/drupal/dbadmincheck.go +++ b/pkg/checks/drupal/dbadmincheck.go @@ -8,6 +8,7 @@ import ( "os/exec" "strings" + "github.com/salsadigitalauorg/shipshape/pkg/command" "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/result" "github.com/salsadigitalauorg/shipshape/pkg/utils" @@ -57,18 +58,18 @@ func (c *AdminUserCheck) getActiveRoles() map[string]string { activeRoles, err := Drush(c.DrushPath, c.Alias, cmd).Exec() var pathErr *fs.PathError if err != nil && errors.As(err, &pathErr) { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: pathErr.Path + ": " + pathErr.Err.Error()}) } else if err != nil { msg := string(err.(*exec.ExitError).Stderr) - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: strings.ReplaceAll(strings.TrimSpace(msg), " \n ", "")}) } else { // Unmarshal roles JSON. err = json.Unmarshal(activeRoles, &rolesListMap) var synErr *json.SyntaxError if err != nil && errors.As(err, &synErr) { - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) } } @@ -94,7 +95,7 @@ func (c *AdminUserCheck) FetchData() { if err != nil { msg := string(err.(*exec.ExitError).Stderr) - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: strings.ReplaceAll(strings.TrimSpace(msg), " \n ", "")}) } } @@ -103,7 +104,7 @@ func (c *AdminUserCheck) FetchData() { // into the roleConfigs for further processing. func (c *AdminUserCheck) UnmarshalDataMap() { if len(c.DataMap) == 0 { - c.AddBreach(result.ValueBreach{Value: "no data provided"}) + c.AddBreach(&result.ValueBreach{Value: "no data provided"}) return } @@ -113,7 +114,7 @@ func (c *AdminUserCheck) UnmarshalDataMap() { err := json.Unmarshal([]byte(element), &role) var synErr *json.SyntaxError if err != nil && errors.As(err, &synErr) { - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) return } // Collect role config. @@ -130,40 +131,31 @@ func (c *AdminUserCheck) RunCheck() { } if isAdmin { - if c.PerformRemediation { - if err := c.Remediate(roleName); err != nil { - c.AddBreach(result.KeyValueBreach{ - Key: "failed to set is_admin to false", - ValueLabel: "role", - Value: roleName, - }) - } else { - c.AddRemediation(fmt.Sprintf( - "Fixed disallowed admin setting for role [%s]", roleName)) - } - } else { - c.AddBreach(result.KeyValueBreach{ - Key: "is_admin: true", - ValueLabel: "role", - Value: roleName, - }) - } + c.AddBreach(&result.KeyValueBreach{ + Key: "is_admin: true", + ValueLabel: "role", + Value: roleName, + }) } } - - if len(c.Result.Breaches) == 0 { - c.Result.Status = result.Pass - } } // Remediate attempts to fix a breach. -func (c *AdminUserCheck) Remediate(breachIfc interface{}) error { - // A breach is expected to be a string. - if b, ok := breachIfc.(string); ok { - _, err := Drush(c.DrushPath, c.Alias, []string{"config:set", "user.role." + b, "is_admin", "0"}).Exec() +func (c *AdminUserCheck) Remediate() { + for _, b := range c.Result.Breaches { + b, ok := b.(*result.KeyValueBreach) + if !ok { + continue + } + + _, err := Drush(c.DrushPath, c.Alias, []string{"config:set", "user.role." + b.Value, "is_admin", "0"}).Exec() if err != nil { - return err + b.SetRemediation(result.RemediationStatusFailed, fmt.Sprintf( + "failed to set is_admin to false for role '%s' due to error: %s", + b.Value, command.GetMsgFromCommandError(err))) + } else { + b.SetRemediation(result.RemediationStatusSuccess, fmt.Sprintf( + "Fixed disallowed admin setting for role [%s]", b.Value)) } } - return nil } diff --git a/pkg/checks/drupal/dbadmincheck_test.go b/pkg/checks/drupal/dbadmincheck_test.go index 4db8c7d..c46e4fd 100644 --- a/pkg/checks/drupal/dbadmincheck_test.go +++ b/pkg/checks/drupal/dbadmincheck_test.go @@ -50,9 +50,8 @@ func TestAdminUserFetchData(t *testing.T) { t.Run("drushNotFound", func(t *testing.T) { c := AdminUserCheck{} c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "vendor/drush/drush/drush: no such file or directory", }}, @@ -70,9 +69,8 @@ func TestAdminUserFetchData(t *testing.T) { ) c := AdminUserCheck{} c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "unable to run drush command", }}, @@ -103,9 +101,8 @@ func TestAdminUserUnmarshalData(t *testing.T) { // Empty datamap. t.Run("emptyDataMap", func(t *testing.T) { c.UnmarshalDataMap() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "no data provided", }}, @@ -122,9 +119,8 @@ func TestAdminUserUnmarshalData(t *testing.T) { }, } c.UnmarshalDataMap() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "invalid character ']' after object key:value pair", }}, @@ -175,7 +171,7 @@ func TestAdminUserRunCheck(t *testing.T) { AllowedRoles: []string{"content-admin"}, }, ExpectStatus: result.Fail, - ExpectFails: []result.Breach{result.KeyValueBreach{ + ExpectFails: []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", Key: "is_admin: true", ValueLabel: "role", @@ -195,59 +191,15 @@ func TestAdminUserRunCheck(t *testing.T) { }, ExpectStatus: result.Pass, }, - - // Role has is_admin:true, with remediation. - { - Name: "roleAdminWithRemediation", - Check: &AdminUserCheck{ - CheckBase: config.CheckBase{ - DataMap: map[string][]byte{ - "anonymous": []byte(`{"is_admin":true, "id": "anonymous"}`)}, - PerformRemediation: true, - }, - AllowedRoles: []string{"content-admin"}, - }, - PreRun: func(t *testing.T) { - command.ShellCommander = internal.ShellCommanderMaker(nil, nil, nil) - }, - ExpectStatus: result.Pass, - ExpectRemediations: []string{"Fixed disallowed admin setting for role [anonymous]"}, - }, - - // Role has is_admin:true, with remediation error. - { - Name: "roleAdminWithRemediationError", - Check: &AdminUserCheck{ - CheckBase: config.CheckBase{ - DataMap: map[string][]byte{ - "anonymous": []byte(`{"is_admin":true, "id": "anonymous"}`)}, - PerformRemediation: true, - }, - AllowedRoles: []string{"content-admin"}, - }, - PreRun: func(t *testing.T) { - command.ShellCommander = internal.ShellCommanderMaker( - nil, - &exec.ExitError{Stderr: []byte("unable to run drush command")}, - nil, - ) - }, - ExpectStatus: result.Fail, - ExpectFails: []result.Breach{result.KeyValueBreach{ - BreachType: "key-value", - Key: "failed to set is_admin to false", - ValueLabel: "role", - Value: "anonymous", - }}, - }, } - curShellCommander := command.ShellCommander - defer func() { command.ShellCommander = curShellCommander }() - for _, test := range tests { t.Run(test.Name, func(t *testing.T) { test.Check.UnmarshalDataMap() + + curShellCommander := command.ShellCommander + defer func() { command.ShellCommander = curShellCommander }() + internal.TestRunCheck(t, test) }) } @@ -266,8 +218,23 @@ func TestAdminUserRemediate(t *testing.T) { nil) c := AdminUserCheck{} - err := c.Remediate("foo") - assert.Error(err, "unable to run drush command") + c.AddBreach(&result.KeyValueBreach{ + Key: "is_admin: true", + ValueLabel: "role", + Value: "foo", + }) + c.Remediate() + assert.EqualValues([]result.Breach{&result.KeyValueBreach{ + BreachType: "key-value", + Key: "is_admin: true", + ValueLabel: "role", + Value: "foo", + Remediation: result.Remediation{ + Status: "failed", + Messages: []string{"failed to set is_admin to false for role 'foo' " + + "due to error: unable to run drush command"}, + }, + }}, c.Result.Breaches) }) t.Run("drushCommandIsCorrect", func(t *testing.T) { @@ -275,7 +242,12 @@ func TestAdminUserRemediate(t *testing.T) { command.ShellCommander = internal.ShellCommanderMaker(nil, nil, &generatedCommand) c := AdminUserCheck{} - c.Remediate("foo") + c.AddBreach(&result.KeyValueBreach{ + Key: "is_admin: true", + ValueLabel: "role", + Value: "foo", + }) + c.Remediate() assert.Equal("vendor/drush/drush/drush config:set user.role.foo is_admin 0", generatedCommand) }) } diff --git a/pkg/checks/drupal/dbmodulecheck_test.go b/pkg/checks/drupal/dbmodulecheck_test.go index 5a02fe0..4c3e713 100644 --- a/pkg/checks/drupal/dbmodulecheck_test.go +++ b/pkg/checks/drupal/dbmodulecheck_test.go @@ -81,6 +81,7 @@ node: c.Init(DbModule) c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) return c } @@ -109,14 +110,14 @@ views_ui: }) assert.ElementsMatch( []result.Breach{ - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", CheckType: "drupal-db-module", Severity: "normal", Key: "required modules are not enabled", Values: []string{"block"}, }, - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", CheckType: "drupal-db-module", Severity: "normal", diff --git a/pkg/checks/drupal/dbpermissionscheck.go b/pkg/checks/drupal/dbpermissionscheck.go index 6959db2..ac843e9 100644 --- a/pkg/checks/drupal/dbpermissionscheck.go +++ b/pkg/checks/drupal/dbpermissionscheck.go @@ -41,7 +41,7 @@ func (c *DbPermissionsCheck) Merge(mergeCheck config.Check) error { // type for further processing. func (c *DbPermissionsCheck) UnmarshalDataMap() { if len(c.DataMap[c.ConfigName]) == 0 { - c.AddBreach(result.ValueBreach{Value: "no data provided"}) + c.AddBreach(&result.ValueBreach{Value: "no data provided"}) } c.Permissions = map[string]DrushRole{} @@ -51,7 +51,7 @@ func (c *DbPermissionsCheck) UnmarshalDataMap() { // RunCheck implements the Check logic for Drupal Permissions in database config. func (c *DbPermissionsCheck) RunCheck() { if len(c.Disallowed) == 0 { - c.AddBreach(result.ValueBreach{Value: "list of disallowed perms not provided"}) + c.AddBreach(&result.ValueBreach{Value: "list of disallowed perms not provided"}) } for r, perms := range c.Permissions { @@ -70,40 +70,33 @@ func (c *DbPermissionsCheck) RunCheck() { return fails[i] < fails[j] }) - if c.PerformRemediation { - if err := c.Remediate(DbPermissionsBreach{Role: r, Perms: strings.Join(fails, ",")}); err != nil { - c.AddBreach(result.KeyValueBreach{ - KeyLabel: "role", - Key: r, - ValueLabel: "failed to fix disallowed permissions due to error", - Value: command.GetMsgFromCommandError(err), - }) - } else { - c.AddRemediation(fmt.Sprintf( - "[%s] fixed disallowed permissions: [%s]", - r, strings.Join(fails, ", "))) - } - } else { - c.AddBreach(result.KeyValuesBreach{ - KeyLabel: "role", - Key: r, - ValueLabel: "permissions", - Values: fails, - }) - } - } - - if len(c.Result.Breaches) == 0 { - c.Result.Status = result.Pass + c.AddBreach(&result.KeyValuesBreach{ + KeyLabel: "role", + Key: r, + ValueLabel: "permissions", + Values: fails, + }) } } -// Remediate attempts to fix a breach. -func (c *DbPermissionsCheck) Remediate(breachIfc interface{}) error { - b := breachIfc.(DbPermissionsBreach) - _, err := Drush(c.DrushPath, c.Alias, []string{"role:perm:remove", b.Role, b.Perms}).Exec() - if err != nil { - return err +// Remediate attempts to remove any disallowed permissions detected. +func (c *DbPermissionsCheck) Remediate() { + for _, b := range c.Result.Breaches { + b, ok := b.(*result.KeyValuesBreach) + if !ok { + continue + } + _, err := Drush( + c.DrushPath, c.Alias, + []string{"role:perm:remove", b.Key, strings.Join(b.Values, ",")}).Exec() + if err != nil { + b.SetRemediation(result.RemediationStatusFailed, fmt.Sprintf( + "failed to fix disallowed permissions for role '%s' due to error: %s", + b.Key, command.GetMsgFromCommandError(err))) + } else { + b.SetRemediation(result.RemediationStatusSuccess, fmt.Sprintf( + "[%s] fixed disallowed permissions: [%s]", + b.Key, strings.Join(b.Values, ", "))) + } } - return nil } diff --git a/pkg/checks/drupal/dbpermissionscheck_test.go b/pkg/checks/drupal/dbpermissionscheck_test.go index a40cf7d..df798da 100644 --- a/pkg/checks/drupal/dbpermissionscheck_test.go +++ b/pkg/checks/drupal/dbpermissionscheck_test.go @@ -7,7 +7,6 @@ import ( . "github.com/salsadigitalauorg/shipshape/pkg/checks/drupal" "github.com/salsadigitalauorg/shipshape/pkg/checks/yaml" "github.com/salsadigitalauorg/shipshape/pkg/command" - "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/internal" "github.com/salsadigitalauorg/shipshape/pkg/result" @@ -67,10 +66,11 @@ func TestDbPermissionsUnmarshalDataMap(t *testing.T) { t.Run("noDataProvided", func(t *testing.T) { c := DbPermissionsCheck{} c.UnmarshalDataMap() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "no data provided", }}, @@ -132,7 +132,7 @@ func TestDbPermissionsRunCheck(t *testing.T) { Init: true, ExpectStatus: result.Fail, ExpectNoPass: true, - ExpectFails: []result.Breach{result.ValueBreach{ + ExpectFails: []result.Breach{&result.ValueBreach{ BreachType: "value", Severity: "normal", Value: "list of disallowed perms not provided", @@ -208,7 +208,7 @@ func TestDbPermissionsRunCheck(t *testing.T) { "[authenticated] no disallowed permissions", }, ExpectFails: []result.Breach{ - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", Severity: "normal", KeyLabel: "role", @@ -216,7 +216,7 @@ func TestDbPermissionsRunCheck(t *testing.T) { ValueLabel: "permissions", Values: []string{"administer modules", "administer permissions"}, }, - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", Severity: "normal", KeyLabel: "role", @@ -226,126 +226,6 @@ func TestDbPermissionsRunCheck(t *testing.T) { }, }, }, - { - Name: "breachRemediation", - Check: &DbPermissionsCheck{ - DrushYamlCheck: DrushYamlCheck{ - YamlBase: yaml.YamlBase{ - CheckBase: config.CheckBase{ - PerformRemediation: true, - }, - }, - }, - Permissions: map[string]DrushRole{ - "anonymous": { - Label: "Anonymous user", - Perms: []string{"access content", "view media"}, - }, - "authenticated": { - Label: "Authenticated user", - Perms: []string{"access content", "view media"}, - }, - "site_administrator": { - Label: "Site Administrator", - Perms: []string{"administer modules", "administer permissions"}, - }, - "another_site_administrator": { - Label: "Site Administrator", - Perms: []string{"administer modules", "administer permissions"}, - }, - "site_editor": { - Label: "Site Editor", - Perms: []string{"administer modules"}, - }, - }, - Disallowed: []string{"administer modules", "administer permissions"}, - ExcludeRoles: []string{"another_site_administrator"}, - }, - Init: true, - PreRun: func(t *testing.T) { - command.ShellCommander = internal.ShellCommanderMaker(nil, nil, nil) - }, - Sort: true, - ExpectStatus: result.Pass, - ExpectPasses: []string{ - "[anonymous] no disallowed permissions", - "[authenticated] no disallowed permissions", - }, - ExpectNoFail: true, - ExpectRemediations: []string{ - "[site_administrator] fixed disallowed permissions: [administer modules, administer permissions]", - "[site_editor] fixed disallowed permissions: [administer modules]", - }, - }, - { - Name: "breachRemediationExitError", - Check: &DbPermissionsCheck{ - DrushYamlCheck: DrushYamlCheck{ - YamlBase: yaml.YamlBase{ - CheckBase: config.CheckBase{ - PerformRemediation: true, - }, - }, - }, - Permissions: map[string]DrushRole{ - "anonymous": { - Label: "Anonymous user", - Perms: []string{"access content", "view media"}, - }, - "authenticated": { - Label: "Authenticated user", - Perms: []string{"access content", "view media"}, - }, - "site_administrator": { - Label: "Site Administrator", - Perms: []string{"administer modules", "administer permissions"}, - }, - "another_site_administrator": { - Label: "Site Administrator", - Perms: []string{"administer modules", "administer permissions"}, - }, - "site_editor": { - Label: "Site Editor", - Perms: []string{"administer modules"}, - }, - }, - Disallowed: []string{"administer modules", "administer permissions"}, - ExcludeRoles: []string{"another_site_administrator"}, - }, - Init: true, - PreRun: func(t *testing.T) { - command.ShellCommander = internal.ShellCommanderMaker( - nil, - &exec.ExitError{Stderr: []byte("unable to run drush command")}, - nil, - ) - }, - Sort: true, - ExpectStatus: result.Fail, - ExpectPasses: []string{ - "[anonymous] no disallowed permissions", - "[authenticated] no disallowed permissions", - }, - ExpectFails: []result.Breach{ - result.KeyValueBreach{ - BreachType: "key-value", - Severity: "normal", - KeyLabel: "role", - Key: "site_administrator", - ValueLabel: "failed to fix disallowed permissions due to error", - Value: "unable to run drush command", - }, - result.KeyValueBreach{ - BreachType: "key-value", - Severity: "normal", - KeyLabel: "role", - Key: "site_editor", - ValueLabel: "failed to fix disallowed permissions due to error", - Value: "unable to run drush command", - }, - }, - ExpectNoRemediations: true, - }, } curShellCommander := command.ShellCommander @@ -371,8 +251,25 @@ func TestDbPermissionsRemediate(t *testing.T) { nil) c := DbPermissionsCheck{} - err := c.Remediate(DbPermissionsBreach{Role: "foo", Perms: "bar,baz"}) - assert.Error(err, "unable to run drush command") + c.AddBreach(&result.KeyValuesBreach{ + BreachType: "key-values", + KeyLabel: "role", + Key: "foo", + ValueLabel: "permissions", + Values: []string{"bar", "baz"}, + }) + c.Remediate() + assert.EqualValues([]result.Breach{&result.KeyValuesBreach{ + BreachType: "key-values", + KeyLabel: "role", + Key: "foo", + ValueLabel: "permissions", + Values: []string{"bar", "baz"}, + Remediation: result.Remediation{ + Status: result.RemediationStatusFailed, + Messages: []string{"failed to fix disallowed permissions for role 'foo' due to error: unable to run drush command"}, + }, + }}, c.Result.Breaches) }) t.Run("drushCommandIsCorrect", func(t *testing.T) { @@ -380,7 +277,13 @@ func TestDbPermissionsRemediate(t *testing.T) { command.ShellCommander = internal.ShellCommanderMaker(nil, nil, &generatedCommand) c := DbPermissionsCheck{} - c.Remediate(DbPermissionsBreach{Role: "foo", Perms: "bar,baz"}) + c.AddBreach(&result.KeyValuesBreach{ + KeyLabel: "role", + Key: "foo", + ValueLabel: "permissions", + Values: []string{"bar", "baz"}, + }) + c.Remediate() assert.Equal("vendor/drush/drush/drush role:perm:remove foo bar,baz", generatedCommand) }) } diff --git a/pkg/checks/drupal/dbtfausercheck.go b/pkg/checks/drupal/dbtfausercheck.go index 8e7153b..459e820 100644 --- a/pkg/checks/drupal/dbtfausercheck.go +++ b/pkg/checks/drupal/dbtfausercheck.go @@ -48,7 +48,7 @@ func (c *DbUserTfaCheck) FetchData() { res, err := Drush(c.DrushPath, c.Alias, cmd).Exec() if err != nil { c.Result.Status = result.Fail - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "error fetching drush user info", Value: command.GetMsgFromCommandError(err), }) @@ -73,7 +73,7 @@ func (c *DbUserTfaCheck) RunCheck() { for _, user := range users { tfaDisabled = append(tfaDisabled, fmt.Sprintf("%s:%s", user.Name, user.UID)) } - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "users with TFA disabled", Value: strings.Join(tfaDisabled, ", "), }) diff --git a/pkg/checks/drupal/dbtfausercheck_test.go b/pkg/checks/drupal/dbtfausercheck_test.go index 5e817f9..6205cd0 100644 --- a/pkg/checks/drupal/dbtfausercheck_test.go +++ b/pkg/checks/drupal/dbtfausercheck_test.go @@ -29,7 +29,7 @@ func TestDbTfaUserCheck(t *testing.T) { assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "drupal-db-user-tfa", Severity: "normal", @@ -62,7 +62,7 @@ func TestDbTfaUserCheck(t *testing.T) { assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "drupal-db-user-tfa", Severity: "normal", @@ -99,7 +99,7 @@ func TestDbTfaUserCheck(t *testing.T) { assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "drupal-db-user-tfa", Severity: "normal", diff --git a/pkg/checks/drupal/drupal.go b/pkg/checks/drupal/drupal.go index 28373ea..9c04af2 100644 --- a/pkg/checks/drupal/drupal.go +++ b/pkg/checks/drupal/drupal.go @@ -73,12 +73,12 @@ func CheckModulesInYaml(c *yaml.YamlBase, ct config.CheckType, configName string required_errored, required_disabled := DetermineModuleStatus(c.NodeMap[configName], ct, required) if len(required_errored) > 0 { - c.AddBreach(result.KeyValuesBreach{ + c.AddBreach(&result.KeyValuesBreach{ Key: "error verifying status for required modules", Values: required_errored}) } if len(required_disabled) > 0 { - c.AddBreach(result.KeyValuesBreach{ + c.AddBreach(&result.KeyValuesBreach{ Key: "required modules are not enabled", Values: required_disabled}) } @@ -94,12 +94,12 @@ func CheckModulesInYaml(c *yaml.YamlBase, ct config.CheckType, configName string disallowed_errored, disallowed_disabled := DetermineModuleStatus(c.NodeMap[configName], ct, disallowed) if len(disallowed_errored) > 0 { - c.AddBreach(result.KeyValuesBreach{ + c.AddBreach(&result.KeyValuesBreach{ Key: "error verifying status for disallowed modules", Values: disallowed_errored}) } if len(disallowed_enabled) > 0 { - c.AddBreach(result.KeyValuesBreach{ + c.AddBreach(&result.KeyValuesBreach{ Key: "disallowed modules are enabled", Values: disallowed_enabled}) } diff --git a/pkg/checks/drupal/drupal_test.go b/pkg/checks/drupal/drupal_test.go index ceaf6ad..9ef3ae2 100644 --- a/pkg/checks/drupal/drupal_test.go +++ b/pkg/checks/drupal/drupal_test.go @@ -241,7 +241,7 @@ module: "all disallowed modules are disabled", }) assert.EqualValues( - []result.Breach{result.KeyValuesBreach{ + []result.Breach{&result.KeyValuesBreach{ BreachType: "key-values", Key: "disallowed modules are enabled", Values: []string{"dblog"}, @@ -265,19 +265,18 @@ func TestCheckModulesInYaml(t *testing.T) { } c.UnmarshalDataMap() CheckModulesInYaml(&c, FileModule, "shipshape.extension.yml", required, disallowed) - assert.Equal(result.Fail, c.Result.Status) assert.ElementsMatch(c.Result.Passes, []string{ "some required modules are enabled: block", "some disallowed modules are disabled: views_ui", }) assert.ElementsMatch( []result.Breach{ - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", Key: "error verifying status for required modules", Values: []string{"invalid character '&' at position 11, following \".node\""}, }, - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", Key: "error verifying status for disallowed modules", Values: []string{"invalid character '&' at position 15, following \".field_ui\""}, @@ -304,19 +303,18 @@ module: } c.UnmarshalDataMap() CheckModulesInYaml(&c, FileModule, "shipshape.extension.yml", required, disallowed) - assert.Equal(result.Fail, c.Result.Status) assert.ElementsMatch(c.Result.Passes, []string{ "some required modules are enabled: block", "some disallowed modules are disabled: field_ui", }) assert.ElementsMatch( []result.Breach{ - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", Key: "required modules are not enabled", Values: []string{"node"}, }, - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: "key-values", Key: "disallowed modules are enabled", Values: []string{"views_ui"}, diff --git a/pkg/checks/drupal/drush_test.go b/pkg/checks/drupal/drush_test.go index ddc7a15..5ab24b0 100644 --- a/pkg/checks/drupal/drush_test.go +++ b/pkg/checks/drupal/drush_test.go @@ -61,5 +61,5 @@ func TestDrushQuery(t *testing.T) { _, err := drupal.Drush("", "", []string{}).Query("SELECT uid FROM users") assert.NoError(t, err) - assert.Equal(t, "vendor/drush/drush/drush sql:query SELECT uid FROM users", generatedCommand) + assert.Equal(t, "vendor/drush/drush/drush sql:query 'SELECT uid FROM users'", generatedCommand) } diff --git a/pkg/checks/drupal/drushyamlcheck.go b/pkg/checks/drupal/drushyamlcheck.go index 82afc85..9351a81 100644 --- a/pkg/checks/drupal/drushyamlcheck.go +++ b/pkg/checks/drupal/drushyamlcheck.go @@ -1,13 +1,17 @@ package drupal import ( + "fmt" "io/fs" "os/exec" "strings" + "github.com/salsadigitalauorg/shipshape/pkg/command" "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/result" "github.com/salsadigitalauorg/shipshape/pkg/utils" + + log "github.com/sirupsen/logrus" ) // Init implementation for the drush-based yaml check. @@ -39,13 +43,44 @@ func (c *DrushYamlCheck) FetchData() { c.DataMap[c.ConfigName], err = Drush(c.DrushPath, c.Alias, c.DrushCommand.Args).Exec() if err != nil { if pathErr, ok := err.(*fs.PathError); ok { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: pathErr.Path + ": " + pathErr.Err.Error()}) } else { msg := string(err.(*exec.ExitError).Stderr) - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: c.ConfigName, Value: strings.ReplaceAll(strings.TrimSpace(msg), " \n ", "")}) } } } + +// Remediate attempts to remediate a breach by running the drush command +// specified in the check. +func (c *DrushYamlCheck) Remediate() { + for _, b := range c.Result.Breaches { + contextLogger := log.WithFields(log.Fields{ + "check-type": c.GetType(), + "check-name": c.GetName(), + "breach": b, + }) + if c.RemediateCommand == "" { + contextLogger.Print("no remediation command specified - failing") + b.SetRemediation(result.RemediationStatusNoSupport, "") + return + } + + contextLogger.Print("running remediation command") + _, err := command.ShellCommander("sh", "-c", c.RemediateCommand).Output() + if err != nil { + b.SetRemediation(result.RemediationStatusFailed, fmt.Sprintf( + "error running remediation command for config '%s' due to error: %s", + c.ConfigName, command.GetMsgFromCommandError(err))) + } else { + if c.RemediateMsg == "" { + c.RemediateMsg = fmt.Sprintf( + "remediation command for config '%s' ran successfully", c.ConfigName) + } + b.SetRemediation(result.RemediationStatusSuccess, c.RemediateMsg) + } + } +} diff --git a/pkg/checks/drupal/drushyamlcheck_test.go b/pkg/checks/drupal/drushyamlcheck_test.go index faf52c1..9719395 100644 --- a/pkg/checks/drupal/drushyamlcheck_test.go +++ b/pkg/checks/drupal/drushyamlcheck_test.go @@ -7,12 +7,21 @@ import ( . "github.com/salsadigitalauorg/shipshape/pkg/checks/drupal" "github.com/salsadigitalauorg/shipshape/pkg/checks/yaml" "github.com/salsadigitalauorg/shipshape/pkg/command" + "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/internal" "github.com/salsadigitalauorg/shipshape/pkg/result" "github.com/stretchr/testify/assert" ) +func TestDrushYamlCheckInit(t *testing.T) { + assert := assert.New(t) + + c := DrushYamlCheck{} + c.Init(DrushYaml) + assert.True(c.RequiresDb) +} + func TestDrushYamlMerge(t *testing.T) { assert := assert.New(t) @@ -59,81 +68,199 @@ func TestDrushYamlMerge(t *testing.T) { }, c) } -func TestDrushYamlCheck(t *testing.T) { - assert := assert.New(t) - - t.Run("drushNotFound", func(t *testing.T) { - c := DrushYamlCheck{ - Command: "status", - ConfigName: "core.extension", - } - - c.Init(DrushYaml) - assert.True(c.RequiresDb) - - c.FetchData() - assert.Equal(result.Fail, c.Result.Status) - assert.Empty(c.Result.Passes) - assert.ElementsMatch( - []result.Breach{result.ValueBreach{ +func TestDrushYamlCheckFetchData(t *testing.T) { + tt := []internal.FetchDataTest{ + { + Name: "drushNotFound", + Check: &DrushYamlCheck{ + Command: "status", + ConfigName: "core.extension", + }, + ExpectBreaches: []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "drush-yaml", Severity: "normal", Value: "vendor/drush/drush/drush: no such file or directory", }}, - c.Result.Breaches, - ) - }) - - curShellCommander := command.ShellCommander - defer func() { command.ShellCommander = curShellCommander }() + }, - t.Run("drushError", func(t *testing.T) { - c := DrushYamlCheck{ - Command: "status", - ConfigName: "core.extension", - } - - command.ShellCommander = internal.ShellCommanderMaker( - nil, - &exec.ExitError{Stderr: []byte("unable to run drush command")}, - nil, - ) - - c.FetchData() - assert.Equal(result.Fail, c.Result.Status) - assert.Empty(c.Result.Passes) - assert.ElementsMatch( - []result.Breach{result.ValueBreach{ + { + Name: "drushError", + Check: &DrushYamlCheck{ + Command: "status", + ConfigName: "core.extension", + }, + PreFetch: func(t *testing.T) { + command.ShellCommander = internal.ShellCommanderMaker( + nil, + &exec.ExitError{Stderr: []byte("unable to run drush command")}, + nil, + ) + }, + ExpectBreaches: []result.Breach{&result.ValueBreach{ BreachType: "value", + CheckType: "drush-yaml", + Severity: "normal", ValueLabel: "core.extension", Value: "unable to run drush command", }}, - c.Result.Breaches, - ) - }) + }, - t.Run("drushOK", func(t *testing.T) { - stdout := ` + { + Name: "drushOK", + Check: &DrushYamlCheck{ + Command: "status", + ConfigName: "core.extension", + }, + PreFetch: func(t *testing.T) { + stdout := ` module: block: 0 - views_ui: 0 + views_ui: 0 ` + command.ShellCommander = internal.ShellCommanderMaker( + &stdout, + nil, + nil, + ) + }, + }, + } + + curShellCommander := command.ShellCommander + defer func() { command.ShellCommander = curShellCommander }() + + for _, test := range tt { + t.Run(test.Name, func(t *testing.T) { + test.Check.Init(DrushYaml) + internal.TestFetchData(t, test) + }) + } +} + +func TestDrushYamlCheckRunCheck(t *testing.T) { + tests := []internal.RunCheckTest{ + { + Name: "pass", + Check: &DrushYamlCheck{ + YamlBase: yaml.YamlBase{ + CheckBase: config.CheckBase{ + DataMap: map[string][]byte{ + "core.extension": []byte(`{"profile":"standard"}`)}, + }, + Values: []yaml.KeyValue{ + {Key: "profile", Value: "standard"}, + }, + }, + ConfigName: "core.extension", + }, + ExpectStatus: result.Pass, + ExpectPasses: []string{"[core.extension] 'profile' equals 'standard'"}, + }, + { + Name: "breach", + Check: &DrushYamlCheck{ + YamlBase: yaml.YamlBase{ + CheckBase: config.CheckBase{ + DataMap: map[string][]byte{ + "core.extension": []byte(`{"profile":"minimal"}`)}, + }, + Values: []yaml.KeyValue{ + {Key: "profile", Value: "standard"}, + }, + }, + ConfigName: "core.extension", + }, + PreRun: func(t *testing.T) { + command.ShellCommander = internal.ShellCommanderMaker(nil, nil, nil) + }, + ExpectFails: []result.Breach{&result.KeyValueBreach{ + BreachType: "key-value", + KeyLabel: "core.extension", + Key: "profile", + ValueLabel: "actual", + Value: "minimal", + ExpectedValue: "standard", + }}, + ExpectStatus: result.Fail, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + curShellCommander := command.ShellCommander + defer func() { command.ShellCommander = curShellCommander }() + test.Check.UnmarshalDataMap() + internal.TestRunCheck(t, test) + }) + } +} + +func TestDrushYamlCheckRemediate(t *testing.T) { + tt := []internal.RemediateTest{ + { + Name: "noCommand", + Check: &DrushYamlCheck{ + YamlBase: yaml.YamlBase{ + CheckBase: config.CheckBase{ + DataMap: map[string][]byte{ + "core.extension": []byte(`{"profile":"minimal"}`)}, + Result: result.Result{ + Breaches: []result.Breach{&result.ValueBreach{}}}}, + Values: []yaml.KeyValue{{Key: "profile", Value: "standard"}}}, + RemediateCommand: "", + ConfigName: "core.extension"}, + ExpectGeneratedCommand: "", + ExpectBreaches: []result.Breach{&result.ValueBreach{ + Remediation: result.Remediation{Status: "no-support"}}}, + ExpectStatusFail: true, + ExpectRemediationStatus: result.RemediationStatusNoSupport, + }, + { + Name: "simpleCommand", + Check: &DrushYamlCheck{ + YamlBase: yaml.YamlBase{ + CheckBase: config.CheckBase{ + Result: result.Result{ + Breaches: []result.Breach{&result.ValueBreach{}}}}}, + RemediateCommand: "drush config:set clamav.settings enabled 1"}, + ExpectGeneratedCommand: "sh -c 'drush config:set clamav.settings enabled 1'", + ExpectBreaches: []result.Breach{&result.ValueBreach{ + Remediation: result.Remediation{ + Status: "success", + Messages: []string{ + "remediation command for config '' ran successfully"}}}}, + ExpectRemediationStatus: result.RemediationStatusSuccess, + }, + { + Name: "multilineCommand", + Check: &DrushYamlCheck{ + YamlBase: yaml.YamlBase{ + CheckBase: config.CheckBase{ + Result: result.Result{ + Breaches: []result.Breach{&result.ValueBreach{}}}}}, + RemediateCommand: `#!/bin/bash +set -eu +drush config:set clamav.settings enabled true +`}, + ExpectGeneratedCommand: `sh -c '#!/bin/bash +set -eu +drush config:set clamav.settings enabled true +'`, + ExpectBreaches: []result.Breach{&result.ValueBreach{ + Remediation: result.Remediation{ + Status: "success", + Messages: []string{ + "remediation command for config '' ran successfully"}}}}, + ExpectRemediationStatus: result.RemediationStatusSuccess, + }, + } + + for _, tc := range tt { + t.Run(tc.Name, func(t *testing.T) { + internal.TestRemediate(t, tc) + }) + } - command.ShellCommander = internal.ShellCommanderMaker( - &stdout, - nil, - nil, - ) - - c := DrushYamlCheck{ - Command: "status", - ConfigName: "core.extension", - } - c.FetchData() - assert.NotEqual(result.Fail, c.Result.Status) - assert.Empty(c.Result.Passes) - assert.Empty(c.Result.Breaches) - }) } diff --git a/pkg/checks/drupal/forbiddenusercheck.go b/pkg/checks/drupal/forbiddenusercheck.go index f6a9e53..84a5492 100644 --- a/pkg/checks/drupal/forbiddenusercheck.go +++ b/pkg/checks/drupal/forbiddenusercheck.go @@ -40,11 +40,11 @@ func (c *ForbiddenUserCheck) CheckUserStatus() bool { userStatus, err := Drush(c.DrushPath, c.Alias, cmd).Exec() var pathError *fs.PathError if err != nil && errors.As(err, &pathError) { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: pathError.Path + ": " + pathError.Err.Error()}) } else if err != nil { msg := string(err.(*exec.ExitError).Stderr) - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: strings.ReplaceAll(strings.TrimSpace(msg), " \n ", "")}) } else { // Unmarshal user:info JSON. @@ -57,7 +57,7 @@ func (c *ForbiddenUserCheck) CheckUserStatus() bool { err = json.Unmarshal(userStatus, &userStatusMap) var syntaxError *json.SyntaxError if err != nil && errors.As(err, &syntaxError) { - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) } if userStatusMap[c.UserId]["user_status"] == "1" { @@ -69,18 +69,23 @@ func (c *ForbiddenUserCheck) CheckUserStatus() bool { } // Remediate attempts to block an active forbidden user. -func (c *ForbiddenUserCheck) Remediate(breachIfc interface{}) error { - _, err := Drush(c.DrushPath, c.Alias, []string{"user:block", "--uid=" + c.UserId}).Exec() - if err != nil { - c.AddBreach(result.KeyValueBreach{ - KeyLabel: "user", - Key: c.UserId, - ValueLabel: "error blocking forbidden user", - Value: command.GetMsgFromCommandError(err), - }) - return err +func (c *ForbiddenUserCheck) Remediate() { + for _, b := range c.Result.Breaches { + if _, ok := b.(*result.KeyValueBreach); !ok { + continue + } + + _, err := Drush(c.DrushPath, c.Alias, []string{"user:block", "--uid=" + c.UserId}).Exec() + if err != nil { + b.SetRemediation(result.RemediationStatusFailed, fmt.Sprintf( + "error blocking forbidden user '%s' due to error: %s", + c.UserId, command.GetMsgFromCommandError(err))) + } else { + b.SetRemediation(result.RemediationStatusSuccess, fmt.Sprintf( + "Blocked the forbidden user [%s]", c.UserId)) + } } - return nil + } // Merge implementation for ForbiddenUserCheck check. @@ -97,16 +102,10 @@ func (c *ForbiddenUserCheck) HasData(failCheck bool) bool { func (c *ForbiddenUserCheck) RunCheck() { userActive := c.CheckUserStatus() if userActive { - if c.PerformRemediation { - if err := c.Remediate(nil); err == nil { - c.AddRemediation(fmt.Sprintf("Blocked the forbidden user [%s]", c.UserId)) - } - } else { - c.AddBreach(result.KeyValueBreach{ - Key: "forbidden user is active", - Value: c.UserId, - }) - } + c.AddBreach(&result.KeyValueBreach{ + Key: "forbidden user is active", + Value: c.UserId, + }) } if len(c.Result.Breaches) == 0 { diff --git a/pkg/checks/drupal/forbiddenusercheck_test.go b/pkg/checks/drupal/forbiddenusercheck_test.go index 1e4b916..17fd2f1 100644 --- a/pkg/checks/drupal/forbiddenusercheck_test.go +++ b/pkg/checks/drupal/forbiddenusercheck_test.go @@ -6,7 +6,6 @@ import ( "github.com/salsadigitalauorg/shipshape/pkg/checks/drupal" "github.com/salsadigitalauorg/shipshape/pkg/command" - "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/internal" "github.com/salsadigitalauorg/shipshape/pkg/result" "github.com/stretchr/testify/assert" @@ -46,14 +45,14 @@ func TestForbiddenUserCheck_RunCheck(t *testing.T) { t.Run("failOnDrushNotFound", func(t *testing.T) { c := drupal.ForbiddenUserCheck{} c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Fail, c.Result.Status) assertions.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "vendor/drush/drush/drush: no such file or directory", }}, - c.Result.Breaches, - ) + c.Result.Breaches) }) t.Run("failOnDrushError", func(t *testing.T) { @@ -69,7 +68,7 @@ func TestForbiddenUserCheck_RunCheck(t *testing.T) { c.RunCheck() assertions.Empty(c.Result.Passes) assertions.ElementsMatch( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "drupal-user-forbidden", Severity: "normal", @@ -91,10 +90,11 @@ func TestForbiddenUserCheck_RunCheck(t *testing.T) { nil, ) c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "drupal-user-forbidden", Severity: "normal", @@ -121,10 +121,11 @@ func TestForbiddenUserCheck_RunCheck(t *testing.T) { nil, ) c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", CheckType: "drupal-user-forbidden", Severity: "normal", @@ -167,24 +168,38 @@ func TestForbiddenUserCheck_Remediate(t *testing.T) { defer func() { command.ShellCommander = curShellCommander }() t.Run("failOnDrushError", func(t *testing.T) { - c := drupal.ForbiddenUserCheck{} - c.Init(drupal.ForbiddenUser) - assertions.True(c.RequiresDb) + c := drupal.ForbiddenUserCheck{UserId: "1"} + c.AddBreach(&result.KeyValueBreach{ + BreachType: "key-value", + Key: "forbidden user is active", + Value: c.UserId, + }) command.ShellCommander = internal.ShellCommanderMaker( nil, &exec.ExitError{Stderr: []byte("Unable to find a matching user")}, nil, ) - err := c.Remediate(nil) - assertions.NotNil(err) - msg := string(err.(*exec.ExitError).Stderr) - assertions.Equal("Unable to find a matching user", msg) + c.Remediate() + assertions.EqualValues([]result.Breach{&result.KeyValueBreach{ + BreachType: "key-value", + Key: "forbidden user is active", + Value: c.UserId, + Remediation: result.Remediation{ + Status: result.RemediationStatusFailed, + Messages: []string{"error blocking forbidden user '1' due to error: " + + "Unable to find a matching user"}}}, + }, c.Result.Breaches) + c.Result.DetermineResultStatus(true) }) t.Run("passOnBlockingInactiveUser", func(t *testing.T) { - c := drupal.ForbiddenUserCheck{CheckBase: config.CheckBase{PerformRemediation: true}} - c.Init(drupal.ForbiddenUser) + c := drupal.ForbiddenUserCheck{UserId: "1"} + c.AddBreach(&result.KeyValueBreach{ + BreachType: "key-value", + Key: "forbidden user is active", + Value: c.UserId, + }) stdout := ` { @@ -198,16 +213,16 @@ func TestForbiddenUserCheck_Remediate(t *testing.T) { nil, nil, ) - c.RunCheck() + c.Remediate() + assertions.EqualValues([]result.Breach{&result.KeyValueBreach{ + BreachType: "key-value", + Key: "forbidden user is active", + Value: c.UserId, + Remediation: result.Remediation{ + Status: result.RemediationStatusSuccess, + Messages: []string{"Blocked the forbidden user [1]"}}}, + }, c.Result.Breaches) + c.Result.DetermineResultStatus(true) assertions.Equal(result.Pass, c.Result.Status) - assertions.Empty(c.Result.Breaches) - assertions.ElementsMatch( - []string{"Blocked the forbidden user [1]"}, - c.Result.Remediations, - ) - assertions.ElementsMatch( - []string{"No forbidden user is active."}, - c.Result.Passes, - ) }) } diff --git a/pkg/checks/drupal/rolepermissionscheck.go b/pkg/checks/drupal/rolepermissionscheck.go index d86974b..b82231f 100644 --- a/pkg/checks/drupal/rolepermissionscheck.go +++ b/pkg/checks/drupal/rolepermissionscheck.go @@ -3,10 +3,9 @@ package drupal import ( "encoding/json" "errors" - "io/fs" - "os/exec" "strings" + "github.com/salsadigitalauorg/shipshape/pkg/command" "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/result" "github.com/salsadigitalauorg/shipshape/pkg/utils" @@ -32,6 +31,16 @@ func (c *RolePermissionsCheck) Init(ct config.CheckType) { c.RequiresDb = true } +// Merge implementation for RolePermissionsCheck check. +func (c *RolePermissionsCheck) Merge(mergeCheck config.Check) error { + return nil +} + +// HasData implementation for RolePermissionsCheck check. +func (c *RolePermissionsCheck) HasData(failCheck bool) bool { + return true +} + // GetRolePermissions get the permissions of the role. func (c *RolePermissionsCheck) GetRolePermissions() []string { // Command: drush role:list --filter=id=anonymous --fields=perms --format=json @@ -39,14 +48,8 @@ func (c *RolePermissionsCheck) GetRolePermissions() []string { drushOutput, err := Drush(c.DrushPath, c.Alias, cmd).Exec() - var pathError *fs.PathError - if err != nil && errors.As(err, &pathError) { - c.AddBreach(result.ValueBreach{ - Value: pathError.Path + ": " + pathError.Err.Error()}) - } else if err != nil { - msg := string(err.(*exec.ExitError).Stderr) - c.AddBreach(result.ValueBreach{ - Value: strings.ReplaceAll(strings.TrimSpace(msg), " \n ", "")}) + if err != nil { + c.AddBreach(&result.ValueBreach{Value: command.GetMsgFromCommandError(err)}) } else { // Unmarshal role:list JSON. // { @@ -63,7 +66,7 @@ func (c *RolePermissionsCheck) GetRolePermissions() []string { err = json.Unmarshal(drushOutput, &rolePermissionsMap) var syntaxError *json.SyntaxError if err != nil && errors.As(err, &syntaxError) { - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) } if len(rolePermissionsMap[c.RoleId]["perms"]) > 0 { @@ -74,20 +77,10 @@ func (c *RolePermissionsCheck) GetRolePermissions() []string { return nil } -// HasData implementation for RolePermissionsCheck check. -func (c *RolePermissionsCheck) HasData(failCheck bool) bool { - return true -} - -// Merge implementation for RolePermissionsCheck check. -func (c *RolePermissionsCheck) Merge(mergeCheck config.Check) error { - return nil -} - // RunCheck implements the Check logic for role permissions. func (c *RolePermissionsCheck) RunCheck() { if c.RoleId == "" { - c.AddBreach(result.ValueBreach{Value: "no role ID provided"}) + c.AddBreach(&result.ValueBreach{Value: "no role ID provided"}) return } @@ -95,7 +88,7 @@ func (c *RolePermissionsCheck) RunCheck() { // Check for required permissions. diff := utils.StringSlicesInterdiffUnique(rolePermissions, c.RequiredPermissions) if len(diff) > 0 { - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ KeyLabel: "role", Key: c.RoleId, ValueLabel: "missing permissions", @@ -106,7 +99,7 @@ func (c *RolePermissionsCheck) RunCheck() { // Check for disallowed permissions. diff = utils.StringSlicesIntersectUnique(rolePermissions, c.DisallowedPermissions) if len(diff) > 0 { - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ KeyLabel: "role", Key: c.RoleId, ValueLabel: "disallowed permissions", diff --git a/pkg/checks/drupal/rolepermissionscheck_test.go b/pkg/checks/drupal/rolepermissionscheck_test.go index 82a3919..ce653ea 100644 --- a/pkg/checks/drupal/rolepermissionscheck_test.go +++ b/pkg/checks/drupal/rolepermissionscheck_test.go @@ -42,9 +42,10 @@ func TestRolePermissionsCheck_RunCheck(t *testing.T) { t.Run("failOnNoRoleProvided", func(t *testing.T) { c := drupal.RolePermissionsCheck{} c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Fail, c.Result.Status) assertions.ElementsMatch( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, Value: "no role ID provided"}}, c.Result.Breaches, @@ -56,9 +57,10 @@ func TestRolePermissionsCheck_RunCheck(t *testing.T) { RoleId: "authenticated", } c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Fail, c.Result.Status) assertions.ElementsMatch( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, Value: "vendor/drush/drush/drush: no such file or directory"}}, c.Result.Breaches) @@ -77,9 +79,10 @@ func TestRolePermissionsCheck_RunCheck(t *testing.T) { nil, ) c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, CheckType: "drupal-role-permissions", Severity: "normal", @@ -101,10 +104,11 @@ func TestRolePermissionsCheck_RunCheck(t *testing.T) { nil, ) c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, CheckType: "drupal-role-permissions", Severity: "normal", @@ -142,11 +146,12 @@ func TestRolePermissionsCheck_RunCheck(t *testing.T) { nil, ) c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, CheckType: "drupal-role-permissions", Severity: "normal", @@ -155,7 +160,7 @@ func TestRolePermissionsCheck_RunCheck(t *testing.T) { ValueLabel: "missing permissions", Value: "[setup own tfa]", }, - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, CheckType: "drupal-role-permissions", Severity: "normal", @@ -197,6 +202,7 @@ func TestRolePermissionsCheck_RunCheck(t *testing.T) { nil, ) c.RunCheck() + c.Result.DetermineResultStatus(false) assertions.Equal(result.Pass, c.Result.Status) assertions.Empty(c.Result.Breaches) }) diff --git a/pkg/checks/drupal/trackingcodecheck.go b/pkg/checks/drupal/trackingcodecheck.go index dc9f0ed..8cab389 100644 --- a/pkg/checks/drupal/trackingcodecheck.go +++ b/pkg/checks/drupal/trackingcodecheck.go @@ -35,14 +35,14 @@ func (c *TrackingCodeCheck) Merge(mergeCheck config.Check) error { // type for further processing. func (c *TrackingCodeCheck) UnmarshalDataMap() { if len(c.DataMap[c.ConfigName]) == 0 { - c.AddBreach(result.ValueBreach{Value: "no data provided"}) + c.AddBreach(&result.ValueBreach{Value: "no data provided"}) } c.DrushStatus = DrushStatus{} err := yaml.Unmarshal(c.DataMap[c.ConfigName], &c.DrushStatus) if err != nil { if _, ok := err.(*yaml.TypeError); !ok { - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) return } } @@ -52,7 +52,7 @@ func (c *TrackingCodeCheck) RunCheck() { resp, err := http.Get(c.DrushStatus.Uri) if err != nil { - c.AddBreach(result.ValueBreach{Value: "could not determine site uri"}) + c.AddBreach(&result.ValueBreach{Value: "could not determine site uri"}) return } @@ -65,7 +65,7 @@ func (c *TrackingCodeCheck) RunCheck() { c.AddPass(fmt.Sprintf("tracking code [%s] present", c.Code)) c.Result.Status = result.Pass } else { - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ Key: "required tracking code not present", Value: c.Code, }) diff --git a/pkg/checks/drupal/trackingcodecheck_test.go b/pkg/checks/drupal/trackingcodecheck_test.go index f497688..323a820 100644 --- a/pkg/checks/drupal/trackingcodecheck_test.go +++ b/pkg/checks/drupal/trackingcodecheck_test.go @@ -86,9 +86,10 @@ func TestTrackingCodeCheckFails(t *testing.T) { Uri: "https://google.com", } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.ElementsMatch( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", CheckType: "drupal-tracking-code", Severity: "normal", @@ -111,6 +112,7 @@ func TestTrackingCodeCheckPass(t *testing.T) { Uri: "https://gist.github.com/Pominova/cf7884e7418f6ebfa412d2d3dc472a97", } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Pass, c.Result.Status) assert.ElementsMatch( []string{"tracking code [UA-xxxxxx-1] present"}, diff --git a/pkg/checks/drupal/types.go b/pkg/checks/drupal/types.go index d9cb823..b88c697 100644 --- a/pkg/checks/drupal/types.go +++ b/pkg/checks/drupal/types.go @@ -22,10 +22,12 @@ type DrushCommand struct { } type DrushYamlCheck struct { - yaml.YamlBase `yaml:",inline"` - DrushCommand `yaml:",inline"` - Command string `yaml:"command"` - ConfigName string `yaml:"config-name"` + yaml.YamlBase `yaml:",inline"` + DrushCommand `yaml:",inline"` + Command string `yaml:"command"` + ConfigName string `yaml:"config-name"` + RemediateCommand string `yaml:"remediate-command"` + RemediateMsg string `yaml:"remediate-msg"` } type FileModuleCheck struct { diff --git a/pkg/checks/drupal/userrolecheck.go b/pkg/checks/drupal/userrolecheck.go index b4817da..4833a18 100644 --- a/pkg/checks/drupal/userrolecheck.go +++ b/pkg/checks/drupal/userrolecheck.go @@ -8,6 +8,7 @@ import ( "os/exec" "strings" + "github.com/salsadigitalauorg/shipshape/pkg/command" "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/result" "github.com/salsadigitalauorg/shipshape/pkg/utils" @@ -55,11 +56,11 @@ func (c *UserRoleCheck) getUserIds() string { var pathErr *fs.PathError if err != nil && errors.As(err, &pathErr) { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: pathErr.Path + ": " + pathErr.Err.Error()}) } else if err != nil { msg := string(err.(*exec.ExitError).Stderr) - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ Value: strings.ReplaceAll(strings.TrimSpace(msg), " \n ", "")}) } return string(userIds) @@ -70,7 +71,7 @@ func (c *UserRoleCheck) FetchData() { var err error userIds := c.getUserIds() - if c.Result.Status == result.Fail { + if len(c.Result.Breaches) > 0 { return } @@ -78,8 +79,8 @@ func (c *UserRoleCheck) FetchData() { cmd := []string{"user:information", "--uid=" + userIds, "--fields=roles", "--format=json"} c.DataMap["user-info"], err = Drush(c.DrushPath, c.Alias, cmd).Exec() if err != nil { - msg := string(err.(*exec.ExitError).Stderr) - c.AddBreach(result.ValueBreach{ + msg := command.GetMsgFromCommandError(err) + c.AddBreach(&result.ValueBreach{ Value: strings.ReplaceAll(strings.TrimSpace(msg), " \n ", "")}) } } @@ -88,7 +89,7 @@ func (c *UserRoleCheck) FetchData() { // into the userRoles for further processing. func (c *UserRoleCheck) UnmarshalDataMap() { if len(c.DataMap["user-info"]) == 0 { - c.AddBreach(result.ValueBreach{Value: "no data provided"}) + c.AddBreach(&result.ValueBreach{Value: "no data provided"}) return } @@ -96,7 +97,7 @@ func (c *UserRoleCheck) UnmarshalDataMap() { err := json.Unmarshal(c.DataMap["user-info"], &userInfoMap) var synErr *json.SyntaxError if err != nil && errors.As(err, &synErr) { - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) return } @@ -109,7 +110,7 @@ func (c *UserRoleCheck) UnmarshalDataMap() { // RunCheck implements the Check logic for disallowed user roles. func (c *UserRoleCheck) RunCheck() { if len(c.Roles) == 0 { - c.AddBreach(result.ValueBreach{Value: "no disallowed role provided"}) + c.AddBreach(&result.ValueBreach{Value: "no disallowed role provided"}) return } @@ -121,7 +122,7 @@ func (c *UserRoleCheck) RunCheck() { disallowed := utils.StringSlicesIntersect(roles, c.Roles) if len(disallowed) > 0 { - c.AddBreach(result.KeyValuesBreach{ + c.AddBreach(&result.KeyValuesBreach{ KeyLabel: "user", Key: fmt.Sprintf("%d", uid), ValueLabel: "disallowed roles", diff --git a/pkg/checks/drupal/userrolecheck_test.go b/pkg/checks/drupal/userrolecheck_test.go index 1d559a5..d97d63d 100644 --- a/pkg/checks/drupal/userrolecheck_test.go +++ b/pkg/checks/drupal/userrolecheck_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestInit(t *testing.T) { +func TestUserRoleCheckInit(t *testing.T) { c := UserRoleCheck{} c.Init(UserRole) assert.True(t, c.RequiresDb) @@ -47,15 +47,14 @@ func TestUserRoleMerge(t *testing.T) { }, c) } -func TestFetchData(t *testing.T) { +func TestUserRoleCheckFetchData(t *testing.T) { assert := assert.New(t) t.Run("drushNotFound", func(t *testing.T) { c := UserRoleCheck{} c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "vendor/drush/drush/drush: no such file or directory", }}, @@ -63,9 +62,9 @@ func TestFetchData(t *testing.T) { ) }) - curShellCommander := command.ShellCommander - defer func() { command.ShellCommander = curShellCommander }() t.Run("drushError", func(t *testing.T) { + curShellCommander := command.ShellCommander + defer func() { command.ShellCommander = curShellCommander }() sqlQueryFail := true command.ShellCommander = func(name string, arg ...string) command.IShellCommand { var stdout []byte @@ -83,9 +82,8 @@ func TestFetchData(t *testing.T) { } c := UserRoleCheck{} c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "unable to run drush sql query", }}, @@ -95,9 +93,8 @@ func TestFetchData(t *testing.T) { sqlQueryFail = false c = UserRoleCheck{} c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "unable to run drush command", }}, @@ -120,15 +117,14 @@ func TestFetchData(t *testing.T) { }) } -func TestUnmarshalData(t *testing.T) { +func TestUserRoleCheckUnmarshalData(t *testing.T) { assert := assert.New(t) // Empty datamap. c := UserRoleCheck{} c.UnmarshalDataMap() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "no data provided", }}, @@ -143,9 +139,8 @@ func TestUnmarshalData(t *testing.T) { }, } c.UnmarshalDataMap() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "invalid character ']' after object key:value pair", }}, @@ -166,64 +161,58 @@ func TestUnmarshalData(t *testing.T) { assert.Equal("map[int][]string{1:[]string{\"authenticated\"}}", fmt.Sprintf("%#v", userRolesVal)) } -func TestRunCheck(t *testing.T) { - assert := assert.New(t) - - // No disallowed roles provided. - c := UserRoleCheck{ - CheckBase: config.CheckBase{ - DataMap: map[string][]byte{ - "user-info": []byte(`{"1":{"roles":["authenticated"]}}`)}, - }, - } - c.UnmarshalDataMap() - c.RunCheck() - assert.Equal(result.Fail, c.Result.Status) - assert.EqualValues( - []result.Breach{result.ValueBreach{ - BreachType: "value", - Value: "no disallowed role provided", - }}, - c.Result.Breaches, - ) - - // User has disallowed roles. - c = UserRoleCheck{ - CheckBase: config.CheckBase{ - DataMap: map[string][]byte{ - "user-info": []byte(`{"1":{"roles":["authenticated","site-admin","content-admin"]}}`)}, - }, - Roles: []string{"site-admin", "content-admin"}, - } - c.UnmarshalDataMap() - c.RunCheck() - assert.Equal(result.Fail, c.Result.Status) - assert.EqualValues( - []result.Breach{result.KeyValuesBreach{ - BreachType: "key-values", - KeyLabel: "user", - Key: "1", - ValueLabel: "disallowed roles", - Values: []string{"site-admin", "content-admin"}, - }}, - c.Result.Breaches, - ) - - // User allowed to have disallowed roles. - c = UserRoleCheck{ - CheckBase: config.CheckBase{ - DataMap: map[string][]byte{ - "user-info": []byte(` +func TestUserRoleCheckRunCheck(t *testing.T) { + tt := []internal.RunCheckTest{ + { + Name: "noDisallowedRoleProvided", + Check: &UserRoleCheck{ + CheckBase: config.CheckBase{ + DataMap: map[string][]byte{ + "user-info": []byte(`{"1":{"roles":["authenticated"]}}`)}, + }, + }, + ExpectStatus: result.Fail, + ExpectFails: []result.Breach{&result.ValueBreach{ + BreachType: "value", + Value: "no disallowed role provided"}}}, + + { + Name: "userHasDisallowedRoles", + Check: &UserRoleCheck{ + CheckBase: config.CheckBase{ + DataMap: map[string][]byte{ + "user-info": []byte(`{"1":{"roles":["authenticated","site-admin","content-admin"]}}`)}, + }, + Roles: []string{"site-admin", "content-admin"}, + }, + ExpectStatus: result.Fail, + ExpectFails: []result.Breach{&result.KeyValuesBreach{ + BreachType: "key-values", + KeyLabel: "user", + Key: "1", + ValueLabel: "disallowed roles", + Values: []string{"site-admin", "content-admin"}}}}, + + { + Name: "userAllowedToHaveDisallowedRoles", + Check: &UserRoleCheck{ + CheckBase: config.CheckBase{ + DataMap: map[string][]byte{ + "user-info": []byte(` { "1":{"roles":["authenticated"]}, "2":{"roles":["authenticated","site-admin","content-admin"]} } `)}, - }, - Roles: []string{"site-admin", "content-admin"}, - AllowedUsers: []int{2}, + }, + Roles: []string{"site-admin", "content-admin"}, + AllowedUsers: []int{2}, + }, + ExpectStatus: result.Pass}, + } + + for _, tc := range tt { + tc.Check.UnmarshalDataMap() + internal.TestRunCheck(t, tc) } - c.UnmarshalDataMap() - c.RunCheck() - assert.Equal(result.Pass, c.Result.Status) } diff --git a/pkg/checks/file/filecheck.go b/pkg/checks/file/filecheck.go index 5433fb7..98b58ef 100644 --- a/pkg/checks/file/filecheck.go +++ b/pkg/checks/file/filecheck.go @@ -55,7 +55,7 @@ func (c *FileCheck) RequiresData() bool { return false } func (c *FileCheck) RunCheck() { files, err := utils.FindFiles(filepath.Join(config.ProjectDir, c.Path), c.DisallowedPattern, c.ExcludePattern, c.SkipDir) if err != nil { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "error finding files", Value: err.Error()}) return @@ -65,7 +65,7 @@ func (c *FileCheck) RunCheck() { c.AddPass("No illegal files") return } - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ Key: fmt.Sprintf("%s - illegal files found", c.Name), Value: strings.Join(files, "\n"), }) diff --git a/pkg/checks/file/filecheck_test.go b/pkg/checks/file/filecheck_test.go index 7ddb3e6..99b08b6 100644 --- a/pkg/checks/file/filecheck_test.go +++ b/pkg/checks/file/filecheck_test.go @@ -55,10 +55,11 @@ func TestFileCheckRunCheck(t *testing.T) { c.Name = "filecheck1" c.Init(File) c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.Equal(0, len(c.Result.Passes)) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ CheckType: "file", CheckName: "filecheck1", BreachType: result.BreachTypeValue, @@ -75,11 +76,12 @@ func TestFileCheckRunCheck(t *testing.T) { c.Name = "filecheck2" c.Init(File) c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.Equal(0, len(c.Result.Passes)) assert.EqualValues( []result.Breach{ - result.KeyValueBreach{ + &result.KeyValueBreach{ CheckType: "file", CheckName: "filecheck2", BreachType: result.BreachTypeKeyValue, diff --git a/pkg/checks/json/json.go b/pkg/checks/json/json.go index afd1799..059d937 100644 --- a/pkg/checks/json/json.go +++ b/pkg/checks/json/json.go @@ -21,7 +21,7 @@ func (c *JsonCheck) UnmarshalDataMap() { var n any err := json.Unmarshal(data, &n) if err != nil { - c.AddBreach(result.ValueBreach{ValueLabel: "JSON error", Value: err.Error()}) + c.AddBreach(&result.ValueBreach{ValueLabel: "JSON error", Value: err.Error()}) return } c.Node[configName] = n @@ -36,16 +36,16 @@ func (c *JsonCheck) processData(configName string) { kvr, fails, err := CheckKeyValue(c.Node[configName], kv) switch kvr { case yaml.KeyValueError: - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) case yaml.KeyValueNotFound: - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ KeyLabel: "config", Key: configName, ValueLabel: "key not found", Value: kv.Key, }) case yaml.KeyValueNotEqual: - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ KeyLabel: configName, Key: kv.Key, ValueLabel: "actual", @@ -53,7 +53,7 @@ func (c *JsonCheck) processData(configName string) { Value: fails[0], }) case yaml.KeyValueDisallowedFound: - c.AddBreach(result.KeyValuesBreach{ + c.AddBreach(&result.KeyValuesBreach{ KeyLabel: "config", Key: configName, ValueLabel: fmt.Sprintf("disallowed %s", kv.Key), diff --git a/pkg/checks/json/json_test.go b/pkg/checks/json/json_test.go index 490a0b0..1ee42e9 100644 --- a/pkg/checks/json/json_test.go +++ b/pkg/checks/json/json_test.go @@ -32,11 +32,10 @@ func TestJsonCheckUnmarshalDataMap(t *testing.T) { } c.UnmarshalDataMap() - assertions.Equal(result.Fail, c.Result.Status) assertions.EqualValues(0, len(c.Result.Passes)) assertions.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: result.BreachTypeValue, ValueLabel: "JSON error", Value: "invalid character 'p' looking for beginning of value", diff --git a/pkg/checks/json/jsoncheck_test.go b/pkg/checks/json/jsoncheck_test.go index daa7ab4..403a7cd 100644 --- a/pkg/checks/json/jsoncheck_test.go +++ b/pkg/checks/json/jsoncheck_test.go @@ -58,31 +58,30 @@ func TestJsonCheckMerge(t *testing.T) { assertions.EqualError(err, "can only merge checks with the same name") } -func TestJsonCheckRunCheck(t *testing.T) { - assertions := assert.New(t) - - mockCheck := func() JsonCheck { - return JsonCheck{ - KeyValues: []KeyValue{ - { - KeyValue: yaml.KeyValue{ - Key: "$.license", - Value: "MIT", - }, - DisallowedValues: nil, - AllowedValues: nil, +func MockJsonCheck() JsonCheck { + return JsonCheck{ + KeyValues: []KeyValue{ + { + KeyValue: yaml.KeyValue{ + Key: "$.license", + Value: "MIT", }, + DisallowedValues: nil, + AllowedValues: nil, }, - } + }, } +} - c := mockCheck() +func TestJsonCheckFetchData(t *testing.T) { + assertions := assert.New(t) + + c := MockJsonCheck() c.FetchData() - assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: result.BreachTypeValue, ValueLabel: "- no file", Value: "no file provided", @@ -93,15 +92,14 @@ func TestJsonCheckRunCheck(t *testing.T) { // Non-existent file. config.ProjectDir = "testdata" - c = mockCheck() + c = MockJsonCheck() c.Init(Json) c.File = "non-existent.json" c.FetchData() - assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ CheckType: "json", Severity: "normal", BreachType: result.BreachTypeValue, @@ -113,38 +111,22 @@ func TestJsonCheckRunCheck(t *testing.T) { ) // Non-existent file with ignore missing. - c = mockCheck() + c = MockJsonCheck() c.File = "non-existent.json" c.IgnoreMissing = &cTrue c.FetchData() - assertions.Equal(result.Pass, c.Result.Status) assertions.Empty(c.Result.Breaches) assertions.EqualValues([]string{"File testdata/non-existent.json does not exist"}, c.Result.Passes) - // Single file. - c = mockCheck() - c.File = "composer.map.json" - c.FetchData() - // Should not fail yet. - assertions.NotEqual(result.Fail, c.Result.Status) - assertions.Empty(c.Result.Breaches) - assertions.True(c.HasData(false)) - c.UnmarshalDataMap() - c.RunCheck() - assertions.Equal(result.Pass, c.Result.Status) - assertions.Empty(c.Result.Breaches) - assertions.EqualValues([]string{"[composer.map.json] '$.license' equals 'MIT'"}, c.Result.Passes) - // Bad File pattern. - c = mockCheck() + c = MockJsonCheck() c.Pattern = "*.composer.json" c.Path = "" c.FetchData() - assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: result.BreachTypeValue, ValueLabel: "error finding files in path: testdata", Value: "error parsing regexp: missing argument to repetition operator: `*`", @@ -154,14 +136,13 @@ func TestJsonCheckRunCheck(t *testing.T) { ) // File pattern with no matching files. - c = mockCheck() + c = MockJsonCheck() c.Pattern = "composer*.json" c.FetchData() - assertions.Equal(result.Fail, c.Result.Status) assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: result.BreachTypeValue, ValueLabel: "- no file", Value: "no matching yaml files found", @@ -171,20 +152,34 @@ func TestJsonCheckRunCheck(t *testing.T) { ) // File pattern with no matching files, ignoring missing. - c = mockCheck() + c = MockJsonCheck() c.Pattern = "composer*.json" c.IgnoreMissing = &cTrue c.FetchData() - assertions.Equal(result.Pass, c.Result.Status) assertions.Empty(c.Result.Breaches) assertions.EqualValues([]string{"no matching config files found"}, c.Result.Passes) +} + +func TestJsonCheckRunCheck(t *testing.T) { + assertions := assert.New(t) + + // Single file. + c := MockJsonCheck() + c.File = "composer.map.json" + c.FetchData() + assertions.Empty(c.Result.Breaches) + assertions.True(c.HasData(false)) + c.UnmarshalDataMap() + c.RunCheck() + assertions.Equal(result.Pass, c.Result.Status) + assertions.Empty(c.Result.Breaches) + assertions.EqualValues([]string{"[composer.map.json] '$.license' equals 'MIT'"}, c.Result.Passes) // Correct single file pattern & value. - c = mockCheck() + c = MockJsonCheck() c.Pattern = "composer.map.json" c.Path = "dir/subdir" c.FetchData() - assertions.NotEqual(result.Fail, c.Result.Status) assertions.Empty(c.Result.Breaches) c.UnmarshalDataMap() c.RunCheck() @@ -192,10 +187,9 @@ func TestJsonCheckRunCheck(t *testing.T) { assertions.Empty(c.Result.Breaches) // Recursive file lookup. - c = mockCheck() + c = MockJsonCheck() c.Pattern = ".*.*.json" c.FetchData() - assertions.NotEqual(result.Fail, c.Result.Status) assertions.Empty(c.Result.Breaches) c.UnmarshalDataMap() c.RunCheck() @@ -209,7 +203,7 @@ func TestJsonCheckRunCheck(t *testing.T) { c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, KeyLabel: "testdata/composer.array.json", Key: "$.license", @@ -217,7 +211,7 @@ func TestJsonCheckRunCheck(t *testing.T) { Value: "BSD", ExpectedValue: "MIT", }, - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, KeyLabel: "testdata/dir/composer.array.json", Key: "$.license", @@ -225,7 +219,7 @@ func TestJsonCheckRunCheck(t *testing.T) { Value: "BSD", ExpectedValue: "MIT", }, - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, KeyLabel: "testdata/dir/subdir/composer.array.json", Key: "$.license", @@ -252,7 +246,6 @@ func TestJsonCheckRunCheck(t *testing.T) { } c.File = "composer.map.json" c.FetchData() - assertions.NotEqual(result.Fail, c.Result.Status) assertions.Empty(c.Result.Breaches) assertions.True(c.HasData(false)) c.UnmarshalDataMap() @@ -261,7 +254,7 @@ func TestJsonCheckRunCheck(t *testing.T) { assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: result.BreachTypeKeyValues, KeyLabel: "config", Key: "composer.map.json", @@ -286,7 +279,6 @@ func TestJsonCheckRunCheck(t *testing.T) { } c.File = "composer.map.json" c.FetchData() - assertions.NotEqual(result.Fail, c.Result.Status) assertions.Empty(c.Result.Breaches) assertions.True(c.HasData(false)) c.UnmarshalDataMap() @@ -295,7 +287,7 @@ func TestJsonCheckRunCheck(t *testing.T) { assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.KeyValuesBreach{ + &result.KeyValuesBreach{ BreachType: result.BreachTypeKeyValues, KeyLabel: "config", Key: "composer.map.json", @@ -320,7 +312,6 @@ func TestJsonCheckRunCheck(t *testing.T) { } c.File = "composer.map.json" c.FetchData() - assertions.NotEqual(result.Fail, c.Result.Status) assertions.Empty(c.Result.Breaches) assertions.True(c.HasData(false)) c.UnmarshalDataMap() @@ -329,7 +320,7 @@ func TestJsonCheckRunCheck(t *testing.T) { assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: result.BreachTypeValue, Value: "json: invalid path format: found invalid path character * after dot", }, @@ -349,7 +340,6 @@ func TestJsonCheckRunCheck(t *testing.T) { } c.File = "composer.map.json" c.FetchData() - assertions.NotEqual(result.Fail, c.Result.Status) assertions.Empty(c.Result.Breaches) assertions.True(c.HasData(false)) c.UnmarshalDataMap() @@ -358,7 +348,7 @@ func TestJsonCheckRunCheck(t *testing.T) { assertions.Empty(c.Result.Passes) assertions.ElementsMatch( []result.Breach{ - result.KeyValueBreach{ + &result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, KeyLabel: "config", Key: "composer.map.json", @@ -382,7 +372,6 @@ func TestJsonCheckRunCheck(t *testing.T) { } c.File = "composer.map.json" c.FetchData() - assertions.NotEqual(result.Fail, c.Result.Status) assertions.Empty(c.Result.Breaches) assertions.True(c.HasData(false)) c.UnmarshalDataMap() diff --git a/pkg/checks/phpstan/phpstan.go b/pkg/checks/phpstan/phpstan.go index bd7a0fa..0d366b5 100644 --- a/pkg/checks/phpstan/phpstan.go +++ b/pkg/checks/phpstan/phpstan.go @@ -119,11 +119,11 @@ func (c *PhpStanCheck) FetchData() { c.DataMap["phpstan"], err = command.ShellCommander(phpstanPath, args...).Output() if err != nil { if pathErr, ok := err.(*fs.PathError); ok { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: pathErr.Path, Value: pathErr.Err.Error()}) } else if len(c.DataMap["phpstan"]) == 0 { // If errors were found, exit code will be 1. - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "Phpstan failed to run", Value: string(err.(*exec.ExitError).Stderr)}) } @@ -135,7 +135,7 @@ func (c *PhpStanCheck) FetchData() { func (c *PhpStanCheck) HasData(failCheck bool) bool { if c.DataMap == nil && len(c.Result.Passes) == 0 { if failCheck { - c.AddBreach(result.ValueBreach{Value: "no data available"}) + c.AddBreach(&result.ValueBreach{Value: "no data available"}) } return false } @@ -158,7 +158,7 @@ func (c *PhpStanCheck) UnmarshalDataMap() { c.phpstanResult = PhpStanResult{} err := json.Unmarshal(c.DataMap["phpstan"], &c.phpstanResult) if err != nil { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "unable to parse phpstan result", Value: err.Error()}) return @@ -172,7 +172,7 @@ func (c *PhpStanCheck) UnmarshalDataMap() { // Unmarshal file errors. err = json.Unmarshal(c.phpstanResult.FilesRaw, &c.phpstanResult.Files) if err != nil { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "unable to parse phpstan file errors", Value: err.Error()}) return @@ -193,14 +193,14 @@ func (c *PhpStanCheck) RunCheck() { errLines = append(errLines, fmt.Sprintf("line %d: %s", er.Line, er.Message)) } - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ Key: fmt.Sprintf("file contains banned functions: %s", file), Value: strings.Join(errLines, "\n"), }) } if len(c.phpstanResult.Errors) > 0 { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "errors encountered when running phpstan", Value: strings.Join(c.phpstanResult.Errors, "\n")}) } diff --git a/pkg/checks/phpstan/phpstan_test.go b/pkg/checks/phpstan/phpstan_test.go index 35a4d7f..852dfe7 100644 --- a/pkg/checks/phpstan/phpstan_test.go +++ b/pkg/checks/phpstan/phpstan_test.go @@ -128,10 +128,8 @@ func TestFetchDataBinNotExists(t *testing.T) { Paths: []string{dir}, } c.FetchData() - - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", ValueLabel: "Phpstan failed to run", Value: "/my/custom/path/phpstan: no such file or directory", @@ -175,9 +173,8 @@ func TestHasData(t *testing.T) { assert := assert.New(t) c := PhpStanCheck{} assert.False(c.HasData(true)) - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", Value: "no data available", }}, @@ -227,9 +224,8 @@ func TestUnmarshalDataMap(t *testing.T) { }, } c.UnmarshalDataMap() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", ValueLabel: "unable to parse phpstan file errors", Value: "json: cannot unmarshal array into Go value of type " + @@ -255,6 +251,7 @@ func TestRunCheck(t *testing.T) { } c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Pass, c.Result.Status) assert.EqualValues([]string{"no error found"}, c.Result.Passes) @@ -268,9 +265,10 @@ func TestRunCheck(t *testing.T) { } c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", Key: "file contains banned functions: /app/web/themes/custom/custom/test-theme/info.php", Value: "line 3: Calling curl_exec() is forbidden, please change the code", @@ -288,9 +286,10 @@ func TestRunCheck(t *testing.T) { } c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.ValueBreach{ + []result.Breach{&result.ValueBreach{ BreachType: "value", ValueLabel: "errors encountered when running phpstan", Value: "Error found in file foo", diff --git a/pkg/checks/sca/apptypecheck.go b/pkg/checks/sca/apptypecheck.go index a3864fa..a4c257e 100644 --- a/pkg/checks/sca/apptypecheck.go +++ b/pkg/checks/sca/apptypecheck.go @@ -105,7 +105,7 @@ func (c *AppTypeCheck) RunCheck() { } } if len(disallowedFound) > 0 { - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ Key: fmt.Sprintf("[%s] contains disallowed frameworks", path), Value: "[" + strings.Join(disallowedFound, ", ") + "]", }) diff --git a/pkg/checks/sca/apptypecheck_test.go b/pkg/checks/sca/apptypecheck_test.go index d8f7d04..39ed5cc 100644 --- a/pkg/checks/sca/apptypecheck_test.go +++ b/pkg/checks/sca/apptypecheck_test.go @@ -24,9 +24,10 @@ func TestIsDrupalCheck(t *testing.T) { c.Dependencies["drupal"] = []string{"drupal/core-recommended"} c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", Key: "[./testdata/drupal] contains disallowed frameworks", Value: "[drupal]", @@ -64,9 +65,10 @@ func TestIsWordpressCheck(t *testing.T) { c.Dirs["wordpress"] = []string{"wp-admin", "wp-content", "wp-includes"} c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", Key: "[./testdata/wordpress] contains disallowed frameworks", Value: "[wordpress]", @@ -90,9 +92,10 @@ func TestIsSymfonyCheck(t *testing.T) { c.Dependencies["symfony"] = []string{"symfony/runtime", "symfony/symfony", "symfony/framework"} c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", Key: "[./testdata/symfony] contains disallowed frameworks", Value: "[symfony]", @@ -116,9 +119,10 @@ func TestIsLaravelCheck(t *testing.T) { c.Dependencies["laravel"] = []string{"laravel/framework", "laravel/tinker"} c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues( - []result.Breach{result.KeyValueBreach{ + []result.Breach{&result.KeyValueBreach{ BreachType: "key-value", Key: "[./testdata/laravel] contains disallowed frameworks", Value: "[laravel]", diff --git a/pkg/checks/yaml/yaml.go b/pkg/checks/yaml/yaml.go index 10ff471..948ba72 100644 --- a/pkg/checks/yaml/yaml.go +++ b/pkg/checks/yaml/yaml.go @@ -52,7 +52,7 @@ func (c *YamlBase) UnmarshalDataMap() { n := yaml.Node{} err := yaml.Unmarshal([]byte(data), &n) if err != nil { - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) return } c.NodeMap[configName] = n @@ -66,16 +66,16 @@ func (c *YamlBase) determineBreaches(configName string) { kvr, fails, err := CheckKeyValue(c.NodeMap[configName], kv) switch kvr { case KeyValueError: - c.AddBreach(result.ValueBreach{Value: err.Error()}) + c.AddBreach(&result.ValueBreach{Value: err.Error()}) case KeyValueNotFound: - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ KeyLabel: "config", Key: configName, ValueLabel: "key not found", Value: kv.Key, }) case KeyValueNotEqual: - c.AddBreach(result.KeyValueBreach{ + c.AddBreach(&result.KeyValueBreach{ KeyLabel: configName, Key: kv.Key, ValueLabel: "actual", @@ -83,7 +83,7 @@ func (c *YamlBase) determineBreaches(configName string) { Value: fails[0], }) case KeyValueDisallowedFound: - c.AddBreach(result.KeyValuesBreach{ + c.AddBreach(&result.KeyValuesBreach{ KeyLabel: "config", Key: configName, ValueLabel: fmt.Sprintf("disallowed %s", kv.Key), diff --git a/pkg/checks/yaml/yaml_test.go b/pkg/checks/yaml/yaml_test.go index d8c3177..6e22bb0 100644 --- a/pkg/checks/yaml/yaml_test.go +++ b/pkg/checks/yaml/yaml_test.go @@ -52,9 +52,8 @@ foo: }, } c.UnmarshalDataMap() - assert.Equal(result.Fail, c.Result.Status) assert.EqualValues(0, len(c.Result.Passes)) - assert.ElementsMatch([]result.Breach{result.ValueBreach{ + assert.ElementsMatch([]result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, Value: "yaml: line 4: found character that cannot start any token"}}, c.Result.Breaches) @@ -91,9 +90,10 @@ foo: }, } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) - assert.ElementsMatch([]result.Breach{result.ValueBreach{ + assert.ElementsMatch([]result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, Value: "invalid character '&' at position 3, following \"baz\""}}, c.Result.Breaches) @@ -270,8 +270,7 @@ func TestYamlBase(t *testing.T) { c := YamlBase{} c.HasData(true) - assert.Equal(result.Fail, c.Result.Status) - assert.ElementsMatch([]result.Breach{result.ValueBreach{ + assert.ElementsMatch([]result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, Value: "no data available"}}, c.Result.Breaches) @@ -314,9 +313,10 @@ notification: }, } c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues(0, len(c.Result.Passes)) - assert.ElementsMatch([]result.Breach{result.KeyValueBreach{ + assert.ElementsMatch([]result.Breach{&result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, KeyLabel: "config", Key: "data", @@ -334,9 +334,10 @@ notification: } c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues(0, len(c.Result.Passes)) - assert.ElementsMatch([]result.Breach{result.KeyValueBreach{ + assert.ElementsMatch([]result.Breach{&result.KeyValueBreach{ BreachType: result.BreachTypeKeyValue, KeyLabel: "data", Key: "check.interval_days", @@ -359,6 +360,7 @@ notification: } c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Pass, c.Result.Status) assert.EqualValues(0, len(c.Result.Breaches)) assert.EqualValues( @@ -393,9 +395,10 @@ efgh: } c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues(0, len(c.Result.Passes)) - assert.ElementsMatch([]result.Breach{result.KeyValuesBreach{ + assert.ElementsMatch([]result.Breach{&result.KeyValuesBreach{ BreachType: result.BreachTypeKeyValues, KeyLabel: "config", Key: "data", @@ -432,9 +435,10 @@ foo: c := mockCheck() c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.EqualValues(0, len(c.Result.Passes)) - assert.ElementsMatch([]result.Breach{result.KeyValuesBreach{ + assert.ElementsMatch([]result.Breach{&result.KeyValuesBreach{ BreachType: result.BreachTypeKeyValues, KeyLabel: "config", Key: "data", @@ -446,6 +450,7 @@ foo: c.Values[0].Disallowed = []string{"e"} c.UnmarshalDataMap() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Pass, c.Result.Status) assert.EqualValues(0, len(c.Result.Breaches)) assert.EqualValues([]string{"[data] no disallowed 'foo'"}, c.Result.Passes) diff --git a/pkg/checks/yaml/yamlcheck.go b/pkg/checks/yaml/yamlcheck.go index 08775ca..169bce3 100644 --- a/pkg/checks/yaml/yamlcheck.go +++ b/pkg/checks/yaml/yamlcheck.go @@ -38,7 +38,7 @@ func (c *YamlCheck) readFile(fkey string, fname string) { c.AddPass(fmt.Sprintf("File %s does not exist", fname)) c.Result.Status = result.Pass } else { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "error reading file: " + fname, Value: err.Error()}) } @@ -65,7 +65,7 @@ func (c *YamlCheck) FetchData() { c.AddPass(fmt.Sprintf("Path %s does not exist", configPath)) c.Result.Status = result.Pass } else { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "error finding files in path: " + configPath, Value: err.Error()}) } @@ -77,7 +77,7 @@ func (c *YamlCheck) FetchData() { c.Result.Status = result.Pass return } else if len(files) == 0 { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: c.Name + "- no file", Value: "no matching yaml files found"}) return @@ -88,7 +88,7 @@ func (c *YamlCheck) FetchData() { c.readFile(fname, fname) } } else { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: c.Name + "- no file", Value: "no file provided"}) } diff --git a/pkg/checks/yaml/yamlcheck_test.go b/pkg/checks/yaml/yamlcheck_test.go index b0494ed..fae5cd0 100644 --- a/pkg/checks/yaml/yamlcheck_test.go +++ b/pkg/checks/yaml/yamlcheck_test.go @@ -5,6 +5,7 @@ import ( . "github.com/salsadigitalauorg/shipshape/pkg/checks/yaml" "github.com/salsadigitalauorg/shipshape/pkg/config" + "github.com/salsadigitalauorg/shipshape/pkg/internal" "github.com/salsadigitalauorg/shipshape/pkg/result" "github.com/stretchr/testify/assert" ) @@ -36,175 +37,204 @@ func TestYamlCheckMerge(t *testing.T) { }, c) } -func TestYamlCheck(t *testing.T) { - assert := assert.New(t) - - mockCheck := func() YamlCheck { - return YamlCheck{ - YamlBase: YamlBase{ - Values: []KeyValue{ - {Key: "check.interval_days", Value: "7"}, +func TestYamlCheckFetchData(t *testing.T) { + tt := []internal.FetchDataTest{ + { + Name: "noFile", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, }, }, - } - } - - c := mockCheck() - c.FetchData() - assert.Equal(result.Fail, c.Result.Status) - assert.Empty(c.Result.Passes) - assert.EqualValues( - []result.Breach{ - result.ValueBreach{ - BreachType: "value", - ValueLabel: "- no file", - Value: "no file provided", + ExpectBreaches: []result.Breach{ + &result.ValueBreach{ + BreachType: "value", + CheckType: "yaml", + Severity: "normal", + ValueLabel: "- no file", + Value: "no file provided", + }, }, }, - c.Result.Breaches, - ) - // Non-existent file. - config.ProjectDir = "testdata" - c = mockCheck() - c.Init(Yaml) - c.File = "non-existent.yml" - c.FetchData() - assert.Equal(result.Fail, c.Result.Status) - assert.Empty(c.Result.Passes) - assert.EqualValues( - []result.Breach{ - result.ValueBreach{ - BreachType: "value", - CheckType: "yaml", - Severity: "normal", - ValueLabel: "error reading file: testdata/non-existent.yml", - Value: "open testdata/non-existent.yml: no such file or directory", + { + Name: "nonExistentFile", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + File: "non-existent.yml", + }, + ExpectBreaches: []result.Breach{ + &result.ValueBreach{ + BreachType: "value", + CheckType: "yaml", + Severity: "normal", + ValueLabel: "error reading file: testdata/non-existent.yml", + Value: "open testdata/non-existent.yml: no such file or directory", + }, }, }, - c.Result.Breaches, - ) - // Non-existent file with ignore missing. - c = mockCheck() - c.File = "non-existent.yml" - c.IgnoreMissing = &cTrue - c.FetchData() - assert.Equal(result.Pass, c.Result.Status) - assert.Empty(c.Result.Breaches) - assert.EqualValues([]string{"File testdata/non-existent.yml does not exist"}, c.Result.Passes) + { + Name: "nonExistentFileIgnoreMissing", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + File: "non-existent.yml", + IgnoreMissing: &cTrue, + }, + ExpectPasses: []string{"File testdata/non-existent.yml does not exist"}, + }, - // Single file. - c = mockCheck() - c.File = "update.settings.yml" - c.FetchData() - // Should not fail yet. - assert.NotEqual(result.Fail, c.Result.Status) - assert.Empty(c.Result.Breaches) - assert.True(c.HasData(false)) - c.UnmarshalDataMap() - c.RunCheck() - assert.Equal(result.Pass, c.Result.Status) - assert.Empty(c.Result.Breaches) - assert.EqualValues([]string{"[update.settings.yml] 'check.interval_days' equals '7'"}, c.Result.Passes) + { + Name: "singleFile", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + File: "update.settings.yml", + }, + ExpectDataMap: map[string][]byte{ + "update.settings.yml": []byte( + `check: + interval_days: 7 +notification: + emails: + - admin@example.com +`), + }, + }, - // Bad File pattern. - c = mockCheck() - c.Pattern = "*.bar.yml" - c.Path = "" - c.FetchData() - assert.Equal(result.Fail, c.Result.Status) - assert.Empty(c.Result.Passes) - assert.EqualValues( - []result.Breach{ - result.ValueBreach{ - BreachType: "value", - ValueLabel: "error finding files in path: testdata", - Value: "error parsing regexp: missing argument to repetition operator: `*`", + { + Name: "badFilePattern", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + Pattern: "*.bar.yml", + Path: "", + }, + ExpectBreaches: []result.Breach{ + &result.ValueBreach{ + BreachType: "value", + CheckType: "yaml", + Severity: "normal", + ValueLabel: "error finding files in path: testdata", + Value: "error parsing regexp: missing argument to repetition operator: `*`", + }, }, }, - c.Result.Breaches, - ) - // File pattern with no matching files. - c = mockCheck() - c.Pattern = "bla.*.yml" - c.FetchData() - assert.Equal(result.Fail, c.Result.Status) - assert.Empty(c.Result.Passes) - assert.EqualValues( - []result.Breach{ - result.ValueBreach{ - BreachType: "value", - ValueLabel: "- no file", - Value: "no matching yaml files found", + { + Name: "filePatternNoMatchingFile", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + Pattern: "bla.*.yml", + }, + ExpectBreaches: []result.Breach{ + &result.ValueBreach{ + BreachType: "value", + CheckType: "yaml", + Severity: "normal", + ValueLabel: "- no file", + Value: "no matching yaml files found", + }, }, }, - c.Result.Breaches, - ) - // File pattern with no matching files, ignoring missing. - c = mockCheck() - c.Pattern = "bla.*.yml" - c.IgnoreMissing = &cTrue - c.FetchData() - assert.Equal(result.Pass, c.Result.Status) - assert.Empty(c.Result.Breaches) - assert.EqualValues([]string{"no matching config files found"}, c.Result.Passes) + { + Name: "filePatternNoMatchingFileIgnoreMissing", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + Pattern: "bla.*.yml", + IgnoreMissing: &cTrue, + }, + ExpectPasses: []string{"no matching config files found"}, + }, - // Correct single file pattern & value. - c = mockCheck() - c.Pattern = "foo.bar.yml" - c.Path = "dir/subdir" - c.FetchData() - assert.NotEqual(result.Fail, c.Result.Status) - assert.Empty(c.Result.Breaches) - c.UnmarshalDataMap() - c.RunCheck() - assert.EqualValues([]string{"[testdata/dir/subdir/foo.bar.yml] 'check.interval_days' equals '7'"}, c.Result.Passes) - assert.Empty(c.Result.Breaches) + { + Name: "correctSingleFilePatternAndValue", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + Pattern: "foo.bar.yml", + Path: "dir/subdir", + }, + ExpectDataMap: map[string][]byte{ + "testdata/dir/subdir/foo.bar.yml": []byte( + `check: + interval_days: 7 +`), + }, + }, - // Recursive file lookup. - c = mockCheck() - c.Pattern = ".*.bar.yml" - c.FetchData() - assert.NotEqual(result.Fail, c.Result.Status) - assert.Empty(c.Result.Breaches) - c.UnmarshalDataMap() - c.RunCheck() - assert.Equal(result.Fail, c.Result.Status) - assert.ElementsMatch( - []string{ - "[testdata/dir/foo.bar.yml] 'check.interval_days' equals '7'", - "[testdata/dir/subdir/foo.bar.yml] 'check.interval_days' equals '7'", - "[testdata/foo.bar.yml] 'check.interval_days' equals '7'"}, - c.Result.Passes) - assert.ElementsMatch( - []result.Breach{ - result.KeyValueBreach{ - BreachType: "key-value", - KeyLabel: "testdata/dir/subdir/zoom.bar.yml", - Key: "check.interval_days", - ValueLabel: "actual", - Value: "5", - ExpectedValue: "7", - }, - result.KeyValueBreach{ - BreachType: "key-value", - KeyLabel: "testdata/dir/zoom.bar.yml", - Key: "check.interval_days", - ValueLabel: "actual", - Value: "5", - ExpectedValue: "7", - }, - result.KeyValueBreach{ - BreachType: "key-value", - KeyLabel: "testdata/zoom.bar.yml", - Key: "check.interval_days", - ValueLabel: "actual", - Value: "5", - ExpectedValue: "7", + { + Name: "recursiveFileLookup", + Check: &YamlCheck{ + YamlBase: YamlBase{ + Values: []KeyValue{ + {Key: "check.interval_days", Value: "7"}, + }, + }, + Pattern: ".*.bar.yml", + }, + ExpectDataMap: map[string][]byte{ + "testdata/dir/foo.bar.yml": []byte( + `check: + interval_days: 7 +`), + "testdata/dir/subdir/foo.bar.yml": []byte( + `check: + interval_days: 7 +`), + "testdata/dir/subdir/zoom.bar.yml": []byte( + `check: + interval_days: 5 +`), + "testdata/dir/zoom.bar.yml": []byte( + `check: + interval_days: 5 +`), + "testdata/foo.bar.yml": []byte( + `check: + interval_days: 7 +`), + "testdata/zoom.bar.yml": []byte( + `check: + interval_days: 5 +`), }, }, - c.Result.Breaches) + } + + config.ProjectDir = "testdata" + for _, tc := range tt { + t.Run(tc.Name, func(innerT *testing.T) { + tc.Check.Init(Yaml) + internal.TestFetchData(innerT, tc) + }) + } } diff --git a/pkg/checks/yaml/yamllintcheck.go b/pkg/checks/yaml/yamllintcheck.go index 755a34b..b1681f4 100644 --- a/pkg/checks/yaml/yamllintcheck.go +++ b/pkg/checks/yaml/yamllintcheck.go @@ -26,11 +26,11 @@ func (c *YamlLintCheck) UnmarshalDataMap() { err := yaml.Unmarshal([]byte(data), &ifc) if err != nil { if typeErr, ok := err.(*yaml.TypeError); ok { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "cannot decode yaml: " + f, Value: strings.Join(typeErr.Errors, "\n")}) } else { - c.AddBreach(result.ValueBreach{ + c.AddBreach(&result.ValueBreach{ ValueLabel: "yaml error: " + f, Value: err.Error()}) } diff --git a/pkg/checks/yaml/yamllintcheck_test.go b/pkg/checks/yaml/yamllintcheck_test.go index 8adc192..453bca8 100644 --- a/pkg/checks/yaml/yamllintcheck_test.go +++ b/pkg/checks/yaml/yamllintcheck_test.go @@ -29,33 +29,32 @@ func TestYamlLintMerge(t *testing.T) { }, c) } -func TestYamlLintCheck(t *testing.T) { - assert := assert.New(t) - - mockCheck := func(file string, files []string, ignoreMissing bool) YamlLintCheck { - return YamlLintCheck{ - YamlCheck: YamlCheck{ - YamlBase: YamlBase{ - CheckBase: config.CheckBase{ - Name: "Test yaml lint", - DataMap: map[string][]byte{}, - }, +func MockYamlLintCheck(file string, files []string, ignoreMissing bool) YamlLintCheck { + return YamlLintCheck{ + YamlCheck: YamlCheck{ + YamlBase: YamlBase{ + CheckBase: config.CheckBase{ + Name: "Test yaml lint", + DataMap: map[string][]byte{}, }, - File: file, - Files: files, - IgnoreMissing: &ignoreMissing, }, - } + File: file, + Files: files, + IgnoreMissing: &ignoreMissing, + }, } +} + +func TestYamlLintCheckFetchData(t *testing.T) { + assert := assert.New(t) - c := mockCheck("", []string{}, false) + c := MockYamlLintCheck("", []string{}, false) c.Init(YamlLint) c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: "value", CheckType: "yamllint", CheckName: "Test yaml lint", @@ -67,34 +66,31 @@ func TestYamlLintCheck(t *testing.T) { c.Result.Breaches, ) - c = mockCheck("non-existent-file.yml", []string{}, true) + c = MockYamlLintCheck("non-existent-file.yml", []string{}, true) c.Init(YamlLint) c.FetchData() - assert.NotEqual(result.Fail, c.Result.Status) assert.Empty(c.Result.Breaches) assert.ElementsMatch( []string{"File testdata/non-existent-file.yml does not exist"}, c.Result.Passes, ) - c = mockCheck("", []string{"non-existent-file.yml", "yaml-invalid.yml"}, true) + c = MockYamlLintCheck("", []string{"non-existent-file.yml", "yaml-invalid.yml"}, true) c.Init(YamlLint) c.FetchData() - assert.NotEqual(result.Fail, c.Result.Status) assert.Empty(c.Result.Breaches) assert.ElementsMatch([]string{ "File testdata/non-existent-file.yml does not exist", "File testdata/yaml-invalid.yml does not exist", }, c.Result.Passes) - c = mockCheck("non-existent-file.yml", []string{}, false) + c = MockYamlLintCheck("non-existent-file.yml", []string{}, false) c.Init(YamlLint) c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: "value", CheckType: "yamllint", CheckName: "Test yaml lint", @@ -106,14 +102,13 @@ func TestYamlLintCheck(t *testing.T) { c.Result.Breaches, ) - c = mockCheck("", []string{"non-existent-file.yml", "yamllint-invalid.yml"}, false) + c = MockYamlLintCheck("", []string{"non-existent-file.yml", "yamllint-invalid.yml"}, false) c.Init(YamlLint) c.FetchData() - assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: "value", CheckType: "yamllint", CheckName: "Test yaml lint", @@ -124,19 +119,24 @@ func TestYamlLintCheck(t *testing.T) { }, c.Result.Breaches, ) +} + +func TestYamlLintCheckUnmarshalDataMap(t *testing.T) { + assert := assert.New(t) - c = mockCheck("", []string{}, false) + c := MockYamlLintCheck("", []string{}, false) c.Init(YamlLint) c.DataMap["yaml-invalid.yml"] = []byte(` this: is invalid this: yaml `) c.UnmarshalDataMap() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: "value", CheckType: "yamllint", CheckName: "Test yaml lint", @@ -148,13 +148,14 @@ this: yaml c.Result.Breaches, ) - c = mockCheck("", []string{}, false) + c = MockYamlLintCheck("", []string{}, false) c.Init(YamlLint) c.DataMap["yaml-valid.yml"] = []byte(` this: is valid: yaml `) c.UnmarshalDataMap() + c.Result.DetermineResultStatus(false) assert.Equal(result.Pass, c.Result.Status) assert.Empty(c.Result.Breaches) assert.ElementsMatch( @@ -170,11 +171,12 @@ foo: bar - item 1 `)} c.UnmarshalDataMap() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) assert.Empty(c.Result.Passes) assert.ElementsMatch( []result.Breach{ - result.ValueBreach{ + &result.ValueBreach{ BreachType: "value", ValueLabel: "yaml error: yaml-invalid-root.yml", Value: "yaml: line 1: did not find expected key", @@ -193,6 +195,7 @@ foo: bar foo: bar `)} c.UnmarshalDataMap() + c.Result.DetermineResultStatus(false) assert.Equal(result.Pass, c.Result.Status) assert.Empty(c.Result.Breaches) assert.ElementsMatch( diff --git a/pkg/command/command.go b/pkg/command/command.go index 045dd35..4d1155f 100644 --- a/pkg/command/command.go +++ b/pkg/command/command.go @@ -10,6 +10,8 @@ import ( "errors" "io/fs" "os/exec" + + log "github.com/sirupsen/logrus" ) // IShellCommand is an interface for running shell commands. @@ -25,7 +27,12 @@ type ExecShellCommand struct { // NewExecShellCommander returns a command instance. func NewExecShellCommander(name string, arg ...string) IShellCommand { execCmd := exec.Command(name, arg...) - return ExecShellCommand{Cmd: execCmd} + return &ExecShellCommand{Cmd: execCmd} +} + +func (c *ExecShellCommand) Output() ([]byte, error) { + log.WithField("command", c).Debug("running command") + return c.Cmd.Output() } // ShellCommander provides a wrapper around the commander to allow for better diff --git a/pkg/command/command_test.go b/pkg/command/command_test.go index c7255f6..0af93ad 100644 --- a/pkg/command/command_test.go +++ b/pkg/command/command_test.go @@ -21,7 +21,7 @@ func TestExecReplacement(t *testing.T) { t.Run("differentStruct", func(t *testing.T) { cmd := command.ShellCommander("foo", "bar") - assert.IsType(command.ExecShellCommand{}, cmd) + assert.IsType(&command.ExecShellCommand{}, cmd) curShellCommander := command.ShellCommander defer func() { command.ShellCommander = curShellCommander }() diff --git a/pkg/config/checkbase.go b/pkg/config/checkbase.go index fd83dc2..e88e412 100644 --- a/pkg/config/checkbase.go +++ b/pkg/config/checkbase.go @@ -63,7 +63,7 @@ func (c *CheckBase) FetchData() {} func (c *CheckBase) HasData(failCheck bool) bool { if c.DataMap == nil { if failCheck { - c.AddBreach(result.ValueBreach{Value: "no data available"}) + c.AddBreach(&result.ValueBreach{Value: "no data available"}) } return false } @@ -76,8 +76,7 @@ func (c *CheckBase) UnmarshalDataMap() {} // AddBreach appends a Breach to the Result and sets the Check as Fail. func (c *CheckBase) AddBreach(b result.Breach) { - c.Result.Status = result.Fail - result.BreachSetCommonValues(&b, string(c.cType), c.Name, string(c.Severity)) + b.SetCommonValues(string(c.cType), c.Name, string(c.Severity)) c.Result.Breaches = append( c.Result.Breaches, b, @@ -102,26 +101,28 @@ func (c *CheckBase) SetPerformRemediation(flag bool) { c.PerformRemediation = flag } -// AddWarning appends a Warning message to the result. -func (c *CheckBase) AddRemediation(msg string) { - c.Result.Remediations = append(c.Result.Remediations, msg) -} - // RunCheck contains the core logic for running the check, // generating the result and remediating breaches. // This is where c.Result should be populated. func (c *CheckBase) RunCheck() { - c.AddBreach(result.ValueBreach{Value: "not implemented"}) + c.AddBreach(&result.ValueBreach{Value: "not implemented"}) } -// GetResult returns a ref of the result. -func (c *CheckBase) GetResult() *result.Result { - return &c.Result +// ShouldPerformRemediation returns whether to remediate or not. +func (c *CheckBase) ShouldPerformRemediation() bool { + return c.PerformRemediation } // Remediate should implement the logic to fix the breach(es). // Any type or custom struct can be used for the breach; it just needs to be // cast to the required type before being used. -func (c *CheckBase) Remediate(breachIfc interface{}) error { - return nil +func (c *CheckBase) Remediate() { + for _, b := range c.Result.Breaches { + b.SetRemediation(result.RemediationStatusNoSupport, "") + } +} + +// GetResult returns a ref of the result. +func (c *CheckBase) GetResult() *result.Result { + return &c.Result } diff --git a/pkg/config/checkbase_test.go b/pkg/config/checkbase_test.go index b7daccf..b84e2af 100644 --- a/pkg/config/checkbase_test.go +++ b/pkg/config/checkbase_test.go @@ -12,8 +12,6 @@ import ( "github.com/stretchr/testify/assert" ) -type testCheckForCheckBaseInit struct{} - const testCheckForCheckBaseInitType CheckType = "testCheckForCheckBaseInitType" func TestCheckBaseInit(t *testing.T) { @@ -62,13 +60,25 @@ func TestHasData(t *testing.T) { c := CheckBase{Name: "foo"} assert.False(c.HasData(false)) + assert.Empty(c.Result.Breaches) + c.Result.DetermineResultStatus(false) assert.NotEqual(result.Fail, c.Result.Status) assert.False(c.HasData(true)) + assert.EqualValues([]result.Breach{ + &result.ValueBreach{ + BreachType: "value", + CheckName: "foo", + Value: "no data available", + }, + }, c.Result.Breaches) + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) c = CheckBase{Name: "foo", DataMap: map[string][]byte{"foo": []byte(`bar`)}} assert.True(c.HasData(true)) + assert.Empty(c.Result.Breaches) + c.Result.DetermineResultStatus(false) assert.NotEqual(result.Fail, c.Result.Status) } @@ -91,21 +101,21 @@ func TestAddBreach(t *testing.T) { checkType: vbCheckType, checkName: "vbCheck", severity: "high", - breach: result.ValueBreach{}, + breach: &result.ValueBreach{}, }, { name: "KeyValueBreach", checkType: kvbCheckType, checkName: "kvbCheck", severity: "low", - breach: result.KeyValueBreach{}, + breach: &result.KeyValueBreach{}, }, { name: "KeyValuesBreach", checkType: kvsbCheckType, checkName: "kvsbCheck", severity: "normal", - breach: result.KeyValuesBreach{}, + breach: &result.KeyValuesBreach{}, }, } @@ -114,9 +124,9 @@ func TestAddBreach(t *testing.T) { c := CheckBase{Name: test.checkName, Severity: test.severity} c.Init(test.checkType) c.AddBreach(test.breach) - assert.Equal(string(test.checkType), result.BreachGetCheckType(c.Result.Breaches[0])) - assert.Equal(test.checkName, result.BreachGetCheckName(c.Result.Breaches[0])) - assert.Equal(string(test.severity), result.BreachGetSeverity(c.Result.Breaches[0])) + assert.Equal(string(test.checkType), c.Result.Breaches[0].GetCheckType()) + assert.Equal(test.checkName, c.Result.Breaches[0].GetCheckName()) + assert.Equal(string(test.severity), c.Result.Breaches[0].GetSeverity()) }) } } @@ -143,9 +153,10 @@ func TestCheckBaseRunCheck(t *testing.T) { c := CheckBase{} c.FetchData() c.RunCheck() + c.Result.DetermineResultStatus(false) assert.Equal(result.Fail, c.Result.Status) - assert.ElementsMatch( - []result.Breach{result.ValueBreach{ + assert.EqualValues( + []result.Breach{&result.ValueBreach{ BreachType: result.BreachTypeValue, Value: "not implemented", }}, @@ -167,8 +178,7 @@ func TestRemediate(t *testing.T) { t.Run("notSupported", func(t *testing.T) { c := testCheckRemediationNotSupported{} - err := c.Remediate(nil) - assert.NoError(err) + c.Remediate() assert.Empty(c.Result.Passes) assert.Empty(c.Result.Breaches) }) diff --git a/pkg/config/types.go b/pkg/config/types.go index 2e8f668..507da6b 100644 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -46,10 +46,10 @@ type Check interface { AddPass(msg string) AddWarning(msg string) SetPerformRemediation(flag bool) - AddRemediation(msg string) RunCheck() + ShouldPerformRemediation() bool + Remediate() GetResult() *result.Result - Remediate(breachIfc interface{}) error } // CheckBase provides the basic structure for all Checks. diff --git a/pkg/internal/testutils_command.go b/pkg/internal/testutils_command.go index 7a06e42..1603d97 100644 --- a/pkg/internal/testutils_command.go +++ b/pkg/internal/testutils_command.go @@ -17,11 +17,18 @@ func (sc TestShellCommand) Output() ([]byte, error) { // ShellCommanderMaker is a commander generator that can return the provided // stdout or stderr, and can also update a given variable with the generated // command. -func ShellCommanderMaker(out *string, err error, updateVar *string) func(name string, arg ...string) command.IShellCommand { +func ShellCommanderMaker(out *string, err error, generatedCommand *string) func(name string, arg ...string) command.IShellCommand { return func(name string, arg ...string) command.IShellCommand { - if updateVar != nil { - fullCmd := append([]string{name}, arg...) - *updateVar = strings.Join(fullCmd, " ") + if generatedCommand != nil { + fullCmd := name + for _, a := range arg { + // Add quotes when there are spaces. + if len(strings.Fields(a)) > 1 { + a = "'" + a + "'" + } + fullCmd += " " + a + } + *generatedCommand = fullCmd } var stdout []byte if out != nil { diff --git a/pkg/internal/testutils_fetchdata.go b/pkg/internal/testutils_fetchdata.go new file mode 100644 index 0000000..a268f94 --- /dev/null +++ b/pkg/internal/testutils_fetchdata.go @@ -0,0 +1,57 @@ +package internal + +import ( + "reflect" + "testing" + + "github.com/salsadigitalauorg/shipshape/pkg/config" + "github.com/salsadigitalauorg/shipshape/pkg/result" + "github.com/stretchr/testify/assert" +) + +// FetchDataTest can be used to create test scenarios using test tables, +// for the FetchData method using TestFetchData below. +type FetchDataTest struct { + // Name of the test. + Name string + Check config.Check + // Initialise the check before testing. + Init bool + // Func to run before running the check + PreFetch func(t *testing.T) + // Expected values after running the check. + ExpectPasses []string + ExpectBreaches []result.Breach + ExpectDataMap map[string][]byte +} + +// TestFetchData can be used to run test scenarios in test tables. +func TestFetchData(t *testing.T, ctest FetchDataTest) { + t.Helper() + assert := assert.New(t) + + if ctest.PreFetch != nil { + ctest.PreFetch(t) + } + + ctest.Check.FetchData() + + r := ctest.Check.GetResult() + + if len(ctest.ExpectPasses) > 0 { + assert.ElementsMatch(ctest.ExpectPasses, r.Passes) + } else { + assert.Empty(r.Passes) + } + + if len(ctest.ExpectBreaches) > 0 { + assert.ElementsMatch(ctest.ExpectBreaches, r.Breaches) + } else { + assert.Empty(r.Breaches) + } + + if ctest.ExpectDataMap != nil { + dataMap := reflect.ValueOf(ctest.Check).Elem().FieldByName("DataMap").Interface().(map[string][]byte) + assert.EqualValues(ctest.ExpectDataMap, dataMap) + } +} diff --git a/pkg/internal/testutils_remediate.go b/pkg/internal/testutils_remediate.go new file mode 100644 index 0000000..c9ce3bc --- /dev/null +++ b/pkg/internal/testutils_remediate.go @@ -0,0 +1,94 @@ +package internal + +import ( + "io" + "testing" + + "github.com/salsadigitalauorg/shipshape/pkg/command" + "github.com/salsadigitalauorg/shipshape/pkg/config" + "github.com/salsadigitalauorg/shipshape/pkg/result" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +// CheckTest can be used to create test scenarios, especially using test tables, +// for the RunCheck method using TestRunCheck below. +type RemediateTest struct { + // Name of the test. + Name string + Check config.Check + Breaches []result.Breach + // Func to run before running Remediate + PreRun func(t *testing.T) + // Expected values after running Remediate. + ExpectGeneratedCommand string + ExpectStatusFail bool + ExpectNoBreach bool + ExpectBreaches []result.Breach + ExpectRemediationStatus result.RemediationStatus + ExpectNoRemediations bool + ExpectRemediations []string +} + +// TestRunCheck can be used to run test scenarios in test tables. +func TestRemediate(t *testing.T, rt RemediateTest) { + t.Helper() + assert := assert.New(t) + // Hide logging output. + currLogOut := logrus.StandardLogger().Out + defer logrus.SetOutput(currLogOut) + logrus.SetOutput(io.Discard) + + if rt.PreRun != nil { + rt.PreRun(t) + } + + var generatedCommand string + if rt.ExpectGeneratedCommand != "" { + curShellCommander := command.ShellCommander + defer func() { command.ShellCommander = curShellCommander }() + command.ShellCommander = ShellCommanderMaker(nil, nil, &generatedCommand) + } + rt.Check.Remediate() + if rt.ExpectGeneratedCommand != "" { + assert.Equal(rt.ExpectGeneratedCommand, generatedCommand) + } + + r := rt.Check.GetResult() + r.DetermineResultStatus(true) + + if rt.ExpectStatusFail { + assert.Equal(result.Fail, r.Status) + } + if rt.ExpectNoBreach { + assert.Empty(r.Breaches) + } else { + assert.ElementsMatchf( + rt.ExpectBreaches, + r.Breaches, + "Expected breaches: %#v \nGot %#v", rt.ExpectBreaches, r.Breaches) + } + + assert.Equal(rt.ExpectRemediationStatus, r.RemediationStatus) + if rt.ExpectNoRemediations { + assert.NotEmpty(r.Breaches) + remediationsFound := false + for _, b := range r.Breaches { + if b.GetRemediation().Status != "" { + remediationsFound = true + break + } + } + assert.False(remediationsFound, "Expected no remediations, but found some") + } else if len(rt.ExpectRemediations) > 0 { + assert.NotEmpty(r.Breaches) + remediationMsgs := []string{} + for _, b := range r.Breaches { + remediationMsgs = append(remediationMsgs, b.GetRemediation().Messages...) + } + assert.ElementsMatchf( + rt.ExpectRemediations, + remediationMsgs, + "Expected remediations: %#v \nGot %#v", rt.ExpectRemediations, remediationMsgs) + } +} diff --git a/pkg/internal/testutils_runcheck.go b/pkg/internal/testutils_runcheck.go index 65babdd..c1db77e 100644 --- a/pkg/internal/testutils_runcheck.go +++ b/pkg/internal/testutils_runcheck.go @@ -1,10 +1,12 @@ package internal import ( + "io" "testing" "github.com/salsadigitalauorg/shipshape/pkg/config" "github.com/salsadigitalauorg/shipshape/pkg/result" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) @@ -21,19 +23,21 @@ type RunCheckTest struct { // Func to run before running the check PreRun func(t *testing.T) // Expected values after running the check. - ExpectStatus result.Status - ExpectNoPass bool - ExpectPasses []string - ExpectNoFail bool - ExpectFails []result.Breach - ExpectNoRemediations bool - ExpectRemediations []string + ExpectStatus result.Status + ExpectNoPass bool + ExpectPasses []string + ExpectNoFail bool + ExpectFails []result.Breach } // TestRunCheck can be used to run test scenarios in test tables. func TestRunCheck(t *testing.T, ctest RunCheckTest) { t.Helper() assert := assert.New(t) + // Hide logging output. + currLogOut := logrus.StandardLogger().Out + defer logrus.SetOutput(currLogOut) + logrus.SetOutput(io.Discard) c := ctest.Check @@ -47,6 +51,7 @@ func TestRunCheck(t *testing.T, ctest RunCheckTest) { c.RunCheck() r := c.GetResult() + r.DetermineResultStatus(false) if ctest.Sort { r.Sort() } @@ -70,13 +75,4 @@ func TestRunCheck(t *testing.T, ctest RunCheckTest) { r.Breaches, "Expected fails: %#v \nGot %#v", ctest.ExpectFails, r.Breaches) } - - if ctest.ExpectNoRemediations { - assert.Empty(r.Remediations) - } else { - assert.ElementsMatchf( - ctest.ExpectRemediations, - r.Remediations, - "Expected remediations: %#v \nGot %#v", ctest.ExpectRemediations, r.Remediations) - } } diff --git a/pkg/lagoon/lagoon.go b/pkg/lagoon/lagoon.go index 1ab47e5..f51403e 100644 --- a/pkg/lagoon/lagoon.go +++ b/pkg/lagoon/lagoon.go @@ -226,8 +226,8 @@ func BreachFactName(b result.Breach) string { } else if result.BreachGetValueLabel(b) != "" { name = result.BreachGetValueLabel(b) } else { - name = result.BreachGetCheckName(b) + " - " + - string(result.BreachGetCheckType(b)) + name = b.GetCheckName() + " - " + + string(b.GetCheckType()) } return name } diff --git a/pkg/lagoon/lagoon_test.go b/pkg/lagoon/lagoon_test.go index 7a28a22..f8a0011 100644 --- a/pkg/lagoon/lagoon_test.go +++ b/pkg/lagoon/lagoon_test.go @@ -205,7 +205,7 @@ func TestBreachFactNameAndValue(t *testing.T) { }{ { name: "value breach - no label", - breach: result.ValueBreach{ + breach: &result.ValueBreach{ CheckName: "illegal file", CheckType: "file", Value: "/an/illegal/file", @@ -215,7 +215,7 @@ func TestBreachFactNameAndValue(t *testing.T) { }, { name: "value breach - label", - breach: result.ValueBreach{ + breach: &result.ValueBreach{ CheckName: "illegal file", CheckType: "file", ValueLabel: "the illegal file exists", @@ -226,7 +226,7 @@ func TestBreachFactNameAndValue(t *testing.T) { }, { name: "key-value breach - with value label", - breach: result.KeyValueBreach{ + breach: &result.KeyValueBreach{ CheckName: "illegal file", CheckType: "file", Key: "illegal file found", @@ -238,7 +238,7 @@ func TestBreachFactNameAndValue(t *testing.T) { }, { name: "key-value breach - with value and key labels", - breach: result.KeyValueBreach{ + breach: &result.KeyValueBreach{ CheckName: "illegal file", CheckType: "file", KeyLabel: "illegal file found in", @@ -251,7 +251,7 @@ func TestBreachFactNameAndValue(t *testing.T) { }, { name: "value breach - with value and key labels and expected value", - breach: result.KeyValueBreach{ + breach: &result.KeyValueBreach{ CheckName: "update module status", CheckType: "module-status", KeyLabel: "disallowed module found", @@ -264,7 +264,7 @@ func TestBreachFactNameAndValue(t *testing.T) { }, { name: "key-values breach - no label", - breach: result.KeyValuesBreach{ + breach: &result.KeyValuesBreach{ CheckName: "illegal files", CheckType: "file", Values: []string{"/an/illegal/file", "/another/illegal/file"}, diff --git a/pkg/result/breach.go b/pkg/result/breach.go index 81632df..1b8bfd7 100644 --- a/pkg/result/breach.go +++ b/pkg/result/breach.go @@ -7,6 +7,13 @@ import ( // Breach provides a representation for different breach types. type Breach interface { + GetCheckName() string + GetCheckType() string + GetRemediation() *Remediation + GetSeverity() string + GetType() BreachType + SetCommonValues(checkType string, checkName string, severity string) + SetRemediation(status RemediationStatus, msg string) String() string } @@ -21,6 +28,8 @@ const ( BreachTypeKeyValues BreachType = "key-values" ) +//go:generate go run ../../cmd/gen.go breach-type --type=Value,KeyValue,KeyValues + // Simple breach with no key. // Example: // @@ -33,13 +42,14 @@ type ValueBreach struct { ValueLabel string Value string ExpectedValue string + Remediation } func (b ValueBreach) String() string { if b.ValueLabel != "" { return fmt.Sprintf("[%s] %s", b.ValueLabel, b.Value) } - return fmt.Sprintf("%s", b.Value) + return b.Value } // Breach with key and value. @@ -60,6 +70,7 @@ type KeyValueBreach struct { ValueLabel string Value string ExpectedValue string + Remediation } func (b KeyValueBreach) String() string { @@ -86,6 +97,7 @@ type KeyValuesBreach struct { Key string ValueLabel string Values []string + Remediation } func (b KeyValuesBreach) String() string { @@ -96,121 +108,55 @@ func (b KeyValuesBreach) String() string { return fmt.Sprintf("%s: %s", b.Key, "["+strings.Join(b.Values, ", ")+"]") } -func BreachSetCommonValues(bIfc *Breach, checkType string, checkName string, severity string) { - if b, ok := (*bIfc).(ValueBreach); ok { - b.BreachType = BreachTypeValue - b.CheckType = checkType - b.CheckName = checkName - b.Severity = severity - *bIfc = b - } else if b, ok := (*bIfc).(KeyValueBreach); ok { - b.BreachType = BreachTypeKeyValue - b.CheckType = checkType - b.CheckName = checkName - b.Severity = severity - *bIfc = b - } else if b, ok := (*bIfc).(KeyValuesBreach); ok { - b.BreachType = BreachTypeKeyValues - b.CheckType = checkType - b.CheckName = checkName - b.Severity = severity - *bIfc = b - } -} - -func BreachGetBreachType(bIfc Breach) BreachType { - if _, ok := bIfc.(ValueBreach); ok { - return BreachTypeValue - } else if _, ok := bIfc.(KeyValueBreach); ok { - return BreachTypeKeyValue - } else if _, ok := bIfc.(KeyValuesBreach); ok { - return BreachTypeKeyValues - } - return "" -} - -func BreachGetCheckType(bIfc Breach) string { - if b, ok := bIfc.(ValueBreach); ok { - return b.CheckType - } else if b, ok := bIfc.(KeyValueBreach); ok { - return b.CheckType - } else if b, ok := bIfc.(KeyValuesBreach); ok { - return b.CheckType - } - return "" -} - -func BreachGetCheckName(bIfc Breach) string { - if b, ok := bIfc.(ValueBreach); ok { - return b.CheckName - } else if b, ok := bIfc.(KeyValueBreach); ok { - return b.CheckName - } else if b, ok := bIfc.(KeyValuesBreach); ok { - return b.CheckName - } - return "" -} - -func BreachGetSeverity(bIfc Breach) string { - if b, ok := bIfc.(ValueBreach); ok { - return b.Severity - } else if b, ok := bIfc.(KeyValueBreach); ok { - return b.Severity - } else if b, ok := bIfc.(KeyValuesBreach); ok { - return b.Severity - } - return "" -} - func BreachGetKeyLabel(bIfc Breach) string { - if b, ok := bIfc.(KeyValueBreach); ok { + if b, ok := bIfc.(*KeyValueBreach); ok { return b.KeyLabel - } else if b, ok := bIfc.(KeyValuesBreach); ok { + } else if b, ok := bIfc.(*KeyValuesBreach); ok { return b.KeyLabel } return "" } func BreachGetKey(bIfc Breach) string { - if b, ok := bIfc.(KeyValueBreach); ok { + if b, ok := bIfc.(*KeyValueBreach); ok { return b.Key - } else if b, ok := bIfc.(KeyValuesBreach); ok { + } else if b, ok := bIfc.(*KeyValuesBreach); ok { return b.Key } return "" } func BreachGetValueLabel(bIfc Breach) string { - if b, ok := bIfc.(ValueBreach); ok { + if b, ok := bIfc.(*ValueBreach); ok { return b.ValueLabel - } else if b, ok := bIfc.(KeyValueBreach); ok { + } else if b, ok := bIfc.(*KeyValueBreach); ok { return b.ValueLabel - } else if b, ok := bIfc.(KeyValuesBreach); ok { + } else if b, ok := bIfc.(*KeyValuesBreach); ok { return b.ValueLabel } return "" } func BreachGetValue(bIfc Breach) string { - if b, ok := bIfc.(ValueBreach); ok { + if b, ok := bIfc.(*ValueBreach); ok { return b.Value - } else if b, ok := bIfc.(KeyValueBreach); ok { + } else if b, ok := bIfc.(*KeyValueBreach); ok { return b.Value } return "" } func BreachGetValues(bIfc Breach) []string { - if b, ok := bIfc.(KeyValuesBreach); ok { + if b, ok := bIfc.(*KeyValuesBreach); ok { return b.Values } return []string(nil) } func BreachGetExpectedValue(bIfc Breach) string { - if b, ok := bIfc.(ValueBreach); ok { + if b, ok := bIfc.(*ValueBreach); ok { return b.ExpectedValue - } else if b, ok := bIfc.(KeyValueBreach); ok { + } else if b, ok := bIfc.(*KeyValueBreach); ok { return b.ExpectedValue } return "" diff --git a/pkg/result/breach_test.go b/pkg/result/breach_test.go index 241c4a7..1fbbb2e 100644 --- a/pkg/result/breach_test.go +++ b/pkg/result/breach_test.go @@ -18,7 +18,7 @@ func TestBreachValueBreachStringer(t *testing.T) { }{ { name: "value-breach", - breach: ValueBreach{ + breach: &ValueBreach{ ValueLabel: "file not found", Value: "foo.ext", }, @@ -43,7 +43,7 @@ func TestBreachKeyValueBreachStringer(t *testing.T) { }{ { name: "key-value-breach-1", - breach: KeyValueBreach{ + breach: &KeyValueBreach{ KeyLabel: "config", Key: "clamav.settings", ValueLabel: "key not found", @@ -53,7 +53,7 @@ func TestBreachKeyValueBreachStringer(t *testing.T) { }, { name: "key-value-breach-2", - breach: KeyValueBreach{ + breach: &KeyValueBreach{ KeyLabel: "clamav.settings", Key: "enabled", Value: "false", @@ -80,7 +80,7 @@ func TestBreachKeyValuesBreachStringers(t *testing.T) { }{ { name: "KeyValuesBreach", - breach: KeyValuesBreach{ + breach: &KeyValuesBreach{ KeyLabel: "role", Key: "admin", ValueLabel: "disallowed permissions", @@ -99,10 +99,35 @@ func TestBreachKeyValuesBreachStringers(t *testing.T) { type bogusBreach struct{} +func (b bogusBreach) GetCheckName() string { + return "" +} + +func (b bogusBreach) GetCheckType() string { + return "" +} + +func (b bogusBreach) GetRemediation() *Remediation { + return &Remediation{} +} + +func (b bogusBreach) GetSeverity() string { + return "" +} + +func (b bogusBreach) GetType() BreachType { + return "" +} + +func (b bogusBreach) SetCommonValues(checkType string, checkName string, severity string) { +} + func (b bogusBreach) String() string { return "" } +func (b bogusBreach) SetRemediation(status RemediationStatus, msg string) {} + func TestBreachSetCommonValues(t *testing.T) { assert := assert.New(t) @@ -117,7 +142,7 @@ func TestBreachSetCommonValues(t *testing.T) { }{ { name: "ValueBreach", - breach: ValueBreach{}, + breach: &ValueBreach{}, expectedBreachType: BreachTypeValue, expectedCheckType: "ctvb", expectedCheckName: "valuebreachcheck", @@ -125,7 +150,7 @@ func TestBreachSetCommonValues(t *testing.T) { }, { name: "KeyValueBreach", - breach: KeyValueBreach{}, + breach: &KeyValueBreach{}, expectedBreachType: BreachTypeKeyValue, expectedCheckType: "ctkvb", expectedCheckName: "keyvaluebreachcheck", @@ -133,7 +158,7 @@ func TestBreachSetCommonValues(t *testing.T) { }, { name: "KeyValuesBreach", - breach: KeyValuesBreach{}, + breach: &KeyValuesBreach{}, expectedBreachType: BreachTypeKeyValues, expectedCheckType: "ctkvsb", expectedCheckName: "keyvaluesbreachcheck", @@ -141,7 +166,7 @@ func TestBreachSetCommonValues(t *testing.T) { }, { name: "BogusBreach", - breach: bogusBreach{}, + breach: &bogusBreach{}, expectedBreachType: "", expectedCheckType: "ctbb", expectedCheckName: "bogusbreachcheck", @@ -152,18 +177,18 @@ func TestBreachSetCommonValues(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - BreachSetCommonValues(&test.breach, test.expectedCheckType, test.expectedCheckName, + test.breach.SetCommonValues(test.expectedCheckType, test.expectedCheckName, test.expectedSeverity) if !test.empty { - assert.Equal(test.expectedBreachType, BreachGetBreachType(test.breach)) - assert.Equal(test.expectedCheckName, BreachGetCheckName(test.breach)) - assert.Equal(test.expectedCheckType, BreachGetCheckType(test.breach)) - assert.Equal(test.expectedSeverity, BreachGetSeverity(test.breach)) + assert.Equal(test.expectedBreachType, test.breach.GetType()) + assert.Equal(test.expectedCheckName, test.breach.GetCheckName()) + assert.Equal(test.expectedCheckType, test.breach.GetCheckType()) + assert.Equal(test.expectedSeverity, test.breach.GetSeverity()) } else { - assert.Equal(BreachType(""), BreachGetBreachType(test.breach)) - assert.Equal("", BreachGetCheckName(test.breach)) - assert.Equal("", BreachGetCheckType(test.breach)) - assert.Equal("", BreachGetSeverity(test.breach)) + assert.Equal(BreachType(""), test.breach.GetType()) + assert.Equal("", test.breach.GetCheckName()) + assert.Equal("", test.breach.GetCheckType()) + assert.Equal("", test.breach.GetSeverity()) } }) } @@ -184,7 +209,7 @@ func TestBreachGetters(t *testing.T) { }{ { name: "ValueBreach", - breach: ValueBreach{ + breach: &ValueBreach{ ValueLabel: "vbvl", Value: "vbv", ExpectedValue: "vbve", @@ -198,7 +223,7 @@ func TestBreachGetters(t *testing.T) { }, { name: "KeyValueBreach", - breach: KeyValueBreach{ + breach: &KeyValueBreach{ KeyLabel: "kvbklbl", Key: "kvbk", ValueLabel: "kvbvl", @@ -214,7 +239,7 @@ func TestBreachGetters(t *testing.T) { }, { name: "KeyValuesBreach", - breach: KeyValuesBreach{ + breach: &KeyValuesBreach{ KeyLabel: "kvsbklbl", Key: "kvsbk", ValueLabel: "kvsbvl", diff --git a/pkg/result/remediation.go b/pkg/result/remediation.go new file mode 100644 index 0000000..2effc3b --- /dev/null +++ b/pkg/result/remediation.go @@ -0,0 +1,15 @@ +package result + +type RemediationStatus string + +const ( + RemediationStatusNoSupport RemediationStatus = "no-support" + RemediationStatusSuccess RemediationStatus = "success" + RemediationStatusFailed RemediationStatus = "failed" + RemediationStatusPartial RemediationStatus = "partial" +) + +type Remediation struct { + Status RemediationStatus + Messages []string +} diff --git a/pkg/result/result.go b/pkg/result/result.go index 137a5cd..d68c2ba 100644 --- a/pkg/result/result.go +++ b/pkg/result/result.go @@ -13,21 +13,21 @@ const ( // Result provides the structure for a Check's outcome. type Result struct { - Name string `json:"name"` - Severity string `json:"severity"` - CheckType string `json:"check-type"` - Status Status `json:"status"` - Passes []string `json:"passes"` - Breaches []Breach `json:"breaches"` - Warnings []string `json:"warnings"` - Remediations []string `json:"remediations"` + Name string `json:"name"` + Severity string `json:"severity"` + CheckType string `json:"check-type"` + Passes []string `json:"passes"` + Breaches []Breach `json:"breaches"` + Warnings []string `json:"warnings"` + Status Status `json:"status"` + RemediationStatus RemediationStatus `json:"remediation-status"` } // Sort reorders the Passes & Failures in order to get consistent output. func (r *Result) Sort() { if len(r.Breaches) > 0 { sort.Slice(r.Breaches, func(i int, j int) bool { - return BreachGetCheckName(r.Breaches[i]) < BreachGetCheckName(r.Breaches[j]) + return r.Breaches[i].GetCheckName() < r.Breaches[j].GetCheckName() }) } @@ -43,3 +43,61 @@ func (r *Result) Sort() { }) } } + +// RemediationsCount returns the number of unsupported, successful, failed and +// partial for all attempted remediations. +func (r *Result) RemediationsCount() (uint32, uint32, uint32, uint32) { + unsupported := uint32(0) + successful := uint32(0) + failed := uint32(0) + partial := uint32(0) + for _, b := range r.Breaches { + switch b.GetRemediation().Status { + case RemediationStatusNoSupport: + unsupported++ + case RemediationStatusSuccess: + successful++ + case RemediationStatusFailed: + failed++ + case RemediationStatusPartial: + partial++ + } + } + return unsupported, successful, failed, partial +} + +// DetermineResultStatus determines the overall status of the result based on +// the breaches and remediation status. +func (r *Result) DetermineResultStatus(remediationPerformed bool) { + r.Sort() + + // Remediation status. + if remediationPerformed { + unsupported, success, failed, partial := r.RemediationsCount() + if partial > 0 || (success > 0 && (failed > 0 || unsupported > 0)) { + r.RemediationStatus = RemediationStatusPartial + r.Status = Fail + return + } + if unsupported > 0 && success == 0 && failed == 0 && partial == 0 { + r.RemediationStatus = RemediationStatusNoSupport + r.Status = Fail + return + } + if failed > 0 && success == 0 && unsupported == 0 && partial == 0 { + r.RemediationStatus = RemediationStatusFailed + r.Status = Fail + return + } + r.RemediationStatus = RemediationStatusSuccess + r.Status = Pass + return + } + + // Overall status. + if len(r.Breaches) > 0 { + r.Status = Fail + return + } + r.Status = Pass +} diff --git a/pkg/result/result_test.go b/pkg/result/result_test.go index 6b65759..a742385 100644 --- a/pkg/result/result_test.go +++ b/pkg/result/result_test.go @@ -14,10 +14,10 @@ func TestResultSort(t *testing.T) { r := Result{ Passes: []string{"z pass", "g pass", "a pass", "b pass"}, Breaches: []Breach{ - ValueBreach{CheckName: "x", Value: "breach 1"}, - ValueBreach{CheckName: "h", Value: "breach 2"}, - ValueBreach{CheckName: "v", Value: "breach 3"}, - ValueBreach{CheckName: "f", Value: "breach 4"}, + &ValueBreach{CheckName: "x", Value: "breach 1"}, + &ValueBreach{CheckName: "h", Value: "breach 2"}, + &ValueBreach{CheckName: "v", Value: "breach 3"}, + &ValueBreach{CheckName: "f", Value: "breach 4"}, }, Warnings: []string{"y warn", "i warn", "u warn", "c warn"}, } @@ -26,11 +26,257 @@ func TestResultSort(t *testing.T) { assert.EqualValues(Result{ Passes: []string{"a pass", "b pass", "g pass", "z pass"}, Breaches: []Breach{ - ValueBreach{CheckName: "f", Value: "breach 4"}, - ValueBreach{CheckName: "h", Value: "breach 2"}, - ValueBreach{CheckName: "v", Value: "breach 3"}, - ValueBreach{CheckName: "x", Value: "breach 1"}, + &ValueBreach{CheckName: "f", Value: "breach 4"}, + &ValueBreach{CheckName: "h", Value: "breach 2"}, + &ValueBreach{CheckName: "v", Value: "breach 3"}, + &ValueBreach{CheckName: "x", Value: "breach 1"}, }, Warnings: []string{"c warn", "i warn", "u warn", "y warn"}, }, r) } + +func TestResultRemediationsCount(t *testing.T) { + assert := assert.New(t) + + r := Result{ + Breaches: []Breach{ + &ValueBreach{CheckName: "x", Remediation: Remediation{Status: RemediationStatusNoSupport}}, + &ValueBreach{CheckName: "h", Remediation: Remediation{Status: RemediationStatusSuccess}}, + &ValueBreach{CheckName: "i", Remediation: Remediation{Status: RemediationStatusSuccess}}, + &ValueBreach{CheckName: "v", Remediation: Remediation{Status: RemediationStatusFailed}}, + &ValueBreach{CheckName: "w", Remediation: Remediation{Status: RemediationStatusFailed}}, + &ValueBreach{CheckName: "x", Remediation: Remediation{Status: RemediationStatusFailed}}, + &ValueBreach{CheckName: "f", Remediation: Remediation{Status: RemediationStatusPartial}}, + &ValueBreach{CheckName: "e", Remediation: Remediation{Status: RemediationStatusPartial}}, + &ValueBreach{CheckName: "d", Remediation: Remediation{Status: RemediationStatusPartial}}, + &ValueBreach{CheckName: "c", Remediation: Remediation{Status: RemediationStatusPartial}}, + }, + } + unsupported, successful, failed, partial := r.RemediationsCount() + assert.EqualValues(1, unsupported) + assert.EqualValues(2, successful) + assert.EqualValues(3, failed) + assert.EqualValues(4, partial) +} + +func TestResultDetermineResultStatus(t *testing.T) { + tt := []struct { + name string + remediationPerformed bool + breaches []Breach + expectedStatus Status + expectedRemediationStatus RemediationStatus + }{ + { + name: "noBreach", + remediationPerformed: false, + breaches: []Breach{}, + expectedStatus: Pass, + expectedRemediationStatus: "", + }, + { + name: "noBreachRemediation", + remediationPerformed: true, + breaches: []Breach{}, + expectedStatus: Pass, + expectedRemediationStatus: RemediationStatusSuccess, + }, + + // Single breach. + { + name: "singleBreach", + remediationPerformed: false, + breaches: []Breach{ + &ValueBreach{CheckName: "x", Value: "breach 1"}, + }, + expectedStatus: Fail, + expectedRemediationStatus: "", + }, + { + name: "singleBreachRemediationNotSupported", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusNoSupport}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusNoSupport, + }, + { + name: "singleBreachRemediationSuccess", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusSuccess}, + }, + }, + expectedStatus: Pass, + expectedRemediationStatus: RemediationStatusSuccess, + }, + { + name: "singleBreachRemediationFailed", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusFailed}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusFailed, + }, + { + name: "singleBreachRemediationPartial", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusPartial}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusPartial, + }, + + // Multiple breaches. + { + name: "multipleBreaches", + remediationPerformed: false, + breaches: []Breach{ + &ValueBreach{CheckName: "x", Value: "breach 1"}, + &ValueBreach{CheckName: "f", Value: "breach 2"}, + }, + expectedStatus: Fail, + expectedRemediationStatus: "", + }, + { + name: "multipleBreachesRemediationFailed", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusFailed}, + }, + &ValueBreach{ + CheckName: "f", + Value: "breach 2", + Remediation: Remediation{Status: RemediationStatusFailed}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusFailed, + }, + { + name: "multipleBreachesRemediationUnsupported", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusNoSupport}, + }, + &ValueBreach{ + CheckName: "f", + Value: "breach 2", + Remediation: Remediation{Status: RemediationStatusNoSupport}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusNoSupport, + }, + + // Multiple breaches with partial remediation. + { + name: "multipleBreachesRemediationPartial", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusSuccess}, + }, + &ValueBreach{ + CheckName: "f", + Value: "breach 2", + Remediation: Remediation{Status: RemediationStatusFailed}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusPartial, + }, + { + name: "multipleBreachesRemediationPartial", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusSuccess}, + }, + &ValueBreach{ + CheckName: "f", + Value: "breach 2", + Remediation: Remediation{Status: RemediationStatusFailed}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusPartial, + }, + { + name: "multipleBreachesRemediationPartial", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusPartial}, + }, + &ValueBreach{ + CheckName: "f", + Value: "breach 2", + Remediation: Remediation{Status: RemediationStatusPartial}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusPartial, + }, + { + name: "multipleBreachesRemediationPartial", + remediationPerformed: true, + breaches: []Breach{ + &ValueBreach{ + CheckName: "x", + Value: "breach 1", + Remediation: Remediation{Status: RemediationStatusSuccess}, + }, + &ValueBreach{ + CheckName: "f", + Value: "breach 2", + Remediation: Remediation{Status: RemediationStatusNoSupport}, + }, + }, + expectedStatus: Fail, + expectedRemediationStatus: RemediationStatusPartial, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + + r := Result{Breaches: tc.breaches} + r.DetermineResultStatus(tc.remediationPerformed) + + assert.Equal(tc.expectedStatus, r.Status) + assert.Equal(tc.expectedRemediationStatus, r.RemediationStatus) + }) + } +} diff --git a/pkg/result/resultlist.go b/pkg/result/resultlist.go index 82e79bd..60f4ace 100644 --- a/pkg/result/resultlist.go +++ b/pkg/result/resultlist.go @@ -9,40 +9,28 @@ import ( // ResultList is a wrapper around a list of results, providing some useful // methods to manipulate and use it. type ResultList struct { - RemediationPerformed bool `json:"remediation-performed"` - TotalChecks uint32 `json:"total-checks"` - TotalBreaches uint32 `json:"total-breaches"` - TotalRemediations uint32 `json:"total-remediations"` - TotalUnsupportedRemediations uint32 `json:"total-unsupported-remediations"` - CheckCountByType map[string]int `json:"check-count-by-type"` - BreachCountByType map[string]int `json:"breach-count-by-type"` - BreachCountBySeverity map[string]int `json:"breach-count-by-severity"` - RemediationCountByType map[string]int `json:"remediation-count-by-type"` - Results []Result `json:"results"` + RemediationPerformed bool `json:"remediation-performed"` + TotalChecks uint32 `json:"total-checks"` + TotalBreaches uint32 `json:"total-breaches"` + RemediationTotals map[string]uint32 `json:"remediation-totals"` + CheckCountByType map[string]int `json:"check-count-by-type"` + BreachCountByType map[string]int `json:"breach-count-by-type"` + BreachCountBySeverity map[string]int `json:"breach-count-by-severity"` + Results []Result `json:"results"` } // Use locks to make map mutations concurrency-safe. var lock = sync.RWMutex{} func NewResultList(remediate bool) ResultList { - return ResultList{ - RemediationPerformed: remediate, - Results: []Result{}, - CheckCountByType: map[string]int{}, - BreachCountByType: map[string]int{}, - BreachCountBySeverity: map[string]int{}, - RemediationCountByType: map[string]int{}, + rl := ResultList{ + RemediationPerformed: remediate, + Results: []Result{}, + CheckCountByType: map[string]int{}, + BreachCountByType: map[string]int{}, + BreachCountBySeverity: map[string]int{}, } -} - -// Status calculates and returns the overall result of all check results. -func (rl *ResultList) Status() Status { - for _, r := range rl.Results { - if r.Status == Fail { - return Fail - } - } - return Pass + return rl } // IncrChecks increments the total checks count & checks count by type. @@ -64,10 +52,62 @@ func (rl *ResultList) AddResult(r Result) { atomic.AddUint32(&rl.TotalBreaches, uint32(breachesIncr)) rl.BreachCountByType[r.CheckType] = rl.BreachCountByType[r.CheckType] + breachesIncr rl.BreachCountBySeverity[r.Severity] = rl.BreachCountBySeverity[r.Severity] + breachesIncr +} + +// Status calculates and returns the overall result of all check results. +func (rl *ResultList) Status() Status { + for _, r := range rl.Results { + if r.Status == Fail { + return Fail + } + } + return Pass +} - remediationsIncr := len(r.Remediations) - atomic.AddUint32(&rl.TotalRemediations, uint32(remediationsIncr)) - rl.RemediationCountByType[r.CheckType] = rl.RemediationCountByType[r.CheckType] + remediationsIncr +// RemediationTotalsCount calculates the total number of unsupported, +// successful, failed and partial remediations across all checks. +func (rl *ResultList) RemediationTotalsCount() { + rl.RemediationTotals = map[string]uint32{ + "unsupported": 0, + "successful": 0, + "failed": 0, + "partial": 0, + } + for _, r := range rl.Results { + unsupported, successful, failed, partial := r.RemediationsCount() + rl.RemediationTotals["unsupported"] = rl.RemediationTotals["unsupported"] + unsupported + rl.RemediationTotals["successful"] = rl.RemediationTotals["successful"] + successful + rl.RemediationTotals["failed"] = rl.RemediationTotals["failed"] + failed + rl.RemediationTotals["partial"] = rl.RemediationTotals["partial"] + partial + } +} + +// RemediationStatus calculates and returns the overall result of +// remediation for all breaches. +func (rl *ResultList) RemediationStatus() RemediationStatus { + if !rl.RemediationPerformed { + return "" + } + + if rl.RemediationTotals["partial"] > 0 || + (rl.RemediationTotals["successful"] > 0 && + (rl.RemediationTotals["failed"] > 0 || + rl.RemediationTotals["unsupported"] > 0)) { + return RemediationStatusPartial + } + if rl.RemediationTotals["unsupported"] > 0 && + rl.RemediationTotals["successful"] == 0 && + rl.RemediationTotals["failed"] == 0 && + rl.RemediationTotals["partial"] == 0 { + return RemediationStatusNoSupport + } + if rl.RemediationTotals["failed"] > 0 && + rl.RemediationTotals["successful"] == 0 && + rl.RemediationTotals["unsupported"] == 0 && + rl.RemediationTotals["partial"] == 0 { + return RemediationStatusFailed + } + return RemediationStatusSuccess } // GetBreachesByCheckName fetches the list of failures by check name. @@ -93,17 +133,6 @@ func (rl *ResultList) GetBreachesBySeverity(s string) []Breach { return breaches } -// GetBreachesByCheckName fetches the list of failures by check name. -func (rl *ResultList) GetRemediationsByCheckName(cn string) []string { - var remediations []string - for _, r := range rl.Results { - if r.Name == cn { - remediations = append(remediations, r.Remediations...) - } - } - return remediations -} - // Sort reorders the results by name. func (rl *ResultList) Sort() { sort.Slice(rl.Results, func(i int, j int) bool { diff --git a/pkg/result/resultlist_test.go b/pkg/result/resultlist_test.go index c92bc16..7cf7aa7 100644 --- a/pkg/result/resultlist_test.go +++ b/pkg/result/resultlist_test.go @@ -14,40 +14,23 @@ func TestNewResultList(t *testing.T) { t.Run("emptyInit", func(t *testing.T) { rl := NewResultList(false) - assert.Equal(rl.RemediationPerformed, false) - assert.Equal(rl.Results, []Result{}) - assert.Equal(rl.CheckCountByType, map[string]int{}) - assert.Equal(rl.BreachCountByType, map[string]int{}) - assert.Equal(rl.BreachCountBySeverity, map[string]int{}) - assert.Equal(rl.RemediationCountByType, map[string]int{}) + assert.Equal(false, rl.RemediationPerformed) + assert.Equal([]Result{}, rl.Results) + assert.Equal(map[string]int{}, rl.CheckCountByType) + assert.Equal(map[string]int{}, rl.BreachCountByType) + assert.Equal(map[string]int{}, rl.BreachCountBySeverity) + assert.Nil(rl.RemediationTotals) }) t.Run("remediation", func(t *testing.T) { rl := NewResultList(true) - assert.Equal(rl.RemediationPerformed, true) - assert.Equal(rl.Results, []Result{}) - assert.Equal(rl.CheckCountByType, map[string]int{}) - assert.Equal(rl.BreachCountByType, map[string]int{}) - assert.Equal(rl.BreachCountBySeverity, map[string]int{}) - assert.Equal(rl.RemediationCountByType, map[string]int{}) + assert.Equal(true, rl.RemediationPerformed) + assert.Equal([]Result{}, rl.Results) + assert.Equal(map[string]int{}, rl.CheckCountByType) + assert.Equal(map[string]int{}, rl.BreachCountByType) + assert.Equal(map[string]int{}, rl.BreachCountBySeverity) + assert.Nil(rl.RemediationTotals) }) - -} - -func TestResultListStatus(t *testing.T) { - assert := assert.New(t) - - rl := ResultList{ - Results: []Result{ - {Status: Pass}, - {Status: Pass}, - {Status: Pass}, - }, - } - assert.Equal(Pass, rl.Status()) - - rl.Results[0].Status = Fail - assert.Equal(Fail, rl.Status()) } const testCheckType config.CheckType = "test-check" @@ -92,36 +75,35 @@ func TestResultListAddResult(t *testing.T) { BreachCountByType: map[string]int{}, BreachCountBySeverity: map[string]int{}, - TotalRemediations: 0, - RemediationCountByType: map[string]int{}, + RemediationTotals: map[string]uint32{"successful": 0}, } rl.AddResult(Result{ Severity: "high", CheckType: string(testCheckType), Breaches: []Breach{ - ValueBreach{Value: "fail1"}, - ValueBreach{Value: "fail2"}, - ValueBreach{Value: "fail3"}, - ValueBreach{Value: "fail4"}, - ValueBreach{Value: "fail5"}, + &ValueBreach{Value: "fail1", Remediation: Remediation{ + Status: RemediationStatusSuccess, + Messages: []string{"fixed1"}, + }}, + &ValueBreach{Value: "fail2"}, + &ValueBreach{Value: "fail3"}, + &ValueBreach{Value: "fail4"}, + &ValueBreach{Value: "fail5"}, }, - Remediations: []string{"fixed1"}, }) assert.Equal(5, int(rl.TotalBreaches)) assert.Equal(5, rl.BreachCountByType[string(testCheckType)]) assert.Equal(5, rl.BreachCountBySeverity["high"]) - assert.Equal(1, int(rl.TotalRemediations)) - assert.Equal(1, rl.RemediationCountByType[string(testCheckType)]) rl.AddResult(Result{ Severity: "critical", CheckType: string(testCheck2Type), Breaches: []Breach{ - ValueBreach{Value: "fail1"}, - ValueBreach{Value: "fail2"}, - ValueBreach{Value: "fail3"}, - ValueBreach{Value: "fail4"}, - ValueBreach{Value: "fail5"}, + &ValueBreach{Value: "fail1"}, + &ValueBreach{Value: "fail2"}, + &ValueBreach{Value: "fail3"}, + &ValueBreach{Value: "fail4"}, + &ValueBreach{Value: "fail5"}, }, }) assert.Equal(10, int(rl.TotalBreaches)) @@ -129,9 +111,6 @@ func TestResultListAddResult(t *testing.T) { assert.Equal(5, rl.BreachCountByType[string(testCheck2Type)]) assert.Equal(5, rl.BreachCountBySeverity["high"]) assert.Equal(5, rl.BreachCountBySeverity["critical"]) - assert.Equal(1, int(rl.TotalRemediations)) - assert.Equal(1, rl.RemediationCountByType[string(testCheckType)]) - assert.Equal(0, rl.RemediationCountByType[string(testCheck2Type)]) var wg sync.WaitGroup for i := 0; i < 100; i++ { @@ -141,13 +120,18 @@ func TestResultListAddResult(t *testing.T) { rl.AddResult(Result{ Severity: "high", CheckType: string(testCheckType), - Breaches: []Breach{ValueBreach{Value: "fail6"}}, + Breaches: []Breach{&ValueBreach{Value: "fail6"}}, }) rl.AddResult(Result{ - Severity: "critical", - CheckType: string(testCheck2Type), - Breaches: []Breach{ValueBreach{Value: "fail7"}}, - Remediations: []string{"fixed2", "fixed3"}, + Severity: "critical", + CheckType: string(testCheck2Type), + Breaches: []Breach{&ValueBreach{ + Value: "fail7", + Remediation: Remediation{ + Status: RemediationStatusSuccess, + Messages: []string{"fixed2", "fixed3"}, + }, + }}, }) }() } @@ -157,9 +141,167 @@ func TestResultListAddResult(t *testing.T) { assert.Equal(105, rl.BreachCountByType[string(testCheck2Type)]) assert.Equal(105, rl.BreachCountBySeverity["high"]) assert.Equal(105, rl.BreachCountBySeverity["critical"]) - assert.Equal(201, int(rl.TotalRemediations)) - assert.Equal(1, rl.RemediationCountByType[string(testCheckType)]) - assert.Equal(200, rl.RemediationCountByType[string(testCheck2Type)]) +} + +func TestResultListStatus(t *testing.T) { + assert := assert.New(t) + + rl := ResultList{ + Results: []Result{ + {Status: Pass}, + {Status: Pass}, + {Status: Pass}, + }, + } + assert.Equal(Pass, rl.Status()) + + rl.Results[0].Status = Fail + assert.Equal(Fail, rl.Status()) +} + +func TestResultListRemediationTotalsCount(t *testing.T) { + tt := []struct { + name string + results []Result + expected map[string]uint32 + }{ + { + name: "allSuccess", + results: []Result{ + {Breaches: []Breach{ + &ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}, + &ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}, + }}, + {Breaches: []Breach{ + &ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}, + }}, + }, + expected: map[string]uint32{ + "unsupported": 0, + "successful": 3, + "failed": 0, + "partial": 0, + }, + }, + { + name: "countingWorks", + results: []Result{ + {Breaches: []Breach{ + &ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}, + &ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}, + &ValueBreach{Remediation: Remediation{Status: RemediationStatusFailed}}, + }}, + {Breaches: []Breach{ + &ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}, + &ValueBreach{Remediation: Remediation{Status: RemediationStatusPartial}}, + }}, + {Breaches: []Breach{ + &ValueBreach{Remediation: Remediation{Status: RemediationStatusFailed}}, + &ValueBreach{Remediation: Remediation{Status: RemediationStatusNoSupport}}, + }}, + }, + expected: map[string]uint32{ + "unsupported": 1, + "successful": 3, + "failed": 2, + "partial": 1, + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + rl := ResultList{ + Results: tc.results, + } + rl.RemediationTotalsCount() + assert.Equal(tc.expected, rl.RemediationTotals) + }) + } +} + +func TestResultListRemediationStatus(t *testing.T) { + tt := []struct { + name string + remediationPerformed bool + results []Result + expected RemediationStatus + }{ + { + name: "noRemediation", + remediationPerformed: false, + results: []Result{{Breaches: []Breach{&ValueBreach{}}}}, + expected: "", + }, + { + name: "allSuccess", + remediationPerformed: true, + results: []Result{ + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}}, + }, + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}}, + }, + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusSuccess}}}, + }, + }, + expected: RemediationStatusSuccess, + }, + { + name: "partial", + remediationPerformed: true, + results: []Result{ + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusPartial}}}, + }, + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusFailed}}}, + }, + }, + expected: RemediationStatusPartial, + }, + { + name: "fail", + remediationPerformed: true, + results: []Result{ + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusFailed}}}, + }, + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusFailed}}}, + }, + }, + expected: RemediationStatusFailed, + }, + { + name: "unsupported", + remediationPerformed: true, + results: []Result{ + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusNoSupport}}}, + }, + { + Breaches: []Breach{&ValueBreach{Remediation: Remediation{Status: RemediationStatusNoSupport}}}, + }, + }, + expected: RemediationStatusNoSupport, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + rl := ResultList{ + RemediationPerformed: tc.remediationPerformed, + Results: tc.results, + } + rl.RemediationTotalsCount() + assert.Equal(tc.expected, rl.RemediationStatus()) + }) + } } func TestResultListGetBreachesByCheckName(t *testing.T) { @@ -170,29 +312,29 @@ func TestResultListGetBreachesByCheckName(t *testing.T) { { Name: "check1", Breaches: []Breach{ - ValueBreach{Value: "failure1"}, - ValueBreach{Value: "failure 2"}, + &ValueBreach{Value: "failure1"}, + &ValueBreach{Value: "failure 2"}, }, }, { Name: "check2", Breaches: []Breach{ - ValueBreach{Value: "failure3"}, - ValueBreach{Value: "failure 4"}, + &ValueBreach{Value: "failure3"}, + &ValueBreach{Value: "failure 4"}, }, }, }, } assert.EqualValues( []Breach{ - ValueBreach{Value: "failure1"}, - ValueBreach{Value: "failure 2"}, + &ValueBreach{Value: "failure1"}, + &ValueBreach{Value: "failure 2"}, }, rl.GetBreachesByCheckName("check1")) assert.EqualValues( []Breach{ - ValueBreach{Value: "failure3"}, - ValueBreach{Value: "failure 4"}, + &ValueBreach{Value: "failure3"}, + &ValueBreach{Value: "failure 4"}, }, rl.GetBreachesByCheckName("check2")) } @@ -205,56 +347,33 @@ func TestResultListGetBreachesBySeverity(t *testing.T) { { Severity: "high", Breaches: []Breach{ - ValueBreach{Value: "failure1"}, - ValueBreach{Value: "failure 2"}, + &ValueBreach{Value: "failure1"}, + &ValueBreach{Value: "failure 2"}, }, }, { Severity: "normal", Breaches: []Breach{ - ValueBreach{Value: "failure3"}, - ValueBreach{Value: "failure 4"}, + &ValueBreach{Value: "failure3"}, + &ValueBreach{Value: "failure 4"}, }, }, }, } assert.EqualValues( []Breach{ - ValueBreach{Value: "failure1"}, - ValueBreach{Value: "failure 2"}, + &ValueBreach{Value: "failure1"}, + &ValueBreach{Value: "failure 2"}, }, rl.GetBreachesBySeverity("high")) assert.EqualValues( []Breach{ - ValueBreach{Value: "failure3"}, - ValueBreach{Value: "failure 4"}, + &ValueBreach{Value: "failure3"}, + &ValueBreach{Value: "failure 4"}, }, rl.GetBreachesBySeverity("normal")) } -func TestResultListGetRemediationsByCheckName(t *testing.T) { - assert := assert.New(t) - - rl := ResultList{ - Results: []Result{ - { - Name: "check1", - Remediations: []string{"fix1", "fix 2"}, - }, - { - Name: "check2", - Remediations: []string{"fix3", "fix 4"}, - }, - }, - } - assert.EqualValues( - []string{"fix1", "fix 2"}, - rl.GetRemediationsByCheckName("check1")) - assert.EqualValues( - []string{"fix3", "fix 4"}, - rl.GetRemediationsByCheckName("check2")) -} - func TestResultListSort(t *testing.T) { assert := assert.New(t) diff --git a/pkg/shipshape/output.go b/pkg/shipshape/output.go index 5192f6e..f9350df 100644 --- a/pkg/shipshape/output.go +++ b/pkg/shipshape/output.go @@ -67,46 +67,66 @@ func SimpleDisplay(w *bufio.Writer) { printRemediations := func() { for _, r := range RunResultList.Results { - if len(r.Remediations) == 0 { + _, successful, _, _ := r.RemediationsCount() + if successful == 0 { continue } fmt.Fprintf(w, " ### %s\n", r.Name) - for _, f := range r.Remediations { - fmt.Fprintf(w, " -- %s\n", f) + for _, b := range r.Breaches { + if b.GetRemediation().Status != result.RemediationStatusSuccess { + continue + } + for _, msg := range b.GetRemediation().Messages { + fmt.Fprintf(w, " -- %s\n", msg) + } } fmt.Fprintln(w) } } - if RunResultList.Status() == result.Pass && int(RunResultList.TotalRemediations) == 0 { + if RunResultList.RemediationPerformed && RunResultList.TotalBreaches > 0 { + switch RunResultList.RemediationStatus() { + case result.RemediationStatusNoSupport: + fmt.Fprint(w, "Breaches were detected but none of them could be "+ + "fixed as remediation is not supported for them yet.\n\n") + fmt.Fprint(w, "# Non-remediated breaches\n\n") + case result.RemediationStatusFailed: + fmt.Fprint(w, "Breaches were detected but none of them could "+ + "be fixed as there were errors when trying to remediate.\n\n") + fmt.Fprint(w, "# Non-remediated breaches\n\n") + case result.RemediationStatusPartial: + fmt.Fprint(w, "Breaches were detected but not all of them could "+ + "be fixed as they are either not supported yet or there were "+ + "errors when trying to remediate.\n\n") + fmt.Fprint(w, "# Remediations\n\n") + printRemediations() + fmt.Fprint(w, "# Non-remediated breaches\n\n") + case result.RemediationStatusSuccess: + fmt.Fprintf(w, "Breaches were detected but were all fixed successfully!\n\n") + printRemediations() + w.Flush() + return + } + } else if RunResultList.Status() == result.Pass { fmt.Fprint(w, "Ship is in top shape; no breach detected!\n") w.Flush() return - } else if RunResultList.Status() == result.Pass && int(RunResultList.TotalRemediations) > 0 { - fmt.Fprintf(w, "Breaches were detected but were all fixed successfully!\n\n") - printRemediations() - w.Flush() - return } - if RunResultList.RemediationPerformed && int(RunResultList.TotalBreaches) > 0 { - fmt.Fprint(w, "Breaches were detected but not all of them could "+ - "be fixed as they are either not supported yet or there were "+ - "errors when trying to remediate.\n\n") - fmt.Fprint(w, "# Remediations\n\n") - printRemediations() - fmt.Fprint(w, "# Non-remediated breaches\n\n") - } else if !RunResultList.RemediationPerformed { + if !RunResultList.RemediationPerformed { fmt.Fprint(w, "# Breaches were detected\n\n") } for _, r := range RunResultList.Results { - if len(r.Breaches) == 0 { + if len(r.Breaches) == 0 || r.RemediationStatus == result.RemediationStatusSuccess { continue } fmt.Fprintf(w, " ### %s\n", r.Name) - for _, f := range r.Breaches { - fmt.Fprintf(w, " -- %s\n", f) + for _, b := range r.Breaches { + if b.GetRemediation().Status == result.RemediationStatusSuccess { + continue + } + fmt.Fprintf(w, " -- %s\n", b) } fmt.Fprintln(w) } diff --git a/pkg/shipshape/output_test.go b/pkg/shipshape/output_test.go index 9cc70dd..e32eb29 100644 --- a/pkg/shipshape/output_test.go +++ b/pkg/shipshape/output_test.go @@ -62,8 +62,8 @@ func TestTableDisplay(t *testing.T) { Name: "c", Status: result.Fail, Breaches: []result.Breach{ - result.ValueBreach{Value: "Fail c"}, - result.ValueBreach{Value: "Fail cb"}, + &result.ValueBreach{Value: "Fail c"}, + &result.ValueBreach{Value: "Fail cb"}, }, }, { @@ -71,8 +71,8 @@ func TestTableDisplay(t *testing.T) { Status: result.Fail, Passes: []string{"Pass d", "Pass db"}, Breaches: []result.Breach{ - result.ValueBreach{Value: "Fail c"}, - result.ValueBreach{Value: "Fail cb"}, + &result.ValueBreach{Value: "Fail c"}, + &result.ValueBreach{Value: "Fail cb"}, }, }, }, @@ -121,10 +121,9 @@ func TestSimpleDisplay(t *testing.T) { Name: "b", Status: result.Fail, Breaches: []result.Breach{ - result.ValueBreach{Value: "Fail b"}, + &result.ValueBreach{Value: "Fail b"}, }, }) - buf = bytes.Buffer{} SimpleDisplay(w) assert.Equal("# Breaches were detected\n\n ### b\n -- Fail b\n\n", buf.String()) }) @@ -135,56 +134,93 @@ func TestSimpleDisplay(t *testing.T) { w := bufio.NewWriter(&buf) RunResultList.Results = append(RunResultList.Results, result.Result{ Name: "a", Status: result.Pass}) - buf = bytes.Buffer{} SimpleDisplay(w) assert.Equal("Ship is in top shape; no breach detected!\n", buf.String()) }) t.Run("allBreachesRemediated", func(t *testing.T) { - RunResultList = result.ResultList{RemediationPerformed: true} + RunResultList = result.ResultList{ + Results: []result.Result{{ + Name: "a", + Breaches: []result.Breach{ + &result.ValueBreach{ + Remediation: result.Remediation{ + Status: result.RemediationStatusSuccess, + Messages: []string{"fixed 1"}, + }, + }, + }}}, + TotalBreaches: 1, + RemediationPerformed: true, + RemediationTotals: map[string]uint32{"successful": 1}, + } + var buf bytes.Buffer w := bufio.NewWriter(&buf) - RunResultList.TotalRemediations = 1 - RunResultList.Results = append(RunResultList.Results, result.Result{ - Name: "a", Status: result.Pass, Remediations: []string{"fixed 1"}}) - buf = bytes.Buffer{} SimpleDisplay(w) assert.Equal("Breaches were detected but were all fixed successfully!\n\n"+ " ### a\n -- fixed 1\n\n", buf.String()) }) t.Run("someBreachesRemediated", func(t *testing.T) { - RunResultList = result.ResultList{RemediationPerformed: true} + RunResultList = result.ResultList{ + Results: []result.Result{{ + Name: "a", + Breaches: []result.Breach{ + &result.ValueBreach{ + Value: "Fail a", + Remediation: result.Remediation{ + Status: result.RemediationStatusSuccess, + Messages: []string{"fixed 1"}, + }, + }, + &result.ValueBreach{ + Value: "Fail b", + Remediation: result.Remediation{ + Status: result.RemediationStatusFailed, + Messages: []string{"not fixed 1"}, + }, + }, + }}}, + TotalBreaches: 2, + RemediationPerformed: true, + RemediationTotals: map[string]uint32{"successful": 1, "failed": 1}, + } + var buf bytes.Buffer w := bufio.NewWriter(&buf) - RunResultList.TotalRemediations = 1 - RunResultList.TotalBreaches = 1 - RunResultList.Results = append(RunResultList.Results, result.Result{ - Name: "a", Status: result.Fail, Remediations: []string{"fixed 1"}}) - buf = bytes.Buffer{} SimpleDisplay(w) assert.Equal("Breaches were detected but not all of them could be "+ "fixed as they are either not supported yet or there were errors "+ "when trying to remediate.\n\n"+ "# Remediations\n\n ### a\n -- fixed 1\n\n"+ - "# Non-remediated breaches\n\n", buf.String()) + "# Non-remediated breaches\n\n ### a\n -- Fail b\n\n", buf.String()) }) t.Run("noBreachRemediated", func(t *testing.T) { - RunResultList = result.ResultList{RemediationPerformed: true} + RunResultList = result.ResultList{ + Results: []result.Result{{ + Name: "a", + Breaches: []result.Breach{ + &result.ValueBreach{ + Remediation: result.Remediation{ + Status: result.RemediationStatusFailed, + Messages: []string{"failed 1"}, + }, + }, + }}}, + TotalBreaches: 1, + RemediationPerformed: true, + RemediationTotals: map[string]uint32{"failed": 1}, + } + var buf bytes.Buffer w := bufio.NewWriter(&buf) - RunResultList.TotalBreaches = 1 - RunResultList.TotalRemediations = 0 - RunResultList.Results = append(RunResultList.Results, result.Result{ - Name: "a", Status: result.Fail}) - buf = bytes.Buffer{} SimpleDisplay(w) - assert.Equal("Breaches were detected but not all of them could be "+ - "fixed as they are either not supported yet or there were errors "+ - "when trying to remediate.\n\n"+ - "# Remediations\n\n"+ - "# Non-remediated breaches\n\n", buf.String()) + assert.Equal("Breaches were detected but none of them could be "+ + "fixed as there were errors when trying to remediate.\n\n"+ + "# Non-remediated breaches\n\n"+ + " ### a\n -- \n\n", buf.String()) }) } @@ -225,7 +261,7 @@ func TestJUnit(t *testing.T) { Name: "b", Status: result.Fail, Breaches: []result.Breach{ - result.ValueBreach{Value: "Fail b"}, + &result.ValueBreach{Value: "Fail b"}, }, }) buf = bytes.Buffer{} diff --git a/pkg/shipshape/shipshape.go b/pkg/shipshape/shipshape.go index 35f6e60..1c1f919 100644 --- a/pkg/shipshape/shipshape.go +++ b/pkg/shipshape/shipshape.go @@ -161,7 +161,7 @@ func ParseConfigData(configData [][]byte) error { return nil } -func RunChecks() result.ResultList { +func RunChecks() { log.Print("preparing concurrent check runs") var wg sync.WaitGroup for ct, checks := range RunConfig.Checks { @@ -178,7 +178,7 @@ func RunChecks() result.ResultList { } wg.Wait() RunResultList.Sort() - return RunResultList + RunResultList.RemediationTotalsCount() } func ProcessCheck(rl *result.ResultList, c config.Check) { @@ -198,7 +198,14 @@ func ProcessCheck(rl *result.ResultList, c config.Check) { if len(c.GetResult().Breaches) == 0 && len(c.GetResult().Passes) == 0 { contextLogger.Print("running check") c.RunCheck() - c.GetResult().Sort() } + if len(c.GetResult().Breaches) > 0 && c.ShouldPerformRemediation() { + contextLogger.Print("performing remediation") + c.Remediate() + } + c.GetResult().DetermineResultStatus(c.ShouldPerformRemediation()) + contextLogger. + WithFields(log.Fields{"result": c.GetResult()}). + Print("check processed") rl.AddResult(*c.GetResult()) } diff --git a/pkg/shipshape/shipshape_test.go b/pkg/shipshape/shipshape_test.go index 6a768b1..a28ec6a 100644 --- a/pkg/shipshape/shipshape_test.go +++ b/pkg/shipshape/shipshape_test.go @@ -134,6 +134,10 @@ checks: } func TestRunChecks(t *testing.T) { + currLogOut := logrus.StandardLogger().Out + defer logrus.SetOutput(currLogOut) + logrus.SetOutput(io.Discard) + assert := assert.New(t) test1stCheck := &testchecks.TestCheck1Check{} @@ -149,13 +153,14 @@ func TestRunChecks(t *testing.T) { }, } - rl := RunChecks() - assert.Equal(uint32(2), rl.TotalChecks) - assert.Equal(uint32(2), rl.TotalBreaches) + RunResultList = result.NewResultList(false) + RunChecks() + assert.Equal(uint32(2), RunResultList.TotalChecks) + assert.Equal(uint32(2), RunResultList.TotalBreaches) assert.EqualValues(map[string]int{ string(testchecks.TestCheck1): 1, string(testchecks.TestCheck2): 1, - }, rl.BreachCountByType) + }, RunResultList.BreachCountByType) assert.ElementsMatch([]result.Result{ { Name: "test1stcheck", @@ -163,7 +168,7 @@ func TestRunChecks(t *testing.T) { CheckType: "test-check-1", Status: "Fail", Passes: []string(nil), - Breaches: []result.Breach{result.ValueBreach{ + Breaches: []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "test-check-1", CheckName: "test1stcheck", @@ -178,7 +183,7 @@ func TestRunChecks(t *testing.T) { CheckType: "test-check-2", Status: "Fail", Passes: []string(nil), - Breaches: []result.Breach{result.ValueBreach{ + Breaches: []result.Breach{&result.ValueBreach{ BreachType: "value", CheckType: "test-check-2", CheckName: "test2ndcheck", @@ -187,5 +192,5 @@ func TestRunChecks(t *testing.T) { }}, Warnings: []string(nil), }}, - rl.Results) + RunResultList.Results) }