Skip to content

Commit

Permalink
Merge pull request #3275 from canonical/wait-for-ssh-user-auth
Browse files Browse the repository at this point in the history
[ssh] wait for ssh user authentication before returning
  • Loading branch information
luis4a0 authored Oct 28, 2023
2 parents ef40bb5 + 34da2b6 commit 9740dfd
Show file tree
Hide file tree
Showing 27 changed files with 147 additions and 100 deletions.
9 changes: 4 additions & 5 deletions include/multipass/ssh/ssh_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ class SSHKeyProvider;
class SSHSession
{
public:
SSHSession(const std::string& host, int port, const std::chrono::milliseconds timeout = std::chrono::seconds(1));
SSHSession(const std::string& host, int port, const std::string& ssh_username, const SSHKeyProvider& key_provider,
SSHSession(const std::string& host,
int port,
const std::string& ssh_username,
const SSHKeyProvider& key_provider,
const std::chrono::milliseconds timeout = std::chrono::seconds(20));

SSHProcess exec(const std::string& cmd);
Expand All @@ -42,9 +44,6 @@ class SSHSession
operator ssh_session() const;

private:
SSHSession(const std::string& host, int port, const std::string& ssh_username, const SSHKeyProvider* key_provider);
SSHSession(const std::string& host, int port, const std::string& ssh_username, const SSHKeyProvider* key_provider,
const std::chrono::milliseconds timeout = std::chrono::seconds(20));
void set_option(ssh_options_e type, const void* value);
std::unique_ptr<ssh_session_struct, void (*)(ssh_session)> session;
};
Expand Down
7 changes: 5 additions & 2 deletions include/multipass/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ std::string match_line_for(const std::string& output, const std::string& matcher

// virtual machine helpers
bool is_running(const VirtualMachine::State& state);
void wait_until_ssh_up(VirtualMachine* virtual_machine, std::chrono::milliseconds timeout,
std::function<void()> const& ensure_vm_is_running = []() {});
void wait_until_ssh_up(
VirtualMachine* virtual_machine,
std::chrono::milliseconds timeout,
const SSHKeyProvider& key_provider,
std::function<void()> const& ensure_vm_is_running = []() {});
std::string run_in_ssh_session(SSHSession& session, const std::string& cmd);

// yaml helpers
Expand Down
4 changes: 2 additions & 2 deletions include/multipass/virtual_machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ class VirtualMachine : private DisabledCopyMove
};
virtual std::string ssh_hostname(std::chrono::milliseconds timeout) = 0;
virtual std::string ssh_username() = 0;
virtual std::string management_ipv4() = 0;
virtual std::string management_ipv4(const SSHKeyProvider& key_provider) = 0;
virtual std::vector<std::string> get_all_ipv4(const SSHKeyProvider& key_provider) = 0;
virtual std::string ipv6() = 0;
virtual void wait_until_ssh_up(std::chrono::milliseconds timeout) = 0;
virtual void wait_until_ssh_up(std::chrono::milliseconds timeout, const SSHKeyProvider& key_provider) = 0;
virtual void ensure_vm_is_running() = 0;
virtual void update_state() = 0;
virtual void update_cpus(int num_cores) = 0;
Expand Down
6 changes: 3 additions & 3 deletions src/daemon/daemon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1629,7 +1629,7 @@ try // clang-format on
mpu::run_in_ssh_session(session, "df -t ext4 -t vfat --total -B1 --output=size | tail -n 1"));
info->set_cpu_count(mpu::run_in_ssh_session(session, "nproc"));

std::string management_ip = vm.management_ipv4();
std::string management_ip = vm.management_ipv4(*config->ssh_key_provider);
auto all_ipv4 = vm.get_all_ipv4(*config->ssh_key_provider);

if (is_ipv4_valid(management_ip))
Expand Down Expand Up @@ -1710,7 +1710,7 @@ try // clang-format on

