Skip to content

Commit

Permalink
kill process tree using syscall on windows & cleanup (#80)
Browse files Browse the repository at this point in the history
* kill process tree using syscall on windows & cleanup

* use job api

* add error check for cmd.Start
  • Loading branch information
uubulb authored Oct 28, 2024
1 parent 134c8c5 commit 069fc30
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 43 deletions.
24 changes: 12 additions & 12 deletions cmd/agent/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bytes"
"context"
"crypto/tls"
"errors"
Expand All @@ -11,7 +12,6 @@ import (
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
Expand Down Expand Up @@ -622,21 +622,22 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) {
return
}
startedAt := time.Now()
var cmd *exec.Cmd
var endCh = make(chan struct{})
endCh := make(chan struct{})
pg, err := processgroup.NewProcessExitGroup()
if err != nil {
// 进程组创建失败,直接退出
result.Data = err.Error()
return
}
timeout := time.NewTimer(time.Hour * 2)
if util.IsWindows() {
cmd = exec.Command("cmd", "/c", task.GetData()) // #nosec
} else {
cmd = exec.Command("sh", "-c", task.GetData()) // #nosec
}
cmd := processgroup.NewCommand(task.GetData())
var b bytes.Buffer
cmd.Stdout = &b
cmd.Env = os.Environ()
if err = cmd.Start(); err != nil {
result.Data = err.Error()
return
}
pg.AddProcess(cmd)
go func() {
select {
Expand All @@ -648,12 +649,11 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) {
timeout.Stop()
}
}()
output, err := cmd.Output()
if err != nil {
result.Data += fmt.Sprintf("%s\n%s", string(output), err.Error())
if err = cmd.Wait(); err != nil {
result.Data += fmt.Sprintf("%s\n%s", b.String(), err.Error())
} else {
close(endCh)
result.Data = string(output)
result.Data = b.String()
result.Successful = true
}
pg.Dispose()
Expand Down
42 changes: 21 additions & 21 deletions pkg/processgroup/process_group.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows

package processgroup

Expand All @@ -17,38 +16,39 @@ func NewProcessExitGroup() (ProcessExitGroup, error) {
return ProcessExitGroup{}, nil
}

func (g *ProcessExitGroup) killChildProcess(c *exec.Cmd) error {
pgid, err := syscall.Getpgid(c.Process.Pid)
if err != nil {
// Fall-back on error. Kill the main process only.
c.Process.Kill()
}
// Kill the whole process group.
syscall.Kill(-pgid, syscall.SIGTERM)
return c.Wait()
func NewCommand(arg string) *exec.Cmd {
cmd := exec.Command("sh", "-c", arg)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
return cmd
}

func (g *ProcessExitGroup) Dispose() []error {
var errors []error
mutex := new(sync.Mutex)
wg := new(sync.WaitGroup)
func (g *ProcessExitGroup) Dispose() error {
var wg sync.WaitGroup
wg.Add(len(g.cmds))

for _, c := range g.cmds {
go func(c *exec.Cmd) {
defer wg.Done()
if err := g.killChildProcess(c); err != nil {
mutex.Lock()
defer mutex.Unlock()
errors = append(errors, err)
}
killChildProcess(c)
}(c)
}

wg.Wait()
return errors
return nil
}

func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
g.cmds = append(g.cmds, cmd)
return nil
}

func killChildProcess(c *exec.Cmd) {
pgid, err := syscall.Getpgid(c.Process.Pid)
if err != nil {
// Fall-back on error. Kill the main process only.
c.Process.Kill()
}
// Kill the whole process group.
syscall.Kill(-pgid, syscall.SIGTERM)
c.Wait()
}
72 changes: 63 additions & 9 deletions pkg/processgroup/process_group_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,80 @@ package processgroup
import (
"fmt"
"os/exec"
"unsafe"

"golang.org/x/sys/windows"
)

type ProcessExitGroup struct {
cmds []*exec.Cmd
cmds []*exec.Cmd
jobHandle windows.Handle
procs []windows.Handle
}

func NewProcessExitGroup() (ProcessExitGroup, error) {
return ProcessExitGroup{}, nil
func NewProcessExitGroup() (*ProcessExitGroup, error) {
job, err := windows.CreateJobObject(nil, nil)
if err != nil {
return nil, err
}

info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{
BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{
LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
},
}

_, err = windows.SetInformationJobObject(
job,
windows.JobObjectExtendedLimitInformation,
uintptr(unsafe.Pointer(&info)),
uint32(unsafe.Sizeof(info)))

return &ProcessExitGroup{jobHandle: job}, nil
}

func (g *ProcessExitGroup) Dispose() error {
for _, c := range g.cmds {
if err := exec.Command("taskkill", "/F", "/T", "/PID", fmt.Sprint(c.Process.Pid)).Run(); err != nil {
return err
}
func NewCommand(args string) *exec.Cmd {
cmd := exec.Command("cmd")
cmd.SysProcAttr = &windows.SysProcAttr{
CmdLine: fmt.Sprintf("/c %s", args),
CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
}
return nil
return cmd
}

func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error {
proc, err := windows.OpenProcess(windows.PROCESS_TERMINATE|windows.PROCESS_SET_QUOTA|windows.PROCESS_SET_INFORMATION, false, uint32(cmd.Process.Pid))
if err != nil {
return err
}

g.procs = append(g.procs, proc)
g.cmds = append(g.cmds, cmd)

return windows.AssignProcessToJobObject(g.jobHandle, proc)
}

func (g *ProcessExitGroup) Dispose() error {
defer func() {
windows.CloseHandle(g.jobHandle)
for _, proc := range g.procs {
windows.CloseHandle(proc)
}
}()

if err := windows.TerminateJobObject(g.jobHandle, 1); err != nil {
// Fall-back on error. Kill the main process only.
for _, cmd := range g.cmds {
cmd.Process.Kill()
}
return err
}

// wait for job to be terminated
status, err := windows.WaitForSingleObject(g.jobHandle, windows.INFINITE)
if status != windows.WAIT_OBJECT_0 {
return err
}

return nil
}
1 change: 0 additions & 1 deletion pkg/pty/pty.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows

package pty

Expand Down

0 comments on commit 069fc30

Please sign in to comment.