Skip to content

Commit

Permalink
Use separate counters for pre- and post-patch reboots. (#788)
Browse files Browse the repository at this point in the history
`RebootCount` was used only for forcing post-patch reboots via
`RebootConfig = always`, but it was incremented for every reboot. As a
result, the forced post-patch reboot was skipped whenever a pre-patch
reboot had been required. This change counts pre-patch and post-patch
reboots separately.

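In addition, it adds a guard against endless patch reboot loops: once a
single patch task has rebooted five times in total (pre- and post-patch
reboots combined), no further reboot is triggered. The limit of five was
selected through common sense.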

Closes: #755
  • Loading branch information
zoltak-g authored Feb 6, 2025
1 parent bbf3baa commit 88b1cc9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 13 deletions.
34 changes: 22 additions & 12 deletions agentendpoint/patch_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ func systemRebootRequired(ctx context.Context) (bool, error) {
type patchStep string

const (
prePatch = "PrePatch"
patching = "Patching"
postPatch = "PostPatch"
prePatch = "PrePatch"
patching = "Patching"
postPatch = "PostPatch"
totalRebootCountLimit = 5
)

type patchTask struct {
Expand All @@ -45,11 +46,12 @@ type patchTask struct {
lastProgressState map[agentendpointpb.ApplyPatchesTaskProgress_State]time.Time
state *taskState

TaskID string
Task *applyPatchesTask
StartedAt time.Time `json:",omitempty"`
PatchStep patchStep `json:",omitempty"`
RebootCount int
TaskID string
Task *applyPatchesTask
StartedAt time.Time `json:",omitempty"`
PatchStep patchStep `json:",omitempty"`
PrePatchRebootCount int
PostPatchRebootCount int

// TODO: add Attempts and track number of retries with backoff, jitter, etc.
}
Expand Down Expand Up @@ -154,8 +156,6 @@ func (r *patchTask) reportContinuingState(ctx context.Context, patchState agente
return r.saveState()
}

// TODO: Add MaxRebootCount so we don't loop endlessly.

// prePatchReboot runs the reboot check in pre-patch mode by delegating to
// rebootIfNeeded with prePatch set to true, and returns its error unchanged.
func (r *patchTask) prePatchReboot(ctx context.Context) error {
return r.rebootIfNeeded(ctx, true)
}
Expand All @@ -167,7 +167,7 @@ func (r *patchTask) postPatchReboot(ctx context.Context) error {
func (r *patchTask) rebootIfNeeded(ctx context.Context, prePatch bool) error {
var reboot bool
var err error
if r.Task.GetPatchConfig().GetRebootConfig() == agentendpointpb.PatchConfig_ALWAYS && !prePatch && r.RebootCount == 0 {
if r.Task.GetPatchConfig().GetRebootConfig() == agentendpointpb.PatchConfig_ALWAYS && !prePatch && r.PostPatchRebootCount == 0 {
reboot = true
clog.Infof(ctx, "PatchConfig RebootConfig set to %s.", agentendpointpb.PatchConfig_ALWAYS)
} else {
Expand All @@ -177,6 +177,11 @@ func (r *patchTask) rebootIfNeeded(ctx context.Context, prePatch bool) error {
}
if reboot {
clog.Infof(ctx, "System indicates a reboot is required.")
totalRebootCount := r.PrePatchRebootCount + r.PostPatchRebootCount
if totalRebootCount >= totalRebootCountLimit {
clog.Infof(ctx, "Detected abnormal number of reboots for a single patch task (%d). Not rebooting to prevent a possible boot loop", totalRebootCount)
return nil
}
} else {
clog.Infof(ctx, "System indicates a reboot is not required.")
}
Expand All @@ -200,7 +205,12 @@ func (r *patchTask) rebootIfNeeded(ctx context.Context, prePatch bool) error {
return nil
}

r.RebootCount++
if prePatch {
r.PrePatchRebootCount++
} else {
r.PostPatchRebootCount++
}

if err := r.saveState(); err != nil {
return fmt.Errorf("error saving state: %v", err)
}
Expand Down
18 changes: 17 additions & 1 deletion agentendpoint/task_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

var (
testPatchTaskStateString = "{\"PatchTask\":{\"TaskID\":\"foo\",\"Task\":{\"patchConfig\":{\"apt\":{\"type\":\"DIST\",\"excludes\":[\"foo\",\"bar\"],\"exclusivePackages\":[\"foo\",\"bar\"]},\"windowsUpdate\":{\"classifications\":[\"CRITICAL\",\"SECURITY\"],\"excludes\":[\"foo\",\"bar\"],\"exclusivePatches\":[\"foo\",\"bar\"]}}},\"StartedAt\":\"0001-01-01T00:00:00Z\",\"RebootCount\":0},\"Labels\":{\"foo\":\"bar\"}}"
testPatchTaskStateString = "{\"PatchTask\":{\"TaskID\":\"foo\",\"Task\":{\"patchConfig\":{\"apt\":{\"type\":\"DIST\",\"excludes\":[\"foo\",\"bar\"],\"exclusivePackages\":[\"foo\",\"bar\"]},\"windowsUpdate\":{\"classifications\":[\"CRITICAL\",\"SECURITY\"],\"excludes\":[\"foo\",\"bar\"],\"exclusivePatches\":[\"foo\",\"bar\"]}}},\"StartedAt\":\"0001-01-01T00:00:00Z\",\"PrePatchRebootCount\":2,\"PostPatchRebootCount\":1},\"Labels\":{\"foo\":\"bar\"}}"
testPatchTaskState = &taskState{
Labels: map[string]string{"foo": "bar"},
PatchTask: &patchTask{
Expand All @@ -41,6 +41,8 @@ var (
},
},
},
PrePatchRebootCount: 2,
PostPatchRebootCount: 1,
},
}
)
Expand Down Expand Up @@ -83,6 +85,20 @@ func TestLoadState(t *testing.T) {
false,
testPatchTaskState,
},
{
"IgnoresOldRebootFieldName",
[]byte("{\"PatchTask\":{\"Task\":{},\"RebootCount\":1}}"),
false,
&taskState{
PatchTask: &patchTask{
Task: &applyPatchesTask{
&agentendpointpb.ApplyPatchesTask{},
},
PrePatchRebootCount: 0,
PostPatchRebootCount: 0,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit 88b1cc9

Please sign in to comment.