Skip to content

Commit

Permalink
Merge pull request #35 from Clever/sigterm-handling
Browse files Browse the repository at this point in the history
SIGTERM handling
  • Loading branch information
meliaj authored Dec 28, 2018
2 parents 18864ae + 5273854 commit 02e0208
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.2
0.3.3
5 changes: 4 additions & 1 deletion cmd/sfncli/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ type TaskRunner struct {
receivedSigterm bool
sigtermGracePeriod time.Duration
workDirectory string
ctxCancel context.CancelFunc
}

// NewTaskRunner instantiates a new TaskRunner
func NewTaskRunner(cmd string, sfnapi sfniface.SFNAPI, taskToken string, workDirectory string) TaskRunner {
func NewTaskRunner(cmd string, sfnapi sfniface.SFNAPI, taskToken string, workDirectory string, cancelFunc context.CancelFunc) TaskRunner {
return TaskRunner{
sfnapi: sfnapi,
taskToken: taskToken,
Expand All @@ -50,6 +51,7 @@ func NewTaskRunner(cmd string, sfnapi sfniface.SFNAPI, taskToken string, workDir
// set the default grace period to something slightly lower than the default
// docker stop grace period in ECS (30s)
sigtermGracePeriod: 25 * time.Second,
ctxCancel: cancelFunc,
}
}

Expand Down Expand Up @@ -188,6 +190,7 @@ func (t *TaskRunner) handleSignals(ctx context.Context) {
go func(pidtokill int) {
time.Sleep(t.sigtermGracePeriod)
signalProcess(pidtokill, os.Signal(syscall.SIGKILL))
t.ctxCancel()
}(pid)
}
signalProcess(pid, sigReceived)
Expand Down
30 changes: 15 additions & 15 deletions cmd/sfncli/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestTaskFailureTaskInputNotJSON(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.Equal(t, err, expectedError)

Expand All @@ -93,7 +93,7 @@ func TestTaskOutputEmptyStringAsJSON(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
Output: aws.String(`{"_EXECUTION_NAME":"fake-WFM-uuid"}`),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)

Expand All @@ -115,7 +115,7 @@ func TestTaskFailureCommandNotFound(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -136,7 +136,7 @@ func TestTaskFailureCommandKilled(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
go func() {
time.Sleep(2 * time.Second)
taskRunner.execCmd.Process.Signal(syscall.SIGKILL)
Expand All @@ -161,7 +161,7 @@ func TestTaskFailureCommandExitedNonzero(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -182,7 +182,7 @@ func TestTaskFailureCustomErrorName(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -203,7 +203,7 @@ func TestTaskFailureTaskOutputNotJSON(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, emptyTaskInput)
require.Equal(t, err, expectedError)
}
Expand All @@ -224,7 +224,7 @@ func TestTaskFailureCommandTerminated(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
go func() {
time.Sleep(1 * time.Second)
process, _ := os.FindProcess(os.Getpid())
Expand All @@ -249,7 +249,7 @@ func TestTaskFailureCommandTerminated(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
go func() {
time.Sleep(1 * time.Second)
process, _ := os.FindProcess(os.Getpid())
Expand All @@ -274,7 +274,7 @@ func TestTaskFailureCommandTerminated(t *testing.T) {
Error: aws.String(expectedError.ErrorName()),
TaskToken: aws.String(mockTaskToken),
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
// lower the grace period so this test doesn't take forever
taskRunner.sigtermGracePeriod = 5 * time.Second
go func() {
Expand All @@ -300,7 +300,7 @@ func TestTaskSuccessSignalForwarded(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
})
defer controller.Finish()
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
go func() {
time.Sleep(1 * time.Second)
process, _ := os.FindProcess(os.Getpid())
Expand All @@ -322,7 +322,7 @@ func TestTaskSuccessOutputIsLastLineOfStdout(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
})
defer controller.Finish()
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
require.Nil(t, taskRunner.Process(testCtx, cmdArgs, emptyTaskInput))
}

Expand All @@ -341,7 +341,7 @@ func TestTaskWorkDirectorySetup(t *testing.T) {
taskToken: mockTaskToken,
expectedPrefix: "/tmp",
}) // returns the result of WORK_DIR
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)
}
Expand All @@ -361,7 +361,7 @@ func TestTaskWorkDirectoryUnsetByDefault(t *testing.T) {
TaskToken: aws.String(mockTaskToken),
Output: aws.String(`{"_EXECUTION_NAME":"fake-WFM-uuid","work_dir":""}`), // returns the result of WORK_DIR
})
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)
}
Expand All @@ -385,7 +385,7 @@ func TestTaskWorkDirectoryCleaned(t *testing.T) {

os.MkdirAll("/tmp/test", os.ModeDir|0777) // base path is created by cmd/sfncli/sfncli.go
defer os.RemoveAll("/tmp/test")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp/test")
taskRunner := NewTaskRunner(path.Join(testScriptsDir, cmd), mockSFN, mockTaskToken, "/tmp/test", testCtxCancel)
err := taskRunner.Process(testCtx, cmdArgs, taskInput)
require.NoError(t, err)
if _, err := os.Stat(dirMatcher.foundWorkdir); os.IsExist(err) {
Expand Down
5 changes: 3 additions & 2 deletions cmd/sfncli/sfncli.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ func main() {
// Anything spawned on behalf of the task should use this context.
var taskCtx context.Context
var taskCtxCancel context.CancelFunc
taskCtx, taskCtxCancel = context.WithCancel(mainCtx)
// context.Background() to disconnect this from the mainCtx cancellation
taskCtx, taskCtxCancel = context.WithCancel(context.Background())

// Begin sending heartbeats
go func() {
Expand All @@ -176,7 +177,7 @@ func main() {

// Run the command. Treat unprocessed args (flag.Args()) as additional args to
// send to the command on every invocation of the command
taskRunner := NewTaskRunner(*cmd, sfnapi, token, *workDirectory)
taskRunner := NewTaskRunner(*cmd, sfnapi, token, *workDirectory, taskCtxCancel)
err = taskRunner.Process(taskCtx, flag.Args(), input)
if err != nil {
log.ErrorD("task-process-error", logger.M{"error": err.Error()})
Expand Down

0 comments on commit 02e0208

Please sign in to comment.