diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 74c20cd..e5a0794 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -7,6 +7,14 @@ CurrentModule = LibSSH This documents notable changes in LibSSH.jl. The format is based on [Keep a Changelog](https://keepachangelog.com). +## Unreleased + +### Added + +- Implemented [`Base.readchomp(::Cmd)`](@ref) for remote commands ([#12]). +- Add support for passing environment variables to remote commands with + [`Base.run(::Cmd)`](@ref) ([#12]). + ## [v0.5.0] - 2024-08-10 ### Added diff --git a/docs/src/sessions_and_channels.md b/docs/src/sessions_and_channels.md index eeabae4..34c27dc 100644 --- a/docs/src/sessions_and_channels.md +++ b/docs/src/sessions_and_channels.md @@ -114,6 +114,7 @@ Base.success(::SshProcess) Base.run(::Cmd, ::Session) Base.read(::Cmd, ::Session) Base.read(::Cmd, ::Session, ::Type{String}) +Base.readchomp(::Cmd, ::Session) Base.success(::Cmd, ::Session) ``` diff --git a/src/channel.jl b/src/channel.jl index ca0a21b..3d2201c 100644 --- a/src/channel.jl +++ b/src/channel.jl @@ -484,6 +484,25 @@ function _exec_command(process::SshProcess) throw(LibSSHException("Failed to open a session channel: $(ret)")) end + # Set environment variables + if !isnothing(process.cmd.env) + for env_var in process.cmd.env + # We explicitly convert the SubString's returned from split() to + # String's so that they're each separate and null-terminated in + # memory, otherwise the entire 'name=value' string would be sent + # when we send `name`. + name, value = String.(split(env_var, "=")) + ret = _session_trywait(session) do + lib.ssh_channel_request_env(sshchan.ptr, name, value) + end + + if ret != SSH_OK + err = get_error(session) + throw(LibSSHException("Error from lib.ssh_channel_request_env(), could not set environment variable: '$(env_var)'")) + end + end + end + # Make the request ret = _session_trywait(session) do GC.@preserve cmd_str begin @@ -521,8 +540,6 @@ Run a command on the remote host over an SSH session. Things that aren't supported compared to `run()`: - Pipelined commands (use a regular pipe like `foo | bar` instead). - Setting the directory to execute the command in. -- Setting environment variables (support is possible, it just hasn't been - implemented yet). # Throws - [`SshProcessFailedException`](@ref): if the command fails and `ignorestatus()` @@ -625,6 +642,13 @@ Base.read(cmd::Cmd, session::Session, ::Type{String}) = String(read(cmd, session """ $(TYPEDSIGNATURES) +`readchomp()` for remote commands. +""" +Base.readchomp(cmd::Cmd, session::Session) = chomp(read(cmd, session, String)) + +""" +$(TYPEDSIGNATURES) + Check the command succeeded. """ Base.success(cmd::Cmd, session::Session) = success(run(cmd, session; print_out=false)) diff --git a/src/server.jl b/src/server.jl index 31fc135..0f856e6 100644 --- a/src/server.jl +++ b/src/server.jl @@ -542,6 +542,9 @@ end function on_channel_env_request(session, sshchan, name, value, client)::Bool _add_log_event!(client, :channel_env_request, (name, value)) + + client.env[name] = value + return true end @@ -567,7 +570,7 @@ function on_channel_exec_request(session, sshchan, command, client)::Bool end owning_sshchan = popat!(client.unclaimed_channels, idx) - push!(client.channel_operations, CommandExecutor(command, owning_sshchan)) + push!(client.channel_operations, CommandExecutor(command, owning_sshchan, client.env)) return true end @@ -679,6 +682,8 @@ end unclaimed_channels::Vector{ssh.SshChannel} = ssh.SshChannel[] channel_operations::Vector{Any} = [] + env::Dict{String, String} = Dict{String, String}() + task::Union{Task, Nothing} = nothing log_timeline::Vector = [] log_lock::ReentrantLock = ReentrantLock() @@ -933,8 +938,9 @@ function exec_command(executor) cmd_stderr = IOBuffer() # Start the process and wait for it - proc = run(pipeline(ignorestatus(`sh -c $(executor.command)`); stdout=cmd_stdout, stderr=cmd_stderr); - wait=false) + cmd_str = join(Base.shell_split(executor.command), " ") + cmd = setenv(ignorestatus(`sh -c $(cmd_str)`), executor.env) + proc = run(pipeline(cmd; stdout=cmd_stdout, stderr=cmd_stderr); wait=false) executor.process = proc notify(executor._started_event) wait(proc) @@ -956,18 +962,19 @@ end @kwdef mutable struct CommandExecutor command::String sshchan::ssh.SshChannel + env::Dict{String, String} task::Union{Task, Nothing} = nothing process::Union{Base.Process, Nothing} = nothing _started_event::Base.Event = Base.Event() end -function CommandExecutor(command::String, sshchan::ssh.SshChannel) +function CommandExecutor(command::String, sshchan::ssh.SshChannel, env) if !sshchan.owning throw(ArgumentError("The passed SshChannel is non-owning, CommandExecutor requires an owning SshChannel")) end - executor = CommandExecutor(; command, sshchan) + executor = CommandExecutor(; command, sshchan, env) executor.task = Threads.@spawn try exec_command(executor) diff --git a/test/LibSSHTests.jl b/test/LibSSHTests.jl index db624ce..011a3a1 100644 --- a/test/LibSSHTests.jl +++ b/test/LibSSHTests.jl @@ -91,13 +91,29 @@ end # https://github.com/JuliaLang/julia/issues/39282 # Also note that we set `-F none` to disabling reading user config files. openssh_cmd = OpenSSH_jll.ssh() - ssh_cmd(cmd::Cmd) = ignorestatus(Cmd(`sshpass -p bar $(openssh_cmd.exec) -F none -o NoHostAuthenticationForLocalhost=yes $cmd`; env=openssh_cmd.env)) + ssh_cmd(cmd::Cmd) = ignorestatus(Cmd(`sshpass -p bar $(openssh_cmd.exec) -F none -o NoHostAuthenticationForLocalhost=yes -p 2222 $cmd`; env=openssh_cmd.env)) + + @testset "Command execution" begin + demo_server = DemoServer(2222; password="bar") do + # Test exitcodes + @test run(ssh_cmd(`foo@localhost exit 0`)).exitcode == 0 + @test run(ssh_cmd(`foo@localhost exit 42`)).exitcode == 42 + + # Test passing environment variables + cmd_out = IOBuffer() + cmd = ssh_cmd(`foo@localhost -o SendEnv=foo echo \$foo`) + cmd = addenv(cmd, "foo" => "bar") + cmd_result = run(pipeline(cmd; stdout=cmd_out)) + + @test strip(String(take!(cmd_out))) == "bar" + end + end @testset "Password authentication and session channels" begin # More complicated test, where we run a command and check the output demo_server = DemoServer(2222; password="bar") do cmd_out = IOBuffer() - cmd = ssh_cmd(`-p 2222 foo@localhost whoami`) + cmd = ssh_cmd(`foo@localhost whoami`) cmd_result = run(pipeline(cmd; stdout=cmd_out)) @test cmd_result.exitcode == 0 @@ -117,7 +133,7 @@ end # Make sure that it can handle errors too DemoServer(2222; password="bar") do - cmd = ssh_cmd(`-p 2222 foo@localhost exit 42`) + cmd = ssh_cmd(`foo@localhost exit 42`) cmd_result = run(pipeline(ignorestatus(cmd))) @test cmd_result.exitcode == 42 end @@ -135,7 +151,7 @@ end tmpfile = joinpath(tmpdir, "foo") # Start a client and wait for it - cmd = ssh_cmd(`-p 2222 -L 8080:localhost:9090 foo@localhost "touch $tmpfile; while [ -f $tmpfile ]; do sleep 0.1; done"`) + cmd = ssh_cmd(`-L 8080:localhost:9090 foo@localhost "touch $tmpfile; while [ -f $tmpfile ]; do sleep 0.1; done"`) ssh_process = run(cmd; wait=false) if timedwait(() -> isfile(tmpfile), 5) == :timed_out error("Timeout waiting for sentinel file $tmpfile to be created") @@ -180,8 +196,8 @@ end @testset "Multiple connections" begin demo_server = DemoServer(2222; password="bar") do - run(ssh_cmd(`-p 2222 foo@localhost exit 0`)) - run(ssh_cmd(`-p 2222 foo@localhost exit 0`)) + run(ssh_cmd(`foo@localhost exit 0`)) + run(ssh_cmd(`foo@localhost exit 0`)) end @test length(demo_server.clients) == 2 end @@ -408,7 +424,7 @@ end end end - @testset "Executing commands" begin + @testset "Command execution" begin demo_server_with_session(2222) do session # Smoke test process = run(`whoami`, session; print_out=false) @@ -421,11 +437,15 @@ end @test !isempty(String(process.out)) # Test Base methods - @test read(`echo foo`, session, String) == "foo\n" + @test readchomp(`echo foo`, session) == "foo" @test success(`whoami`, session) # Check that commands with quotes are properly escaped - @test read(`echo 'foo bar'`, session, String) == "foo bar\n" + @test readchomp(`echo 'foo bar'`, session) == "foo bar" + + # Test setting environment variables + cmd = setenv(`echo \$foo`, "foo" => "bar") + @test readchomp(cmd, session) == "bar" end end