refactor: use std::map for handshake tracking (#2715)

This commit is contained in:
Charles Kerr 2022-02-26 00:03:32 -06:00 committed by GitHub
parent 70cce3abeb
commit 1598774b8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 60 deletions

View File

@ -85,19 +85,24 @@ struct tr_address
struct in_addr addr4;
} addr;
bool operator==(tr_address const& that) const
[[nodiscard]] int compare(tr_address const& that) const
{
return tr_address_compare(this, &that) == 0;
return tr_address_compare(this, &that);
}
bool operator<(tr_address const& that) const
[[nodiscard]] bool operator==(tr_address const& that) const
{
return tr_address_compare(this, &that) < 0;
return compare(that) == 0;
}
bool operator>(tr_address const& that) const
[[nodiscard]] bool operator<(tr_address const& that) const
{
return tr_address_compare(this, &that) > 0;
return compare(that) < 0;
}
[[nodiscard]] bool operator>(tr_address const& that) const
{
return compare(that) > 0;
}
};

View File

@ -11,6 +11,7 @@
#include <cstdlib> /* qsort */
#include <ctime> // time_t
#include <iterator> // std::back_inserter
#include <map>
#include <vector>
#include <event2/event.h>
@ -125,6 +126,26 @@ struct peer_atom
time_t shelf_date;
tr_peer* peer; /* will be nullptr if not connected */
tr_address addr;
[[nodiscard]] int compare(peer_atom const& that) const
{
return addr.compare(that.addr);
}
[[nodiscard]] bool operator==(peer_atom const& that) const
{
return compare(that) == 0;
}
[[nodiscard]] bool operator<(peer_atom const& that) const
{
return compare(that) < 0;
}
[[nodiscard]] bool operator>(peer_atom const& that) const
{
return compare(that) > 0;
}
};
#ifndef TR_ENABLE_ASSERTS
@ -160,7 +181,7 @@ public:
public:
tr_swarm_stats stats = {};
tr_ptrArray outgoingHandshakes = {}; /* tr_handshake */
std::map<tr_address, tr_handshake*> outgoing_handshakes;
tr_ptrArray pool = {}; /* struct peer_atom */
tr_ptrArray peers = {}; /* tr_peerMsgs */
tr_ptrArray webseeds = {}; /* tr_webseed */
@ -186,17 +207,22 @@ public:
struct tr_peerMgr
{
tr_peerMgr(tr_session* session_in)
: session{ session_in }
{
}
[[nodiscard]] auto unique_lock() const
{
return session->unique_lock();
}
tr_session* session;
tr_ptrArray incomingHandshakes; /* tr_handshake */
struct event* bandwidthTimer;
struct event* rechokeTimer;
struct event* refillUpkeepTimer;
struct event* atomTimer;
tr_session* const session;
std::map<tr_address, tr_handshake*> incoming_handshakes;
event* bandwidthTimer = nullptr;
event* rechokeTimer = nullptr;
event* refillUpkeepTimer = nullptr;
event* atomTimer = nullptr;
};
#define tordbg(t, ...) tr_logAddDeepNamed(tr_torrentName((t)->tor), __VA_ARGS__)
@ -240,30 +266,6 @@ tr_peer::~tr_peer()
***
**/
static int handshakeCompareToAddr(void const* va, void const* vb)
{
auto const* const a = static_cast<tr_handshake const*>(va);
auto const* const b = static_cast<tr_address const*>(vb);
return tr_address_compare(tr_handshakeGetAddr(a, nullptr), b);
}
static int handshakeCompare(void const* va, void const* vb)
{
auto const* const b = static_cast<tr_handshake const*>(vb);
return handshakeCompareToAddr(va, tr_handshakeGetAddr(b, nullptr));
}
static inline tr_handshake* getExistingHandshake(tr_ptrArray* handshakes, tr_address const* addr)
{
if (tr_ptrArrayEmpty(handshakes))
{
return nullptr;
}
return static_cast<tr_handshake*>(tr_ptrArrayFindSorted(handshakes, addr, handshakeCompareToAddr));
}
static int comparePeerAtomToAddress(void const* va, void const* vb)
{
auto const* const a = static_cast<struct peer_atom const*>(va);
@ -315,8 +317,8 @@ static bool peerIsInUse(tr_swarm const* cs, struct peer_atom const* atom)
auto* s = const_cast<tr_swarm*>(cs);
auto const lock = s->manager->unique_lock();
return atom->peer != nullptr || getExistingHandshake(&s->outgoingHandshakes, &atom->addr) != nullptr ||
getExistingHandshake(&s->manager->incomingHandshakes, &atom->addr) != nullptr;
return atom->peer != nullptr || s->outgoing_handshakes.count(atom->addr) != 0 ||
s->manager->incoming_handshakes.count(atom->addr) != 0;
}
static void swarmFree(tr_swarm* s)
@ -325,12 +327,11 @@ static void swarmFree(tr_swarm* s)
auto const lock = s->manager->unique_lock();
TR_ASSERT(!s->isRunning);
TR_ASSERT(tr_ptrArrayEmpty(&s->outgoingHandshakes));
TR_ASSERT(std::empty(s->outgoing_handshakes));
TR_ASSERT(tr_ptrArrayEmpty(&s->peers));
tr_ptrArrayDestruct(&s->webseeds, [](void* peer) { delete static_cast<tr_peer*>(peer); });
tr_ptrArrayDestruct(&s->pool, (PtrArrayForeachFunc)tr_free);
tr_ptrArrayDestruct(&s->outgoingHandshakes, nullptr);
tr_ptrArrayDestruct(&s->peers, nullptr);
s->stats = {};
@ -367,9 +368,7 @@ static void ensureMgrTimersExist(struct tr_peerMgr* m);
tr_peerMgr* tr_peerMgrNew(tr_session* session)
{
auto* const m = tr_new0(tr_peerMgr, 1);
m->session = session;
m->incomingHandshakes = {};
auto* const m = new tr_peerMgr{ session };
ensureMgrTimersExist(m);
return m;
}
@ -399,14 +398,12 @@ void tr_peerMgrFree(tr_peerMgr* manager)
/* free the handshakes. Abort invokes handshakeDoneCB(), which removes
* the item from manager->handshakes, so this is a little roundabout... */
while (!tr_ptrArrayEmpty(&manager->incomingHandshakes))
while (!std::empty(manager->incoming_handshakes))
{
tr_handshakeAbort(static_cast<tr_handshake*>(tr_ptrArrayNth(&manager->incomingHandshakes, 0)));
tr_handshakeAbort(std::begin(manager->incoming_handshakes)->second);
}
tr_ptrArrayDestruct(&manager->incomingHandshakes, nullptr);
tr_free(manager);
delete manager;
}
/***
@ -989,20 +986,20 @@ static bool on_handshake_done(tr_handshake_result const& result)
auto const hash = tr_peerIoGetTorrentHash(result.io);
tr_swarm* const s = hash ? getExistingSwarm(manager, *hash) : nullptr;
auto port = tr_port{};
auto const* const addr = tr_peerIoGetAddress(result.io, &port);
if (tr_peerIoIsIncoming(result.io))
{
tr_ptrArrayRemoveSortedPointer(&manager->incomingHandshakes, result.handshake, handshakeCompare);
manager->incoming_handshakes.erase(*addr);
}
else if (s != nullptr)
{
tr_ptrArrayRemoveSortedPointer(&s->outgoingHandshakes, result.handshake, handshakeCompare);
s->outgoing_handshakes.erase(*addr);
}
auto const lock = manager->unique_lock();
auto port = tr_port{};
tr_address const* const addr = tr_peerIoGetAddress(result.io, &port);
if (!ok || s == nullptr || !s->isRunning)
{
if (s != nullptr)
@ -1093,7 +1090,7 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port
tr_logAddDebug("Banned IP address \"%s\" tried to connect to us", tr_address_to_string(addr));
tr_netClosePeerSocket(session, socket);
}
else if (getExistingHandshake(&manager->incomingHandshakes, addr) != nullptr)
else if (manager->incoming_handshakes.count(*addr) > 0)
{
tr_netClosePeerSocket(session, socket);
}
@ -1104,7 +1101,7 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port
tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIoNewIncoming() */
tr_ptrArrayInsertSorted(&manager->incomingHandshakes, handshake, handshakeCompare);
manager->incoming_handshakes.insert({ *addr, handshake });
}
}
@ -1432,10 +1429,10 @@ static void stopSwarm(tr_swarm* swarm)
removeAllPeers(swarm);
/* disconnect the handshakes. handshakeAbort calls handshakeDoneCB(),
* which removes the handshake from t->outgoingHandshakes... */
while (!tr_ptrArrayEmpty(&swarm->outgoingHandshakes))
* which removes the handshake from t->outgoing_handshakes... */
while (!std::empty(swarm->outgoing_handshakes))
{
tr_handshakeAbort(static_cast<tr_handshake*>(tr_ptrArrayNth(&swarm->outgoingHandshakes, 0)));
tr_handshakeAbort(std::begin(swarm->outgoing_handshakes)->second);
}
}
@ -3032,7 +3029,7 @@ static void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, struct peer_atom* a
tr_peerIoUnref(io); /* balanced by the initial ref in tr_peerIoNewOutgoing() */
tr_ptrArrayInsertSorted(&s->outgoingHandshakes, handshake, handshakeCompare);
s->outgoing_handshakes.insert({ atom->addr, handshake });
}
atom->lastConnectionAttemptAt = now;