diff --git a/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD b/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD index 897bc547a..91bf1da4d 100644 --- a/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD +++ b/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD @@ -64,6 +64,7 @@ kt_android_library( "ShellCommandExecutor.java", "ShellCommandExecutorServer.java", "ShellCommandFileObserverExecutorServer.kt", + "ShellCommandLocalSocketExecutorServer.kt", "ShellExecSharedConstants.java", "ShellMain.java", ], @@ -72,6 +73,8 @@ kt_android_library( deps = [ ":coroutine_file_observer", ":file_observer_protocol", + ":local_socket_protocol", + ":local_socket_protocol_pb_java_proto_lite", "//services/speakeasy/java/androidx/test/services/speakeasy:protocol", "//services/speakeasy/java/androidx/test/services/speakeasy/client", "//services/speakeasy/java/androidx/test/services/speakeasy/client:tool_connection", diff --git a/services/shellexecutor/java/androidx/test/services/shellexecutor/ShellCommandLocalSocketExecutorServer.kt b/services/shellexecutor/java/androidx/test/services/shellexecutor/ShellCommandLocalSocketExecutorServer.kt new file mode 100644 index 000000000..775f81d89 --- /dev/null +++ b/services/shellexecutor/java/androidx/test/services/shellexecutor/ShellCommandLocalSocketExecutorServer.kt @@ -0,0 +1,245 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package androidx.test.services.shellexecutor + +import android.net.LocalServerSocket +import android.net.LocalSocket +import android.net.LocalSocketAddress +import android.os.Process as AndroidProcess +import android.util.Log +import androidx.test.services.shellexecutor.LocalSocketProtocol.asBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocol.readRequest +import androidx.test.services.shellexecutor.LocalSocketProtocol.sendResponse +import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest +import java.io.IOException +import java.io.InterruptedIOException +import java.security.SecureRandom +import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.runInterruptible +import kotlinx.coroutines.withTimeout + +/** Server that run shell commands for a client talking over a LocalSocket. */ +final class ShellCommandLocalSocketExecutorServer +@JvmOverloads +constructor( + private val scope: CoroutineScope = + CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher()) +) { + // Use the same secret generation as SpeakEasy does. + private val secret = java.lang.Long.toHexString(SecureRandom().nextLong()) + lateinit var socket: LocalServerSocket + lateinit var address: LocalSocketAddress + // Since LocalServerSocket.accept() has to be interrupted, we keep that in its own Job... + lateinit var serverJob: Job + // ...while all the child jobs are under a single SupervisorJob that we can join later. + val shellJobs = SupervisorJob() + val running = AtomicBoolean(true) + + /** Returns the binder key to pass to client processes. */ + fun binderKey(): String { + // The address can contain spaces, and since it gets passed through a command line, we need to + // encode it. java.net.URLEncoder is conveniently available in all SDK versions. + return address.asBinderKey(secret) + } + + /** Runs a simple server. */ + private suspend fun server() = coroutineScope { + while (running.get()) { + val connection = + try { + runInterruptible { socket.accept() } + } catch (x: Exception) { + // None of my tests have managed to trigger this one. + Log.e(TAG, "LocalServerSocket.accept() failed", x) + break + } + launch(scope.coroutineContext + shellJobs) { handleConnection(connection) } + } + } + + /** + * Relays the output of process to connection with a series of RunCommandResponses. + * + * @param process The process to relay output from. + * @param connection The connection to relay output to. + * @return false if there was a problem, true otherwise. + */ + private suspend fun relay(process: Process, connection: LocalSocket): Boolean { + // Experiment shows that 64K is *much* faster than 4K, especially on API 21-23. Streaming 1MB + // takes 3s with 4K buffers and 2s with 64K on API 23. 22 is a bit faster (2.6s -> 1.5s), + // 21 faster still (630ms -> 545ms). Higher API levels are *much* faster (24 is 119 ms -> + // 75ms). + val buffer = ByteArray(65536) + var size: Int + + // LocalSocket.isOutputShutdown() throws UnsupportedOperationException, so we can't use + // that as our loop constraint. + while (true) { + try { + size = runInterruptible { process.inputStream.read(buffer) } + if (size < 0) return true // EOF + if (size == 0) { + delay(1.milliseconds) + continue + } + } catch (x: InterruptedIOException) { + // We start getting these at API 24 when the timeout handling kicks in. + Log.i(TAG, "Interrupted while reading from ${process}: ${x.message}") + return false + } catch (x: IOException) { + Log.i(TAG, "Error reading from ${process}; did it time out?", x) + return false + } + + if (!connection.sendResponse(buffer = buffer, size = size)) { + return false + } + } + } + + /** Handle one connection. */ + private suspend fun handleConnection(connection: LocalSocket) { + // connection.localSocketAddress is always null, so no point in logging it. + + // Close the connection when done. + connection.use { + val request = connection.readRequest() + + if (request.secret.compareTo(secret) != 0) { + Log.w(TAG, "Ignoring request with wrong secret: $request") + return + } + + val pb = request.toProcessBuilder() + pb.redirectErrorStream(true) + + val process: Process + try { + process = pb.start() + } catch (x: IOException) { + Log.e(TAG, "Failed to start process", x) + connection.sendResponse( + buffer = x.stackTraceToString().toByteArray(), + exitCode = EXIT_CODE_FAILED_TO_START, + ) + return + } + + // We will not be writing anything to the process' stdin. + process.outputStream.close() + + // Close the process' stdout when we're done reading. + process.inputStream.use { + // Launch a coroutine to relay the process' output to the client. If it times out, kill the + // process and cancel the job. This is more coroutine-friendly than using waitFor() to + // handle timeouts. + val ioJob = scope.async { relay(process, connection) } + + try { + withTimeout(request.timeout()) { + if (!ioJob.await()) { + Log.w(TAG, "Relaying ${process} output failed") + } + runInterruptible { process.waitFor() } + } + } catch (x: TimeoutCancellationException) { + Log.e(TAG, "Process ${process} timed out after ${request.timeout()}") + process.destroy() + ioJob.cancel() + connection.sendResponse(exitCode = EXIT_CODE_TIMED_OUT) + return + } + + connection.sendResponse(exitCode = process.exitValue()) + } + } + } + + /** Starts the server. */ + fun start() { + socket = LocalServerSocket("androidx.test.services ${AndroidProcess.myPid()}") + address = socket.localSocketAddress + Log.i(TAG, "Starting server on ${address.name}") + + // Launch a coroutine to call socket.accept() + serverJob = scope.launch { server() } + } + + /** Stops the server. */ + fun stop(timeout: Duration) { + running.set(false) + // Closing the socket does not interrupt accept()... + socket.close() + runBlocking(scope.coroutineContext) { + try { + // ...so we simply cancel that job... + serverJob.cancel() + // ...and play nicely with all the shell jobs underneath. + withTimeout(timeout) { + shellJobs.complete() + shellJobs.join() + } + } catch (x: TimeoutCancellationException) { + Log.w(TAG, "Shell jobs did not stop after $timeout", x) + shellJobs.cancel() + } + } + } + + private fun RunCommandRequest.timeout(): Duration = + if (timeoutMs <= 0) { + Duration.INFINITE + } else { + timeoutMs.milliseconds + } + + /** + * Sets up a ProcessBuilder with information from the request; other configuration is up to the + * caller. + */ + private fun RunCommandRequest.toProcessBuilder(): ProcessBuilder { + val pb = ProcessBuilder(argvList) + val redacted = argvList.map { it.replace(secret, "(SECRET)") } // Don't log the secret! + Log.i(TAG, "Command to execute: [${redacted.joinToString("] [")}] within ${timeout()}") + if (environmentMap.isNotEmpty()) { + pb.environment().putAll(environmentMap) + val env = environmentMap.entries.map { (k, v) -> "$k=$v" }.joinToString(", ") + Log.i(TAG, "Environment: $env") + } + return pb + } + + private companion object { + const val TAG = "SCLSEServer" // up to 23 characters + + const val EXIT_CODE_FAILED_TO_START = -1 + const val EXIT_CODE_TIMED_OUT = -2 + } +} diff --git a/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD index 0b69391ab..ff90f5063 100644 --- a/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD +++ b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD @@ -109,6 +109,23 @@ axt_android_library_test( ], ) +axt_android_library_test( + name = "ShellCommandLocalSocketExecutorServerTest", + srcs = [ + "ShellCommandLocalSocketExecutorServerTest.kt", + ], + deps = [ + "//runner/monitor", + "//services/shellexecutor:exec_server", + "//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol", + "//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol_pb_java_proto_lite", + "@com_google_protobuf//:protobuf_javalite", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_jetbrains_kotlinx_kotlinx_coroutines_android", + ], +) + axt_android_library_test( name = "ShellExecutorTest", srcs = [ diff --git a/services/shellexecutor/javatests/androidx/test/services/shellexecutor/ShellCommandLocalSocketExecutorServerTest.kt b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/ShellCommandLocalSocketExecutorServerTest.kt new file mode 100644 index 000000000..230ea9820 --- /dev/null +++ b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/ShellCommandLocalSocketExecutorServerTest.kt @@ -0,0 +1,104 @@ +package androidx.test.services.shellexecutor + +import android.net.LocalSocket +import android.os.Build +import androidx.test.services.shellexecutor.LocalSocketProtocol.addressFromBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocol.hasExited +import androidx.test.services.shellexecutor.LocalSocketProtocol.readResponse +import androidx.test.services.shellexecutor.LocalSocketProtocol.secretFromBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocol.sendRequest +import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandResponse +import com.google.common.truth.Truth.assertThat +import kotlin.time.Duration.Companion.milliseconds +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +@RunWith(JUnit4::class) +class ShellCommandLocalSocketExecutorServerTest { + + @Test + fun success_simple() { + val responses = mutableListOf() + runBlocking { + val server = ShellCommandLocalSocketExecutorServer() + server.start() + val client = LocalSocket(LocalSocket.SOCKET_STREAM) + client.connect(addressFromBinderKey(server.binderKey())) + client.sendRequest( + secretFromBinderKey(server.binderKey()), + listOf("echo", "\${POTRZEBIE}"), + mapOf("POTRZEBIE" to "furshlugginer"), + 1000.milliseconds, + ) + do { + client.readResponse()?.let { responses.add(it) } + } while (!responses.last().hasExited()) + server.stop(100.milliseconds) + } + if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.LOLLIPOP_MR1) { + // On API 21 and 22, echo only exists as a shell builtin! + assertThat(responses).hasSize(1) + assertThat(responses[0].exitCode).isEqualTo(-1) + assertThat(responses[0].buffer.toStringUtf8()).contains("Permission denied") + } else { + // On rare occasions, the output of the command will come back in two packets! So to keep + // this test from being 1% flaky: + val stdout = buildString { + for (response in responses) { + if (response.buffer.size() > 0) append(response.buffer.toStringUtf8()) + } + } + assertThat(stdout).isEqualTo("\${POTRZEBIE}\n") + assertThat(responses.last().hasExited()).isTrue() + assertThat(responses.last().exitCode).isEqualTo(0) + } + } + + @Test + fun success_shell_expansion() { + val responses = mutableListOf() + runBlocking { + val server = ShellCommandLocalSocketExecutorServer() + server.start() + val client = LocalSocket(LocalSocket.SOCKET_STREAM) + client.connect(addressFromBinderKey(server.binderKey())) + client.sendRequest( + secretFromBinderKey(server.binderKey()), + listOf("sh", "-c", "echo \${POTRZEBIE}"), + mapOf("POTRZEBIE" to "furshlugginer"), + 1000.milliseconds, + ) + do { + client.readResponse()?.let { responses.add(it) } + } while (!responses.last().hasExited()) + server.stop(100.milliseconds) + } + val stdout = buildString { + for (response in responses) { + if (response.buffer.size() > 0) append(response.buffer.toStringUtf8()) + } + } + assertThat(stdout).isEqualTo("furshlugginer\n") + assertThat(responses.last().hasExited()).isTrue() + assertThat(responses.last().exitCode).isEqualTo(0) + } + + @Test + fun failure_bad_secret() { + runBlocking { + val server = ShellCommandLocalSocketExecutorServer() + server.start() + val client = LocalSocket(LocalSocket.SOCKET_STREAM) + client.connect(addressFromBinderKey(server.binderKey())) + client.sendRequest( + "potrzebie!", + listOf("sh", "-c", "echo \${POTRZEBIE}"), + mapOf("POTRZEBIE" to "furshlugginer"), + 1000.milliseconds, + ) + assertThat(client.inputStream.read()).isEqualTo(-1) + } + } +}