Skip to content

Commit

Permalink
Adding a ShellCommandLocalSocketExecutorServer for the ShellExecutor …
Browse files Browse the repository at this point in the history
…to talk to the ShellMain.

PiperOrigin-RevId: 681170566
  • Loading branch information
copybara-androidxtest committed Nov 7, 2024
1 parent 35bdab8 commit d3c0e21
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ kt_android_library(
"ShellCommandExecutor.java",
"ShellCommandExecutorServer.java",
"ShellCommandFileObserverExecutorServer.kt",
"ShellCommandLocalSocketExecutorServer.kt",
"ShellExecSharedConstants.java",
"ShellMain.java",
],
Expand All @@ -72,6 +73,8 @@ kt_android_library(
deps = [
":coroutine_file_observer",
":file_observer_protocol",
":local_socket_protocol",
":local_socket_protocol_pb_java_proto_lite",
"//services/speakeasy/java/androidx/test/services/speakeasy:protocol",
"//services/speakeasy/java/androidx/test/services/speakeasy/client",
"//services/speakeasy/java/androidx/test/services/speakeasy/client:tool_connection",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Copyright (C) 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package androidx.test.services.shellexecutor

import android.net.LocalServerSocket
import android.net.LocalSocket
import android.net.LocalSocketAddress
import android.os.Process as AndroidProcess
import android.util.Log
import androidx.test.services.shellexecutor.LocalSocketProtocol.asBinderKey
import androidx.test.services.shellexecutor.LocalSocketProtocol.readRequest
import androidx.test.services.shellexecutor.LocalSocketProtocol.sendResponse
import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest
import java.io.IOException
import java.io.InterruptedIOException
import java.security.SecureRandom
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withTimeout

/** Server that run shell commands for a client talking over a LocalSocket. */
final class ShellCommandLocalSocketExecutorServer
@JvmOverloads
constructor(
private val scope: CoroutineScope =
CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher())
) {
// Use the same secret generation as SpeakEasy does.
private val secret = java.lang.Long.toHexString(SecureRandom().nextLong())
lateinit var socket: LocalServerSocket
lateinit var address: LocalSocketAddress
// Since LocalServerSocket.accept() has to be interrupted, we keep that in its own Job...
lateinit var serverJob: Job
// ...while all the child jobs are under a single SupervisorJob that we can join later.
val shellJobs = SupervisorJob()
val running = AtomicBoolean(true)

/** Returns the binder key to pass to client processes. */
fun binderKey(): String {
// The address can contain spaces, and since it gets passed through a command line, we need to
// encode it. java.net.URLEncoder is conveniently available in all SDK versions.
return address.asBinderKey(secret)
}

/** Runs a simple server. */
private suspend fun server() = coroutineScope {
while (running.get()) {
val connection =
try {
runInterruptible { socket.accept() }
} catch (x: Exception) {
// None of my tests have managed to trigger this one.
Log.e(TAG, "LocalServerSocket.accept() failed", x)
break
}
launch(scope.coroutineContext + shellJobs) { handleConnection(connection) }
}
}

/**
* Relays the output of process to connection with a series of RunCommandResponses.
*
* @param process The process to relay output from.
* @param connection The connection to relay output to.
* @return false if there was a problem, true otherwise.
*/
private suspend fun relay(process: Process, connection: LocalSocket): Boolean {
// Experiment shows that 64K is *much* faster than 4K, especially on API 21-23. Streaming 1MB
// takes 3s with 4K buffers and 2s with 64K on API 23. 22 is a bit faster (2.6s -> 1.5s),
// 21 faster still (630ms -> 545ms). Higher API levels are *much* faster (24 is 119 ms ->
// 75ms).
val buffer = ByteArray(65536)
var size: Int

// LocalSocket.isOutputShutdown() throws UnsupportedOperationException, so we can't use
// that as our loop constraint.
while (true) {
try {
size = runInterruptible { process.inputStream.read(buffer) }
if (size < 0) return true // EOF
if (size == 0) {
delay(1.milliseconds)
continue
}
} catch (x: InterruptedIOException) {
// We start getting these at API 24 when the timeout handling kicks in.
Log.i(TAG, "Interrupted while reading from ${process}: ${x.message}")
return false
} catch (x: IOException) {
Log.i(TAG, "Error reading from ${process}; did it time out?", x)
return false
}

if (!connection.sendResponse(buffer = buffer, size = size)) {
return false
}
}
}

/** Handle one connection. */
private suspend fun handleConnection(connection: LocalSocket) {
// connection.localSocketAddress is always null, so no point in logging it.

// Close the connection when done.
connection.use {
val request = connection.readRequest()

if (request.secret.compareTo(secret) != 0) {
Log.w(TAG, "Ignoring request with wrong secret: $request")
return
}

val pb = request.toProcessBuilder()
pb.redirectErrorStream(true)

val process: Process
try {
process = pb.start()
} catch (x: IOException) {
Log.e(TAG, "Failed to start process", x)
connection.sendResponse(
buffer = x.stackTraceToString().toByteArray(),
exitCode = EXIT_CODE_FAILED_TO_START,
)
return
}

// We will not be writing anything to the process' stdin.
process.outputStream.close()

// Close the process' stdout when we're done reading.
process.inputStream.use {
// Launch a coroutine to relay the process' output to the client. If it times out, kill the
// process and cancel the job. This is more coroutine-friendly than using waitFor() to
// handle timeouts.
val ioJob = scope.async { relay(process, connection) }

try {
withTimeout(request.timeout()) {
if (!ioJob.await()) {
Log.w(TAG, "Relaying ${process} output failed")
}
runInterruptible { process.waitFor() }
}
} catch (x: TimeoutCancellationException) {
Log.e(TAG, "Process ${process} timed out after ${request.timeout()}")
process.destroy()
ioJob.cancel()
connection.sendResponse(exitCode = EXIT_CODE_TIMED_OUT)
return
}

connection.sendResponse(exitCode = process.exitValue())
}
}
}

/** Starts the server. */
fun start() {
socket = LocalServerSocket("androidx.test.services ${AndroidProcess.myPid()}")
address = socket.localSocketAddress
Log.i(TAG, "Starting server on ${address.name}")

// Launch a coroutine to call socket.accept()
serverJob = scope.launch { server() }
}

/** Stops the server. */
fun stop(timeout: Duration) {
running.set(false)
// Closing the socket does not interrupt accept()...
socket.close()
runBlocking(scope.coroutineContext) {
try {
// ...so we simply cancel that job...
serverJob.cancel()
// ...and play nicely with all the shell jobs underneath.
withTimeout(timeout) {
shellJobs.complete()
shellJobs.join()
}
} catch (x: TimeoutCancellationException) {
Log.w(TAG, "Shell jobs did not stop after $timeout", x)
shellJobs.cancel()
}
}
}

private fun RunCommandRequest.timeout(): Duration =
if (timeoutMs <= 0) {
Duration.INFINITE
} else {
timeoutMs.milliseconds
}

/**
* Sets up a ProcessBuilder with information from the request; other configuration is up to the
* caller.
*/
private fun RunCommandRequest.toProcessBuilder(): ProcessBuilder {
val pb = ProcessBuilder(argvList)
val redacted = argvList.map { it.replace(secret, "(SECRET)") } // Don't log the secret!
Log.i(TAG, "Command to execute: [${redacted.joinToString("] [")}] within ${timeout()}")
if (environmentMap.isNotEmpty()) {
pb.environment().putAll(environmentMap)
val env = environmentMap.entries.map { (k, v) -> "$k=$v" }.joinToString(", ")
Log.i(TAG, "Environment: $env")
}
return pb
}

private companion object {
const val TAG = "SCLSEServer" // up to 23 characters

const val EXIT_CODE_FAILED_TO_START = -1
const val EXIT_CODE_TIMED_OUT = -2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,23 @@ axt_android_library_test(
],
)

axt_android_library_test(
name = "ShellCommandLocalSocketExecutorServerTest",
srcs = [
"ShellCommandLocalSocketExecutorServerTest.kt",
],
deps = [
"//runner/monitor",
"//services/shellexecutor:exec_server",
"//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol",
"//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol_pb_java_proto_lite",
"@com_google_protobuf//:protobuf_javalite",
"@maven//:com_google_truth_truth",
"@maven//:junit_junit",
"@maven//:org_jetbrains_kotlinx_kotlinx_coroutines_android",
],
)

axt_android_library_test(
name = "ShellExecutorTest",
srcs = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package androidx.test.services.shellexecutor

import android.net.LocalSocket
import android.os.Build
import androidx.test.services.shellexecutor.LocalSocketProtocol.addressFromBinderKey
import androidx.test.services.shellexecutor.LocalSocketProtocol.hasExited
import androidx.test.services.shellexecutor.LocalSocketProtocol.readResponse
import androidx.test.services.shellexecutor.LocalSocketProtocol.secretFromBinderKey
import androidx.test.services.shellexecutor.LocalSocketProtocol.sendRequest
import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandResponse
import com.google.common.truth.Truth.assertThat
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.runBlocking
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

@RunWith(JUnit4::class)
class ShellCommandLocalSocketExecutorServerTest {

@Test
fun success_simple() {
val responses = mutableListOf<RunCommandResponse>()
runBlocking {
val server = ShellCommandLocalSocketExecutorServer()
server.start()
val client = LocalSocket(LocalSocket.SOCKET_STREAM)
client.connect(addressFromBinderKey(server.binderKey()))
client.sendRequest(
secretFromBinderKey(server.binderKey()),
listOf("echo", "\${POTRZEBIE}"),
mapOf("POTRZEBIE" to "furshlugginer"),
1000.milliseconds,
)
do {
client.readResponse()?.let { responses.add(it) }
} while (!responses.last().hasExited())
server.stop(100.milliseconds)
}
if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.LOLLIPOP_MR1) {
// On API 21 and 22, echo only exists as a shell builtin!
assertThat(responses).hasSize(1)
assertThat(responses[0].exitCode).isEqualTo(-1)
assertThat(responses[0].buffer.toStringUtf8()).contains("Permission denied")
} else {
// On rare occasions, the output of the command will come back in two packets! So to keep
// this test from being 1% flaky:
val stdout = buildString {
for (response in responses) {
if (response.buffer.size() > 0) append(response.buffer.toStringUtf8())
}
}
assertThat(stdout).isEqualTo("\${POTRZEBIE}\n")
assertThat(responses.last().hasExited()).isTrue()
assertThat(responses.last().exitCode).isEqualTo(0)
}
}

@Test
fun success_shell_expansion() {
val responses = mutableListOf<RunCommandResponse>()
runBlocking {
val server = ShellCommandLocalSocketExecutorServer()
server.start()
val client = LocalSocket(LocalSocket.SOCKET_STREAM)
client.connect(addressFromBinderKey(server.binderKey()))
client.sendRequest(
secretFromBinderKey(server.binderKey()),
listOf("sh", "-c", "echo \${POTRZEBIE}"),
mapOf("POTRZEBIE" to "furshlugginer"),
1000.milliseconds,
)
do {
client.readResponse()?.let { responses.add(it) }
} while (!responses.last().hasExited())
server.stop(100.milliseconds)
}
val stdout = buildString {
for (response in responses) {
if (response.buffer.size() > 0) append(response.buffer.toStringUtf8())
}
}
assertThat(stdout).isEqualTo("furshlugginer\n")
assertThat(responses.last().hasExited()).isTrue()
assertThat(responses.last().exitCode).isEqualTo(0)
}

@Test
fun failure_bad_secret() {
runBlocking {
val server = ShellCommandLocalSocketExecutorServer()
server.start()
val client = LocalSocket(LocalSocket.SOCKET_STREAM)
client.connect(addressFromBinderKey(server.binderKey()))
client.sendRequest(
"potrzebie!",
listOf("sh", "-c", "echo \${POTRZEBIE}"),
mapOf("POTRZEBIE" to "furshlugginer"),
1000.milliseconds,
)
assertThat(client.inputStream.read()).isEqualTo(-1)
}
}
}

0 comments on commit d3c0e21

Please sign in to comment.