diff --git a/cmd/ssh-portal/main.go b/cmd/ssh-portal/main.go index ea5737f7..121df927 100644 --- a/cmd/ssh-portal/main.go +++ b/cmd/ssh-portal/main.go @@ -1,3 +1,4 @@ +// Package main implements the ssh-portal executable. package main import ( diff --git a/cmd/ssh-portal/serve.go b/cmd/ssh-portal/serve.go index fc716544..4384a377 100644 --- a/cmd/ssh-portal/serve.go +++ b/cmd/ssh-portal/serve.go @@ -16,11 +16,12 @@ import ( // ServeCmd represents the serve command. type ServeCmd struct { - NATSServer string `kong:"required,env='NATS_URL',help='NATS server URL (nats://... or tls://...)'"` - SSHServerPort uint `kong:"default='2222',env='SSH_SERVER_PORT',help='Port the SSH server will listen on for SSH client connections'"` - HostKeyECDSA string `kong:"env='HOST_KEY_ECDSA',help='PEM encoded ECDSA host key'"` - HostKeyED25519 string `kong:"env='HOST_KEY_ED25519',help='PEM encoded Ed25519 host key'"` - HostKeyRSA string `kong:"env='HOST_KEY_RSA',help='PEM encoded RSA host key'"` + NATSServer string `kong:"required,env='NATS_URL',help='NATS server URL (nats://... or tls://...)'"` + SSHServerPort uint `kong:"default='2222',env='SSH_SERVER_PORT',help='Port the SSH server will listen on for SSH client connections'"` + HostKeyECDSA string `kong:"env='HOST_KEY_ECDSA',help='PEM encoded ECDSA host key'"` + HostKeyED25519 string `kong:"env='HOST_KEY_ED25519',help='PEM encoded Ed25519 host key'"` + HostKeyRSA string `kong:"env='HOST_KEY_RSA',help='PEM encoded RSA host key'"` + LogAccessEnabled bool `kong:"env='LOG_ACCESS_ENABLED',help='Allow any user who can SSH into a pod to also access its logs.'"` } // Run the serve command to handle SSH connection requests. @@ -72,5 +73,5 @@ func (cmd *ServeCmd) Run(log *zap.Logger) error { } } // start serving SSH connection requests - return sshserver.Serve(ctx, log, nc, l, c, hostkeys) + return sshserver.Serve(ctx, log, nc, l, c, hostkeys, cmd.LogAccessEnabled) } diff --git a/go.mod b/go.mod index 7a5bd4b6..765f25fe 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,9 @@ require ( go.opentelemetry.io/otel v1.21.0 go.uber.org/zap v1.26.0 golang.org/x/crypto v0.16.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.15.0 + golang.org/x/sync v0.3.0 k8s.io/api v0.28.4 k8s.io/apimachinery v0.28.4 k8s.io/client-go v0.28.4 @@ -65,7 +67,6 @@ require ( go.opentelemetry.io/otel/metric v1.21.0 // indirect go.opentelemetry.io/otel/trace v1.21.0 // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect golang.org/x/net v0.19.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect diff --git a/go.sum b/go.sum index c174bfc2..93d61986 100644 --- a/go.sum +++ b/go.sum @@ -181,6 +181,8 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/k8s/client.go b/internal/k8s/client.go index 769e984b..bc656b3e 100644 --- a/internal/k8s/client.go +++ b/internal/k8s/client.go @@ -3,6 +3,7 @@ package k8s import ( + "sync" "time" "k8s.io/client-go/kubernetes" @@ -14,10 +15,15 @@ const ( timeout = 90 * time.Second ) +// timeoutSeconds defines the common timeout for k8s API operations in the type +// required by metav1.ListOptions. +var timeoutSeconds = int64(timeout / time.Second) + // Client is a k8s client. type Client struct { - config *rest.Config - clientset *kubernetes.Clientset + config *rest.Config + clientset *kubernetes.Clientset + logStreamIDs sync.Map } // NewClient creates a new kubernetes API client. diff --git a/internal/k8s/finddeployment.go b/internal/k8s/finddeployment.go index 7eb3a966..dcdb6ef2 100644 --- a/internal/k8s/finddeployment.go +++ b/internal/k8s/finddeployment.go @@ -13,7 +13,8 @@ func (c *Client) FindDeployment(ctx context.Context, namespace, service string) (string, error) { deployments, err := c.clientset.AppsV1().Deployments(namespace). List(ctx, metav1.ListOptions{ - LabelSelector: fmt.Sprintf("lagoon.sh/service=%s", service), + LabelSelector: fmt.Sprintf("lagoon.sh/service=%s", service), + TimeoutSeconds: &timeoutSeconds, }) if err != nil { return "", fmt.Errorf("couldn't list deployments: %v", err) diff --git a/internal/k8s/logs.go b/internal/k8s/logs.go new file mode 100644 index 00000000..1b20ae23 --- /dev/null +++ b/internal/k8s/logs.go @@ -0,0 +1,294 @@ +package k8s + +import ( + "bufio" + "context" + "fmt" + "io" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/exp/slices" + "golang.org/x/sync/errgroup" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/client-go/informers" + "k8s.io/client-go/tools/cache" +) + +var ( + // defaultTailLines is the number of log lines to tail by default if no number + // is specified + defaultTailLines int64 = 32 + // maxTailLines is the maximum number of log lines to tail + maxTailLines int64 = 1024 + // limitBytes defines the maximum number of bytes of logs returned from a + // single container + limitBytes int64 = 1 * 1024 * 1024 // 1MiB +) + +// linewiseCopy reads newline-separated strings from logStream and writes them +// to the logs channel with the given prefix prepended and the trailing \n +// stripped. It returns when ctx is cancelled or the logStream closes. +func linewiseCopy(ctx context.Context, prefix string, logs chan<- string, + logStream io.ReadCloser) { + defer logStream.Close() + s := bufio.NewScanner(logStream) + for s.Scan() { + select { + case logs <- fmt.Sprintf("%s %s", prefix, s.Text()): + case <-ctx.Done(): + return + } + } +} + +// readLogs reads logs from the given pod, writing them back to the logs +// channel in a linewise manner. A goroutine is started via egSend to tail logs +// for each container. requestID is used to de-duplicate simultaneous log +// requests associated with a single call to the higher-level Logs() function. +// +// readLogs returns immediately, and relies on ctx cancellation to ensure the +// goroutines it starts are cleaned up.
+func (c *Client) readLogs(ctx context.Context, requestID string, + egSend *errgroup.Group, p *corev1.Pod, containerName string, follow bool, + tailLines int64, logs chan<- string) error { + var cStatuses []corev1.ContainerStatus + // if containerName is not specified, send logs for all containers + if containerName == "" { + cStatuses = p.Status.ContainerStatuses + } else { + for _, cStatus := range p.Status.ContainerStatuses { + if containerName == cStatus.Name { + cStatuses = append(cStatuses, cStatus) + break + } + } + if len(cStatuses) == 0 { + return fmt.Errorf("couldn't find container: %s", containerName) + } + } + for _, cStatus := range cStatuses { + // skip setting up another log stream if container is already being logged + _, exists := c.logStreamIDs.LoadOrStore(requestID+cStatus.ContainerID, true) + if exists { + continue + } + // set up stream for a single container + req := c.clientset.CoreV1().Pods(p.Namespace).GetLogs(p.Name, + &corev1.PodLogOptions{ + Container: cStatus.Name, + Follow: follow, + Timestamps: true, + TailLines: &tailLines, + LimitBytes: &limitBytes, + }) + logStream, err := req.Stream(ctx) + if err != nil { + return fmt.Errorf("couldn't stream logs: %v", err) + } + // copy loop vars so they can be referenced in the closure + cName := cStatus.Name + cID := cStatus.ContainerID + egSend.Go(func() error { + defer c.logStreamIDs.Delete(cID) + linewiseCopy(ctx, fmt.Sprintf("[pod/%s/%s]", p.Name, cName), logs, + logStream) + // When a pod is terminating, the k8s API sometimes sends an event + // showing a healthy pod _after_ an existing logStream for the same pod + // has closed. This happens occasionally on scale-down of a deployment. + // When this occurs there is a race where linewiseCopy() returns, then + // the "healthy" event comes in and linewiseCopy() is called again, only + // to return immediately. This can result in duplicated log lines being + // returned on the logs channel. + // To hack around this behaviour, pause here before exiting. This means + // that the container ID is retained in c.logStreamIDs for a brief period + // after logs stop streaming, which causes "healthy pod" events from the + // k8s API to be ignored for that period, thereby avoiding duplicate + // log lines being returned to the caller. + time.Sleep(time.Second) + return nil + }) + } + return nil +} + +// podEventHandler receives pod objects from the podInformer and, if they are +// in a ready state, starts streaming logs from them. +func (c *Client) podEventHandler(ctx context.Context, + cancel context.CancelFunc, requestID string, egSend *errgroup.Group, + container string, follow bool, tailLines int64, logs chan<- string, obj any) { + // panic if obj is not a pod, since we specifically use a pod informer + pod := obj.(*corev1.Pod) + if !slices.ContainsFunc(pod.Status.Conditions, + func(cond corev1.PodCondition) bool { + return cond.Type == corev1.ContainersReady && + cond.Status == corev1.ConditionTrue + }) { + return // pod not ready + } + egSend.Go(func() error { + readLogsErr := c.readLogs(ctx, requestID, egSend, pod, container, follow, + tailLines, logs) + if readLogsErr != nil { + cancel() + return fmt.Errorf("couldn't read logs on new pod: %v", readLogsErr) + } + return nil + }) +} + +// newPodInformer sets up a k8s informer on pods in the given deployment, and +// returns the informer in an inert state. The informer is configured with +// event handlers to read logs from pods in the deployment, writing log lines +// back to the logs channel.
It transparently handles the deployment scaling up +// and down (e.g. pods being added / deleted / restarted). +// +// When the caller calls Run() on the returned informer, it will start watching +// for events and sending to the logs channel. +func (c *Client) newPodInformer(ctx context.Context, + cancel context.CancelFunc, requestID string, egSend *errgroup.Group, + namespace, deployment, container string, follow bool, tailLines int64, + logs chan<- string) (cache.SharedIndexInformer, error) { + // get the deployment + d, err := c.clientset.AppsV1().Deployments(namespace).Get(ctx, deployment, + metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("couldn't get deployment: %v", err) + } + // configure the informer factory, filtering on deployment selector labels + factory := informers.NewSharedInformerFactoryWithOptions(c.clientset, + time.Hour, informers.WithNamespace(namespace), + informers.WithTweakListOptions(func(opts *metav1.ListOptions) { + opts.LabelSelector = labels.SelectorFromSet( + d.Spec.Selector.MatchLabels).String() + })) + // construct the informer + podInformer := factory.Core().V1().Pods().Informer() + _, err = podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + // AddFunc handles events for new and existing pods. Since new pods are not + // in a ready state when initially added, it doesn't start log streaming + // for those. + AddFunc: func(obj any) { + c.podEventHandler(ctx, cancel, requestID, egSend, container, follow, + tailLines, logs, obj) + }, + // UpdateFunc handles events for pod state changes. When new pods are added + // (e.g. deployment is scaled up) it repeatedly receives events until the + // pod is in its final healthy state. For that reason, the + // podEventHandler() inspects the pod state before initiating log + // streaming. + UpdateFunc: func(_, obj any) { + c.podEventHandler(ctx, cancel, requestID, egSend, container, follow, + tailLines, logs, obj) + }, + }) + if err != nil { + return nil, fmt.Errorf("couldn't add event handlers to informer: %v", err) + } + return podInformer, nil +} + +// Logs takes a target namespace, deployment, and stdio stream, and writes the +// log output of the pods of the deployment to the stdio stream. If +// container is specified, only logs of this container within the deployment +// are returned. +// +// This function exits on one of the following events: +// +// 1. It finishes sending the logs of the pods. This only occurs if +// follow=false. +// 2. ctx is cancelled (signalling that the SSH channel was closed). +// 3. An unrecoverable error occurs. +func (c *Client) Logs(ctx context.Context, + namespace, deployment, container string, follow bool, tailLines int64, + stdio io.ReadWriter) error { + // Wrap the context so we can cancel subroutines of this function on error. + childCtx, cancel := context.WithCancel(ctx) + defer cancel() + // Generate a requestID value to uniquely distinguish between multiple calls + // to this function. This requestID is used in readLogs() to distinguish + // entries in c.logStreamIDs.
+ requestID := uuid.New().String() + // clamp tailLines + if tailLines < 1 { + tailLines = defaultTailLines + } + if tailLines > maxTailLines { + tailLines = maxTailLines + } + // put sending goroutines in an errgroup.Group to handle errors, and + // receiving goroutines in a waitgroup (since they have no errors) + var egSend errgroup.Group + var wgRecv sync.WaitGroup + // initialise a buffered channel for the worker goroutines to write to, and + // for this function to read log lines from + logs := make(chan string, 4) + // start a goroutine reading from the logs channel and writing back to stdio + wgRecv.Add(1) + go func() { + defer wgRecv.Done() + for { + select { + case msg := <-logs: + // ignore errors writing to stdio. this may happen if the client + // disconnects after reading off the channel but before the log can be + // written. there's nothing we can do in this case and we'll select + // ctx.Done() shortly anyway. + _, _ = fmt.Fprintln(stdio, msg) + case <-childCtx.Done(): + return // context done - client went away or error within Logs() + } + } + }() + if follow { + // If following the logs, start a goroutine which watches for new (and + // existing) pods in the deployment and starts streaming logs from them. + egSend.Go(func() error { + podInformer, err := c.newPodInformer(childCtx, cancel, requestID, + &egSend, namespace, deployment, container, follow, tailLines, logs) + if err != nil { + return fmt.Errorf("couldn't construct new pod informer: %v", err) + } + podInformer.Run(childCtx.Done()) + return nil + }) + } else { + // If not following the logs, avoid constructing an informer. Instead just + // read the logs from all existing pods. + d, err := c.clientset.AppsV1().Deployments(namespace).Get(childCtx, + deployment, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("couldn't get deployment: %v", err) + } + pods, err := c.clientset.CoreV1().Pods(namespace).List(childCtx, + metav1.ListOptions{ + LabelSelector: labels.FormatLabels(d.Spec.Selector.MatchLabels), + }) + if err != nil { + return fmt.Errorf("couldn't get pods: %v", err) + } + if len(pods.Items) == 0 { + return fmt.Errorf("no pods for deployment %s", deployment) + } + for i := range pods.Items { + pod := pods.Items[i] // copy loop var so it can be referenced in the closure + egSend.Go(func() error { + readLogsErr := c.readLogs(childCtx, requestID, &egSend, &pod, + container, follow, tailLines, logs) + if readLogsErr != nil { + return fmt.Errorf("couldn't read logs on existing pods: %v", readLogsErr) + } + return nil + }) + } + } + // Wait for the send goroutines to finish, then cancel the child context so + // the receive goroutine exits, wait for it, and return any sendErr. Note + // that the logs channel is deliberately never closed. + sendErr := egSend.Wait() + cancel() + wgRecv.Wait() + return sendErr +} diff --git a/internal/sshserver/connectionparams.go b/internal/sshserver/connectionparams.go index 31ab47ae..2fb42e18 100644 --- a/internal/sshserver/connectionparams.go +++ b/internal/sshserver/connectionparams.go @@ -1,51 +1,132 @@ package sshserver -import "regexp" +import ( + "errors" + "regexp" + "strconv" + "strings" +) + +var ( + serviceRegex = regexp.MustCompile(`^service=(.+)`) + containerRegex = regexp.MustCompile(`^container=(.+)`) + logsRegex = regexp.MustCompile(`^logs=(.+)`) + tailLinesRegex = regexp.MustCompile(`^tailLines=(\d+)$`) +) var ( - serviceRegex = regexp.MustCompile(`service=(.+)`) - containerRegex = regexp.MustCompile(`container=(.+)`) + // ErrCmdArgsAfterLogs is returned when command arguments are found after + // the logs=... argument.
+ ErrCmdArgsAfterLogs = errors.New("command arguments after logs argument") + // ErrInvalidLogsValue is returned when the value of the logs=... + // argument is an invalid value. + ErrInvalidLogsValue = errors.New("invalid logs argument value") + // ErrNoServiceForLogs is returned when logs=... is specified, but + // service=... is not. + ErrNoServiceForLogs = errors.New("missing service argument for logs argument") ) // parseConnectionParams takes the raw SSH command, and parses out any -// leading service=... and container=... arguments. It returns: -// * If a service=... argument is given, the value of that argument. If no such -// argument is given, it falls back to a default of "cli". -// * If a container=... argument is given, the value of that argument. If no -// such argument is given, it returns an empty string. -// * The remaining arguments with any leading service= or container= arguments -// removed. +// leading service=..., container=..., and logs=... arguments. It returns: +// - If a service=... argument is given, the value of that argument. +// If no such argument is given, it falls back to a default of "cli". +// - If a container=... argument is given, the value of that argument. +// If no such argument is given, it returns an empty string. +// - If a logs=... argument is given, the value of that argument. +// If no such argument is given, it returns an empty string. +// - The remaining arguments as a slice of strings, with any leading +// service=, container=, or logs= arguments removed. // // Notes about the logic implemented here: -// * container=... may not be specified without service=... -// * service=... must be given as the first argument to be recognised. -// * If not given in the expected order or with empty values, these arguments -// will be interpreted as regular command-line arguments. +// - service=... must be given as the first argument to be recognised. +// - It is an error to specify container=... without service=... +// - If logs=... is given, it must be the final argument. +// - If not given in the expected order or with empty values, these +// parameters may be interpreted as regular command-line arguments. // // In manpage syntax: // -// [service=... [container=...]] CMD... -// -func parseConnectionParams(args []string) (string, string, []string) { +// [service=... [container=...]] CMD... +// service=... [container=...] logs=... 
+func parseConnectionParams(args []string) (string, string, string, []string) { // exit early if we have no args if len(args) == 0 { - return "cli", "", args + return "cli", "", "", nil } // check for service argument serviceMatches := serviceRegex.FindStringSubmatch(args[0]) if len(serviceMatches) == 0 { - return "cli", "", args + // no service= match, so assume cli and return all args + return "cli", "", "", args } service := serviceMatches[1] // exit early if we are out of arguments - if len(args) < 2 { - return service, "", args[1:] + if len(args) == 1 { + return service, "", "", nil } - // check for container argument + // check for container and/or logs argument containerMatches := containerRegex.FindStringSubmatch(args[1]) if len(containerMatches) == 0 { - return service, "", args[1:] + // no container= match, so check for logs= + logsMatches := logsRegex.FindStringSubmatch(args[1]) + if len(logsMatches) == 0 { + // no container= or logs= match, so just return the args + return service, "", "", args[1:] + } + // found logs=, so return it along with the remaining args + // (which should be empty) + return service, "", logsMatches[1], args[2:] } container := containerMatches[1] - return service, container, args[2:] + // exit early if we are out of arguments + if len(args) == 2 { + return service, container, "", nil + } + // container= matched, so check for logs= + logsMatches := logsRegex.FindStringSubmatch(args[2]) + if len(logsMatches) == 0 { + // no logs= match, so just return the remaining args + return service, container, "", args[2:] + } + // container= and logs= matched, so return both + return service, container, logsMatches[1], args[3:] +} + +// parseLogsArg checks that: +// - logs value is one or both of "follow" and "tailLines=n" arguments, comma +// separated. +// - n is a positive integer. +// - if logs is valid, service is not empty. +// - if logs is valid, cmd is empty. +// +// It returns the follow and tailLines values, and an error if one occurs (or +// nil otherwise). +// +// Note that if multiple tailLines= values are specified, the last one will be +// the value used. 
+func parseLogsArg(service, logs string, cmd []string) (bool, int64, error) { + if len(cmd) != 0 { + return false, 0, ErrCmdArgsAfterLogs + } + if service == "" { + return false, 0, ErrNoServiceForLogs + } + var follow bool + var tailLines int64 + var err error + for _, arg := range strings.Split(logs, ",") { + matches := tailLinesRegex.FindStringSubmatch(arg) + switch { + case arg == "follow": + follow = true + case len(matches) == 2: + tailLines, err = strconv.ParseInt(matches[1], 10, 64) + if err != nil { + return false, 0, ErrInvalidLogsValue + } + default: + return false, 0, ErrInvalidLogsValue + } + } + return follow, tailLines, nil } diff --git a/internal/sshserver/connectionparams_test.go b/internal/sshserver/connectionparams_test.go index a0832d9c..81931a9c 100644 --- a/internal/sshserver/connectionparams_test.go +++ b/internal/sshserver/connectionparams_test.go @@ -1,6 +1,7 @@ package sshserver_test import ( + "errors" "reflect" "testing" @@ -10,6 +11,7 @@ import ( type parsedParams struct { service string container string + logs string args []string } @@ -23,22 +25,25 @@ func TestParseConnectionParams(t *testing.T) { expect: parsedParams{ service: "cli", container: "", + logs: "", args: []string{"drush", "do", "something"}, }, }, - "service arg": { + "service params": { input: []string{"service=mongo", "drush", "do", "something"}, expect: parsedParams{ service: "mongo", container: "", + logs: "", args: []string{"drush", "do", "something"}, }, }, - "service and container args": { + "service and container params": { input: []string{"service=nginx", "container=php", "drush", "do", "something"}, expect: parsedParams{ service: "nginx", container: "php", + logs: "", args: []string{"drush", "do", "something"}, }, }, @@ -47,22 +52,193 @@ func TestParseConnectionParams(t *testing.T) { expect: parsedParams{ service: "cli", container: "", + logs: "", args: []string{"container=php", "service=nginx", "drush", "do", "something"}, }, }, + "service and logs params": { + input: []string{"service=nginx", "logs=follow", "drush do something"}, + expect: parsedParams{ + service: "nginx", + container: "", + logs: "follow", + args: []string{"drush do something"}, + }, + }, + "service, container and logs params": { + input: []string{"service=nginx", "container=php", "logs=follow", "drush do something"}, + expect: parsedParams{ + service: "nginx", + container: "php", + logs: "follow", + args: []string{"drush do something"}, + }, + }, + "service, container and logs params (wrong order)": { + input: []string{"service=nginx", "logs=follow", "container=php", "drush do something"}, + expect: parsedParams{ + service: "nginx", + container: "", + logs: "follow", + args: []string{"container=php", "drush do something"}, + }, + }, + "service and logs params (invalid logs value)": { + input: []string{"service=nginx", "logs=php", "drush", "do", "something"}, + expect: parsedParams{ + service: "nginx", + container: "", + logs: "php", + args: []string{"drush", "do", "something"}, + }, + }, } for name, tc := range testCases { t.Run(name, func(tt *testing.T) { - service, container, args := sshserver.ParseConnectionParams(tc.input) + service, container, logs, args := sshserver.ParseConnectionParams(tc.input) if tc.expect.service != service { tt.Fatalf("service: expected %v, got %v", tc.expect.service, service) } if tc.expect.container != container { tt.Fatalf("container: expected %v, got %v", tc.expect.container, container) } + if tc.expect.logs != logs { + tt.Fatalf("logs: expected %v, got %v", tc.expect.logs, logs) + } 
if !reflect.DeepEqual(tc.expect.args, args) { tt.Fatalf("args: expected %v, got %v", tc.expect.args, args) } }) } } + +func TestValidateConnectionParams(t *testing.T) { + type result struct { + follow bool + tailLines int64 + err error + } + var testCases = map[string]struct { + input parsedParams + expect result + }{ + "follow": { + input: parsedParams{ + service: "nginx-php", + logs: "follow", + }, + expect: result{ + follow: true, + }, + }, + "tail": { + input: parsedParams{ + service: "nginx-php", + logs: "tailLines=201", + }, + expect: result{ + tailLines: 201, + }, + }, + "follow and tail": { + input: parsedParams{ + service: "nginx-php", + logs: "follow,tailLines=10", + }, + expect: result{ + follow: true, + tailLines: 10, + }, + }, + "tail and follow": { + input: parsedParams{ + service: "nginx-php", + logs: "tailLines=100,follow", + }, + expect: result{ + follow: true, + tailLines: 100, + }, + }, + "multiple tail and follow": { + input: parsedParams{ + service: "nginx-php", + logs: "tailLines=100,follow,tailLines=11", + }, + expect: result{ + follow: true, + tailLines: 11, + }, + }, + "invalid tail value": { + input: parsedParams{ + service: "nginx-php", + logs: "tailLines=10f", + }, + expect: result{ + err: sshserver.ErrInvalidLogsValue, + }, + }, + "garbage prefix in logs arg": { + input: parsedParams{ + service: "nginx-php", + logs: "fallow,tailLines=10", + }, + expect: result{ + err: sshserver.ErrInvalidLogsValue, + }, + }, + "garbage infix in logs arg": { + input: parsedParams{ + service: "nginx-php", + logs: "follow,nofollow,tailLines=10f", + }, + expect: result{ + err: sshserver.ErrInvalidLogsValue, + }, + }, + "garbage suffix in logs arg": { + input: parsedParams{ + service: "nginx-php", + logs: "follow,tailLines=10,nofollow", + }, + expect: result{ + err: sshserver.ErrInvalidLogsValue, + }, + }, + "arguments after logs and invalid logs value": { + input: parsedParams{ + service: "cli", + logs: "php", + args: []string{"drush", "do", "something"}, + }, + expect: result{ + err: sshserver.ErrCmdArgsAfterLogs, + }, + }, + "invalid logs value": { + input: parsedParams{ + service: "cli", + logs: "php", + }, + expect: result{ + err: sshserver.ErrInvalidLogsValue, + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + follow, tailLines, err := sshserver.ParseLogsArg( + tc.input.service, tc.input.logs, tc.input.args) + if !errors.Is(err, tc.expect.err) { + tt.Fatalf("expected %v, got %v", tc.expect.err, err) + } + if follow != tc.expect.follow { + tt.Fatalf("expected %v, got %v", tc.expect.follow, follow) + } + if tailLines != tc.expect.tailLines { + tt.Fatalf("expected %v, got %v", tc.expect.tailLines, tailLines) + } + }) + } +} diff --git a/internal/sshserver/helper_test.go b/internal/sshserver/helper_test.go index 471b2a4f..598faa70 100644 --- a/internal/sshserver/helper_test.go +++ b/internal/sshserver/helper_test.go @@ -2,6 +2,11 @@ package sshserver // ParseConnectionParams exposes the private parseConnectionParams for testing // only. -func ParseConnectionParams(args []string) (string, string, []string) { +func ParseConnectionParams(args []string) (string, string, string, []string) { return parseConnectionParams(args) } + +// ParseLogsArg exposes the private parseLogsArg for testing only. 
+func ParseLogsArg(service, logs string, args []string) (bool, int64, error) { + return parseLogsArg(service, logs, args) +} diff --git a/internal/sshserver/serve.go b/internal/sshserver/serve.go index d182d823..52422da9 100644 --- a/internal/sshserver/serve.go +++ b/internal/sshserver/serve.go @@ -32,11 +32,11 @@ func disableSHA1Kex(ctx ssh.Context) *gossh.ServerConfig { // Serve contains the main ssh session logic func Serve(ctx context.Context, log *zap.Logger, nc *nats.EncodedConn, - l net.Listener, c *k8s.Client, hostKeys [][]byte) error { + l net.Listener, c *k8s.Client, hostKeys [][]byte, logAccessEnabled bool) error { srv := ssh.Server{ - Handler: sessionHandler(log, c, false), + Handler: sessionHandler(log, c, false, logAccessEnabled), SubsystemHandlers: map[string]ssh.SubsystemHandler{ - "sftp": ssh.SubsystemHandler(sessionHandler(log, c, true)), + "sftp": ssh.SubsystemHandler(sessionHandler(log, c, true, logAccessEnabled)), }, PublicKeyHandler: pubKeyAuth(log, nc, c), ServerConfigCallback: disableSHA1Kex, diff --git a/internal/sshserver/sessionhandler.go b/internal/sshserver/sessionhandler.go index 33741bb6..7c16857c 100644 --- a/internal/sshserver/sessionhandler.go +++ b/internal/sshserver/sessionhandler.go @@ -1,8 +1,10 @@ package sshserver import ( + "context" "fmt" "strings" + "time" "github.com/gliderlabs/ssh" "github.com/prometheus/client_golang/prometheus" @@ -44,20 +46,19 @@ func getSSHIntent(sftp bool, cmd []string) []string { // handler is that the command is set to sftp-server. This implies that the // target container must have a sftp-server binary installed for sftp to work. // There is no support for a built-in sftp server. -func sessionHandler(log *zap.Logger, c *k8s.Client, sftp bool) ssh.Handler { +func sessionHandler(log *zap.Logger, c *k8s.Client, + sftp, logAccessEnabled bool) ssh.Handler { return func(s ssh.Session) { sessionTotal.Inc() ctx := s.Context() sid := ctx.SessionID() - // start the command - log.Debug("starting command exec", + log.Debug("starting session", zap.String("sessionID", sid), zap.Strings("rawCommand", s.Command()), zap.String("subsystem", s.Subsystem()), ) // parse the command line arguments to extract any service or container args - service, container, rawCmd := parseConnectionParams(s.Command()) - cmd := getSSHIntent(sftp, rawCmd) + service, container, logs, rawCmd := parseConnectionParams(s.Command()) // validate the service and container if err := k8s.ValidateLabelValue(service); err != nil { log.Debug("invalid service name", @@ -103,8 +104,6 @@ func sessionHandler(log *zap.Logger, c *k8s.Client, sftp bool) ssh.Handler { } return } - // check if a pty was requested, and get the window size channel - _, winch, pty := s.Pty() // extract info passed through the context by the authhandler eid, ok := ctx.Value(environmentIDKey).(int) if !ok { @@ -126,6 +125,71 @@ func sessionHandler(log *zap.Logger, c *k8s.Client, sftp bool) ssh.Handler { if !ok { log.Warn("couldn't extract SSH key fingerprint from session context") } + if len(logs) != 0 { + if !logAccessEnabled { + log.Debug("logs access is not enabled", + zap.String("logsArgument", logs), + zap.String("sessionID", sid)) + _, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n", + sid) + if err != nil { + log.Warn("couldn't send error to client", + zap.String("sessionID", sid), + zap.Error(err)) + } + // Send a non-zero exit code to the client on internal logs error. + // OpenSSH uses 255 for this, 254 is an exec failure, so use 253 to + // differentiate this error. 
+ if err = s.Exit(253); err != nil { + log.Warn("couldn't send exit code to client", + zap.String("sessionID", sid), + zap.Error(err)) + } + return + } + follow, tailLines, err := parseLogsArg(service, logs, rawCmd) + if err != nil { + log.Debug("couldn't parse logs argument", + zap.String("logsArgument", logs), + zap.String("sessionID", sid), + zap.Error(err)) + _, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n", + sid) + if err != nil { + log.Warn("couldn't send error to client", + zap.String("sessionID", sid), + zap.Error(err)) + } + // Send a non-zero exit code to the client on internal logs error. + // OpenSSH uses 255 for this, 254 is an exec failure, so use 253 to + // differentiate this error. + if err = s.Exit(253); err != nil { + log.Warn("couldn't send exit code to client", + zap.String("sessionID", sid), + zap.Error(err)) + } + return + } + log.Info("sending logs to SSH client", + zap.Int("environmentID", eid), + zap.Int("projectID", pid), + zap.String("SSHFingerprint", fingerprint), + zap.String("container", container), + zap.String("deployment", deployment), + zap.String("environmentName", ename), + zap.String("namespace", s.User()), + zap.String("projectName", pname), + zap.String("sessionID", sid), + zap.Bool("follow", follow), + zap.Int64("tailLines", tailLines), + ) + doLogs(ctx, log, s, deployment, container, follow, tailLines, c, sid) + return + } + // handle sftp and sh fallback + cmd := getSSHIntent(sftp, rawCmd) + // check if a pty was requested, and get the window size channel + _, winch, pty := s.Pty() log.Info("executing SSH command", zap.Bool("pty", pty), zap.Int("environmentID", eid), @@ -139,39 +203,108 @@ func sessionHandler(log *zap.Logger, c *k8s.Client, sftp bool) ssh.Handler { zap.String("sessionID", sid), zap.Strings("command", cmd), ) - err = c.Exec(ctx, s.User(), deployment, container, cmd, s, - s.Stderr(), pty, winch) + doExec(ctx, log, s, deployment, container, cmd, c, pty, winch, sid) + } +} + +// startClientKeepalive sends a keepalive request to the client via the channel +// embedded in ssh.Session at a regular interval. If the client fails to +// respond, the channel is closed, and cancel is called. +func startClientKeepalive(ctx context.Context, cancel context.CancelFunc, + log *zap.Logger, s ssh.Session) { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + // https://github.com/openssh/openssh-portable/blob/ + // edc2ef4e418e514c99701451fae4428ec04ce538/serverloop.c#L127-L158 + _, err := s.SendRequest("keepalive@openssh.com", true, nil) + if err != nil { + log.Debug("client closed connection", zap.Error(err)) + _ = s.Close() + cancel() + return + } + case <-ctx.Done(): + return + } + } +} + +func doLogs(ctx ssh.Context, log *zap.Logger, s ssh.Session, deployment, + container string, follow bool, tailLines int64, c *k8s.Client, sid string) { + // Wrap the ssh.Context so we can cancel goroutines started from this + // function without affecting the SSH session. + childCtx, cancel := context.WithCancel(ctx) + defer cancel() + // In a multiplexed connection (multiple SSH channels to the single TCP + // connection), if the client disconnects from the channel the session + // context will not be cancelled (because the TCP connection is still up), + // and k8s.Logs() will hang. + // + // To work around this problem, start a goroutine to send a regular keepalive + // ping to the client. If the keepalive fails, close the channel and cancel + // the childCtx. 
+ go startClientKeepalive(childCtx, cancel, log, s) + err := c.Logs(childCtx, s.User(), deployment, container, follow, tailLines, s) + if err != nil { + log.Warn("couldn't send logs", + zap.String("sessionID", sid), + zap.Error(err)) + _, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n", + sid) if err != nil { - if exitErr, ok := err.(exec.ExitError); ok { - log.Debug("couldn't execute command", + log.Warn("couldn't send error to client", + zap.String("sessionID", sid), + zap.Error(err)) + } + // Send a non-zero exit code to the client on internal logs error. + // OpenSSH uses 255 for this, 254 is an exec failure, so use 253 to + // differentiate this error. + if err = s.Exit(253); err != nil { + log.Warn("couldn't send exit code to client", + zap.String("sessionID", sid), + zap.Error(err)) + } + } + log.Debug("finished command logs", zap.String("sessionID", sid)) +} + +func doExec(ctx ssh.Context, log *zap.Logger, s ssh.Session, deployment, + container string, cmd []string, c *k8s.Client, pty bool, + winch <-chan ssh.Window, sid string) { + err := c.Exec(ctx, s.User(), deployment, container, cmd, s, + s.Stderr(), pty, winch) + if err != nil { + if exitErr, ok := err.(exec.ExitError); ok { + log.Debug("couldn't execute command", + zap.String("sessionID", sid), + zap.Error(err)) + if err = s.Exit(exitErr.ExitStatus()); err != nil { + log.Warn("couldn't send exit code to client", zap.String("sessionID", sid), zap.Error(err)) - if err = s.Exit(exitErr.ExitStatus()); err != nil { - log.Warn("couldn't send exit code to client", - zap.String("sessionID", sid), - zap.Error(err)) - } - } else { - log.Warn("couldn't execute command", + } + } else { + log.Warn("couldn't execute command", + zap.String("sessionID", sid), + zap.Error(err)) + _, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n", + sid) + if err != nil { + log.Warn("couldn't send error to client", + zap.String("sessionID", sid), + zap.Error(err)) + } + // Send a non-zero exit code to the client on internal exec error. + // OpenSSH uses 255 for this, so use 254 to differentiate the error. + if err = s.Exit(254); err != nil { + log.Warn("couldn't send exit code to client", zap.String("sessionID", sid), zap.Error(err)) - _, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n", - sid) - if err != nil { - log.Warn("couldn't send error to client", - zap.String("sessionID", sid), - zap.Error(err)) - } - // Send a non-zero exit code to the client on internal exec error. - // OpenSSH uses 255 for this, so use 254 to differentiate the error. - if err = s.Exit(254); err != nil { - log.Warn("couldn't send exit code to client", - zap.String("sessionID", sid), - zap.Error(err)) - } } } - log.Debug("finished command exec", - zap.String("sessionID", sid)) } + log.Debug("finished command exec", zap.String("sessionID", sid)) }
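Reviewer notes (not part of the patch): a few self-contained Go sketches of the mechanisms this PR introduces. Any identifier below that does not appear in the diff is illustrative only.

readLogs() de-duplicates log streams by keying Client.logStreamIDs (a sync.Map) on requestID+containerID: the first goroutine to LoadOrStore a key sets up the stream, and a concurrent duplicate sees exists == true and skips setup. A minimal sketch of that primitive, with hypothetical key values:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var logStreamIDs sync.Map // mirrors Client.logStreamIDs
	// Hypothetical requestID + containerID key, composed as in readLogs().
	key := "req-123" + "containerd://abc"

	// First caller stores the key and proceeds to set up the log stream.
	_, exists := logStreamIDs.LoadOrStore(key, true)
	fmt.Println("first caller skips setup:", exists) // false

	// A duplicate event for the same container finds the key and skips setup.
	_, exists = logStreamIDs.LoadOrStore(key, true)
	fmt.Println("second caller skips setup:", exists) // true

	// When the stream ends the key is released (after the 1s grace period the
	// PR adds to absorb duplicate "healthy pod" informer events).
	logStreamIDs.Delete(key)
}
```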
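podEventHandler() gates log streaming on the ContainersReady condition, because informer Add/Update events also fire for pods that are still starting or terminating. The gate in isolation, using the same golang.org/x/exp/slices call as the diff (the helper name is mine):

```go
package main

import (
	"fmt"

	"golang.org/x/exp/slices"
	corev1 "k8s.io/api/core/v1"
)

// containersReady mirrors the condition check in podEventHandler().
func containersReady(pod *corev1.Pod) bool {
	return slices.ContainsFunc(pod.Status.Conditions,
		func(cond corev1.PodCondition) bool {
			return cond.Type == corev1.ContainersReady &&
				cond.Status == corev1.ConditionTrue
		})
}

func main() {
	starting := &corev1.Pod{} // no conditions yet: not ready
	ready := &corev1.Pod{Status: corev1.PodStatus{
		Conditions: []corev1.PodCondition{
			{Type: corev1.ContainersReady, Status: corev1.ConditionTrue},
		},
	}}
	fmt.Println(containersReady(starting), containersReady(ready)) // false true
}
```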
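Logs() never closes the logs channel; shutdown is ordered wait-for-senders, cancel, wait-for-receiver, so a late sender can never panic on a closed channel. The shape of that pattern reduced to the standard library plus errgroup (all names here are illustrative):

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logs := make(chan string, 4) // buffered, as in Logs()
	var egSend errgroup.Group    // senders, which may fail
	done := make(chan struct{})  // the receiver has no errors, so no errgroup

	// Single receiver: drains the channel until the context is cancelled.
	go func() {
		defer close(done)
		for {
			select {
			case msg := <-logs:
				fmt.Println(msg)
			case <-ctx.Done():
				return
			}
		}
	}()

	// Several senders, standing in for per-container linewiseCopy goroutines.
	for i := 0; i < 3; i++ {
		i := i // copy loop var, matching the PR's pre-Go-1.22 style
		egSend.Go(func() error {
			select {
			case logs <- fmt.Sprintf("[stream-%d] a log line", i):
				return nil
			case <-ctx.Done():
				return ctx.Err()
			}
		})
	}

	// Shutdown order: senders first, then cancel the receiver and wait for it.
	// The channel is deliberately never closed; lines still buffered at
	// cancellation may be dropped, just as in Logs().
	sendErr := egSend.Wait()
	cancel()
	<-done
	fmt.Println("send error:", sendErr)
}
```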
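End to end, the new argument parsing composes as below. This is a test-style sketch using the helpers exported in helper_test.go; the import path is assumed from the repository layout and the expected values follow the table-driven tests above:

```go
package sshserver_test

import (
	"fmt"

	"github.com/uselagoon/ssh-portal/internal/sshserver" // assumed module path
)

func Example_logsParams() {
	// Equivalent to: ssh ... service=nginx container=php logs=follow,tailLines=100
	raw := []string{"service=nginx", "container=php", "logs=follow,tailLines=100"}
	service, container, logs, args := sshserver.ParseConnectionParams(raw)
	fmt.Println(service, container, logs, len(args))
	follow, tailLines, err := sshserver.ParseLogsArg(service, logs, args)
	fmt.Println(follow, tailLines, err)
	// Output:
	// nginx php follow,tailLines=100 0
	// true 100 <nil>
}
```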
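Finally, the keepalive added for doLogs() reduces to a ticker that cancels a context when a ping fails. In this standalone sketch a stub stands in for the real s.SendRequest("keepalive@openssh.com", true, nil) call, and the interval is shortened from the PR's 2 seconds so the demo finishes quickly:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var clientAlive atomic.Bool
	clientAlive.Store(true)

	// ping stands in for s.SendRequest("keepalive@openssh.com", true, nil).
	ping := func() error {
		if !clientAlive.Load() {
			return errors.New("EOF")
		}
		return nil
	}

	// The loop from startClientKeepalive, with an illustrative interval.
	go func() {
		ticker := time.NewTicker(50 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if err := ping(); err != nil {
					fmt.Println("client closed connection:", err)
					cancel() // unblocks the long-running Logs() call
					return
				}
			case <-ctx.Done():
				return
			}
		}
	}()

	time.Sleep(120 * time.Millisecond)
	clientAlive.Store(false) // simulate the client going away mid-stream
	<-ctx.Done()
	fmt.Println("log streaming cancelled")
}
```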