refactor: use std::map for handshake tracking (#2715)

This commit is contained in:
Charles Kerr 2022-02-26 00:03:32 -06:00 committed by GitHub
parent 70cce3abeb
commit 1598774b8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 60 deletions

View File

@ -85,19 +85,24 @@ struct tr_address
struct in_addr addr4; struct in_addr addr4;
} addr; } addr;
bool operator==(tr_address const& that) const [[nodiscard]] int compare(tr_address const& that) const
{ {
return tr_address_compare(this, &that) == 0; return tr_address_compare(this, &that);
} }
bool operator<(tr_address const& that) const [[nodiscard]] bool operator==(tr_address const& that) const
{ {
return tr_address_compare(this, &that) < 0; return compare(that) == 0;
} }
bool operator>(tr_address const& that) const [[nodiscard]] bool operator<(tr_address const& that) const
{ {
return tr_address_compare(this, &that) > 0; return compare(that) < 0;
}
[[nodiscard]] bool operator>(tr_address const& that) const
{
return compare(that) > 0;
} }
}; };

View File

@ -11,6 +11,7 @@
#include <cstdlib> /* qsort */ #include <cstdlib> /* qsort */
#include <ctime> // time_t #include <ctime> // time_t
#include <iterator> // std::back_inserter #include <iterator> // std::back_inserter
#include <map>
#include <vector> #include <vector>
#include <event2/event.h> #include <event2/event.h>
@ -125,6 +126,26 @@ struct peer_atom
time_t shelf_date; time_t shelf_date;
tr_peer* peer; /* will be nullptr if not connected */ tr_peer* peer; /* will be nullptr if not connected */
tr_address addr; tr_address addr;
[[nodiscard]] int compare(peer_atom const& that) const
{
return addr.compare(that.addr);
}
[[nodiscard]] bool operator==(peer_atom const& that) const
{
return compare(that) == 0;
}
[[nodiscard]] bool operator<(peer_atom const& that) const
{
return compare(that) < 0;
}
[[nodiscard]] bool operator>(peer_atom const& that) const
{
return compare(that) > 0;
}
}; };
#ifndef TR_ENABLE_ASSERTS #ifndef TR_ENABLE_ASSERTS
@ -160,7 +181,7 @@ public:
public: public:
tr_swarm_stats stats = {}; tr_swarm_stats stats = {};
tr_ptrArray outgoingHandshakes = {}; /* tr_handshake */ std::map<tr_address, tr_handshake*> outgoing_handshakes;
tr_ptrArray pool = {}; /* struct peer_atom */ tr_ptrArray pool = {}; /* struct peer_atom */
tr_ptrArray peers = {}; /* tr_peerMsgs */ tr_ptrArray peers = {}; /* tr_peerMsgs */
tr_ptrArray webseeds = {}; /* tr_webseed */ tr_ptrArray webseeds = {}; /* tr_webseed */
@ -186,17 +207,22 @@ public:
struct tr_peerMgr struct tr_peerMgr
{ {
tr_peerMgr(tr_session* session_in)
: session{ session_in }
{
}
[[nodiscard]] auto unique_lock() const [[nodiscard]] auto unique_lock() const
{ {
return session->unique_lock(); return session->unique_lock();
} }
tr_session* session; tr_session* const session;
tr_ptrArray incomingHandshakes; /* tr_handshake */ std::map<tr_address, tr_handshake*> incoming_handshakes;
struct event* bandwidthTimer; event* bandwidthTimer = nullptr;
struct event* rechokeTimer; event* rechokeTimer = nullptr;
struct event* refillUpkeepTimer; event* refillUpkeepTimer = nullptr;
struct event* atomTimer; event* atomTimer = nullptr;
}; };
#define tordbg(t, ...) tr_logAddDeepNamed(tr_torrentName((t)->tor), __VA_ARGS__) #define tordbg(t, ...) tr_logAddDeepNamed(tr_torrentName((t)->tor), __VA_ARGS__)
@ -240,30 +266,6 @@ tr_peer::~tr_peer()
*** ***
**/ **/
static int handshakeCompareToAddr(void const* va, void const* vb)
{
auto const* const a = static_cast<tr_handshake const*>(va);
auto const* const b = static_cast<tr_address const*>(vb);
return tr_address_compare(tr_handshakeGetAddr(a, nullptr), b);
}
static int handshakeCompare(void const* va, void const* vb)
{
auto const* const b = static_cast<tr_handshake const*>(vb);
return handshakeCompareToAddr(va, tr_handshakeGetAddr(b, nullptr));
}
static inline tr_handshake* getExistingHandshake(tr_ptrArray* handshakes, tr_address const* addr)
{
if (tr_ptrArrayEmpty(handshakes))
{
return nullptr;
}
return static_cast<tr_handshake*>(tr_ptrArrayFindSorted(handshakes, addr, handshakeCompareToAddr));
}
static int comparePeerAtomToAddress(void const* va, void const* vb) static int comparePeerAtomToAddress(void const* va, void const* vb)
{ {
auto const* const a = static_cast<struct peer_atom const*>(va); auto const* const a = static_cast<struct peer_atom const*>(va);
@ -315,8 +317,8 @@ static bool peerIsInUse(tr_swarm const* cs, struct peer_atom const* atom)
auto* s = const_cast<tr_swarm*>(cs); auto* s = const_cast<tr_swarm*>(cs);
auto const lock = s->manager->unique_lock(); auto const lock = s->manager->unique_lock();
return atom->peer != nullptr || getExistingHandshake(&s->outgoingHandshakes, &atom->addr) != nullptr || return atom->peer != nullptr || s->outgoing_handshakes.count(atom->addr) != 0 ||
getExistingHandshake(&s->manager->incomingHandshakes, &atom->addr) != nullptr; s->manager->incoming_handshakes.count(atom->addr) != 0;
} }
static void swarmFree(tr_swarm* s) static void swarmFree(tr_swarm* s)
@ -325,12 +327,11 @@ static void swarmFree(tr_swarm* s)
auto const lock = s->manager->unique_lock(); auto const lock = s->manager->unique_lock();
TR_ASSERT(!s->isRunning); TR_ASSERT(!s->isRunning);
TR_ASSERT(tr_ptrArrayEmpty(&s->outgoingHandshakes)); TR_ASSERT(std::empty(s->outgoing_handshakes));
TR_ASSERT(tr_ptrArrayEmpty(&s->peers)); TR_ASSERT(tr_ptrArrayEmpty(&s->peers));
tr_ptrArrayDestruct(&s->webseeds, [](void* peer) { delete static_cast<tr_peer*>(peer); }); tr_ptrArrayDestruct(&s->webseeds, [](void* peer) { delete static_cast<tr_peer*>(peer); });
tr_ptrArrayDestruct(&s->pool, (PtrArrayForeachFunc)tr_free); tr_ptrArrayDestruct(&s->pool, (PtrArrayForeachFunc)tr_free);
tr_ptrArrayDestruct(&s->outgoingHandshakes, nullptr);
tr_ptrArrayDestruct(&s->peers, nullptr); tr_ptrArrayDestruct(&s->peers, nullptr);
s->stats = {}; s->stats = {};
@ -367,9 +368,7 @@ static void ensureMgrTimersExist(struct tr_peerMgr* m);
tr_peerMgr* tr_peerMgrNew(tr_session* session) tr_peerMgr* tr_peerMgrNew(tr_session* session)
{ {
auto* const m = tr_new0(tr_peerMgr, 1); auto* const m = new tr_peerMgr{ session };
m->session = session;
m->incomingHandshakes = {};
ensureMgrTimersExist(m); ensureMgrTimersExist(m);
return m; return m;
} }
@ -399,14 +398,12 @@ void tr_peerMgrFree(tr_peerMgr* manager)
/* free the handshakes. Abort invokes handshakeDoneCB(), which removes /* free the handshakes. Abort invokes handshakeDoneCB(), which removes
* the item from manager->handshakes, so this is a little roundabout... */ * the item from manager->handshakes, so this is a little roundabout... */
while (!tr_ptrArrayEmpty(&manager->incomingHandshakes)) while (!std::empty(manager->incoming_handshakes))
{ {
tr_handshakeAbort(static_cast<tr_handshake*>(tr_ptrArrayNth(&manager->incomingHandshakes, 0))); tr_handshakeAbort(std::begin(manager->incoming_handshakes)->second);
} }
tr_ptrArrayDestruct(&manager->incomingHandshakes, nullptr); delete manager;
tr_free(manager);
} }
/*** /***
@ -989,20 +986,20 @@ static bool on_handshake_done(tr_handshake_result const& result)
auto const hash = tr_peerIoGetTorrentHash(result.io); auto const hash = tr_peerIoGetTorrentHash(result.io);
tr_swarm* const s = hash ? getExistingSwarm(manager, *hash) : nullptr; tr_swarm* const s = hash ? getExistingSwarm(manager, *hash) : nullptr;
auto port = tr_port{};
auto const* const addr = tr_peerIoGetAddress(result.io, &port);
if (tr_peerIoIsIncoming(result.io)) if (tr_peerIoIsIncoming(result.io))
{ {
tr_ptrArrayRemoveSortedPointer(&manager->incomingHandshakes, result.handshake, handshakeCompare); manager->incoming_handshakes.erase(*addr);
} }
else if (s != nullptr) else if (s != nullptr)
{ {
tr_ptrArrayRemoveSortedPointer(&s->outgoingHandshakes, result.handshake, handshakeCompare); s->outgoing_handshakes.erase(*addr);
} }
auto const lock = manager->unique_lock(); auto const lock = manager->unique_lock();
auto port = tr_port{};
tr_address const* const addr = tr_peerIoGetAddress(result.io, &port);
if (!ok || s == nullptr || !s->isRunning) if (!ok || s == nullptr || !s->isRunning)
{ {
if (s != nullptr) if (s != nullptr)
@ -1093,7 +1090,7 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port
tr_logAddDebug("Banned IP address \"%s\" tried to connect to us", tr_address_to_string(addr)); tr_logAddDebug("Banned IP address \"%s\" tried to connect to us", tr_address_to_string(addr));
tr_netClosePeerSocket(session, socket); tr_netClosePeerSocket(session, socket);
} }
else if (getExistingHandshake(&manager->incomingHandshakes, addr) != nullptr) else if (manager->incoming_handshakes.count(*addr) > 0)
{ {
tr_netClosePeerSocket(session, socket); tr_netClosePeerSocket(session, socket);
} }
@ -1104,7 +1101,7 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port
tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIoNewIncoming() */ tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIoNewIncoming() */
tr_ptrArrayInsertSorted(&manager->incomingHandshakes, handshake, handshakeCompare); manager->incoming_handshakes.insert({ *addr, handshake });
} }
} }
@ -1432,10 +1429,10 @@ static void stopSwarm(tr_swarm* swarm)
removeAllPeers(swarm); removeAllPeers(swarm);
/* disconnect the handshakes. handshakeAbort calls handshakeDoneCB(), /* disconnect the handshakes. handshakeAbort calls handshakeDoneCB(),
* which removes the handshake from t->outgoingHandshakes... */ * which removes the handshake from t->outgoing_handshakes... */
while (!tr_ptrArrayEmpty(&swarm->outgoingHandshakes)) while (!std::empty(swarm->outgoing_handshakes))
{ {
tr_handshakeAbort(static_cast<tr_handshake*>(tr_ptrArrayNth(&swarm->outgoingHandshakes, 0))); tr_handshakeAbort(std::begin(swarm->outgoing_handshakes)->second);
} }
} }
@ -3032,7 +3029,7 @@ static void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, struct peer_atom* a
tr_peerIoUnref(io); /* balanced by the initial ref in tr_peerIoNewOutgoing() */ tr_peerIoUnref(io); /* balanced by the initial ref in tr_peerIoNewOutgoing() */
tr_ptrArrayInsertSorted(&s->outgoingHandshakes, handshake, handshakeCompare); s->outgoing_handshakes.insert({ atom->addr, handshake });
} }
atom->lastConnectionAttemptAt = now; atom->lastConnectionAttemptAt = now;