diff --git a/shutdown/shutdown.go b/shutdown/shutdown.go index d2bd3ca7..872802a7 100644 --- a/shutdown/shutdown.go +++ b/shutdown/shutdown.go @@ -1,24 +1,49 @@ package shutdown import ( + "container/heap" "os" "os/signal" "sync" + "time" "github.com/flanksource/commons/logger" ) -var shutdownHooks []func() +// Some helper priority levels +const ( + PriorityIngress = 100 + PriorityJobs = 500 + PriorityCritical = 1000 +) + +type shutdownHook func() + +var ( + registryLock sync.Mutex + shutdownTaskRegistry ShutdownTasks +) + +func init() { + heap.Init(&shutdownTaskRegistry) +} var Shutdown = sync.OnceFunc(func() { - if len(shutdownHooks) == 0 { - return - } - logger.Infof("Shutting down") - for _, fn := range shutdownHooks { - fn() + logger.Infof("begin shutdown") + + for len(shutdownTaskRegistry) > 0 { + _task := heap.Pop(&shutdownTaskRegistry) + if _task == nil { + break + } + + task := _task.(ShutdownTask) + logger.Infof("shutting down: %s", task.Label) + + s := time.Now() + task.Hook() + logger.Infof("shutdown %s completed in %v", task.Label, time.Since(s)) } - shutdownHooks = []func(){} }) func ShutdownAndExit(code int, msg string) { @@ -27,8 +52,23 @@ func ShutdownAndExit(code int, msg string) { os.Exit(code) } -func AddHook(fn func()) { - shutdownHooks = append(shutdownHooks, fn) +// Add a hook with the least priority. +// Least priority hooks are run first. +// +// Prefer AddHookWithPriority() +func AddHook(fn shutdownHook) { + registryLock.Lock() + heap.Push(&shutdownTaskRegistry, ShutdownTask{Hook: fn, Priority: 0}) + registryLock.Unlock() +} + +// AddHookWithPriority adds a hook with a priority level. +// +// Execution order goes from lowest to highest priority numbers. +func AddHookWithPriority(label string, priority int, fn shutdownHook) { + registryLock.Lock() + heap.Push(&shutdownTaskRegistry, ShutdownTask{Label: label, Hook: fn, Priority: priority}) + registryLock.Unlock() } func WaitForSignal() { @@ -41,3 +81,34 @@ func WaitForSignal() { Shutdown() }() } + +type ShutdownTask struct { + Hook shutdownHook + Label string + Priority int +} + +// ShutdownTasks implements heap.Interface +type ShutdownTasks []ShutdownTask + +func (st ShutdownTasks) Len() int { return len(st) } + +func (st ShutdownTasks) Less(i, j int) bool { + return st[i].Priority < st[j].Priority +} + +func (st ShutdownTasks) Swap(i, j int) { + st[i], st[j] = st[j], st[i] +} + +func (st *ShutdownTasks) Push(x interface{}) { + *st = append(*st, x.(ShutdownTask)) +} + +func (st *ShutdownTasks) Pop() interface{} { + old := *st + n := len(old) + item := old[n-1] + *st = old[0 : n-1] + return item +} diff --git a/shutdown/shutdown_test.go b/shutdown/shutdown_test.go new file mode 100644 index 00000000..30c8cda0 --- /dev/null +++ b/shutdown/shutdown_test.go @@ -0,0 +1,25 @@ +package shutdown + +import "testing" + +func TestShutdownPriority(t *testing.T) { + var lastClosed int + + add := func(label string, priority int) { + AddHookWithPriority(label, priority, func() { + if lastClosed > priority { + t.Fatalf("something higher priority (%d) was closed earlier than (%d)", lastClosed, priority) + } else { + lastClosed = priority + } + }) + } + + add("database", PriorityCritical) + add("gRPC", PriorityIngress) + add("checkJob", PriorityJobs) + add("echo", PriorityIngress) + add("topologyJob", PriorityJobs) + + Shutdown() +}