Skip to content

Commit

Permalink
add a rety option to the step command (#1585)
Browse files Browse the repository at this point in the history
* add a retry option to the step command
  • Loading branch information
lewismarshall authored Jan 13, 2025
1 parent dd29b2d commit c327b00
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 30 deletions.
103 changes: 73 additions & 30 deletions cmd/step/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -71,14 +72,26 @@ func main() {
flags.StringSliceVarP(&step.UploadFile, "upload", "u", []string{}, "Upload file as a kubernetes secret")
flags.StringVar(&step.WaitFile, "wait-on", "", "The path to a file to indicate this step can be run")
flags.StringSliceVarP(&step.Commands, "command", "c", []string{}, "Command to execute")

flags.IntVar(&step.RetryAttempts, "retry-attempts", 0, "Number of times to retry the commands")
flags.DurationVar(&step.RetryMinBackoff, "retry-min-backoff", 0, "Minimum duration to wait between retry attempts")
flags.DurationVar(&step.RetryMaxJitter, "retry-max-jitter", 2*time.Second, "Maximum random jitter to add to backoff time")
if err := cmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "[Error] %s\n", err)

os.Exit(1)
}
}

// calculateBackoff returns a duration that includes the minimum backoff plus a random jitter
func calculateBackoff(minBackoff, maxJitter time.Duration) time.Duration {
if maxJitter <= 0 {
return minBackoff
}
//nolint:gosec // math/rand is acceptable here as jitter value is not used for security purposes
jitter := time.Duration(rand.Int63n(int64(maxJitter)))
return minBackoff + jitter
}

// Run is called to implement the action
func Run(ctx context.Context, step Step) error {
if err := step.IsValid(); err != nil {
Expand Down Expand Up @@ -128,49 +141,79 @@ func Run(ctx context.Context, step Step) error {
}

for i, command := range step.Commands {
//nolint:gosec
cmd := exec.CommandContext(ctx, step.Shell, "-c", command)
cmd.Env = os.Environ()
attempt := 0
var lastErr error

for attempt <= step.RetryAttempts {
if attempt > 0 {
backoff := calculateBackoff(step.RetryMinBackoff, step.RetryMaxJitter)
log.WithFields(log.Fields{
"attempt": attempt,
"command-index": i,
"backoff": backoff,
}).Info("retrying command")

select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
}

logger := log.WithField("command", i)
//nolint:gosec
cmd := exec.CommandContext(ctx, step.Shell, "-c", command)
cmd.Env = os.Environ()

stdout, err := cmd.StdoutPipe()
if err != nil {
logger.WithError(err).Error("failed to acquire stdout pipe on command")
logger := log.WithFields(log.Fields{
"command-index": i,
"attempt": attempt,
})

return err
}
stderr, err := cmd.StderrPipe()
if err != nil {
logger.WithError(err).Error("failed to acquire stderr pipe on command")
stdout, err := cmd.StdoutPipe()
if err != nil {
logger.WithError(err).Error("failed to acquire stdout pipe on command")
return err
}
stderr, err := cmd.StderrPipe()
if err != nil {
logger.WithError(err).Error("failed to acquire stderr pipe on command")
return err
}

return err
}
//nolint:errcheck
go io.Copy(os.Stdout, stdout)
//nolint:errcheck
go io.Copy(os.Stdout, stderr)

//nolint:errcheck
go io.Copy(os.Stdout, stdout)
//nolint:errcheck
go io.Copy(os.Stdout, stderr)
if err := cmd.Start(); err != nil {
logger.WithError(err).Error("failed to execute the command")
lastErr = err
attempt++
continue
}

if err := cmd.Start(); err != nil {
logger.WithError(err).Error("failed to execute the command")
// @step: wait for the command to finish
if err := cmd.Wait(); err != nil {
logger.WithError(err).Error("command execution failed")
lastErr = err
attempt++
continue
}

return err
// Command succeeded, break the retry loop
lastErr = nil
break
}

// @step: wait for the command to finish
if err := cmd.Wait(); err != nil {
logger.WithError(err).Error("failed to execute command successfully")

// If we exhausted all retries and still have an error
if lastErr != nil {
if step.ErrorFile != "" {
if err := utils.TouchFile(step.ErrorFile); err != nil {
logger.WithError(err).WithField("file", step.ErrorFile).Error("failed to create error file")

log.WithError(err).WithField("file", step.ErrorFile).Error("failed to create error file")
return err
}
}

return err
return fmt.Errorf("command failed after %d attempts: %w", attempt, lastErr)
}
}
log.Info("successfully executed the step")
Expand Down
15 changes: 15 additions & 0 deletions cmd/step/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ import (
type Step struct {
// Commands is the commands and arguments to run
Commands []string
// RetryMinBackoff is the minimum backoff time to retry ANY of the commands
RetryMinBackoff time.Duration
// RetryMaxJitter is the maximum random jitter to add to the backoff time
RetryMaxJitter time.Duration
// RetryAttempts is the number of times to retry the commands before giving up
RetryAttempts int
// Comment adds a banner to the stage
Comment string
// ErrorFile is the path to a file which is created when the command failed
Expand Down Expand Up @@ -57,6 +63,15 @@ func (s Step) IsValid() error {
case s.Timeout < 0:
return errors.New("timeout must be greater than 0")

case s.RetryAttempts < 0:
return errors.New("retry attempts must be greater than or equal to 0")

case s.RetryAttempts > 0 && s.RetryMinBackoff < 0:
return errors.New("minimum retry backoff must be greater than or equal to 0")

case s.RetryAttempts > 0 && s.RetryMaxJitter < 0:
return errors.New("maximum jitter must be greater than or equal to 0")

case len(s.UploadFile) > 0 && s.Namespace == "":
return errors.New("namespace must be specified when uploading files")

Expand Down

0 comments on commit c327b00

Please sign in to comment.