diff --git a/libtransmission/handshake.cc b/libtransmission/handshake.cc index 183ec70fd..ea5679226 100644 --- a/libtransmission/handshake.cc +++ b/libtransmission/handshake.cc @@ -15,16 +15,12 @@ #include #include "transmission.h" + #include "clients.h" -#include "crypto-utils.h" #include "handshake.h" #include "log.h" #include "peer-io.h" -#include "peer-mgr.h" -#include "session.h" -#include "torrent.h" #include "tr-assert.h" -#include "tr-dht.h" #include "utils.h" using namespace std::literals; @@ -108,16 +104,32 @@ enum handshake_state_t struct tr_handshake { + tr_handshake(std::shared_ptr mediator_in) + : mediator{ std::move(mediator_in) } + { + } + + ~tr_handshake() + { + if (io != nullptr) + { + tr_peerIoUnref(io); /* balanced by the ref in tr_handshakeNew */ + } + + event_free(timeout_timer); + } + [[nodiscard]] auto constexpr isIncoming() const noexcept { return io->isIncoming(); } + std::shared_ptr const mediator; + bool haveReadAnythingFromPeer; bool haveSentBitTorrentHandshake; tr_peerIo* io; tr_crypto* crypto; - tr_session* session; handshake_state_t state; tr_encryption_mode encryptionMode; uint16_t pad_c_len; @@ -173,37 +185,34 @@ static void setReadState(tr_handshake* handshake, handshake_state_t state) static bool buildHandshakeMessage(tr_handshake* handshake, uint8_t* buf) { - auto const torrent_hash = handshake->crypto->torrentHash(); - auto* const tor = torrent_hash ? handshake->session->torrents().get(*torrent_hash) : nullptr; - bool const success = tor != nullptr; - - if (success) + auto const info_hash = handshake->crypto->torrentHash(); + auto const info = info_hash ? handshake->mediator->torrentInfo(*info_hash) : std::nullopt; + if (!info) { - uint8_t* walk = buf; - - walk = std::copy_n(HANDSHAKE_NAME, HANDSHAKE_NAME_LEN, walk); - - memset(walk, 0, HANDSHAKE_FLAGS_LEN); - HANDSHAKE_SET_LTEP(walk); - HANDSHAKE_SET_FASTEXT(walk); - /* Note that this doesn't depend on whether the torrent is private. - * We don't accept DHT peers for a private torrent, - * but we participate in the DHT regardless. */ - if (tr_dhtEnabled(handshake->session)) - { - HANDSHAKE_SET_DHT(walk); - } - walk += HANDSHAKE_FLAGS_LEN; - - walk = std::copy_n(reinterpret_cast(std::data(tor->infoHash())), std::size(tor->infoHash()), walk); - - auto const& peer_id = tr_torrentGetPeerId(tor); - std::copy_n(std::data(peer_id), std::size(peer_id), walk); - - TR_ASSERT(walk + std::size(peer_id) - buf == HANDSHAKE_SIZE); + return false; } - return success; + uint8_t* walk = buf; + + walk = std::copy_n(HANDSHAKE_NAME, HANDSHAKE_NAME_LEN, walk); + + memset(walk, 0, HANDSHAKE_FLAGS_LEN); + HANDSHAKE_SET_LTEP(walk); + HANDSHAKE_SET_FASTEXT(walk); + /* Note that this doesn't depend on whether the torrent is private. + * We don't accept DHT peers for a private torrent, + * but we participate in the DHT regardless. */ + if (handshake->mediator->isDHTEnabled()) + { + HANDSHAKE_SET_DHT(walk); + } + walk += HANDSHAKE_FLAGS_LEN; + + walk = std::copy_n(reinterpret_cast(std::data(*info_hash)), std::size(*info_hash), walk); + walk = std::copy(std::begin(info->client_peer_id), std::end(info->client_peer_id), walk); + + TR_ASSERT(walk - buf == HANDSHAKE_SIZE); + return true; } static ReadState tr_handshakeDone(tr_handshake* handshake, bool isConnected); @@ -257,7 +266,7 @@ static handshake_parse_err_t parseHandshake(tr_handshake* handshake, struct evbu auto const peer_id_sv = std::string_view{ std::data(peer_id), std::size(peer_id) }; tr_logAddTraceHand(handshake, fmt::format("peer-id is '{}'", peer_id_sv)); - if (auto* const tor = handshake->session->torrents().get(hash); peer_id == tr_torrentGetPeerId(tor)) + if (auto const info = handshake->mediator->torrentInfo(hash); info && info->client_peer_id == peer_id) { tr_logAddTraceHand(handshake, "streuth! we've connected to ourselves."); return HANDSHAKE_PEER_IS_SELF; @@ -651,7 +660,7 @@ static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf) if (handshake->isIncoming()) { - if (!handshake->session->torrents().contains(hash)) + if (!handshake->mediator->torrentInfo(hash)) { tr_logAddTraceHand(handshake, "peer is trying to connect to us for a torrent we don't have."); return tr_handshakeDone(handshake, false); @@ -708,8 +717,8 @@ static ReadState readPeerId(tr_handshake* handshake, struct evbuffer* inbuf) // if we've somehow connected to ourselves, don't keep the connection auto const hash = handshake->io->torrentHash(); - auto* const tor = hash ? handshake->session->torrents().get(*hash) : nullptr; - bool const connected_to_self = tor != nullptr && peer_id == tr_torrentGetPeerId(tor); + auto const info = hash ? handshake->mediator->torrentInfo(*hash) : std::nullopt; + auto const connected_to_self = info && info->client_peer_id == peer_id; return tr_handshakeDone(handshake, !connected_to_self); } @@ -818,16 +827,14 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb obfuscated_hash[i] = req2[i] ^ (*req3)[i]; } - if (auto const* const tor = tr_torrentFindFromObfuscatedHash(handshake->session, obfuscated_hash); tor != nullptr) + if (auto const info = handshake->mediator->torrentInfoFromObfuscated(obfuscated_hash); info) { - bool const clientIsSeed = tor->isDone(); - bool const peerIsSeed = tr_peerMgrPeerIsSeed(tor, handshake->io->address()); - tr_logAddTraceHand( - handshake, - fmt::format("got INCOMING connection's encrypted handshake for torrent [{}]", tor->name())); - handshake->io->setTorrentHash(tor->infoHash()); + bool const client_is_seed = info->is_done; + bool const peer_is_seed = handshake->mediator->isPeerKnownSeed(info->id, handshake->io->address()); + tr_logAddTraceHand(handshake, fmt::format("got INCOMING connection's encrypted handshake for torrent [{}]", info->id)); + handshake->io->setTorrentHash(info->info_hash); - if (clientIsSeed && peerIsSeed) + if (client_is_seed && peer_is_seed) { tr_logAddTraceHand(handshake, "another seed tried to reconnect to us!"); return tr_handshakeDone(handshake, false); @@ -1102,24 +1109,13 @@ static bool fireDoneFunc(tr_handshake* handshake, bool isConnected) return success; } -static void tr_handshakeFree(tr_handshake* handshake) -{ - if (handshake->io != nullptr) - { - tr_peerIoUnref(handshake->io); /* balanced by the ref in tr_handshakeNew */ - } - - event_free(handshake->timeout_timer); - tr_free(handshake); -} - static ReadState tr_handshakeDone(tr_handshake* handshake, bool isOK) { tr_logAddTraceHand(handshake, isOK ? "handshakeDone: connected" : "handshakeDone: aborting"); tr_peerIoSetIOFuncs(handshake->io, nullptr, nullptr, nullptr, nullptr); bool const success = fireDoneFunc(handshake, isOK); - tr_handshakeFree(handshake); + delete handshake; return success ? READ_LATER : READ_ERR; } @@ -1138,15 +1134,15 @@ static void gotError(tr_peerIo* io, short what, void* vhandshake) if (io->socket.type == TR_PEER_SOCKET_TYPE_UTP && !io->isIncoming() && handshake->state == AWAITING_YB) { - /* This peer probably doesn't speak uTP. */ + // the peer probably doesn't speak uTP. auto const hash = io->torrentHash(); - auto* const tor = hash ? handshake->session->torrents().get(*hash) : nullptr; + auto const info = hash ? handshake->mediator->torrentInfo(*hash) : std::nullopt; /* Don't mark a peer as non-uTP unless it's really a connect failure. */ - if ((errcode == ETIMEDOUT || errcode == ECONNREFUSED) && tr_isTorrent(tor)) + if ((errcode == ETIMEDOUT || errcode == ECONNREFUSED) && info) { - tr_peerMgrSetUtpFailed(tor, io->address(), true); + handshake->mediator->setUTPFailed(*hash, io->address()); } if (tr_peerIoReconnect(handshake->io) == 0) @@ -1192,24 +1188,22 @@ static void handshakeTimeout(evutil_socket_t /*s*/, short /*type*/, void* handsh } tr_handshake* tr_handshakeNew( + std::shared_ptr mediator, tr_peerIo* io, tr_encryption_mode encryptionMode, tr_handshake_done_func done_func, void* done_func_user_data) { - tr_session* session = tr_peerIoGetSession(io); - - auto* const handshake = tr_new0(tr_handshake, 1); + auto* const handshake = new tr_handshake{ std::move(mediator) }; handshake->io = io; handshake->crypto = tr_peerIoGetCrypto(io); handshake->encryptionMode = encryptionMode; handshake->done_func = done_func; handshake->done_func_user_data = done_func_user_data; - handshake->session = session; - handshake->timeout_timer = evtimer_new(session->event_base, handshakeTimeout, handshake); + handshake->timeout_timer = evtimer_new(handshake->mediator->eventBase(), handshakeTimeout, handshake); tr_timerAdd(*handshake->timeout_timer, HANDSHAKE_TIMEOUT_SEC, 0); - tr_peerIoRef(io); /* balanced by the unref in tr_handshakeFree */ + tr_peerIoRef(io); /* balanced by the unref in ~tr_handshake() */ tr_peerIoSetIOFuncs(handshake->io, canRead, nullptr, gotError, handshake); tr_peerIoSetEncryption(io, PEER_ENCRYPTION_NONE); diff --git a/libtransmission/handshake.h b/libtransmission/handshake.h index 57479f34b..d80aed9b5 100644 --- a/libtransmission/handshake.h +++ b/libtransmission/handshake.h @@ -10,9 +10,12 @@ #endif #include +#include #include "transmission.h" +#include "net.h" // tr_address + /** @addtogroup peers Peers @{ */ @@ -21,6 +24,7 @@ class tr_peerIo; /** @brief opaque struct holding hanshake state information. freed when the handshake is completed. */ struct tr_handshake; +struct event_base; struct tr_handshake_result { @@ -32,11 +36,36 @@ struct tr_handshake_result std::optional peer_id; }; +class tr_handshake_mediator +{ +public: + struct torrent_info + { + tr_sha1_digest_t info_hash; + tr_peer_id_t client_peer_id; + tr_torrent_id_t id; + bool is_done; + }; + + [[nodiscard]] virtual std::optional torrentInfo(tr_sha1_digest_t const& info_hash) const = 0; + + [[nodiscard]] virtual std::optional torrentInfoFromObfuscated(tr_sha1_digest_t const& info_hash) const = 0; + + [[nodiscard]] virtual event_base* eventBase() const = 0; + + [[nodiscard]] virtual bool isDHTEnabled() const = 0; + + [[nodiscard]] virtual bool isPeerKnownSeed(tr_torrent_id_t tor_id, tr_address addr) const = 0; + + virtual void setUTPFailed(tr_sha1_digest_t const& info_hash, tr_address) = 0; +}; + /* returns true on success, false on error */ using tr_handshake_done_func = bool (*)(tr_handshake_result const& result); /** @brief create a new handshake */ tr_handshake* tr_handshakeNew( + std::shared_ptr mediator, tr_peerIo* io, tr_encryption_mode encryption_mode, tr_handshake_done_func when_done, diff --git a/libtransmission/peer-mgr.cc b/libtransmission/peer-mgr.cc index 8b18be6e5..123fa37cd 100644 --- a/libtransmission/peer-mgr.cc +++ b/libtransmission/peer-mgr.cc @@ -45,6 +45,7 @@ #include "stats.h" /* tr_statsAddUploaded, tr_statsAddDownloaded */ #include "torrent.h" #include "tr-assert.h" +#include "tr-dht.h" #include "tr-utp.h" #include "utils.h" #include "webseed.h" @@ -63,6 +64,73 @@ static auto constexpr CancelHistorySec = int{ 60 }; *** **/ +static bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr); + +class tr_handshake_mediator_impl final : public tr_handshake_mediator +{ +private: + [[nodiscard]] static std::optional torrentInfo(tr_torrent* tor) + { + if (tor == nullptr) + { + return {}; + } + + auto info = torrent_info{}; + info.info_hash = tor->infoHash(); + info.client_peer_id = tr_torrentGetPeerId(tor); + info.id = tor->id(); + info.is_done = tor->isDone(); + return info; + } + +public: + tr_handshake_mediator_impl(tr_session& session) + : session_{ session } + { + } + + virtual ~tr_handshake_mediator_impl() = default; + + [[nodiscard]] std::optional torrentInfo(tr_sha1_digest_t const& info_hash) const override + { + return torrentInfo(session_.torrents().get(info_hash)); + } + + [[nodiscard]] std::optional torrentInfoFromObfuscated( + tr_sha1_digest_t const& obfuscated_info_hash) const override + { + return torrentInfo(tr_torrentFindFromObfuscatedHash(&session_, obfuscated_info_hash)); + } + + [[nodiscard]] bool isDHTEnabled() const override + { + return tr_dhtEnabled(&session_); + } + + void setUTPFailed(tr_sha1_digest_t const& info_hash, tr_address addr) override + { + if (auto* const tor = session_.torrents().get(info_hash); tor != nullptr) + { + tr_peerMgrSetUtpFailed(tor, addr, true); + } + } + + [[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t tor_id, tr_address addr) const override + { + auto* const tor = session_.torrents().get(tor_id); + return tor != nullptr && tr_peerMgrPeerIsSeed(tor, addr); + } + + [[nodiscard]] event_base* eventBase() const override + { + return session_.event_base; + } + +private: + tr_session& session_; +}; + /** * Peer information that should be kept even before we've connected and * after we've disconnected. These are kept in a pool of peer_atoms to decide @@ -683,7 +751,7 @@ static void atomSetSeed(tr_swarm* swarm, peer_atom& atom) swarm->markAllSeedsFlagDirty(); } -bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr) +static bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr) { if (auto const* atom = getExistingAtom(tor->swarm, addr); atom != nullptr) { @@ -1193,8 +1261,9 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port } else /* we don't have a connection to them yet... */ { + auto mediator = std::make_shared(*session); tr_peerIo* const io = tr_peerIoNewIncoming(session, &session->top_bandwidth_, addr, port, tr_time(), socket); - tr_handshake* const handshake = tr_handshakeNew(io, session->encryptionMode, on_handshake_done, manager); + tr_handshake* const handshake = tr_handshakeNew(mediator, io, session->encryptionMode, on_handshake_done, manager); tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIoNewIncoming() */ @@ -2769,7 +2838,8 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom) } else { - tr_handshake* handshake = tr_handshakeNew(io, mgr->session->encryptionMode, on_handshake_done, mgr); + auto mediator = std::make_shared(*mgr->session); + tr_handshake* handshake = tr_handshakeNew(mediator, io, mgr->session->encryptionMode, on_handshake_done, mgr); TR_ASSERT(io->torrentHash()); diff --git a/libtransmission/peer-mgr.h b/libtransmission/peer-mgr.h index 3be2c7070..08650651d 100644 --- a/libtransmission/peer-mgr.h +++ b/libtransmission/peer-mgr.h @@ -110,8 +110,6 @@ tr_peerMgr* tr_peerMgrNew(tr_session* session); void tr_peerMgrFree(tr_peerMgr* manager); -bool tr_peerMgrPeerIsSeed(tr_torrent const* tor, tr_address const& addr); - void tr_peerMgrSetUtpSupported(tr_torrent* tor, tr_address const& addr); void tr_peerMgrSetUtpFailed(tr_torrent* tor, tr_address const& addr, bool failed);