diff --git a/td/mtproto/Handshake.cpp b/td/mtproto/Handshake.cpp index 1ec70aef2b8d..23b32e758f11 100644 --- a/td/mtproto/Handshake.cpp +++ b/td/mtproto/Handshake.cpp @@ -47,12 +47,20 @@ static Result fetch_result(Slice message, bool check_end } AuthKeyHandshake::AuthKeyHandshake(int32 dc_id, int32 expires_in) - : mode_(expires_in == 0 ? Mode::Main : Mode::Temp), dc_id_(dc_id), expires_in_(expires_in) { + : mode_(expires_in == 0 ? Mode::Main : Mode::Temp) + , dc_id_(dc_id) + , expires_in_(expires_in) + , timeout_at_(Time::now() + 1e9) { +} + +void AuthKeyHandshake::set_timeout_in(double timeout_in) { + timeout_at_ = Time::now() + timeout_in; } void AuthKeyHandshake::clear() { last_query_ = BufferSlice(); state_ = Start; + timeout_at_ = Time::now() + 1e9; } bool AuthKeyHandshake::is_ready_for_finish() const { @@ -294,6 +302,10 @@ Status AuthKeyHandshake::on_start(Callback *connection) { Status AuthKeyHandshake::on_message(Slice message, Callback *connection, AuthKeyHandshakeContext *context) { Status status = [&] { + if (Time::now() >= timeout_at_) { + return Status::Error("Handshake timeout expired"); + } + switch (state_) { case ResPQ: return on_res_pq(message, connection, context->get_public_rsa_key_interface()); diff --git a/td/mtproto/Handshake.h b/td/mtproto/Handshake.h index 14077b546652..435350d9416d 100644 --- a/td/mtproto/Handshake.h +++ b/td/mtproto/Handshake.h @@ -45,6 +45,8 @@ class AuthKeyHandshake { AuthKeyHandshake(int32 dc_id, int32 expires_in); + void set_timeout_in(double timeout_in); + bool is_ready_for_finish() const; void on_finish(); @@ -80,6 +82,8 @@ class AuthKeyHandshake { int32 expires_in_ = 0; double expires_at_ = 0; + double timeout_at_ = 0; + AuthKey auth_key_; double server_time_diff_ = 0; uint64 server_salt_ = 0; diff --git a/td/mtproto/HandshakeActor.cpp b/td/mtproto/HandshakeActor.cpp index 3bb95d1e2a34..1f9c14b12056 100644 --- a/td/mtproto/HandshakeActor.cpp +++ b/td/mtproto/HandshakeActor.cpp @@ -34,6 +34,7 @@ void HandshakeActor::close() { void HandshakeActor::start_up() { Scheduler::subscribe(connection_->get_poll_info().extract_pollable_fd(this)); set_timeout_in(timeout_); + handshake_->set_timeout_in(timeout_); yield(); }