Skip to content

Commit

Permalink
fix: handle goroutine errors via errgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
smlx committed Jul 6, 2023
1 parent aa0a7a9 commit 23304a1
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 34 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.11.0
golang.org/x/oauth2 v0.10.0
golang.org/x/sync v0.3.0
k8s.io/api v0.27.3
k8s.io/apimachinery v0.27.3
k8s.io/client-go v0.27.3
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
75 changes: 41 additions & 34 deletions internal/k8s/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"context"
"fmt"
"io"
"sync"

"golang.org/x/sync/errgroup"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
Expand All @@ -19,16 +19,16 @@ var (
// defaultTailLines is the number of log lines to tail by default if no number
// is specified
defaultTailLines int64 = 32
// maxTailLines is the maximum number of log lines to tail.
// maxTailLines is the maximum number of log lines to tail
maxTailLines int64 = 1024
// limitBytes defines the maximum number of bytes of logs returned from a single container
// limitBytes defines the maximum number of bytes of logs returned from a
// single container
limitBytes int64 = 1 * 1024 * 1024 // 1MiB
)

// linewiseCopy reads strings separated by \n from logStream, and writes them
// with the \n stripped to the logs channel.
func linewiseCopy(wg *sync.WaitGroup, logs chan<- string, logStream io.ReadCloser) {
defer wg.Done()
func linewiseCopy(logs chan<- string, logStream io.ReadCloser) {
defer logStream.Close()
s := bufio.NewScanner(logStream)
for s.Scan() {
Expand All @@ -38,10 +38,9 @@ func linewiseCopy(wg *sync.WaitGroup, logs chan<- string, logStream io.ReadClose

// readLogs reads logs from the given pod, writing them back to the given logs
// channel in a linewise manner.
func (c *Client) readLogs(ctx context.Context, wg *sync.WaitGroup,
func (c *Client) readLogs(ctx context.Context, eg *errgroup.Group,
p *corev1.Pod, container string, follow bool, tailLines int64,
logs chan<- string) {
defer wg.Done()
logs chan<- string) error {
var containers []string
if container != "" {
containers = append(containers, container)
Expand All @@ -52,28 +51,31 @@ func (c *Client) readLogs(ctx context.Context, wg *sync.WaitGroup,
}
for _, name := range containers {
// read logs for a single container
req := c.clientset.CoreV1().Pods(p.Namespace).GetLogs(p.Name, &corev1.PodLogOptions{
Container: name,
Follow: follow,
Timestamps: true,
TailLines: &tailLines,
LimitBytes: &limitBytes,
})
req := c.clientset.CoreV1().Pods(p.Namespace).GetLogs(p.Name,
&corev1.PodLogOptions{
Container: name,
Follow: follow,
Timestamps: true,
TailLines: &tailLines,
LimitBytes: &limitBytes,
})
logStream, err := req.Stream(ctx)
if err != nil {
return // TODO: handle the err somehow?
return fmt.Errorf("couldn't stream logs: %v", err)
}
wg.Add(1)
go linewiseCopy(wg, logs, logStream)
eg.Go(func() error {
linewiseCopy(logs, logStream)
return nil
})
}
return nil
}

// watchForNewPods sets up a k8s watch on pods in the given deployment and
// starts reading logs from any new pods that are created.
func (c *Client) watchForNewPods(ctx context.Context, wg *sync.WaitGroup,
func (c *Client) watchForNewPods(ctx context.Context, eg *errgroup.Group,
namespace, deployment, container string, follow bool, tailLines int64,
logs chan<- string) error {
defer wg.Done()
// set up the watch for new pods
watchFunc := func(options metav1.ListOptions) (watch.Interface, error) {
d, err := c.clientset.AppsV1().Deployments(namespace).Get(ctx, deployment,
Expand All @@ -100,8 +102,9 @@ func (c *Client) watchForNewPods(ctx context.Context, wg *sync.WaitGroup,
}
switch event.Type {
case watch.Added:
wg.Add(1)
go c.readLogs(ctx, wg, item, container, follow, tailLines, logs)
eg.Go(func() error {
return c.readLogs(ctx, eg, item, container, follow, tailLines, logs)
})
default:
// no-op
}
Expand All @@ -116,9 +119,6 @@ func (c *Client) Logs(ctx context.Context, namespace, deployment,
// initialise a buffered channel for the worker goroutines to write to, and
// for this function to read log lines from
logs := make(chan string, 4)
// pass each goroutine a waitgroup to signal when the goroutine completes
var wg sync.WaitGroup
defer wg.Wait()
// wrap the parent context so we can cancel from this function when returning
// an error
childCtx, cancel := context.WithCancel(ctx)
Expand All @@ -129,9 +129,10 @@ func (c *Client) Logs(ctx context.Context, namespace, deployment,
if err != nil {
return fmt.Errorf("couldn't get deployment: %v", err)
}
pods, err := c.clientset.CoreV1().Pods(namespace).List(childCtx, metav1.ListOptions{
LabelSelector: labels.FormatLabels(d.Spec.Selector.MatchLabels),
})
pods, err := c.clientset.CoreV1().Pods(namespace).List(childCtx,
metav1.ListOptions{
LabelSelector: labels.FormatLabels(d.Spec.Selector.MatchLabels),
})
if err != nil {
return fmt.Errorf("couldn't get pods: %v", err)
}
Expand All @@ -144,15 +145,21 @@ func (c *Client) Logs(ctx context.Context, namespace, deployment,
if tailLines > maxTailLines {
tailLines = maxTailLines
}
// put goroutines in an errgroup.Group to handle errors
var eg errgroup.Group
// start a goroutine which watches for new pods in the deployment and adds
// them to the list of streams to read.
wg.Add(1)
go c.watchForNewPods(childCtx, &wg, namespace, deployment, container, follow,
tailLines, logs)
eg.Go(func() error {
return c.watchForNewPods(childCtx, &eg, namespace, deployment, container,
follow, tailLines, logs)
})
// start a goroutine for each pod which reads logs into the channel
for i := range pods.Items {
wg.Add(1)
go c.readLogs(childCtx, &wg, &pods.Items[i], container, follow, tailLines, logs)
pod := pods.Items[i] // avoid copying loop var
eg.Go(func() error {
return c.readLogs(childCtx, &eg, &pod, container, follow, tailLines,
logs)
})
}
// start reading off the channel and writing lines back to stdio
for {
Expand All @@ -164,7 +171,7 @@ func (c *Client) Logs(ctx context.Context, namespace, deployment,
// ctx.Done() shortly anyway.
_, _ = fmt.Fprintln(stdio, msg)
case <-childCtx.Done():
return nil
return eg.Wait()
}
}
}

0 comments on commit 23304a1

Please sign in to comment.