refactor: decouple tr_handshake (#3435)

Add a shim between the handshake code and the rest of the codebase to 
improve decoupling so that a followup PR can add handshake unit tests. 

The handshake code no longer directly relies on tr_torrent, tr_session, 
tr_dht, or tr_peerMgr.
This commit is contained in:
Charles Kerr 2022-07-11 18:29:48 -05:00 committed by GitHub
parent 0c9ca9ac30
commit ba26e79afe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 164 additions and 73 deletions

View File

@ -15,16 +15,12 @@
#include <fmt/format.h>
#include "transmission.h"
#include "clients.h"
#include "crypto-utils.h"
#include "handshake.h"
#include "log.h"
#include "peer-io.h"
#include "peer-mgr.h"
#include "session.h"
#include "torrent.h"
#include "tr-assert.h"
#include "tr-dht.h"
#include "utils.h"
using namespace std::literals;
@ -108,16 +104,32 @@ enum handshake_state_t
struct tr_handshake
{
tr_handshake(std::shared_ptr<tr_handshake_mediator> mediator_in)
: mediator{ std::move(mediator_in) }
{
}
~tr_handshake()
{
if (io != nullptr)
{
tr_peerIoUnref(io); /* balanced by the ref in tr_handshakeNew */
}
event_free(timeout_timer);
}
[[nodiscard]] auto constexpr isIncoming() const noexcept
{
return io->isIncoming();
}
std::shared_ptr<tr_handshake_mediator> const mediator;
bool haveReadAnythingFromPeer;
bool haveSentBitTorrentHandshake;
tr_peerIo* io;
tr_crypto* crypto;
tr_session* session;
handshake_state_t state;
tr_encryption_mode encryptionMode;
uint16_t pad_c_len;
@ -173,37 +185,34 @@ static void setReadState(tr_handshake* handshake, handshake_state_t state)
static bool buildHandshakeMessage(tr_handshake* handshake, uint8_t* buf)
{
auto const torrent_hash = handshake->crypto->torrentHash();
auto* const tor = torrent_hash ? handshake->session->torrents().get(*torrent_hash) : nullptr;
bool const success = tor != nullptr;
if (success)
auto const info_hash = handshake->crypto->torrentHash();
auto const info = info_hash ? handshake->mediator->torrentInfo(*info_hash) : std::nullopt;
if (!info)
{
uint8_t* walk = buf;
walk = std::copy_n(HANDSHAKE_NAME, HANDSHAKE_NAME_LEN, walk);
memset(walk, 0, HANDSHAKE_FLAGS_LEN);
HANDSHAKE_SET_LTEP(walk);
HANDSHAKE_SET_FASTEXT(walk);
/* Note that this doesn't depend on whether the torrent is private.
* We don't accept DHT peers for a private torrent,
* but we participate in the DHT regardless. */
if (tr_dhtEnabled(handshake->session))
{
HANDSHAKE_SET_DHT(walk);
}
walk += HANDSHAKE_FLAGS_LEN;
walk = std::copy_n(reinterpret_cast<char const*>(std::data(tor->infoHash())), std::size(tor->infoHash()), walk);
auto const& peer_id = tr_torrentGetPeerId(tor);
std::copy_n(std::data(peer_id), std::size(peer_id), walk);
TR_ASSERT(walk + std::size(peer_id) - buf == HANDSHAKE_SIZE);
return false;
}
return success;
uint8_t* walk = buf;
walk = std::copy_n(HANDSHAKE_NAME, HANDSHAKE_NAME_LEN, walk);
memset(walk, 0, HANDSHAKE_FLAGS_LEN);
HANDSHAKE_SET_LTEP(walk);
HANDSHAKE_SET_FASTEXT(walk);
/* Note that this doesn't depend on whether the torrent is private.
* We don't accept DHT peers for a private torrent,
* but we participate in the DHT regardless. */
if (handshake->mediator->isDHTEnabled())
{
HANDSHAKE_SET_DHT(walk);
}
walk += HANDSHAKE_FLAGS_LEN;
walk = std::copy_n(reinterpret_cast<char const*>(std::data(*info_hash)), std::size(*info_hash), walk);
walk = std::copy(std::begin(info->client_peer_id), std::end(info->client_peer_id), walk);
TR_ASSERT(walk - buf == HANDSHAKE_SIZE);
return true;
}
static ReadState tr_handshakeDone(tr_handshake* handshake, bool isConnected);
@ -257,7 +266,7 @@ static handshake_parse_err_t parseHandshake(tr_handshake* handshake, struct evbu
auto const peer_id_sv = std::string_view{ std::data(peer_id), std::size(peer_id) };
tr_logAddTraceHand(handshake, fmt::format("peer-id is '{}'", peer_id_sv));
if (auto* const tor = handshake->session->torrents().get(hash); peer_id == tr_torrentGetPeerId(tor))
if (auto const info = handshake->mediator->torrentInfo(hash); info && info->client_peer_id == peer_id)
{
tr_logAddTraceHand(handshake, "streuth! we've connected to ourselves.");
return HANDSHAKE_PEER_IS_SELF;
@ -651,7 +660,7 @@ static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
if (handshake->isIncoming())
{
if (!handshake->session->torrents().contains(hash))
if (!handshake->mediator->torrentInfo(hash))
{
tr_logAddTraceHand(handshake, "peer is trying to connect to us for a torrent we don't have.");
return tr_handshakeDone(handshake, false);
@ -708,8 +717,8 @@ static ReadState readPeerId(tr_handshake* handshake, struct evbuffer* inbuf)
// if we've somehow connected to ourselves, don't keep the connection
auto const hash = handshake->io->torrentHash();
auto* const tor = hash ? handshake->session->torrents().get(*hash) : nullptr;
bool const connected_to_self = tor != nullptr && peer_id == tr_torrentGetPeerId(tor);
auto const info = hash ? handshake->mediator->torrentInfo(*hash) : std::nullopt;
auto const connected_to_self = info && info->client_peer_id == peer_id;
return tr_handshakeDone(handshake, !connected_to_self);
}
@ -818,16 +827,14 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb
obfuscated_hash[i] = req2[i] ^ (*req3)[i];
}
if (auto const* const tor = tr_torrentFindFromObfuscatedHash(handshake->session, obfuscated_hash); tor != nullptr)
if (auto const info = handshake->mediator->torrentInfoFromObfuscated(obfuscated_hash); info)
{
bool const clientIsSeed = tor->isDone();
bool const peerIsSeed = tr_peerMgrPeerIsSeed(tor, handshake->io->address());
tr_logAddTraceHand(
handshake,
fmt::format("got INCOMING connection's encrypted handshake for torrent [{}]", tor->name()));
handshake->io->setTorrentHash(tor->infoHash());
bool const client_is_seed = info->is_done;
bool const peer_is_seed = handshake->mediator->isPeerKnownSeed(info->id, handshake->io->address());
tr_logAddTraceHand(handshake, fmt::format("got INCOMING connection's encrypted handshake for torrent [{}]", info->id));
handshake->io->setTorrentHash(info->info_hash);
if (clientIsSeed && peerIsSeed)
if (client_is_seed && peer_is_seed)
{
tr_logAddTraceHand(handshake, "another seed tried to reconnect to us!");
return tr_handshakeDone(handshake, false);
@ -1102,24 +1109,13 @@ static bool fireDoneFunc(tr_handshake* handshake, bool isConnected)
return success;
}
static void tr_handshakeFree(tr_handshake* handshake)
{
if (handshake->io != nullptr)
{
tr_peerIoUnref(handshake->io); /* balanced by the ref in tr_handshakeNew */
}
event_free(handshake->timeout_timer);
tr_free(handshake);
}
static ReadState tr_handshakeDone(tr_handshake* handshake, bool isOK)
{
tr_logAddTraceHand(handshake, isOK ? "handshakeDone: connected" : "handshakeDone: aborting");
tr_peerIoSetIOFuncs(handshake->io, nullptr, nullptr, nullptr, nullptr);
bool const success = fireDoneFunc(handshake, isOK);
tr_handshakeFree(handshake);
delete handshake;
return success ? READ_LATER : READ_ERR;
}
@ -1138,15 +1134,15 @@ static void gotError(tr_peerIo* io, short what, void* vhandshake)
if (io->socket.type == TR_PEER_SOCKET_TYPE_UTP && !io->isIncoming() && handshake->state == AWAITING_YB)
{
/* This peer probably doesn't speak uTP. */
// the peer probably doesn't speak uTP.
auto const hash = io->torrentHash();
auto* const tor = hash ? handshake->session->torrents().get(*hash) : nullptr;
auto const info = hash ? handshake->mediator->torrentInfo(*hash) : std::nullopt;
/* Don't mark a peer as non-uTP unless it's really a connect failure. */
if ((errcode == ETIMEDOUT || errcode == ECONNREFUSED) && tr_isTorrent(tor))
if ((errcode == ETIMEDOUT || errcode == ECONNREFUSED) && info)
{
tr_peerMgrSetUtpFailed(tor, io->address(), true);
handshake->mediator->setUTPFailed(*hash, io->address());
}
if (tr_peerIoReconnect(handshake->io) == 0)
@ -1192,24 +1188,22 @@ static void handshakeTimeout(evutil_socket_t /*s*/, short /*type*/, void* handsh
}
tr_handshake* tr_handshakeNew(
std::shared_ptr<tr_handshake_mediator> mediator,
tr_peerIo* io,
tr_encryption_mode encryptionMode,
tr_handshake_done_func done_func,
void* done_func_user_data)
{
tr_session* session = tr_peerIoGetSession(io);
auto* const handshake = tr_new0(tr_handshake, 1);
auto* const handshake = new tr_handshake{ std::move(mediator) };
handshake->io = io;
handshake->crypto = tr_peerIoGetCrypto(io);
handshake->encryptionMode = encryptionMode;
handshake->done_func = done_func;
handshake->done_func_user_data = done_func_user_data;
handshake->session = session;
handshake->timeout_timer = evtimer_new(session->event_base, handshakeTimeout, handshake);
handshake->timeout_timer = evtimer_new(handshake->mediator->eventBase(), handshakeTimeout, handshake);
tr_timerAdd(*handshake->timeout_timer, HANDSHAKE_TIMEOUT_SEC, 0);
tr_peerIoRef(io); /* balanced by the unref in tr_handshakeFree */
tr_peerIoRef(io); /* balanced by the unref in ~tr_handshake() */
tr_peerIoSetIOFuncs(handshake->io, canRead, nullptr, gotError, handshake);
tr_peerIoSetEncryption(io, PEER_ENCRYPTION_NONE);

View File

@ -10,9 +10,12 @@
#endif
#include <optional>
#include <memory>
#include "transmission.h"
#include "net.h" // tr_address
/** @addtogroup peers Peers
@{ */
@ -21,6 +24,7 @@ class tr_peerIo;
/** @brief opaque struct holding hanshake state information.
freed when the handshake is completed. */
struct tr_handshake;
struct event_base;
struct tr_handshake_result
{
@ -32,11 +36,36 @@ struct tr_handshake_result
std::optional<tr_peer_id_t> peer_id;
};
class tr_handshake_mediator
{
public:
struct torrent_info
{
tr_sha1_digest_t info_hash;
tr_peer_id_t client_peer_id;
tr_torrent_id_t id;
bool is_done;
};
[[nodiscard]] virtual std::optional<torrent_info> torrentInfo(tr_sha1_digest_t const& info_hash) const = 0;
[[nodiscard]] virtual std::optional<torrent_info> torrentInfoFromObfuscated(tr_sha1_digest_t const& info_hash) const = 0;
[[nodiscard]] virtual event_base* eventBase() const = 0;
[[nodiscard]] virtual bool isDHTEnabled() const = 0;
[[nodiscard]] virtual bool isPeerKnownSeed(tr_torrent_id_t tor_id, tr_address addr) const = 0;
virtual void setUTPFailed(tr_sha1_digest_t const& info_hash, tr_address) = 0;
};
/* returns true on success, false on error */
using tr_handshake_done_func = bool (*)(tr_handshake_result const& result);
/** @brief create a new handshake */
tr_handshake* tr_handshakeNew(
std::shared_ptr<tr_handshake_mediator> mediator,
tr_peerIo* io,
tr_encryption_mode encryption_mode,
tr_handshake_done_func when_done,

View File

@ -45,6 +45,7 @@
#include "stats.h" /* tr_statsAddUploaded, tr_statsAddDownloaded */
#include "torrent.h"
#include "tr-assert.h"
#include "tr-dht.h"
#include "tr-utp.h"
#include "utils.h"
#include "webseed.h"
@ -63,6 +64,73 @@ static auto constexpr CancelHistorySec = int{ 60 };
***
**/
static bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr);
class tr_handshake_mediator_impl final : public tr_handshake_mediator
{
private:
[[nodiscard]] static std::optional<torrent_info> torrentInfo(tr_torrent* tor)
{
if (tor == nullptr)
{
return {};
}
auto info = torrent_info{};
info.info_hash = tor->infoHash();
info.client_peer_id = tr_torrentGetPeerId(tor);
info.id = tor->id();
info.is_done = tor->isDone();
return info;
}
public:
tr_handshake_mediator_impl(tr_session& session)
: session_{ session }
{
}
virtual ~tr_handshake_mediator_impl() = default;
[[nodiscard]] std::optional<torrent_info> torrentInfo(tr_sha1_digest_t const& info_hash) const override
{
return torrentInfo(session_.torrents().get(info_hash));
}
[[nodiscard]] std::optional<torrent_info> torrentInfoFromObfuscated(
tr_sha1_digest_t const& obfuscated_info_hash) const override
{
return torrentInfo(tr_torrentFindFromObfuscatedHash(&session_, obfuscated_info_hash));
}
[[nodiscard]] bool isDHTEnabled() const override
{
return tr_dhtEnabled(&session_);
}
void setUTPFailed(tr_sha1_digest_t const& info_hash, tr_address addr) override
{
if (auto* const tor = session_.torrents().get(info_hash); tor != nullptr)
{
tr_peerMgrSetUtpFailed(tor, addr, true);
}
}
[[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t tor_id, tr_address addr) const override
{
auto* const tor = session_.torrents().get(tor_id);
return tor != nullptr && tr_peerMgrPeerIsSeed(tor, addr);
}
[[nodiscard]] event_base* eventBase() const override
{
return session_.event_base;
}
private:
tr_session& session_;
};
/**
* Peer information that should be kept even before we've connected and
* after we've disconnected. These are kept in a pool of peer_atoms to decide
@ -683,7 +751,7 @@ static void atomSetSeed(tr_swarm* swarm, peer_atom& atom)
swarm->markAllSeedsFlagDirty();
}
bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr)
static bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr)
{
if (auto const* atom = getExistingAtom(tor->swarm, addr); atom != nullptr)
{
@ -1193,8 +1261,9 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port
}
else /* we don't have a connection to them yet... */
{
auto mediator = std::make_shared<tr_handshake_mediator_impl>(*session);
tr_peerIo* const io = tr_peerIoNewIncoming(session, &session->top_bandwidth_, addr, port, tr_time(), socket);
tr_handshake* const handshake = tr_handshakeNew(io, session->encryptionMode, on_handshake_done, manager);
tr_handshake* const handshake = tr_handshakeNew(mediator, io, session->encryptionMode, on_handshake_done, manager);
tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIoNewIncoming() */
@ -2769,7 +2838,8 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
}
else
{
tr_handshake* handshake = tr_handshakeNew(io, mgr->session->encryptionMode, on_handshake_done, mgr);
auto mediator = std::make_shared<tr_handshake_mediator_impl>(*mgr->session);
tr_handshake* handshake = tr_handshakeNew(mediator, io, mgr->session->encryptionMode, on_handshake_done, mgr);
TR_ASSERT(io->torrentHash());

View File

@ -110,8 +110,6 @@ tr_peerMgr* tr_peerMgrNew(tr_session* session);
void tr_peerMgrFree(tr_peerMgr* manager);
bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr);
void tr_peerMgrSetUtpSupported(tr_torrent* tor, tr_address const& addr);
void tr_peerMgrSetUtpFailed(tr_torrent* tor, tr_address const& addr, bool failed);