Skip to content

Commit

Permalink
Merge pull request #84 from Clever/sigterm-on-timeout
Browse files Browse the repository at this point in the history
sigterm + sigkill on task context finishing before command is done
  • Loading branch information
Sayan- authored Nov 20, 2019
2 parents b634ad0 + 1d588e6 commit 46319fd
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 29 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.6
0.3.7
32 changes: 24 additions & 8 deletions cmd/sfncli/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type TaskRunner struct {
}

// NewTaskRunner instantiates a new TaskRunner
func NewTaskRunner(cmd string, sfnapi sfniface.SFNAPI, taskToken string, workDirectory string, cancelFunc context.CancelFunc) TaskRunner {
func NewTaskRunner(cmd string, sfnapi sfniface.SFNAPI, taskToken string, workDirectory string) TaskRunner {
return TaskRunner{
sfnapi: sfnapi,
taskToken: taskToken,
Expand All @@ -51,7 +51,6 @@ func NewTaskRunner(cmd string, sfnapi sfniface.SFNAPI, taskToken string, workDir
// set the default grace period to something slightly lower than the default
// docker stop grace period in ECS (30s)
sigtermGracePeriod: 25 * time.Second,
ctxCancel: cancelFunc,
}
}

Expand Down Expand Up @@ -83,7 +82,10 @@ func (t *TaskRunner) Process(ctx context.Context, args []string, input string) e

args = append(args, string(marshaledInput))

t.execCmd = exec.CommandContext(ctx, t.cmd, args...)
// Don't use exec.CommandContext here: it sends SIGKILL immediately when the
// context finishes, whereas we want a graceful shutdown of
// SIGTERM + (grace period) + SIGKILL on the context finishing.
t.execCmd = exec.Command(t.cmd, args...)
t.execCmd.Env = append(os.Environ(), "_EXECUTION_NAME="+executionName)

tmpDir := ""
Expand Down Expand Up @@ -176,6 +178,14 @@ func (t *TaskRunner) handleSignals(ctx context.Context) {
for {
select {
case <-ctx.Done():
// if the context has ended, but the command is still running,
// initiate graceful shutdown with a much shorter grace period,
// since most likely this is a case of SFN timing out the
// activity. This means there is likely another activity
// out there beginning work on the same input.
if t.execCmd.Process != nil && t.execCmd.ProcessState == nil {
sigTermAndThenKill(t.execCmd.Process.Pid, 5*time.Second)
}
return
case sigReceived := <-sigChan:
if t.execCmd.Process == nil {
Expand All @@ -187,11 +197,8 @@ func (t *TaskRunner) handleSignals(ctx context.Context) {
// - after a grace period send SIGKILL to the command if it's still running
if sigReceived == syscall.SIGTERM {
t.receivedSigterm = true
go func(pidtokill int) {
time.Sleep(t.sigtermGracePeriod)
signalProcess(pidtokill, os.Signal(syscall.SIGKILL))
t.ctxCancel()
}(pid)
sigTermAndThenKill(pid, t.sigtermGracePeriod)
return
}
signalProcess(pid, sigReceived)
}
Expand All @@ -206,6 +213,15 @@ func signalProcess(pid int, signal os.Signal) {
proc.Signal(signal)
}

// sigTermAndThenKill performs a docker-stop style shutdown of the process
// identified by pid:
//   - SIGTERM is sent right away,
//   - the calling goroutine then blocks for gracePeriod,
//   - SIGKILL is sent once the grace period has elapsed.
//
// The call is synchronous: it does not return until the SIGKILL step has
// been attempted. This function itself does not check whether the process
// has already exited before sending SIGKILL; both signals are delivered via
// signalProcess.
// NOTE(review): presumably signalProcess is a harmless no-op for a pid that
// is already gone — confirm against its implementation.
func sigTermAndThenKill(pid int, gracePeriod time.Duration) {
	var (
		polite   os.Signal = syscall.SIGTERM
		forceful os.Signal = syscall.SIGKILL
	)

	// Ask the process to stop on its own first.
	signalProcess(pid, polite)

	// Give it the whole grace period to clean up.
	<-time.After(gracePeriod)

	// Time is up: force termination.
	signalProcess(pid, forceful)
}

func parseCustomErrorFromStdout(stdout string) (TaskFailureCustom, error) {
var customError TaskFailureCustom
err := json.Unmarshal([]byte(taskOutputFromStdout(stdout)), &customError)
Expand Down
30 changes: 15 additions & 15 deletions cmd/sfncli/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestTaskFailureTaskInputNotJSON(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.Equal(t, err, expectedError)

Expand All @@ -93,7 +93,7 @@ func TestTaskOutputEmptyStringAsJSON(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
Output: aws.String(`{"_EXECUTION_NAME":"fake-WFM-uuid"}`),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)

Expand All @@ -115,7 +115,7 @@ func TestTaskFailureCommandNotFound(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -136,7 +136,7 @@ func TestTaskFailureCommandKilled(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
go func() {
time.Sleep(2 * time.Second)
taskRunner.execCmd.Process.Signal(syscall.SIGKILL)
Expand All @@ -161,7 +161,7 @@ func TestTaskFailureCommandExitedNonzero(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -182,7 +182,7 @@ func TestTaskFailureCustomErrorName(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -203,7 +203,7 @@ func TestTaskFailureTaskOutputNotJSON(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -224,7 +224,7 @@ func TestTaskFailureCommandTerminated(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
go func() {
time.Sleep(1 * time.Second)
process, _ := os.FindProcess(os.Getpid())
Expand All @@ -249,7 +249,7 @@ func TestTaskFailureCommandTerminated(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
go func() {
time.Sleep(1 * time.Second)
process, _ := os.FindProcess(os.Getpid())
Expand All @@ -274,7 +274,7 @@ func TestTaskFailureCommandTerminated(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
// lower the grace period so this test doesn't take forever
taskRunner.sigtermGracePeriod = 5 * time.Second
go func() {
Expand All @@ -300,7 +300,7 @@ func TestTaskSuccessSignalForwarded(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
})
defer controller.Finish()
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
go func() {
time.Sleep(1 * time.Second)
process, _ := os.FindProcess(os.Getpid())
Expand All @@ -322,7 +322,7 @@ func TestTaskSuccessOutputIsLastLineOfStdout(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
})
defer controller.Finish()
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
require.Nil(t, taskRunner.Process(testCtx, cmdArgs, emptyTaskInput))
}

Expand All @@ -341,7 +341,7 @@ func TestTaskWorkDirectorySetup(t *testing.T) {
taskToken: mockTaskToken,
expectedPrefix: "/tmp",
}) // returns the result of WORK_DIR
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp")
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)
}
Expand All @@ -361,7 +361,7 @@ func TestTaskWorkDirectoryUnsetByDefault(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
Output: aws.String(`{"_EXECUTION_NAME":"fake-WFM-uuid","work_dir":""}`), // returns the result of WORK_DIR
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)
}
Expand All @@ -385,7 +385,7 @@ func TestTaskWorkDirectoryCleaned(t *testing.T) {

os.MkdirAll("/tmp/test", os.ModeDir|0777) // base path is created by cmd/sfncli/sfncli.go
defer os.RemoveAll("/tmp/test")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp/test", testCtxCancel)
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp/test")
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)
if _, err := os.Stat(dirMatcher.foundWorkdir); os.IsExist(err) {
Expand Down
7 changes: 2 additions & 5 deletions cmd/sfncli/sfncli.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ func main() {

// Create a context for this task. We'll cancel this context on errors.
// Anything spawned on behalf of the task should use this context.
var taskCtx context.Context
var taskCtxCancel context.CancelFunc
// context.Background() to disconnect this from the mainCtx cancellation
taskCtx, taskCtxCancel = context.WithCancel(context.Background())
taskCtx, taskCtxCancel := context.WithCancel(mainCtx)

// Begin sending heartbeats
go func() {
Expand All @@ -197,7 +194,7 @@ func main() {

// Run the command. Treat unprocessed args (flag.Args()) as additional args to
// send to the command on every invocation of the command
taskRunner := NewTaskRunner(*cmd, sfnapi, token, *workDirectory, taskCtxCancel)
taskRunner := NewTaskRunner(*cmd, sfnapi, token, *workDirectory)
err = taskRunner.Process(taskCtx, flag.Args(), input)
if err != nil {
log.ErrorD("task-process-error", logger.M{"error": err.Error()})
Expand Down

0 comments on commit 46319fd

Please sign in to comment.