if (request->request_ipv4() && mp::utils::is_running(present_state))
{
std::string management_ip = vm->management_ipv4();
std::string management_ip = vm->management_ipv4(*config->ssh_key_provider);
auto all_ipv4 = vm->get_all_ipv4(*config->ssh_key_provider);

if (is_ipv4_valid(management_ip))
Expand Down Expand Up @@ -2888,7 +2888,7 @@ mp::Daemon::async_wait_for_ssh_and_start_mounts_for(const std::string& name, con
{
auto it = operative_instances.find(name);
auto vm = it->second;
vm->wait_until_ssh_up(timeout);
vm->wait_until_ssh_up(timeout, *config->ssh_key_provider);

if (std::is_same<Reply, LaunchReply>::value)
{
Expand Down
38 changes: 24 additions & 14 deletions src/platform/backends/libvirt/libvirt_virtual_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,22 @@ void update_max_and_property(virDomainPtr domain_ptr, Updater* fun_ptr, Integer
flags &= ~max_flag;
} while (!twice++); // first set the maximum, then actual
}

std::string management_ipv4_impl(std::optional<mp::IPAddress>& management_ip,
const std::string& mac_addr,
const mp::LibvirtWrapper::UPtr& libvirt_wrapper)
{
if (!management_ip)
{
auto result = instance_ip_for(mac_addr, libvirt_wrapper);
if (result)
management_ip.emplace(result.value());
else
return "UNKNOWN";
}

return management_ip.value().as_string();
}
} // namespace

mp::LibVirtVirtualMachine::LibVirtVirtualMachine(const mp::VirtualMachineDescription& desc,
Expand Down Expand Up @@ -440,28 +456,22 @@ std::string mp::LibVirtVirtualMachine::ssh_username()
return username;
}

std::string mp::LibVirtVirtualMachine::management_ipv4()
std::string mp::LibVirtVirtualMachine::management_ipv4(const SSHKeyProvider& /* not used on this backend */)
{
if (!management_ip)
{
auto result = instance_ip_for(mac_addr, libvirt_wrapper);
if (result)
management_ip.emplace(result.value());
else
return "UNKNOWN";
}

return management_ip.value().as_string();
return management_ipv4_impl(management_ip, mac_addr, libvirt_wrapper);
}

std::string mp::LibVirtVirtualMachine::ipv6()
{
return {};
}

void mp::LibVirtVirtualMachine::wait_until_ssh_up(std::chrono::milliseconds timeout)
void mp::LibVirtVirtualMachine::wait_until_ssh_up(std::chrono::milliseconds timeout, const SSHKeyProvider& key_provider)
{
mp::utils::wait_until_ssh_up(this, timeout, std::bind(&LibVirtVirtualMachine::ensure_vm_is_running, this));
mp::utils::wait_until_ssh_up(this,
timeout,
key_provider,
std::bind(&LibVirtVirtualMachine::ensure_vm_is_running, this));
}

void mp::LibVirtVirtualMachine::update_state()
Expand All @@ -481,7 +491,7 @@ mp::LibVirtVirtualMachine::DomainUPtr mp::LibVirtVirtualMachine::initialize_doma
if (mac_addr.empty())
mac_addr = instance_mac_addr_for(domain.get(), libvirt_wrapper);

management_ipv4(); // To set ip
management_ipv4_impl(management_ip, mac_addr, libvirt_wrapper); // To set the IP.
state = refresh_instance_state_for_domain(domain.get(), state, libvirt_wrapper);

return domain;
Expand Down
4 changes: 2 additions & 2 deletions src/platform/backends/libvirt/libvirt_virtual_machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class LibVirtVirtualMachine final : public BaseVirtualMachine
int ssh_port() override;
std::string ssh_hostname(std::chrono::milliseconds timeout) override;
std::string ssh_username() override;
std::string management_ipv4() override;
std::string management_ipv4(const SSHKeyProvider& key_provider) override;
std::string ipv6() override;
void wait_until_ssh_up(std::chrono::milliseconds timeout) override;
void wait_until_ssh_up(std::chrono::milliseconds timeout, const SSHKeyProvider& key_provider) override;
void ensure_vm_is_running() override;
void update_state() override;
void update_cpus(int num_cores) override;
Expand Down
6 changes: 3 additions & 3 deletions src/platform/backends/lxd/lxd_virtual_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ std::string mp::LXDVirtualMachine::ssh_username()
return username;
}

std::string mp::LXDVirtualMachine::management_ipv4()
std::string mp::LXDVirtualMachine::management_ipv4(const SSHKeyProvider& /* unused on this backend */)
{
if (!management_ip)
{
Expand All @@ -385,9 +385,9 @@ std::string mp::LXDVirtualMachine::ipv6()
return {};
}

void mp::LXDVirtualMachine::wait_until_ssh_up(std::chrono::milliseconds timeout)
void mp::LXDVirtualMachine::wait_until_ssh_up(std::chrono::milliseconds timeout, const SSHKeyProvider& key_provider)
{
mpu::wait_until_ssh_up(this, timeout, [this] { ensure_vm_is_running(); });
mpu::wait_until_ssh_up(this, timeout, key_provider, [this] { ensure_vm_is_running(); });
}

const QUrl mp::LXDVirtualMachine::url()
Expand Down
4 changes: 2 additions & 2 deletions src/platform/backends/lxd/lxd_virtual_machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class LXDVirtualMachine : public BaseVirtualMachine
int ssh_port() override;
std::string ssh_hostname(std::chrono::milliseconds timeout) override;
std::string ssh_username() override;
std::string management_ipv4() override;
std::string management_ipv4(const SSHKeyProvider& key_provider) override;
std::string ipv6() override;
void ensure_vm_is_running() override;
void ensure_vm_is_running(const std::chrono::milliseconds& timeout);
void wait_until_ssh_up(std::chrono::milliseconds timeout) override;
void wait_until_ssh_up(std::chrono::milliseconds timeout, const SSHKeyProvider& key_provider) override;
void update_state() override;
void update_cpus(int num_cores) override;
void resize_memory(const MemorySize& new_size) override;
Expand Down
9 changes: 6 additions & 3 deletions src/platform/backends/qemu/qemu_virtual_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ std::string mp::QemuVirtualMachine::ssh_username()
return username;
}

std::string mp::QemuVirtualMachine::management_ipv4()
std::string mp::QemuVirtualMachine::management_ipv4(const SSHKeyProvider& /* unused on this backend */)
{
if (!management_ip)
{
Expand All @@ -496,9 +496,12 @@ std::string mp::QemuVirtualMachine::ipv6()
return {};
}

void mp::QemuVirtualMachine::wait_until_ssh_up(std::chrono::milliseconds timeout)
void mp::QemuVirtualMachine::wait_until_ssh_up(std::chrono::milliseconds timeout, const SSHKeyProvider& key_provider)
{
mp::utils::wait_until_ssh_up(this, timeout, std::bind(&QemuVirtualMachine::ensure_vm_is_running, this));
mp::utils::wait_until_ssh_up(this,
timeout,
key_provider,
std::bind(&QemuVirtualMachine::ensure_vm_is_running, this));

if (is_starting_from_suspend)
{
Expand Down
4 changes: 2 additions & 2 deletions src/platform/backends/qemu/qemu_virtual_machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ class QemuVirtualMachine : public QObject, public BaseVirtualMachine
int ssh_port() override;
std::string ssh_hostname(std::chrono::milliseconds timeout) override;
std::string ssh_username() override;
std::string management_ipv4() override;
std::string management_ipv4(const SSHKeyProvider& key_provider) override;
std::string ipv6() override;
void ensure_vm_is_running() override;
void wait_until_ssh_up(std::chrono::milliseconds timeout) override;
void wait_until_ssh_up(std::chrono::milliseconds timeout, const SSHKeyProvider& key_provider) override;
void update_state() override;
void update_cpus(int num_cores) override;
void resize_memory(const MemorySize& new_size) override;
Expand Down
27 changes: 10 additions & 17 deletions src/ssh/ssh_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
namespace mp = multipass;
namespace mpl = multipass::logging;

mp::SSHSession::SSHSession(const std::string& host, int port, const std::string& username,
const SSHKeyProvider* key_provider, const std::chrono::milliseconds timeout)
mp::SSHSession::SSHSession(const std::string& host,
int port,
const std::string& username,
const SSHKeyProvider& key_provider,
const std::chrono::milliseconds timeout)
: session{ssh_new(), ssh_free}
{
if (session == nullptr)
Expand All @@ -53,22 +56,12 @@ mp::SSHSession::SSHSession(const std::string& host, int port, const std::string&
set_option(SSH_OPTIONS_SSH_DIR, ssh_dir.c_str());

SSH::throw_on_error(session, "ssh connection failed", ssh_connect);
if (key_provider)
{
SSH::throw_on_error(session, "ssh failed to authenticate", ssh_userauth_publickey, nullptr,
key_provider->private_key());
}
}

mp::SSHSession::SSHSession(const std::string& host, int port, const std::string& username,
const SSHKeyProvider& key_provider, const std::chrono::milliseconds timeout)
: SSHSession(host, port, username, &key_provider, timeout)
{
}

mp::SSHSession::SSHSession(const std::string& host, int port, const std::chrono::milliseconds timeout)
: SSHSession(host, port, "ubuntu", nullptr, timeout)
{
SSH::throw_on_error(session,
"ssh failed to authenticate",
ssh_userauth_publickey,
nullptr,
key_provider.private_key());
}

mp::SSHProcess mp::SSHSession::exec(const std::string& cmd)
Expand Down
11 changes: 8 additions & 3 deletions src/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,17 +320,22 @@ bool mp::utils::valid_mac_address(const std::string& mac)
return match.hasMatch();
}

void mp::utils::wait_until_ssh_up(VirtualMachine* virtual_machine, std::chrono::milliseconds timeout,
void mp::utils::wait_until_ssh_up(VirtualMachine* virtual_machine,
std::chrono::milliseconds timeout,
const mp::SSHKeyProvider& key_provider,
std::function<void()> const& ensure_vm_is_running)
{
static constexpr auto wait_step = 1s;
mpl::log(mpl::Level::debug, virtual_machine->vm_name, "Waiting for SSH to be up");

auto action = [virtual_machine, &ensure_vm_is_running] {
auto action = [virtual_machine, &key_provider, &ensure_vm_is_running] {
ensure_vm_is_running();
try
{
mp::SSHSession session{virtual_machine->ssh_hostname(wait_step), virtual_machine->ssh_port()};
mp::SSHSession session{virtual_machine->ssh_hostname(wait_step),
virtual_machine->ssh_port(),
virtual_machine->ssh_username(),
key_provider};

std::lock_guard<decltype(virtual_machine->state_mutex)> lock{virtual_machine->state_mutex};
virtual_machine->state = VirtualMachine::State::running;
Expand Down
6 changes: 4 additions & 2 deletions tests/libvirt/test_libvirt_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct LibVirtBackend : public Test
"pied-piper-valley",
"",
{},
"",
"ubuntu",
{dummy_image.name(), "", "", "", "", {}},
dummy_cloud_init_iso.name(),
{},
Expand Down Expand Up @@ -130,7 +130,9 @@ TEST_F(LibVirtBackend, creates_in_suspended_state_with_managed_save)

TEST_F(LibVirtBackend, machine_sends_monitoring_events)
{
const mpt::StubSSHKeyProvider key_provider;
REPLACE(ssh_connect, [](auto...) { return SSH_OK; });
REPLACE(ssh_userauth_publickey, [](auto...) { return SSH_AUTH_SUCCESS; });

mp::LibVirtVirtualMachineFactory backend{data_dir.path(), fake_libvirt_path};
backend.libvirt_wrapper->virNetworkGetDHCPLeases = [](auto, auto, auto leases, auto) {
Expand All @@ -154,7 +156,7 @@ TEST_F(LibVirtBackend, machine_sends_monitoring_events)
return 0;
};

machine->wait_until_ssh_up(2min);
machine->wait_until_ssh_up(2min, key_provider);

EXPECT_CALL(mock_monitor, on_shutdown());
machine->shutdown();
Expand Down
5 changes: 3 additions & 2 deletions tests/lxd/test_lxd_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "tests/mock_logger.h"
#include "tests/mock_platform.h"
#include "tests/mock_status_monitor.h"
#include "tests/stub_ssh_key_provider.h"
#include "tests/stub_status_monitor.h"
#include "tests/stub_url_downloader.h"
#include "tests/temp_dir.h"
Expand Down Expand Up @@ -1099,7 +1100,7 @@ TEST_P(LXDNetworkInfoSuite, returns_expected_network_info)
mp::LXDVirtualMachine machine{default_description, stub_monitor, mock_network_access_manager.get(), base_url,
bridge_name, default_storage_pool};

EXPECT_EQ(machine.management_ipv4(), "10.217.27.168");
EXPECT_EQ(machine.management_ipv4(mpt::StubSSHKeyProvider()), "10.217.27.168");
EXPECT_TRUE(machine.ipv6().empty());
EXPECT_EQ(machine.ssh_username(), default_description.ssh_username);
EXPECT_EQ(machine.ssh_port(), 22);
Expand Down Expand Up @@ -1182,7 +1183,7 @@ TEST_F(LXDBackend, no_ip_address_returns_unknown)
mp::LXDVirtualMachine machine{default_description, stub_monitor, mock_network_access_manager.get(), base_url,
bridge_name, default_storage_pool};

EXPECT_EQ(machine.management_ipv4(), "UNKNOWN");
EXPECT_EQ(machine.management_ipv4(mpt::StubSSHKeyProvider()), "UNKNOWN");
}

TEST_F(LXDBackend, lxd_request_timeout_aborts_and_throws)
Expand Down
6 changes: 3 additions & 3 deletions tests/mock_virtual_machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct MockVirtualMachineT : public T
ON_CALL(*this, ssh_hostname()).WillByDefault(Return("localhost"));
ON_CALL(*this, ssh_hostname(_)).WillByDefault(Return("localhost"));
ON_CALL(*this, ssh_username()).WillByDefault(Return("ubuntu"));
ON_CALL(*this, management_ipv4()).WillByDefault(Return("0.0.0.0"));
ON_CALL(*this, management_ipv4(_)).WillByDefault(Return("0.0.0.0"));
ON_CALL(*this, get_all_ipv4(_)).WillByDefault(Return(std::vector<std::string>{"192.168.2.123"}));
ON_CALL(*this, ipv6()).WillByDefault(Return("::/0"));
}
Expand All @@ -55,11 +55,11 @@ struct MockVirtualMachineT : public T
MOCK_METHOD(std::string, ssh_hostname, (), (override));
MOCK_METHOD(std::string, ssh_hostname, (std::chrono::milliseconds), (override));
MOCK_METHOD(std::string, ssh_username, (), (override));
MOCK_METHOD(std::string, management_ipv4, (), (override));
MOCK_METHOD(std::string, management_ipv4, (const SSHKeyProvider&), (override));
MOCK_METHOD(std::vector<std::string>, get_all_ipv4, (const SSHKeyProvider&), (override));
MOCK_METHOD(std::string, ipv6, (), (override));
MOCK_METHOD(void, ensure_vm_is_running, (), (override));
MOCK_METHOD(void, wait_until_ssh_up, (std::chrono::milliseconds), (override));
MOCK_METHOD(void, wait_until_ssh_up, (std::chrono::milliseconds, const SSHKeyProvider&), (override));
MOCK_METHOD(void, update_state, (), (override));
MOCK_METHOD(void, update_cpus, (int num_cores), (override));
MOCK_METHOD(void, resize_memory, (const MemorySize& new_size), (override));
Expand Down
4 changes: 2 additions & 2 deletions tests/qemu/test_qemu_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ TEST_F(QemuBackend, gets_management_ip)
machine.start();
machine.state = mp::VirtualMachine::State::running;

EXPECT_EQ(machine.management_ipv4(), expected_ip);
EXPECT_EQ(machine.management_ipv4(mpt::StubSSHKeyProvider()), expected_ip);
}

TEST_F(QemuBackend, fails_to_get_management_ip_if_dnsmasq_does_not_return_an_ip)
Expand All @@ -655,7 +655,7 @@ TEST_F(QemuBackend, fails_to_get_management_ip_if_dnsmasq_does_not_return_an_ip)
machine.start();
machine.state = mp::VirtualMachine::State::running;

EXPECT_EQ(machine.management_ipv4(), "UNKNOWN");
EXPECT_EQ(machine.management_ipv4(mpt::StubSSHKeyProvider()), "UNKNOWN");
}

TEST_F(QemuBackend, ssh_hostname_timeout_throws_and_sets_unknown_state)
Expand Down
Loading

0 comments on commit 9740dfd

Please sign in to comment.