refactor: tr_handshake lifecycle (#4358)

This commit is contained in:
Charles Kerr 2022-12-13 11:59:21 -06:00 committed by GitHub
parent 9e0b42a61d
commit 2f6315b649
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 561 additions and 590 deletions

File diff suppressed because it is too large Load Diff

View File

@ -9,17 +9,18 @@
#error only libtransmission should #include this header.
#endif
#include <chrono>
#include <cstddef> // for size_t
#include <optional>
#include <functional>
#include <memory>
#include <optional>
#include "transmission.h"
#include "net.h" // tr_address
#include "peer-mse.h" // tr_message_stream_encryption::DH
/** @addtogroup peers Peers
@{ */
#include "peer-io.h"
#include "timer.h"
namespace libtransmission
{
@ -30,64 +31,265 @@ class tr_peerIo;
/** @brief opaque struct holding handshake state information.
freed when the handshake is completed. */
struct tr_handshake;
struct tr_handshake_result
{
struct tr_handshake* handshake;
std::shared_ptr<tr_peerIo> io;
bool readAnythingFromPeer;
bool isConnected;
void* userData;
std::optional<tr_peer_id_t> peer_id;
};
class tr_handshake_mediator
class tr_handshake
{
public:
struct torrent_info
using DH = tr_message_stream_encryption::DH;
struct Result
{
tr_sha1_digest_t info_hash;
tr_peer_id_t client_peer_id;
tr_torrent_id_t id;
bool is_done;
std::shared_ptr<tr_peerIo> io;
std::optional<tr_peer_id_t> peer_id;
bool read_anything_from_peer;
bool is_connected;
};
virtual ~tr_handshake_mediator() = default;
using DoneFunc = std::function<bool(Result const&)>;
[[nodiscard]] virtual std::optional<torrent_info> torrentInfo(tr_sha1_digest_t const& info_hash) const = 0;
[[nodiscard]] virtual std::optional<torrent_info> torrentInfoFromObfuscated(tr_sha1_digest_t const& info_hash) const = 0;
[[nodiscard]] virtual libtransmission::TimerMaker& timerMaker() = 0;
[[nodiscard]] virtual bool allowsDHT() const = 0;
[[nodiscard]] virtual bool allowsTCP() const = 0;
[[nodiscard]] virtual bool isPeerKnownSeed(tr_torrent_id_t tor_id, tr_address addr) const = 0;
[[nodiscard]] virtual size_t pad(void* setme, size_t max_bytes) const = 0;
[[nodiscard]] virtual tr_message_stream_encryption::DH::private_key_bigend_t privateKey() const
class Mediator
{
return tr_message_stream_encryption::DH::randomPrivateKey();
public:
struct TorrentInfo
{
tr_sha1_digest_t info_hash;
tr_peer_id_t client_peer_id;
tr_torrent_id_t id;
bool is_done;
};
virtual ~Mediator() = default;
[[nodiscard]] virtual std::optional<TorrentInfo> torrent_info(tr_sha1_digest_t const& info_hash) const = 0;
[[nodiscard]] virtual std::optional<TorrentInfo> torrent_info_from_obfuscated(
tr_sha1_digest_t const& info_hash) const = 0;
[[nodiscard]] virtual libtransmission::TimerMaker& timer_maker() = 0;
[[nodiscard]] virtual bool allows_dht() const = 0;
[[nodiscard]] virtual bool allows_tcp() const = 0;
[[nodiscard]] virtual bool is_peer_known_seed(tr_torrent_id_t tor_id, tr_address const& addr) const = 0;
[[nodiscard]] virtual size_t pad(void* setme, size_t max_bytes) const = 0;
[[nodiscard]] virtual DH::private_key_bigend_t private_key() const
{
return DH::randomPrivateKey();
}
virtual void set_utp_failed(tr_sha1_digest_t const& info_hash, tr_address const&) = 0;
};
tr_handshake(Mediator* mediator, std::shared_ptr<tr_peerIo> peer_io, tr_encryption_mode mode_in, DoneFunc done_func);
private:
enum class State
{
// incoming
AwaitingHandshake,
AwaitingPeerId,
AwaitingYa,
AwaitingPadA,
AwaitingCryptoProvide,
AwaitingPadC,
AwaitingIa,
AwaitingPayloadStream,
// outgoing
AwaitingYb,
AwaitingVc,
AwaitingCryptoSelect,
AwaitingPadD
};
bool build_handshake_message(tr_peerIo* io, uint8_t* buf) const;
ReadState read_crypto_provide(tr_peerIo* peer_io);
ReadState read_crypto_select(tr_peerIo* peer_io);
ReadState read_handshake(tr_peerIo* peer_io);
ReadState read_ia(tr_peerIo* peer_io);
ReadState read_pad_a(tr_peerIo* peer_io);
ReadState read_pad_c(tr_peerIo* peer_io);
ReadState read_pad_d(tr_peerIo* peer_io);
ReadState read_payload_stream(tr_peerIo* peer_io);
ReadState read_peer_id(tr_peerIo* peer_io);
ReadState read_vc(tr_peerIo* peer_io);
ReadState read_ya(tr_peerIo* peer_io);
ReadState read_yb(tr_peerIo* peer_io);
void send_ya(tr_peerIo* io);
enum class ParseResult
{
Ok,
EncryptionWrong,
BadTorrent,
PeerIsSelf,
};
ParseResult parse_handshake(tr_peerIo* peer_io);
static ReadState can_read(tr_peerIo* peer_io, void* vhandshake, size_t* piece);
static void on_error(tr_peerIo* io, short what, void* vhandshake);
void set_peer_id(tr_peer_id_t const& id) noexcept
{
peer_id_ = id;
}
virtual void setUTPFailed(tr_sha1_digest_t const& info_hash, tr_address) = 0;
void set_have_read_anything_from_peer(bool val) noexcept
{
have_read_anything_from_peer_ = val;
}
ReadState done(bool is_connected)
{
peer_io_->clearCallbacks();
return fire_done(is_connected) ? READ_LATER : READ_ERR;
}
[[nodiscard]] auto is_incoming() const noexcept
{
return peer_io_->isIncoming();
}
[[nodiscard]] auto display_name() const
{
return peer_io_->display_name();
}
void set_utp_failed(tr_sha1_digest_t const& info_hash, tr_address const& addr)
{
mediator_->set_utp_failed(info_hash, addr);
}
[[nodiscard]] constexpr auto state() const noexcept
{
return state_;
}
[[nodiscard]] constexpr auto is_state(State state) const noexcept
{
return state_ == state;
}
constexpr void set_state(State state)
{
state_ = state;
}
[[nodiscard]] constexpr std::string_view state_string() const
{
return state_string(state_);
}
[[nodiscard]] constexpr uint32_t crypto_provide() const
{
uint32_t provide = 0;
switch (encryption_mode_)
{
case TR_ENCRYPTION_REQUIRED:
case TR_ENCRYPTION_PREFERRED:
provide |= CryptoProvideCrypto;
break;
case TR_CLEAR_PREFERRED:
provide |= CryptoProvideCrypto | CryptoProvidePlaintext;
break;
}
return provide;
}
bool fire_done(bool is_connected)
{
if (!done_func_)
{
return false;
}
auto cb = DoneFunc{};
std::swap(cb, done_func_);
auto peer_io = std::shared_ptr<tr_peerIo>{};
std::swap(peer_io, peer_io_);
bool const success = (cb)(Result{ std::move(peer_io), peer_id_, have_read_anything_from_peer_, is_connected });
return success;
}
static auto constexpr HandshakeTimeoutSec = std::chrono::seconds{ 30 };
[[nodiscard]] static constexpr std::string_view state_string(State state)
{
using State = tr_handshake::State;
switch (state)
{
case State::AwaitingHandshake:
return "awaiting handshake";
case State::AwaitingPeerId:
return "awaiting peer id";
case State::AwaitingYa:
return "awaiting ya";
case State::AwaitingPadA:
return "awaiting pad a";
case State::AwaitingCryptoProvide:
return "awaiting crypto provide";
case State::AwaitingPadC:
return "awaiting pad c";
case State::AwaitingIa:
return "awaiting ia";
case State::AwaitingPayloadStream:
return "awaiting payload stream";
// outgoing
case State::AwaitingYb:
return "awaiting yb";
case State::AwaitingVc:
return "awaiting vc";
case State::AwaitingCryptoSelect:
return "awaiting crypto select";
case State::AwaitingPadD:
return "awaiting pad d";
}
}
template<size_t PadMax>
void send_public_key_and_pad(tr_peerIo* io)
{
auto const public_key = dh_.publicKey();
auto outbuf = std::array<std::byte, std::size(public_key) + PadMax>{};
auto const data = std::data(outbuf);
auto walk = data;
walk = std::copy(std::begin(public_key), std::end(public_key), walk);
walk += mediator_->pad(walk, PadMax);
io->writeBytes(data, walk - data, false);
}
static auto constexpr CryptoProvidePlaintext = int{ 1 };
static auto constexpr CryptoProvideCrypto = int{ 2 };
DH dh_ = {};
DoneFunc done_func_;
std::optional<tr_peer_id_t> peer_id_;
std::shared_ptr<tr_peerIo> peer_io_;
std::unique_ptr<libtransmission::Timer> timeout_timer_;
Mediator* mediator_ = nullptr;
State state_ = State::AwaitingHandshake;
tr_encryption_mode encryption_mode_;
uint32_t crypto_select_ = {};
uint32_t crypto_provide_ = {};
uint16_t pad_c_len_ = {};
uint16_t pad_d_len_ = {};
uint16_t ia_len_ = {};
bool have_read_anything_from_peer_ = false;
bool have_sent_bittorrent_handshake_ = false;
};
/* returns true on success, false on error */
using tr_handshake_done_func = bool (*)(tr_handshake_result const& result);
/** @brief create a new handshake */
tr_handshake* tr_handshakeNew(
tr_handshake_mediator& mediator,
std::shared_ptr<tr_peerIo> io,
tr_encryption_mode encryption_mode,
tr_handshake_done_func done_func,
void* done_func_user_data);
void tr_handshakeAbort(tr_handshake* handshake);
/** @} */

View File

@ -12,6 +12,7 @@
#include <cstdint>
#include <ctime> // time_t
#include <deque>
#include <map>
#include <iterator> // std::back_inserter
#include <memory>
#include <numeric> // std::accumulate
@ -66,17 +67,17 @@ static auto constexpr CancelHistorySec = int{ 60 };
static bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr);
class tr_handshake_mediator_impl final : public tr_handshake_mediator
class HandshakeMediator final : public tr_handshake::Mediator
{
private:
[[nodiscard]] static std::optional<torrent_info> torrentInfo(tr_torrent* tor)
[[nodiscard]] static std::optional<TorrentInfo> torrent_info(tr_torrent* tor)
{
if (tor == nullptr)
{
return {};
}
auto info = torrent_info{};
auto info = TorrentInfo{};
info.info_hash = tor->infoHash();
info.client_peer_id = tr_torrentGetPeerId(tor);
info.id = tor->id();
@ -85,33 +86,33 @@ private:
}
public:
explicit tr_handshake_mediator_impl(tr_session& session) noexcept
explicit HandshakeMediator(tr_session& session) noexcept
: session_{ session }
{
}
[[nodiscard]] std::optional<torrent_info> torrentInfo(tr_sha1_digest_t const& info_hash) const override
[[nodiscard]] std::optional<TorrentInfo> torrent_info(tr_sha1_digest_t const& info_hash) const override
{
return torrentInfo(session_.torrents().get(info_hash));
return torrent_info(session_.torrents().get(info_hash));
}
[[nodiscard]] std::optional<torrent_info> torrentInfoFromObfuscated(
[[nodiscard]] std::optional<TorrentInfo> torrent_info_from_obfuscated(
tr_sha1_digest_t const& obfuscated_info_hash) const override
{
return torrentInfo(tr_torrentFindFromObfuscatedHash(&session_, obfuscated_info_hash));
return torrent_info(tr_torrentFindFromObfuscatedHash(&session_, obfuscated_info_hash));
}
[[nodiscard]] bool allowsDHT() const override
[[nodiscard]] bool allows_dht() const override
{
return session_.allowsDHT();
}
[[nodiscard]] bool allowsTCP() const override
[[nodiscard]] bool allows_tcp() const override
{
return session_.allowsTCP();
}
void setUTPFailed(tr_sha1_digest_t const& info_hash, tr_address addr) override
void set_utp_failed(tr_sha1_digest_t const& info_hash, tr_address const& addr) override
{
if (auto* const tor = session_.torrents().get(info_hash); tor != nullptr)
{
@ -119,13 +120,13 @@ public:
}
}
[[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t tor_id, tr_address addr) const override
[[nodiscard]] bool is_peer_known_seed(tr_torrent_id_t tor_id, tr_address const& addr) const override
{
auto const* const tor = session_.torrents().get(tor_id);
return tor != nullptr && tr_peerMgrPeerIsSeed(tor, addr);
}
[[nodiscard]] libtransmission::TimerMaker& timerMaker() override
[[nodiscard]] libtransmission::TimerMaker& timer_maker() override
{
return session_.timerMaker();
}
@ -296,58 +297,7 @@ private:
static auto constexpr MinimumReconnectIntervalSecs = int{ 5 };
};
// a container for keeping track of tr_handshakes
class Handshakes
{
public:
void add(tr_address const& address, tr_handshake* handshake)
{
TR_ASSERT(!contains(address));
handshakes_.emplace_back(address, handshake);
}
[[nodiscard]] bool contains(tr_address const& address) const noexcept
{
return std::any_of(
std::begin(handshakes_),
std::end(handshakes_),
[&address](auto const& pair) { return pair.first == address; });
}
void erase(tr_address const& address)
{
for (auto iter = std::begin(handshakes_), end = std::end(handshakes_); iter != end; ++iter)
{
if (iter->first == address)
{
handshakes_.erase(iter);
return;
}
}
}
[[nodiscard]] auto empty() const noexcept
{
return std::empty(handshakes_);
}
void abortAll()
{
// make a tmp copy so that calls to tr_handshakeAbort() won't
// be able to invalidate its loop iteration
auto tmp = handshakes_;
for (auto& [addr, handshake] : tmp)
{
tr_handshakeAbort(handshake);
}
handshakes_ = {};
}
private:
std::vector<std::pair<tr_address, tr_handshake*>> handshakes_;
};
using Handshakes = std::map<tr_address, tr_handshake>;
#define tr_logAddDebugSwarm(swarm, msg) tr_logAddDebugTor((swarm)->tor, msg)
#define tr_logAddTraceSwarm(swarm, msg) tr_logAddTraceTor((swarm)->tor, msg)
@ -414,7 +364,7 @@ public:
is_running = false;
removeAllPeers();
outgoing_handshakes.abortAll();
outgoing_handshakes.clear();
}
void removePeer(tr_peer* peer)
@ -586,7 +536,7 @@ struct tr_peerMgr
~tr_peerMgr()
{
auto const lock = unique_lock();
incoming_handshakes.abortAll();
incoming_handshakes.clear();
}
void rechokeSoon() noexcept
@ -603,7 +553,7 @@ struct tr_peerMgr
tr_session* const session;
Handshakes incoming_handshakes;
tr_handshake_mediator_impl handshake_mediator_;
HandshakeMediator handshake_mediator_;
private:
void rechokePulseMarshall()
@ -684,8 +634,8 @@ static struct peer_atom* getExistingAtom(tr_swarm const* cswarm, tr_address cons
static bool peerIsInUse(tr_swarm const* swarm, struct peer_atom const* atom)
{
return atom->is_connected || swarm->outgoing_handshakes.contains(atom->addr) ||
swarm->manager->incoming_handshakes.contains(atom->addr);
return atom->is_connected || swarm->outgoing_handshakes.count(atom->addr) != 0U ||
swarm->manager->incoming_handshakes.count(atom->addr) != 0U;
}
static void swarmFree(tr_swarm* s)
@ -1133,13 +1083,12 @@ static void createBitTorrentPeer(tr_torrent* tor, std::shared_ptr<tr_peerIo> io,
}
/* FIXME: this is kind of a mess. */
static bool on_handshake_done(tr_handshake_result const& result)
static bool on_handshake_done(tr_peerMgr* manager, tr_handshake::Result const& result)
{
TR_ASSERT(result.io != nullptr);
bool const ok = result.isConnected;
bool const ok = result.is_connected;
bool success = false;
auto* manager = static_cast<tr_peerMgr*>(result.userData);
tr_swarm* const s = getExistingSwarm(manager, result.io->torrentHash());
@ -1166,7 +1115,7 @@ static bool on_handshake_done(tr_handshake_result const& result)
{
++atom->num_fails;
if (!result.readAnythingFromPeer)
if (!result.read_anything_from_peer)
{
tr_logAddTraceSwarm(
s,
@ -1244,21 +1193,19 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket)
tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", socket.display_name()));
socket.close(session);
}
else if (manager->incoming_handshakes.contains(socket.address()))
else if (manager->incoming_handshakes.count(socket.address()) != 0U)
{
socket.close(session);
}
else /* we don't have a connection to them yet... */
{
auto address = socket.address();
manager->incoming_handshakes.add(
manager->incoming_handshakes.try_emplace(
address,
tr_handshakeNew(
manager->handshake_mediator_,
tr_peerIo::newIncoming(session, &session->top_bandwidth_, std::move(socket)),
session->encryptionMode(),
on_handshake_done,
manager));
&manager->handshake_mediator_,
tr_peerIo::newIncoming(session, &session->top_bandwidth_, std::move(socket)),
session->encryptionMode(),
[manager](tr_handshake::Result const& result) { return on_handshake_done(manager, result); });
}
}
@ -2813,7 +2760,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
s,
fmt::format("Starting an OUTGOING {} connection with {}", utp ? " µTP" : "TCP", atom.display_name()));
auto io = tr_peerIo::newOutgoing(
auto peer_io = tr_peerIo::newOutgoing(
mgr->session,
&mgr->session->top_bandwidth_,
atom.addr,
@ -2822,7 +2769,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
s->tor->completeness == TR_SEED,
utp);
if (io == nullptr)
if (!peer_io)
{
tr_logAddTraceSwarm(s, fmt::format("peerIo not created; marking peer {} as unreachable", atom.display_name()));
atom.flags2 |= MyflagUnreachable;
@ -2830,13 +2777,12 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
}
else
{
auto* const handshake = tr_handshakeNew(
mgr->handshake_mediator_,
std::move(io),
s->outgoing_handshakes.try_emplace(
atom.addr,
&mgr->handshake_mediator_,
peer_io,
mgr->session->encryptionMode(),
on_handshake_done,
mgr);
s->outgoing_handshakes.add(atom.addr, handshake);
[mgr](tr_handshake::Result const& result) { return on_handshake_done(mgr, result); });
}
atom.lastConnectionAttemptAt = now;

View File

@ -35,7 +35,7 @@ auto constexpr MaxWaitMsec = int{ 5000 };
class HandshakeTest : public SessionTest
{
public:
class MediatorMock final : public tr_handshake_mediator
class MediatorMock final : public tr_handshake::Mediator
{
public:
explicit MediatorMock(tr_session* session)
@ -43,7 +43,7 @@ public:
{
}
[[nodiscard]] std::optional<torrent_info> torrentInfo(tr_sha1_digest_t const& info_hash) const override
[[nodiscard]] std::optional<TorrentInfo> torrent_info(tr_sha1_digest_t const& info_hash) const override
{
if (auto const iter = torrents.find(info_hash); iter != std::end(torrents))
{
@ -53,7 +53,7 @@ public:
return {};
}
[[nodiscard]] std::optional<torrent_info> torrentInfoFromObfuscated(tr_sha1_digest_t const& obfuscated) const override
[[nodiscard]] std::optional<TorrentInfo> torrent_info_from_obfuscated(tr_sha1_digest_t const& obfuscated) const override
{
for (auto const& [info_hash, info] : torrents)
{
@ -66,22 +66,22 @@ public:
return {};
}
[[nodiscard]] libtransmission::TimerMaker& timerMaker() override
[[nodiscard]] libtransmission::TimerMaker& timer_maker() override
{
return session_->timerMaker();
}
[[nodiscard]] bool allowsDHT() const override
[[nodiscard]] bool allows_dht() const override
{
return false;
}
[[nodiscard]] bool allowsTCP() const override
[[nodiscard]] bool allows_tcp() const override
{
return true;
}
[[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t /*tor_id*/, tr_address /*addr*/) const override
[[nodiscard]] bool is_peer_known_seed(tr_torrent_id_t /*tor_id*/, tr_address const& /*addr*/) const override
{
return false;
}
@ -94,12 +94,12 @@ public:
return len;
}
[[nodiscard]] tr_message_stream_encryption::DH::private_key_bigend_t privateKey() const override
[[nodiscard]] tr_message_stream_encryption::DH::private_key_bigend_t private_key() const override
{
return private_key_;
}
void setUTPFailed(tr_sha1_digest_t const& /*info_hash*/, tr_address /*addr*/) override
void set_utp_failed(tr_sha1_digest_t const& /*info_hash*/, tr_address const& /*addr*/) override
{
}
@ -111,7 +111,7 @@ public:
}
tr_session* const session_;
std::map<tr_sha1_digest_t, torrent_info> torrents;
std::map<tr_sha1_digest_t, TorrentInfo> torrents;
tr_message_stream_encryption::DH::private_key_bigend_t private_key_ = {};
};
@ -145,11 +145,11 @@ public:
tr_address const DefaultPeerAddr = *tr_address::from_string("127.0.0.1"sv);
tr_port const DefaultPeerPort = tr_port::fromHost(8080);
tr_handshake_mediator::torrent_info const TorrentWeAreSeeding{ tr_sha1::digest("abcde"sv),
tr_handshake::Mediator::TorrentInfo const TorrentWeAreSeeding{ tr_sha1::digest("abcde"sv),
tr_peerIdInit(),
tr_torrent_id_t{ 100 },
true /*is_done*/ };
tr_handshake_mediator::torrent_info const UbuntuTorrent{ *tr_sha1_from_string("2c6b6858d61da9543d4231a71db4b1c9264b0685"sv),
tr_handshake::Mediator::TorrentInfo const UbuntuTorrent{ *tr_sha1_from_string("2c6b6858d61da9543d4231a71db4b1c9264b0685"sv),
tr_peerIdInit(),
tr_torrent_id_t{ 101 },
false /*is_done*/ };
@ -158,25 +158,27 @@ public:
{
auto sockpair = std::array<evutil_socket_t, 2>{ -1, -1 };
EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno);
auto io = tr_peerIo::newIncoming(
session,
&session->top_bandwidth_,
tr_peer_socket(session, DefaultPeerAddr, DefaultPeerPort, sockpair[0]));
return std::make_pair(io, sockpair[1]);
return std::make_pair(
tr_peerIo::newIncoming(
session,
&session->top_bandwidth_,
tr_peer_socket(session, DefaultPeerAddr, DefaultPeerPort, sockpair[0])),
sockpair[1]);
}
auto createOutgoingIo(tr_session* session, tr_sha1_digest_t const& info_hash)
{
auto sockpair = std::array<evutil_socket_t, 2>{ -1, -1 };
EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno);
auto io = tr_peerIo::create(
session,
&session->top_bandwidth_,
&info_hash,
false /*is_incoming*/,
false /*is_seed*/,
tr_peer_socket(session, DefaultPeerAddr, DefaultPeerPort, sockpair[0]));
return std::make_pair(io, sockpair[1]);
return std::make_pair(
tr_peerIo::create(
session,
&session->top_bandwidth_,
&info_hash,
false /*is_incoming*/,
false /*is_seed*/,
tr_peer_socket(session, DefaultPeerAddr, DefaultPeerPort, sockpair[0])),
sockpair[1]);
}
static constexpr auto makePeerId(std::string_view sv)
@ -198,22 +200,22 @@ public:
}
static auto runHandshake(
tr_handshake_mediator& mediator,
std::shared_ptr<tr_peerIo> io,
tr_handshake::Mediator* mediator,
std::shared_ptr<tr_peerIo> const& peer_io,
tr_encryption_mode encryption_mode = TR_CLEAR_PREFERRED)
{
auto result = std::optional<tr_handshake_result>{};
auto result = std::optional<tr_handshake::Result>{};
static auto const DoneCallback = [](auto const& resin)
{
*static_cast<std::optional<tr_handshake_result>*>(resin.userData) = resin;
return true;
};
tr_handshakeNew(mediator, std::move(io), encryption_mode, DoneCallback, &result);
auto handshake = tr_handshake{ mediator,
peer_io,
encryption_mode,
[&result](auto const& resin)
{
result = resin;
return true;
} };
waitFor([&result]() { return result.has_value(); }, MaxWaitMsec);
return result;
}
};
@ -239,12 +241,12 @@ TEST_F(HandshakeTest, incomingPlaintext)
sendToClient(sock, TorrentWeAreSeeding.info_hash);
sendToClient(sock, peer_id);
auto const res = runHandshake(mediator, io);
auto const res = runHandshake(&mediator, io);
// check the results
EXPECT_TRUE(res);
EXPECT_TRUE(res->isConnected);
EXPECT_TRUE(res->readAnythingFromPeer);
EXPECT_TRUE(res->is_connected);
EXPECT_TRUE(res->read_anything_from_peer);
EXPECT_EQ(io, res->io);
EXPECT_TRUE(res->peer_id);
EXPECT_EQ(peer_id, res->peer_id);
@ -266,12 +268,12 @@ TEST_F(HandshakeTest, incomingPlaintextUnknownInfoHash)
sendToClient(sock, tr_sha1::digest("some other torrent unknown to us"sv));
sendToClient(sock, makeRandomPeerId());
auto const res = runHandshake(mediator, io);
auto const res = runHandshake(&mediator, io);
// check the results
EXPECT_TRUE(res);
EXPECT_FALSE(res->isConnected);
EXPECT_TRUE(res->readAnythingFromPeer);
EXPECT_FALSE(res->is_connected);
EXPECT_TRUE(res->read_anything_from_peer);
EXPECT_EQ(io, res->io);
EXPECT_FALSE(res->peer_id);
EXPECT_EQ(tr_sha1_digest_t{}, io->torrentHash());
@ -291,12 +293,12 @@ TEST_F(HandshakeTest, outgoingPlaintext)
sendToClient(sock, UbuntuTorrent.info_hash);
sendToClient(sock, peer_id);
auto const res = runHandshake(mediator, io);
auto const res = runHandshake(&mediator, io);
// check the results
EXPECT_TRUE(res);
EXPECT_TRUE(res->isConnected);
EXPECT_TRUE(res->readAnythingFromPeer);
EXPECT_TRUE(res->is_connected);
EXPECT_TRUE(res->read_anything_from_peer);
EXPECT_EQ(io, res->io);
EXPECT_TRUE(res->peer_id);
EXPECT_EQ(peer_id, res->peer_id);
@ -329,12 +331,12 @@ TEST_F(HandshakeTest, incomingEncrypted)
"VGwrTPstEPu3V5lmzjtMGVLaL5EErlpJ93Xrz+ea6EIQEUZA+D4jKaV/to9NVi"
"04/1W1A2PHgg+I9puac/i9BsFPcjdQeoVtU73lNCbTDQgTieyjDWmwo="sv);
auto const res = runHandshake(mediator, io);
auto const res = runHandshake(&mediator, io);
// check the results
EXPECT_TRUE(res);
EXPECT_TRUE(res->isConnected);
EXPECT_TRUE(res->readAnythingFromPeer);
EXPECT_TRUE(res->is_connected);
EXPECT_TRUE(res->read_anything_from_peer);
EXPECT_EQ(io, res->io);
EXPECT_TRUE(res->peer_id);
EXPECT_EQ(ExpectedPeerId, res->peer_id);
@ -366,12 +368,12 @@ TEST_F(HandshakeTest, incomingEncryptedUnknownInfoHash)
"VGwrTPstEPu3V5lmzjtMGVLaL5EErlpJ93Xrz+ea6EIQEUZA+D4jKaV/to9NVi"
"04/1W1A2PHgg+I9puac/i9BsFPcjdQeoVtU73lNCbTDQgTieyjDWmwo="sv);
auto const res = runHandshake(mediator, io);
auto const res = runHandshake(&mediator, io);
// check the results
EXPECT_TRUE(res);
EXPECT_FALSE(res->isConnected);
EXPECT_TRUE(res->readAnythingFromPeer);
EXPECT_FALSE(res->is_connected);
EXPECT_TRUE(res->read_anything_from_peer);
EXPECT_EQ(tr_sha1_digest_t{}, io->torrentHash());
evutil_closesocket(sock);
@ -405,12 +407,12 @@ TEST_F(HandshakeTest, outgoingEncrypted)
"3+o/RdiKQJAsGxMIU08scBc5VOmrAmjeYrLNpFnpXVuavH5if7490zMCu3DEn"
"G9hpbYbiX95T+EUcRbM6pSCvr3Twq1Q="sv);
auto const res = runHandshake(mediator, io, TR_ENCRYPTION_PREFERRED);
auto const res = runHandshake(&mediator, io, TR_ENCRYPTION_PREFERRED);
// check the results
EXPECT_TRUE(res);
EXPECT_TRUE(res->isConnected);
EXPECT_TRUE(res->readAnythingFromPeer);
EXPECT_TRUE(res->is_connected);
EXPECT_TRUE(res->read_anything_from_peer);
EXPECT_EQ(io, res->io);
EXPECT_TRUE(res->peer_id);
EXPECT_EQ(ExpectedPeerId, res->peer_id);