From 4834438c80d17a8b9ad769e477fa0f902b2e99b2 Mon Sep 17 00:00:00 2001 From: Nick Silverman Date: Tue, 10 Sep 2024 15:15:59 -0400 Subject: [PATCH] clean exiting by establishing return pattern instead of exit pattern and move human input to goroutine to unblock main routine --- src/datachannel/streaming.go | 26 ++++++-- src/message/messageparser.go | 14 ++-- .../portsession/basicportforwarding.go | 66 ++++++++----------- .../session/portsession/muxportforwarding.go | 34 ++++++---- .../portsession/standardstreamforwarding.go | 19 +++++- src/sessionmanagerplugin/session/session.go | 22 +++---- .../session/sessionhandler.go | 6 +- .../sessionutil/sessionutil_windows.go | 2 +- .../session/shellsession/shellsession_unix.go | 33 +++++----- .../shellsession/shellsession_windows.go | 50 +++++++++----- 10 files changed, 160 insertions(+), 112 deletions(-) diff --git a/src/datachannel/streaming.go b/src/datachannel/streaming.go index 3ff4c59f..3b2221b1 100644 --- a/src/datachannel/streaming.go +++ b/src/datachannel/streaming.go @@ -62,6 +62,8 @@ type IDataChannel interface { RegisterOutputStreamHandler(handler OutputStreamDataMessageHandler, isSessionSpecificHandler bool) DeregisterOutputStreamHandler(handler OutputStreamDataMessageHandler) IsSessionTypeSet() chan bool + EndSession() error + IsSessionEnded() bool IsStreamMessageResendTimeout() chan bool GetSessionType() string SetSessionType(sessionType string) @@ -106,6 +108,8 @@ type DataChannel struct { isSessionTypeSet chan bool sessionProperties interface{} + isSessionEnded bool + // Used to detect if resending a streaming message reaches timeout isStreamMessageResendTimeout chan bool @@ -187,6 +191,7 @@ func (dataChannel *DataChannel) Initialize(log log.T, clientId string, sessionId dataChannel.wsChannel = &communicator.WebSocketChannel{} dataChannel.encryptionEnabled = false dataChannel.isSessionTypeSet = make(chan bool, 1) + dataChannel.isSessionEnded = false dataChannel.isStreamMessageResendTimeout = make(chan bool, 1) dataChannel.sessionType = "" dataChannel.IsAwsCliUpgradeNeeded = isAwsCliUpgradeNeeded @@ -199,7 +204,7 @@ func (dataChannel *DataChannel) SetWebsocket(log log.T, channelUrl string, chann // FinalizeHandshake sends the token for service to acknowledge the connection. func (dataChannel *DataChannel) FinalizeDataChannelHandshake(log log.T, tokenValue string) (err error) { - uuid.SwitchFormat(uuid.CleanHyphen) + uuid.SwitchFormat(uuid.FormatCanonical) uid := uuid.NewV4().String() log.Infof("Sending token through data channel %s to acknowledge connection", dataChannel.wsChannel.GetStreamUrl()) @@ -772,7 +777,7 @@ func (dataChannel *DataChannel) HandleAcknowledgeMessage( } // handleChannelClosedMessage exits the shell -func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler Stop, sessionId string, outputMessage message.ClientMessage) { +func (dataChannel *DataChannel) HandleChannelClosedMessage(log log.T, stopHandler Stop, sessionId string, outputMessage message.ClientMessage) { var ( channelClosedMessage message.ChannelClosed err error @@ -787,6 +792,8 @@ func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler } else { fmt.Fprintf(os.Stdout, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output) } + dataChannel.EndSession() + dataChannel.Close(log) stopHandler() } @@ -849,7 +856,7 @@ func (dataChannel *DataChannel) CalculateRetransmissionTimeout(log log.T, stream func (dataChannel *DataChannel) ProcessKMSEncryptionHandshakeAction(log log.T, actionParams json.RawMessage) (err error) { if dataChannel.IsAwsCliUpgradeNeeded { - return errors.New("Installed version of CLI does not support Session Manager encryption feature. Please upgrade to the latest version of your CLI (e.g., AWS CLI).") + return errors.New("installed version of CLI does not support Session Manager encryption feature. Please upgrade to the latest version of your CLI (e.g., AWS CLI)") } kmsEncRequest := message.KMSEncryptionRequest{} json.Unmarshal(actionParams, &kmsEncRequest) @@ -881,7 +888,7 @@ func (dataChannel *DataChannel) ProcessSessionTypeHandshakeAction(actionParams j dataChannel.sessionProperties = sessTypeReq.Properties return nil default: - return errors.New(fmt.Sprintf("Unknown session type %s", sessTypeReq.SessionType)) + return fmt.Errorf("Unknown session type %s", sessTypeReq.SessionType) } } @@ -890,6 +897,17 @@ func (dataChannel *DataChannel) IsSessionTypeSet() chan bool { return dataChannel.isSessionTypeSet } +// IsSessionEnded check if session has ended +func (dataChannel *DataChannel) IsSessionEnded() bool { + return dataChannel.isSessionEnded +} + +// IsSessionEnded check if session has ended +func (dataChannel *DataChannel) EndSession() error { + dataChannel.isSessionEnded = true + return nil +} + // IsStreamMessageResendTimeout checks if resending a streaming message reaches timeout func (dataChannel *DataChannel) IsStreamMessageResendTimeout() chan bool { return dataChannel.isStreamMessageResendTimeout diff --git a/src/message/messageparser.go b/src/message/messageparser.go index bcab244b..64f871c1 100644 --- a/src/message/messageparser.go +++ b/src/message/messageparser.go @@ -167,31 +167,31 @@ func getUuid(log log.T, byteArray []byte, offset int) (result uuid.UUID, err err byteArrayLength := len(byteArray) if offset > byteArrayLength-1 || offset+16-1 > byteArrayLength-1 || offset < 0 { log.Error("getUuid failed: Offset is invalid.") - return nil, errors.New("Offset is outside the byte array.") + return uuid.Nil.UUID(), errors.New("Offset is outside the byte array.") } leastSignificantLong, err := getLong(log, byteArray, offset) if err != nil { log.Error("getUuid failed: failed to get uuid LSBs Long value.") - return nil, errors.New("Failed to get uuid LSBs long value.") + return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs long value.") } leastSignificantBytes, err := longToBytes(log, leastSignificantLong) if err != nil { log.Error("getUuid failed: failed to get uuid LSBs bytes value.") - return nil, errors.New("Failed to get uuid LSBs bytes value.") + return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs bytes value.") } mostSignificantLong, err := getLong(log, byteArray, offset+8) if err != nil { log.Error("getUuid failed: failed to get uuid MSBs Long value.") - return nil, errors.New("Failed to get uuid MSBs long value.") + return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs long value.") } mostSignificantBytes, err := longToBytes(log, mostSignificantLong) if err != nil { log.Error("getUuid failed: failed to get uuid MSBs bytes value.") - return nil, errors.New("Failed to get uuid MSBs bytes value.") + return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs bytes value.") } uuidBytes := append(mostSignificantBytes, leastSignificantBytes...) @@ -414,7 +414,7 @@ func putBytes(log log.T, byteArray []byte, offsetStart int, offsetEnd int, input // putUuid puts the 128 bit uuid to an array of bytes starting from the offset. func putUuid(log log.T, byteArray []byte, offset int, input uuid.UUID) (err error) { - if input == nil { + if uuid.IsNil(input) { log.Error("putUuid failed: input is null.") return errors.New("putUuid failed: input is null.") } @@ -494,7 +494,7 @@ func SerializeClientMessageWithAcknowledgeContent(log log.T, acknowledgeContent return } - uuid.SwitchFormat(uuid.CleanHyphen) + uuid.SwitchFormat(uuid.FormatCanonical) messageId := uuid.NewV4() clientMessage := ClientMessage{ MessageType: AcknowledgeMessage, diff --git a/src/sessionmanagerplugin/session/portsession/basicportforwarding.go b/src/sessionmanagerplugin/session/portsession/basicportforwarding.go index 65f1057c..eed9e11d 100644 --- a/src/sessionmanagerplugin/session/portsession/basicportforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/basicportforwarding.go @@ -34,23 +34,13 @@ import ( // accepts one client connection at a time type BasicPortForwarding struct { port IPortSession - stream *net.Conn - listener *net.Listener + stream net.Conn + listener net.Listener sessionId string portParameters PortParameters session session.Session } -// getNewListener returns a new listener to given address and type like tcp, unix etc. -var getNewListener = func(listenerType string, listenerAddress string) (listener net.Listener, err error) { - return net.Listen(listenerType, listenerAddress) -} - -// acceptConnection returns connection to the listener -var acceptConnection = func(log log.T, listener net.Listener) (tcpConn net.Conn, err error) { - return listener.Accept() -} - // IsStreamNotSet checks if stream is not set func (p *BasicPortForwarding) IsStreamNotSet() (status bool) { return p.stream == nil @@ -58,10 +48,11 @@ func (p *BasicPortForwarding) IsStreamNotSet() (status bool) { // Stop closes the stream func (p *BasicPortForwarding) Stop() { + p.listener.Close() if p.stream != nil { - (*p.stream).Close() + p.stream.Close() } - os.Exit(0) + return } // InitializeStreams establishes connection and initializes the stream @@ -77,7 +68,7 @@ func (p *BasicPortForwarding) InitializeStreams(log log.T, agentVersion string) func (p *BasicPortForwarding) ReadStream(log log.T) (err error) { msg := make([]byte, config.StreamDataPayloadSize) for { - numBytes, err := (*p.stream).Read(msg) + numBytes, err := p.stream.Read(msg) if err != nil { log.Debugf("Reading from port %s failed with error: %v. Close this connection, listen and accept new one.", p.portParameters.PortNumber, err) @@ -108,7 +99,7 @@ func (p *BasicPortForwarding) ReadStream(log log.T) (err error) { // WriteStream writes data to stream func (p *BasicPortForwarding) WriteStream(outputMessage message.ClientMessage) error { - _, err := (*p.stream).Write(outputMessage.Payload) + _, err := p.stream.Write(outputMessage.Payload) return err } @@ -120,41 +111,40 @@ func (p *BasicPortForwarding) startLocalConn(log log.T) (err error) { localPortNumber = "0" } - var listener net.Listener - if listener, err = p.startLocalListener(log, localPortNumber); err != nil { + if err = p.startLocalListener(log, localPortNumber); err != nil { log.Errorf("Unable to open tcp connection to port. %v", err) return err } - var tcpConn net.Conn - if tcpConn, err = acceptConnection(log, listener); err != nil { - log.Errorf("Failed to accept connection with error. %v", err) - return err + if p.stream, err = p.listener.Accept(); err != nil { + if p.session.DataChannel.IsSessionEnded() == false { + log.Errorf("Failed to accept connection with error. %v", err) + return err + } + } + if p.session.DataChannel.IsSessionEnded() == false { + log.Infof("Connection accepted for session %s.", p.sessionId) + fmt.Printf("Connection accepted for session %s.\n", p.sessionId) } - log.Infof("Connection accepted for session %s.", p.sessionId) - fmt.Printf("Connection accepted for session %s.\n", p.sessionId) - - p.listener = &listener - p.stream = &tcpConn return } // startLocalListener starts a local listener to given address -func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (listener net.Listener, err error) { +func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (err error) { var displayMessage string switch p.portParameters.LocalConnectionType { case "unix": - if listener, err = getNewListener(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil { + if p.listener, err = net.Listen(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil { return } displayMessage = fmt.Sprintf("Unix socket %s opened for sessionId %s.", p.portParameters.LocalUnixSocket, p.sessionId) default: - if listener, err = getNewListener("tcp", "localhost:"+portNumber); err != nil { + if p.listener, err = net.Listen("tcp", "localhost:"+portNumber); err != nil { return } // get port number the TCP listener opened - p.portParameters.LocalPortNumber = strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + p.portParameters.LocalPortNumber = strconv.Itoa(p.listener.Addr().(*net.TCPAddr).Port) displayMessage = fmt.Sprintf("Port %s opened for sessionId %s.", p.portParameters.LocalPortNumber, p.sessionId) } @@ -171,29 +161,31 @@ func (p *BasicPortForwarding) handleControlSignals(log log.T) { <-c fmt.Println("Terminate signal received, exiting.") + p.session.DataChannel.EndSession() if version.DoesAgentSupportTerminateSessionFlag(log, p.session.DataChannel.GetAgentVersion()) { if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil { log.Errorf("Failed to send TerminateSession flag: %v", err) } fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId) - p.Stop() } else { p.session.TerminateSession(log) } + p.Stop() }() } // reconnect closes existing connection, listens to new connection and accept it func (p *BasicPortForwarding) reconnect(log log.T) (err error) { // close existing connection as it is in a state from which data cannot be read - (*p.stream).Close() + p.stream.Close() // wait for new connection on listener and accept it - var conn net.Conn - if conn, err = acceptConnection(log, *p.listener); err != nil { - return log.Errorf("Failed to accept connection with error. %v", err) + if p.stream, err = p.listener.Accept(); err != nil { + if p.session.DataChannel.IsSessionEnded() == false { + log.Errorf("Failed to accept connection with error. %v", err) + return err + } } - p.stream = &conn return } diff --git a/src/sessionmanagerplugin/session/portsession/muxportforwarding.go b/src/sessionmanagerplugin/session/portsession/muxportforwarding.go index 85fe0ce6..28a20571 100644 --- a/src/sessionmanagerplugin/session/portsession/muxportforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/muxportforwarding.go @@ -41,8 +41,9 @@ import ( // MuxClient contains smux client session and corresponding network connection type MuxClient struct { - conn net.Conn - session *smux.Session + conn net.Conn + localListener net.Listener + session *smux.Session } // MgsConn contains local server and corresponding connection to smux client @@ -71,6 +72,7 @@ func (c *MgsConn) close() { func (c *MuxClient) close() { c.session.Close() c.conn.Close() + c.localListener.Close() } // IsStreamNotSet checks if stream is not set @@ -87,7 +89,7 @@ func (p *MuxPortForwarding) Stop() { p.muxClient.close() } p.cleanUp() - os.Exit(0) + return } // InitializeStreams initializes i/o streams @@ -116,6 +118,16 @@ func (p *MuxPortForwarding) ReadStream(log log.T) (err error) { return p.handleClientConnections(log, ctx) }) + g.Go(func() error { + for { + time.Sleep(50 * time.Millisecond) + if p.session.DataChannel.IsSessionEnded() == true { + p.Stop() + return nil + } + } + }) + return g.Wait() } @@ -169,13 +181,13 @@ func (p *MuxPortForwarding) initialize(log log.T, agentVersion string) (err erro } else { smuxConfig := smux.DefaultConfig() if version.DoesAgentSupportDisableSmuxKeepAlive(log, agentVersion) { - // Disable smux KeepAlive or else it breaks Session Manager idle timeout. smuxConfig.KeepAliveDisabled = true } if muxSession, err := smux.Client(muxConn, smuxConfig); err != nil { return err } else { - p.muxClient = &MuxClient{muxConn, muxSession} + var localListener net.Listener + p.muxClient = &MuxClient{muxConn, localListener, muxSession} } } return nil @@ -195,7 +207,6 @@ func (p *MuxPortForwarding) handleControlSignals(log log.T) { if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil { log.Errorf("Failed to send TerminateSession flag: %v", err) } - fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId) p.Stop() }() } @@ -228,12 +239,11 @@ func (p *MuxPortForwarding) transferDataToServer(log log.T, ctx context.Context) // handleClientConnections sets up network server on local ssm port to accept connections from clients (browser/terminal) func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Context) (err error) { var ( - listener net.Listener displayMsg string ) if p.portParameters.LocalConnectionType == "unix" { - if listener, err = net.Listen(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil { + if p.muxClient.localListener, err = net.Listen(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil { return err } displayMsg = fmt.Sprintf("Unix socket %s opened for sessionId %s.", p.portParameters.LocalUnixSocket, p.sessionId) @@ -242,14 +252,14 @@ func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Conte if p.portParameters.LocalPortNumber == "" { localPortNumber = "0" } - if listener, err = net.Listen("tcp", "localhost:"+localPortNumber); err != nil { + if p.muxClient.localListener, err = net.Listen("tcp", "localhost:"+localPortNumber); err != nil { return err } - p.portParameters.LocalPortNumber = strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + p.portParameters.LocalPortNumber = strconv.Itoa(p.muxClient.localListener.Addr().(*net.TCPAddr).Port) displayMsg = fmt.Sprintf("Port %s opened for sessionId %s.", p.portParameters.LocalPortNumber, p.sessionId) } - defer listener.Close() + defer p.muxClient.localListener.Close() log.Infof(displayMsg) fmt.Printf(displayMsg) @@ -263,7 +273,7 @@ func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Conte case <-ctx.Done(): return ctx.Err() default: - if conn, err := listener.Accept(); err != nil { + if conn, err := p.muxClient.localListener.Accept(); err != nil { log.Errorf("Error while accepting connection: %v", err) } else { log.Infof("Connection accepted from %s\n for session [%s]", conn.RemoteAddr(), p.sessionId) diff --git a/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go b/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go index 2154904a..d551f2d4 100644 --- a/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go +++ b/src/sessionmanagerplugin/session/portsession/standardstreamforwarding.go @@ -15,14 +15,17 @@ package portsession import ( + "fmt" "io" "os" + "os/signal" "time" "github.com/aws/session-manager-plugin/src/config" "github.com/aws/session-manager-plugin/src/log" "github.com/aws/session-manager-plugin/src/message" "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session" + "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/sessionutil" ) type StandardStreamForwarding struct { @@ -42,16 +45,30 @@ func (p *StandardStreamForwarding) IsStreamNotSet() (status bool) { func (p *StandardStreamForwarding) Stop() { p.inputStream.Close() p.outputStream.Close() - os.Exit(0) + return } // InitializeStreams initializes the streams with its file descriptors func (p *StandardStreamForwarding) InitializeStreams(log log.T, agentVersion string) (err error) { + p.handleControlSignals(log) p.inputStream = os.Stdin p.outputStream = os.Stdout return } +// handleControlSignals handles terminate signals +func (p *StandardStreamForwarding) handleControlSignals(log log.T) { + c := make(chan os.Signal, 1) + signal.Notify(c, sessionutil.ControlSignals...) + go func() { + <-c + fmt.Println("Terminate signal received, exiting.") + + p.session.DataChannel.EndSession() + p.Stop() + }() +} + // ReadStream reads data from the input stream func (p *StandardStreamForwarding) ReadStream(log log.T) (err error) { msg := make([]byte, config.StreamDataPayloadSize) diff --git a/src/sessionmanagerplugin/session/session.go b/src/sessionmanagerplugin/session/session.go index 6f4d154b..3b223aba 100644 --- a/src/sessionmanagerplugin/session/session.go +++ b/src/sessionmanagerplugin/session/session.go @@ -20,7 +20,6 @@ import ( "fmt" "io" "os" - "strings" "time" "github.com/aws/session-manager-plugin/src/config" @@ -140,7 +139,7 @@ func ValidateInputAndStartSession(args []string, out io.Writer) { target string ) log := log.Logger(true, "session-manager-plugin") - uuid.SwitchFormat(uuid.CleanHyphen) + uuid.SwitchFormat(uuid.FormatCanonical) if len(args) == 1 { fmt.Fprint(out, "\nThe Session Manager plugin was installed successfully. "+ @@ -163,14 +162,7 @@ func ValidateInputAndStartSession(args []string, out io.Writer) { for argsIndex := 1; argsIndex < len(args); argsIndex++ { switch argsIndex { case 1: - if strings.HasPrefix(args[1], "AWS_SSM_START_SESSION_RESPONSE") == true { - response = []byte(os.Getenv(args[1])) - if err = os.Unsetenv(args[1]); err != nil { - log.Errorf("Failed to remove temporary session env parameter: %v", err) - } - } else { - response = []byte(args[1]) - } + response = []byte(args[1]) case 2: region = args[2] case 3: @@ -211,8 +203,10 @@ func ValidateInputAndStartSession(args []string, out io.Writer) { } if err = startSession(&session, log); err != nil { - log.Errorf("Cannot perform start session: %v", err) - fmt.Fprintf(out, "Cannot perform start session: %v\n", err) + if session.DataChannel.IsSessionEnded() == false { + log.Errorf("Cannot perform start session: %v", err) + fmt.Fprintf(out, "Cannot perform start session: %v\n", err) + } return } } @@ -239,7 +233,9 @@ func (s *Session) Execute(log log.T) (err error) { s.SessionType = s.DataChannel.GetSessionType() s.SessionProperties = s.DataChannel.GetSessionProperties() if err = setSessionHandlersWithSessionType(s, log); err != nil { - log.Errorf("Session ending with error: %v", err) + if s.DataChannel.IsSessionEnded() == false { + log.Errorf("Session ending with error: %v", err) + } return } } diff --git a/src/sessionmanagerplugin/session/sessionhandler.go b/src/sessionmanagerplugin/session/sessionhandler.go index cab3cfa7..2b72d8ea 100644 --- a/src/sessionmanagerplugin/session/sessionhandler.go +++ b/src/sessionmanagerplugin/session/sessionhandler.go @@ -87,9 +87,7 @@ func (s *Session) ProcessFirstMessage(log log.T, outputMessage message.ClientMes } // Stop will end the session -func (s *Session) Stop() { - os.Exit(0) -} +func (s *Session) Stop() {} // GetResumeSessionParams calls ResumeSession API and gets tokenvalue for reconnecting func (s *Session) GetResumeSessionParams(log log.T) (string, error) { @@ -130,7 +128,7 @@ func (s *Session) ResumeSessionHandler(log log.T) (err error) { } else if s.TokenValue == "" { log.Debugf("Session: %s timed out", s.SessionId) fmt.Fprintf(os.Stdout, "Session: %s timed out.\n", s.SessionId) - os.Exit(0) + return } s.DataChannel.GetWsChannel().SetChannelToken(s.TokenValue) err = s.DataChannel.Reconnect(log) diff --git a/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go b/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go index 95dc2775..68516d40 100644 --- a/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go +++ b/src/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go @@ -72,7 +72,7 @@ func (d *DisplayMode) DisplayMessage(log log.T, message message.ClientMessage) { if err = windows.WriteFile(d.handle, message.Payload, done, nil); err != nil { log.Errorf("error occurred while writing to file: %v", err) fmt.Fprintf(os.Stdout, "\nError getting the output. %s\n", err.Error()) - os.Exit(0) + return } } diff --git a/src/sessionmanagerplugin/session/shellsession/shellsession_unix.go b/src/sessionmanagerplugin/session/shellsession/shellsession_unix.go index 783c2195..679038a8 100644 --- a/src/sessionmanagerplugin/session/shellsession/shellsession_unix.go +++ b/src/sessionmanagerplugin/session/shellsession/shellsession_unix.go @@ -55,7 +55,6 @@ func setState(state *bytes.Buffer) error { func (s *ShellSession) Stop() { setState(&s.originalSttyState) setState(bytes.NewBufferString("echo")) // for linux and ubuntu - os.Exit(0) } // handleKeyboardInput handles input entered by customer on terminal @@ -64,23 +63,27 @@ func (s *ShellSession) handleKeyboardInput(log log.T) (err error) { stdinBytesLen int ) - //handle double echo and disable input buffering s.disableEchoAndInputBuffering() - - stdinBytes := make([]byte, StdinBufferLimit) - reader := bufio.NewReader(os.Stdin) - for { - if stdinBytesLen, err = reader.Read(stdinBytes); err != nil { - log.Errorf("Unable read from Stdin: %v", err) - break + ch := make(chan []byte) + go func(ch chan []byte) { + reader := bufio.NewReader(os.Stdin) + for { + stdinBytes := make([]byte, StdinBufferLimit) + stdinBytesLen, _ = reader.Read(stdinBytes) + ch <- stdinBytes } + }(ch) - if err = s.Session.DataChannel.SendInputDataMessage(log, message.Output, stdinBytes[:stdinBytesLen]); err != nil { - log.Errorf("Failed to send UTF8 char: %v", err) - break + for { + select { + case <-time.After(time.Second): + if s.Session.DataChannel.IsSessionEnded() { + return + } + case stdinBytes := <-ch: + if err = s.Session.DataChannel.SendInputDataMessage(log, message.Output, stdinBytes[:stdinBytesLen]); err != nil { + return + } } - // sleep to limit the rate of data transfer - time.Sleep(time.Millisecond) } - return } diff --git a/src/sessionmanagerplugin/session/shellsession/shellsession_windows.go b/src/sessionmanagerplugin/session/shellsession/shellsession_windows.go index 71a3538c..398832e0 100644 --- a/src/sessionmanagerplugin/session/shellsession/shellsession_windows.go +++ b/src/sessionmanagerplugin/session/shellsession/shellsession_windows.go @@ -18,7 +18,6 @@ package shellsession import ( - "os" "time" "github.com/aws/session-manager-plugin/src/log" @@ -55,7 +54,7 @@ var specialKeysInputMap = map[keyboard.Key][]byte{ // stop restores the terminal settings and exits func (s *ShellSession) Stop() { - os.Exit(0) + keyboard.Close() } // handleKeyboardInput handles input entered by customer on terminal @@ -64,35 +63,50 @@ func (s *ShellSession) handleKeyboardInput(log log.T) (err error) { character rune //character input from keyboard key keyboard.Key //special keys like arrows and function keys ) - if err = keyboard.Open(); err != nil { - log.Errorf("Failed to load Keyboard: %v", err) - return - } - defer keyboard.Close() - for { - if character, key, err = keyboard.GetKey(); err != nil { - log.Errorf("Failed to get the key stroke: %v", err) + charCH := make(chan rune) + keyCH := make(chan keyboard.Key) + go func(charCH chan rune, keyCH chan keyboard.Key) { + if err = keyboard.Open(); err != nil { + log.Errorf("Failed to load Keyboard: %v", err) return } - if character != 0 { - charBytes := []byte(string(character)) + for { + if character, key, err = keyboard.GetKey(); err != nil { + log.Errorf("Failed to get the key stroke: %v", err) + return + } + if character != 0 { + charCH <- character + } else if key != 0 { + keyCH <- key + } + } + }(charCH, keyCH) + + for { + select { + case <-time.After(time.Second): + if s.Session.DataChannel.IsSessionEnded() == true { + s.Stop() + return + } + case charStr := <-charCH: + charBytes := []byte(string(charStr)) if err = s.Session.DataChannel.SendInputDataMessage(log, message.Output, charBytes); err != nil { log.Errorf("Failed to send UTF8 char: %v", err) - break + return } - } else if key != 0 { - keyBytes := []byte(string(key)) + case keyStr := <-keyCH: + keyBytes := []byte(string(keyStr)) if byteValue, ok := specialKeysInputMap[key]; ok { keyBytes = byteValue } if err = s.Session.DataChannel.SendInputDataMessage(log, message.Output, keyBytes); err != nil { log.Errorf("Failed to send UTF8 char: %v", err) - break + return } } - // sleep to limit the rate of transfer - time.Sleep(time.Millisecond) } return }