refactor: tr_peerIo ref, unref (#3735)

* refactor: replace manual peerIo refcounting with std::shared_ptr
This commit is contained in:
Charles Kerr 2022-08-30 12:38:30 -05:00 committed by GitHub
parent 7c014e3256
commit b7ea4d9f04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 280 additions and 361 deletions

View File

@ -139,7 +139,7 @@ void tr_bandwidth::allocateBandwidth(
tr_priority_t parent_priority,
tr_direction dir,
unsigned int period_msec,
std::vector<tr_peerIo*>& peer_pool)
std::vector<std::shared_ptr<tr_peerIo>>& peer_pool)
{
tr_priority_t const priority = std::max(parent_priority, this->priority_);
@ -151,10 +151,10 @@ void tr_bandwidth::allocateBandwidth(
}
/* add this bandwidth's peer, if any, to the peer pool */
if (this->peer_ != nullptr)
if (auto shared = this->peer_.lock(); shared)
{
this->peer_->priority = priority;
peer_pool.push_back(this->peer_);
shared->priority = priority;
peer_pool.push_back(std::move(shared));
}
// traverse & repeat for the subtree
@ -199,33 +199,34 @@ void tr_bandwidth::allocate(tr_direction dir, unsigned int period_msec)
{
TR_ASSERT(tr_isDirection(dir));
// keep these peers alive for the scope of this function
auto refs = std::vector<std::shared_ptr<tr_peerIo>>{};
auto high = std::vector<tr_peerIo*>{};
auto low = std::vector<tr_peerIo*>{};
auto normal = std::vector<tr_peerIo*>{};
auto tmp = std::vector<tr_peerIo*>{};
/* allocateBandwidth () is a helper function with two purposes:
* 1. allocate bandwidth to b and its subtree
* 2. accumulate an array of all the peerIos from b and its subtree. */
this->allocateBandwidth(TR_PRI_LOW, dir, period_msec, tmp);
this->allocateBandwidth(TR_PRI_LOW, dir, period_msec, refs);
for (auto* io : tmp)
for (auto& io : refs)
{
tr_peerIoRef(io);
io->flushOutgoingProtocolMsgs();
switch (io->priority)
{
case TR_PRI_HIGH:
high.push_back(io);
high.push_back(io.get());
[[fallthrough]];
case TR_PRI_NORMAL:
normal.push_back(io);
normal.push_back(io.get());
[[fallthrough]];
default:
low.push_back(io);
low.push_back(io.get());
}
}
@ -241,15 +242,10 @@ void tr_bandwidth::allocate(tr_direction dir, unsigned int period_msec)
* enable on-demand IO for peers with bandwidth left to burn.
* This on-demand IO is enabled until (1) the peer runs out of bandwidth,
* or (2) the next tr_bandwidth::allocate () call, when we start over again. */
for (auto* io : tmp)
for (auto& io : refs)
{
io->setEnabled(dir, io->hasBandwidthLeft(dir));
}
for (auto* io : tmp)
{
tr_peerIoUnref(io);
}
}
/***

View File

@ -12,6 +12,7 @@
#include <array>
#include <cstddef> // size_t
#include <cstdint> // uint64_t
#include <memory>
#include <vector>
#include "transmission.h"
@ -98,12 +99,10 @@ public:
tr_bandwidth(tr_bandwidth&&) = delete;
tr_bandwidth(tr_bandwidth&) = delete;
/**
* @brief Sets new peer, nullptr is allowed.
*/
constexpr void setPeer(tr_peerIo* peer) noexcept
// @brief Sets the peer. nullptr is allowed.
void setPeer(std::weak_ptr<tr_peerIo> peer) noexcept
{
this->peer_ = peer;
this->peer_ = std::move(peer);
}
/**
@ -258,12 +257,12 @@ private:
tr_priority_t parent_priority,
tr_direction dir,
unsigned int period_msec,
std::vector<tr_peerIo*>& peer_pool);
std::vector<std::shared_ptr<tr_peerIo>>& peer_pool);
mutable std::array<Band, 2> band_ = {};
std::vector<tr_bandwidth*> children_;
tr_bandwidth* parent_ = nullptr;
tr_peerIo* peer_ = nullptr;
std::weak_ptr<tr_peerIo> peer_;
tr_priority_t priority_ = 0;
};

View File

@ -21,7 +21,6 @@
#include "torrent.h"
#include "torrents.h"
#include "tr-assert.h"
#include "trevent.h"
#include "utils.h" // tr_time(), tr_formatter
Cache::Key Cache::makeKey(tr_torrent const* torrent, tr_block_info::Location loc) noexcept

View File

@ -114,8 +114,12 @@ enum handshake_state_t
struct tr_handshake
{
tr_handshake(std::shared_ptr<tr_handshake_mediator> mediator_in, tr_encryption_mode encryption_mode_in)
tr_handshake(
std::shared_ptr<tr_handshake_mediator> mediator_in,
std::shared_ptr<tr_peerIo> io_in,
tr_encryption_mode encryption_mode_in)
: mediator{ std::move(mediator_in) }
, io{ std::move(io_in) }
, dh{ mediator->privateKey() }
, encryption_mode{ encryption_mode_in }
{
@ -125,16 +129,9 @@ struct tr_handshake
tr_handshake(tr_handshake const&) = delete;
tr_handshake& operator=(tr_handshake&&) = delete;
tr_handshake& operator=(tr_handshake const&) = delete;
~tr_handshake() = default;
~tr_handshake()
{
if (io != nullptr)
{
tr_peerIoUnref(io); /* balanced by the ref in tr_handshakeNew */
}
}
[[nodiscard]] auto constexpr isIncoming() const noexcept
[[nodiscard]] auto isIncoming() const noexcept
{
return io->isIncoming();
}
@ -143,7 +140,7 @@ struct tr_handshake
bool haveReadAnythingFromPeer = false;
bool haveSentBitTorrentHandshake = false;
tr_peerIo* io = nullptr;
std::shared_ptr<tr_peerIo> const io;
DH dh = {};
handshake_state_t state = AWAITING_HANDSHAKE;
tr_encryption_mode encryption_mode;
@ -1133,20 +1130,18 @@ static void gotError(tr_peerIo* io, short what, void* vhandshake)
tr_handshake* tr_handshakeNew(
std::shared_ptr<tr_handshake_mediator> mediator,
tr_peerIo* io,
std::shared_ptr<tr_peerIo> io,
tr_encryption_mode encryption_mode,
tr_handshake_done_func done_func,
void* done_func_user_data)
{
auto* const handshake = new tr_handshake{ std::move(mediator), encryption_mode };
handshake->io = io;
auto* const handshake = new tr_handshake{ std::move(mediator), std::move(io), encryption_mode };
handshake->done_func = done_func;
handshake->done_func_user_data = done_func_user_data;
handshake->timeout_timer = handshake->mediator->createTimer();
handshake->timeout_timer->setCallback([handshake]() { tr_handshakeAbort(handshake); });
handshake->timeout_timer->startSingleShot(HandshakeTimeoutSec);
tr_peerIoRef(io); /* balanced by the unref in ~tr_handshake() */
handshake->io->setCallbacks(canRead, nullptr, gotError, handshake);
if (handshake->isIncoming())
@ -1169,13 +1164,3 @@ tr_handshake* tr_handshakeNew(
return handshake;
}
tr_peerIo* tr_handshakeStealIO(tr_handshake* handshake)
{
TR_ASSERT(handshake != nullptr);
TR_ASSERT(handshake->io != nullptr);
tr_peerIo* io = handshake->io;
handshake->io = nullptr;
return io;
}

View File

@ -31,7 +31,7 @@ struct tr_handshake;
struct tr_handshake_result
{
struct tr_handshake* handshake;
tr_peerIo* io;
std::shared_ptr<tr_peerIo> io;
bool readAnythingFromPeer;
bool isConnected;
void* userData;
@ -77,13 +77,11 @@ using tr_handshake_done_func = bool (*)(tr_handshake_result const& result);
/** @brief create a new handshake */
tr_handshake* tr_handshakeNew(
std::shared_ptr<tr_handshake_mediator> mediator,
tr_peerIo* io,
std::shared_ptr<tr_peerIo> io,
tr_encryption_mode encryption_mode,
tr_handshake_done_func done_func,
void* done_func_user_data);
void tr_handshakeAbort(tr_handshake* handshake);
tr_peerIo* tr_handshakeStealIO(tr_handshake* handshake);
/** @} */

View File

@ -26,7 +26,6 @@
#include "peer-io.h"
#include "tr-assert.h"
#include "tr-utp.h"
#include "trevent.h" /* tr_runInEventThread() */
#include "utils.h"
#ifdef _WIN32
@ -113,12 +112,11 @@ static void didWriteWrapper(tr_peerIo* io, unsigned int bytes_transferred)
}
}
static void canReadWrapper(tr_peerIo* io)
static void canReadWrapper(tr_peerIo* io_in)
{
auto const io = io_in->shared_from_this();
tr_logAddTraceIo(io, "canRead");
tr_peerIoRef(io);
tr_session const* const session = io->session;
/* try to consume the input buffer */
@ -134,7 +132,7 @@ static void canReadWrapper(tr_peerIo* io)
{
size_t piece = 0;
size_t const oldLen = evbuffer_get_length(io->inbuf.get());
int const ret = io->canRead(io, io->userData, &piece);
int const ret = io->canRead(io.get(), io->userData, &piece);
size_t const used = oldLen - evbuffer_get_length(io->inbuf.get());
unsigned int const overhead = guessPacketOverhead(used);
@ -175,12 +173,8 @@ static void canReadWrapper(tr_peerIo* io)
err = true;
break;
}
TR_ASSERT(tr_isPeerIo(io));
}
}
tr_peerIoUnref(io);
}
static void event_read_cb(evutil_socket_t fd, short /*event*/, void* vio)
@ -491,7 +485,7 @@ static uint64 utp_callback(utp_callback_arguments* args)
#endif /* #ifdef WITH_UTP */
tr_peerIo* tr_peerIo::create(
std::shared_ptr<tr_peerIo> tr_peerIo::create(
tr_session* session,
tr_bandwidth* parent,
tr_address const* addr,
@ -519,7 +513,9 @@ tr_peerIo* tr_peerIo::create(
maybeSetCongestionAlgorithm(socket.handle.tcp, session->peerCongestionAlgorithm());
}
auto* io = new tr_peerIo{ session, torrent_hash, is_incoming, *addr, port, is_seed, current_time, parent };
auto io = std::shared_ptr<tr_peerIo>{
new tr_peerIo{ session, torrent_hash, is_incoming, *addr, port, is_seed, current_time, parent }
};
io->socket = socket;
io->bandwidth().setPeer(io);
tr_logAddTraceIo(io, fmt::format("bandwidth is {}; its parent is {}", fmt::ptr(&io->bandwidth()), fmt::ptr(parent)));
@ -528,15 +524,15 @@ tr_peerIo* tr_peerIo::create(
{
case TR_PEER_SOCKET_TYPE_TCP:
tr_logAddTraceIo(io, fmt::format("socket (tcp) is {}", socket.handle.tcp));
io->event_read = event_new(session->eventBase(), socket.handle.tcp, EV_READ, event_read_cb, io);
io->event_write = event_new(session->eventBase(), socket.handle.tcp, EV_WRITE, event_write_cb, io);
io->event_read = event_new(session->eventBase(), socket.handle.tcp, EV_READ, event_read_cb, io.get());
io->event_write = event_new(session->eventBase(), socket.handle.tcp, EV_WRITE, event_write_cb, io.get());
break;
#ifdef WITH_UTP
case TR_PEER_SOCKET_TYPE_UTP:
tr_logAddTraceIo(io, fmt::format("socket (utp) is {}", fmt::ptr(socket.handle.utp)));
utp_set_userdata(socket.handle.utp, io);
utp_set_userdata(socket.handle.utp, io.get());
break;
#endif
@ -563,7 +559,7 @@ void tr_peerIo::utpInit([[maybe_unused]] struct_utp_context* ctx)
#endif
}
tr_peerIo* tr_peerIo::newIncoming(
std::shared_ptr<tr_peerIo> tr_peerIo::newIncoming(
tr_session* session,
tr_bandwidth* parent,
tr_address const* addr,
@ -577,7 +573,7 @@ tr_peerIo* tr_peerIo::newIncoming(
return tr_peerIo::create(session, parent, addr, port, current_time, nullptr, true, false, socket);
}
tr_peerIo* tr_peerIo::newOutgoing(
std::shared_ptr<tr_peerIo> tr_peerIo::newOutgoing(
tr_session* session,
tr_bandwidth* parent,
tr_address const* addr,
@ -620,7 +616,6 @@ tr_peerIo* tr_peerIo::newOutgoing(
static void event_enable(tr_peerIo* io, short event)
{
TR_ASSERT(tr_amInEventThread(io->session));
TR_ASSERT(io->session != nullptr);
TR_ASSERT(io->session->events != nullptr);
@ -659,8 +654,6 @@ static void event_enable(tr_peerIo* io, short event)
static void event_disable(tr_peerIo* io, short event)
{
TR_ASSERT(tr_amInEventThread(io->session));
TR_ASSERT(io->session != nullptr);
TR_ASSERT(io->session->events != nullptr);
bool const need_events = io->socket.type == TR_PEER_SOCKET_TYPE_TCP;
@ -699,8 +692,6 @@ static void event_disable(tr_peerIo* io, short event)
void tr_peerIo::setEnabled(tr_direction dir, bool is_enabled)
{
TR_ASSERT(tr_isDirection(dir));
TR_ASSERT(tr_amInEventThread(session));
TR_ASSERT(session->events != nullptr);
short const event = dir == TR_UP ? EV_WRITE : EV_READ;
@ -757,55 +748,18 @@ static void io_close_socket(tr_peerIo* io)
}
}
static void io_dtor(tr_peerIo* const io)
tr_peerIo::~tr_peerIo()
{
TR_ASSERT(tr_isPeerIo(io));
TR_ASSERT(tr_amInEventThread(io->session));
TR_ASSERT(io->session->events != nullptr);
auto const lock = session->unique_lock();
TR_ASSERT(session->events != nullptr);
tr_logAddTraceIo(io, "in tr_peerIo destructor");
event_disable(io, EV_READ | EV_WRITE);
io_close_socket(io);
this->canRead = nullptr;
this->didWrite = nullptr;
this->gotError = nullptr;
io->magic_number = ~0;
delete io;
}
static void tr_peerIoFree(tr_peerIo* io)
{
if (io != nullptr)
{
tr_logAddTraceIo(io, "in tr_peerIoFree");
io->canRead = nullptr;
io->didWrite = nullptr;
io->gotError = nullptr;
tr_runInEventThread(io->session, io_dtor, io);
}
}
void tr_peerIoRefImpl(char const* file, int line, tr_peerIo* io)
{
TR_ASSERT(tr_isPeerIo(io));
tr_logAddTraceIo(
io,
fmt::format("{}:{} incrementing the IO's refcount from {} to {}", file, line, io->refCount, io->refCount + 1));
++io->refCount;
}
void tr_peerIoUnrefImpl(char const* file, int line, tr_peerIo* io)
{
TR_ASSERT(tr_isPeerIo(io));
tr_logAddTraceIo(
io,
fmt::format("{}:{} decrementing the IO's refcount from {} to {}", file, line, io->refCount, io->refCount - 1));
if (--io->refCount == 0)
{
tr_peerIoFree(io);
}
tr_logAddTraceIo(this, "in tr_peerIo destructor");
event_disable(this, EV_READ | EV_WRITE);
io_close_socket(this);
}
std::string tr_peerIo::addrStr() const

View File

@ -66,14 +66,23 @@ struct evbuffer_deleter
using tr_evbuffer_ptr = std::unique_ptr<evbuffer, evbuffer_deleter>;
class tr_peerIo
namespace libtransmission::test
{
class HandshakeTest;
} // namespace libtransmission::test
class tr_peerIo final : public std::enable_shared_from_this<tr_peerIo>
{
using DH = tr_message_stream_encryption::DH;
using Filter = tr_message_stream_encryption::Filter;
public:
~tr_peerIo();
// TODO: 8 constructor args is too many; maybe a builder object?
static tr_peerIo* newOutgoing(
static std::shared_ptr<tr_peerIo> newOutgoing(
tr_session* session,
tr_bandwidth* parent,
struct tr_address const* addr,
@ -83,7 +92,7 @@ public:
bool is_seed,
bool utp);
static tr_peerIo* newIncoming(
static std::shared_ptr<tr_peerIo> newIncoming(
tr_session* session,
tr_bandwidth* parent,
struct tr_address const* addr,
@ -91,19 +100,6 @@ public:
time_t current_time,
struct tr_peer_socket const socket);
// this is only public for testing purposes.
// production code should use newOutgoing() or newIncoming()
static tr_peerIo* create(
tr_session* session,
tr_bandwidth* parent,
tr_address const* addr,
tr_port port,
time_t current_time,
tr_sha1_digest_t const* torrent_hash,
bool is_incoming,
bool is_seed,
struct tr_peer_socket const socket);
void clear();
void readBytes(void* bytes, size_t byte_count);
@ -218,7 +214,7 @@ public:
bandwidth_.setParent(parent);
}
[[nodiscard]] constexpr auto isIncoming() noexcept
[[nodiscard]] constexpr auto isIncoming() const noexcept
{
return is_incoming_;
}
@ -235,12 +231,6 @@ public:
void setCallbacks(tr_can_read_cb readcb, tr_did_write_cb writecb, tr_net_error_cb errcb, void* user_data);
// TODO(ckerr): yikes, unlike other class' magic_numbers it looks
// like this one isn't being used just for assertions, but also in
// didWriteWrapper() to see if the tr_peerIo got freed during the
// notify-consumed events. Fix this before removing this field.
int magic_number = PEER_IO_MAGIC_NUMBER;
struct tr_peer_socket socket = {};
tr_session* const session;
@ -260,9 +250,6 @@ public:
struct event* event_read = nullptr;
struct event* event_write = nullptr;
// TODO: use std::shared_ptr instead of manual refcounting?
int refCount = 1;
short int pendingEvents = 0;
tr_priority_t priority = TR_PRI_NORMAL;
@ -297,6 +284,21 @@ public:
static void utpInit(struct_utp_context* ctx);
private:
friend class libtransmission::test::HandshakeTest;
// this is only public for testing purposes.
// production code should use newOutgoing() or newIncoming()
static std::shared_ptr<tr_peerIo> create(
tr_session* session,
tr_bandwidth* parent,
tr_address const* addr,
tr_port port,
time_t current_time,
tr_sha1_digest_t const* torrent_hash,
bool is_incoming,
bool is_seed,
struct tr_peer_socket const socket);
tr_peerIo(
tr_session* session_in,
tr_sha1_digest_t const* torrent_hash,
@ -347,24 +349,11 @@ private:
bool fast_extension_supported_ = false;
};
void tr_peerIoRefImpl(char const* file, int line, tr_peerIo* io);
#define tr_peerIoRef(io) tr_peerIoRefImpl(__FILE__, __LINE__, (io))
void tr_peerIoUnrefImpl(char const* file, int line, tr_peerIo* io);
#define tr_peerIoUnref(io) tr_peerIoUnrefImpl(__FILE__, __LINE__, (io))
constexpr bool tr_isPeerIo(tr_peerIo const* io)
{
return io != nullptr && io->magic_number == PEER_IO_MAGIC_NUMBER && io->refCount >= 0 &&
tr_address_is_valid(&io->address());
return io != nullptr && tr_address_is_valid(&io->address());
}
/**
***
**/
void evbuffer_add_uint8(struct evbuffer* outbuf, uint8_t addme);
void evbuffer_add_uint16(struct evbuffer* outbuf, uint16_t hs);
void evbuffer_add_uint32(struct evbuffer* outbuf, uint32_t hl);
@ -373,5 +362,3 @@ void evbuffer_add_uint64(struct evbuffer* outbuf, uint64_t hll);
void evbuffer_add_hton_16(struct evbuffer* buf, uint16_t val);
void evbuffer_add_hton_32(struct evbuffer* buf, uint32_t val);
void evbuffer_add_hton_64(struct evbuffer* buf, uint64_t val);
/* @} */

View File

@ -1118,7 +1118,7 @@ static struct peer_atom* ensureAtomExists(
return tor->max_connected_peers;
}
static void createBitTorrentPeer(tr_torrent* tor, tr_peerIo* io, struct peer_atom* atom, tr_quark client)
static void createBitTorrentPeer(tr_torrent* tor, std::shared_ptr<tr_peerIo> io, struct peer_atom* atom, tr_quark client)
{
TR_ASSERT(atom != nullptr);
TR_ASSERT(tr_isTorrent(tor));
@ -1126,7 +1126,7 @@ static void createBitTorrentPeer(tr_torrent* tor, tr_peerIo* io, struct peer_ato
tr_swarm* swarm = tor->swarm;
auto* peer = tr_peerMsgsNew(tor, atom, io, peerCallbackFunc, swarm);
auto* peer = tr_peerMsgsNew(tor, atom, std::move(io), peerCallbackFunc, swarm);
peer->client = client;
atom->is_connected = true;
@ -1232,10 +1232,8 @@ static bool on_handshake_done(tr_handshake_result const& result)
client = tr_quark_new(std::data(buf));
}
/* this steals its refcount too, which is balanced by our unref in peerDelete() */
tr_peerIo* stolen = tr_handshakeStealIO(result.handshake);
stolen->setParent(&s->tor->bandwidth_);
createBitTorrentPeer(s->tor, stolen, atom, client);
result.io->setParent(&s->tor->bandwidth_);
createBitTorrentPeer(s->tor, result.io, atom, client);
success = true;
}
@ -1263,11 +1261,8 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port
else /* we don't have a connection to them yet... */
{
auto mediator = std::make_shared<tr_handshake_mediator_impl>(*session);
tr_peerIo* const io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, addr, port, tr_time(), socket);
tr_handshake* const handshake = tr_handshakeNew(mediator, io, session->encryptionMode(), on_handshake_done, manager);
tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIo::NewIncoming() */
auto io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, addr, port, tr_time(), socket);
auto* const handshake = tr_handshakeNew(mediator, std::move(io), session->encryptionMode(), on_handshake_done, manager);
manager->incoming_handshakes.add(*addr, handshake);
}
}
@ -2836,7 +2831,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
tr_logAddTraceSwarm(s, fmt::format("Starting an OUTGOING {} connection with {}", utp ? " µTP" : "TCP", atom.readable()));
tr_peerIo* const io = tr_peerIo::newOutgoing(
auto io = tr_peerIo::newOutgoing(
mgr->session,
&mgr->session->top_bandwidth_,
&atom.addr,
@ -2855,12 +2850,12 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
else
{
auto mediator = std::make_shared<tr_handshake_mediator_impl>(*mgr->session);
tr_handshake* handshake = tr_handshakeNew(mediator, io, mgr->session->encryptionMode(), on_handshake_done, mgr);
TR_ASSERT(io->torrentHash());
tr_peerIoUnref(io); /* balanced by the initial ref in tr_peerIo::newOutgoing() */
auto* const handshake = tr_handshakeNew(
mediator,
std::move(io),
mgr->session->encryptionMode(),
on_handshake_done,
mgr);
s->outgoing_handshakes.add(atom.addr, handshake);
}

View File

@ -248,12 +248,17 @@ static void updateDesiredRequestCount(tr_peerMsgsImpl* msgs);
class tr_peerMsgsImpl final : public tr_peerMsgs
{
public:
tr_peerMsgsImpl(tr_torrent* torrent_in, peer_atom* atom_in, tr_peerIo* io_in, tr_peer_callback callback, void* callbackData)
tr_peerMsgsImpl(
tr_torrent* torrent_in,
peer_atom* atom_in,
std::shared_ptr<tr_peerIo> io_in,
tr_peer_callback callback,
void* callbackData)
: tr_peerMsgs{ torrent_in, atom_in }
, outMessagesBatchPeriod{ LowPriorityIntervalSecs }
, torrent{ torrent_in }
, outMessages{ evbuffer_new() }
, io{ io_in }
, io{ std::move(io_in) }
, have_{ torrent_in->pieceCount() }
, callback_{ callback }
, callbackData_{ callbackData }
@ -300,10 +305,9 @@ public:
set_active(TR_UP, false);
set_active(TR_DOWN, false);
if (this->io != nullptr)
if (this->io)
{
this->io->clear();
tr_peerIoUnref(this->io); /* balanced by the ref in handshakeDoneCB() */
}
evbuffer_free(this->outMessages);
@ -816,7 +820,7 @@ public:
evbuffer* const outMessages; /* all the non-piece messages */
tr_peerIo* const io;
std::shared_ptr<tr_peerIo> const io;
struct QueuedPeerRequest : public peer_request
{
@ -864,9 +868,14 @@ private:
static auto constexpr SendPexInterval = 90s;
};
tr_peerMsgs* tr_peerMsgsNew(tr_torrent* torrent, peer_atom* atom, tr_peerIo* io, tr_peer_callback callback, void* callback_data)
tr_peerMsgs* tr_peerMsgsNew(
tr_torrent* torrent,
peer_atom* atom,
std::shared_ptr<tr_peerIo> io,
tr_peer_callback callback,
void* callback_data)
{
return new tr_peerMsgsImpl(torrent, atom, io, callback, callback_data);
return new tr_peerMsgsImpl(torrent, atom, std::move(io), callback, callback_data);
}
/**
@ -2377,7 +2386,7 @@ static void peerPulse(void* vmsgs)
auto* msgs = static_cast<tr_peerMsgsImpl*>(vmsgs);
time_t const now = tr_time();
if (tr_isPeerIo(msgs->io))
if (msgs->io)
{
updateDesiredRequestCount(msgs);
updateBlockRequests(msgs);

View File

@ -12,6 +12,7 @@
#include <cstdint> // int8_t
#include <cstddef> // size_t
#include <ctime> // time_t
#include <memory>
#include <utility>
#include "bitfield.h"
@ -77,7 +78,7 @@ protected:
tr_peerMsgs* tr_peerMsgsNew(
tr_torrent* torrent,
peer_atom* atom,
tr_peerIo* io,
std::shared_ptr<tr_peerIo> io,
tr_peer_callback callback,
void* callback_data);

View File

@ -33,197 +33,199 @@ namespace test
auto constexpr MaxWaitMsec = int{ 5000 };
using HandshakeTest = SessionTest;
class MediatorMock final : public tr_handshake_mediator
class HandshakeTest : public SessionTest
{
public:
explicit MediatorMock(tr_session* session)
: session_{ session }
class MediatorMock final : public tr_handshake_mediator
{
}
virtual ~MediatorMock() = default;
[[nodiscard]] std::optional<torrent_info> torrentInfo(tr_sha1_digest_t const& info_hash) const override
{
if (auto const iter = torrents.find(info_hash); iter != std::end(torrents))
public:
explicit MediatorMock(tr_session* session)
: session_{ session }
{
return iter->second;
}
return {};
}
virtual ~MediatorMock() = default;
[[nodiscard]] std::optional<torrent_info> torrentInfoFromObfuscated(tr_sha1_digest_t const& obfuscated) const override
{
for (auto const& [info_hash, info] : torrents)
[[nodiscard]] std::optional<torrent_info> torrentInfo(tr_sha1_digest_t const& info_hash) const override
{
if (obfuscated == tr_sha1::digest("req2"sv, info.info_hash))
if (auto const iter = torrents.find(info_hash); iter != std::end(torrents))
{
return info;
return iter->second;
}
return {};
}
return {};
}
[[nodiscard]] std::optional<torrent_info> torrentInfoFromObfuscated(tr_sha1_digest_t const& obfuscated) const override
{
for (auto const& [info_hash, info] : torrents)
{
if (obfuscated == tr_sha1::digest("req2"sv, info.info_hash))
{
return info;
}
}
[[nodiscard]] std::unique_ptr<libtransmission::Timer> createTimer() override
{
return session_->timerMaker().create();
}
return {};
}
[[nodiscard]] bool isDHTEnabled() const override
{
return false;
}
[[nodiscard]] std::unique_ptr<libtransmission::Timer> createTimer() override
{
return session_->timerMaker().create();
}
[[nodiscard]] bool allowsTCP() const override
{
return true;
}
[[nodiscard]] bool isDHTEnabled() const override
{
return false;
}
[[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t /*tor_id*/, tr_address /*addr*/) const override
{
return false;
}
[[nodiscard]] bool allowsTCP() const override
{
return true;
}
[[nodiscard]] size_t pad(void* setme, [[maybe_unused]] size_t maxlen) const override
{
TR_ASSERT(maxlen > 10);
auto const len = size_t{ 10 };
std::fill_n(static_cast<char*>(setme), 10, ' ');
return len;
}
[[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t /*tor_id*/, tr_address /*addr*/) const override
{
return false;
}
[[nodiscard]] tr_message_stream_encryption::DH::private_key_bigend_t privateKey() const override
{
return private_key_;
}
[[nodiscard]] size_t pad(void* setme, [[maybe_unused]] size_t maxlen) const override
{
TR_ASSERT(maxlen > 10);
auto const len = size_t{ 10 };
std::fill_n(static_cast<char*>(setme), 10, ' ');
return len;
}
void setUTPFailed(tr_sha1_digest_t const& /*info_hash*/, tr_address /*addr*/) override
{
}
[[nodiscard]] tr_message_stream_encryption::DH::private_key_bigend_t privateKey() const override
{
return private_key_;
}
void setPrivateKeyFromBase64(std::string_view b64)
{
auto const str = tr_base64_decode(b64);
assert(std::size(str) == std::size(private_key_));
std::copy_n(reinterpret_cast<std::byte const*>(std::data(str)), std::size(str), std::begin(private_key_));
}
void setUTPFailed(tr_sha1_digest_t const& /*info_hash*/, tr_address /*addr*/) override
{
}
tr_session* const session_;
std::map<tr_sha1_digest_t, torrent_info> torrents;
tr_message_stream_encryption::DH::private_key_bigend_t private_key_ = {};
};
void setPrivateKeyFromBase64(std::string_view b64)
{
auto const str = tr_base64_decode(b64);
assert(std::size(str) == std::size(private_key_));
std::copy_n(reinterpret_cast<std::byte const*>(std::data(str)), std::size(str), std::begin(private_key_));
}
template<typename Span>
void sendToClient(evutil_socket_t sock, Span const& data)
{
auto const* walk = std::data(data);
static_assert(sizeof(*walk) == 1);
size_t len = std::size(data);
while (len > 0)
{
#if defined(_WIN32)
auto const n = send(sock, reinterpret_cast<char const*>(walk), len, 0);
#else
auto const n = write(sock, walk, len);
#endif
assert(n >= 0);
len -= n;
walk += n;
}
}
void sendB64ToClient(evutil_socket_t sock, std::string_view b64)
{
sendToClient(sock, tr_base64_decode(b64));
}
auto constexpr ReservedBytesNoExtensions = std::array<uint8_t, 8>{ 0, 0, 0, 0, 0, 0, 0, 0 };
auto constexpr PlaintextProtocolName = "\023BitTorrent protocol"sv;
auto const DefaultPeerAddr = *tr_address::fromString("127.0.0.1"sv);
auto const DefaultPeerPort = tr_port::fromHost(8080);
auto const TorrentWeAreSeeding = tr_handshake_mediator::torrent_info{ tr_sha1::digest("abcde"sv),
tr_peerIdInit(),
tr_torrent_id_t{ 100 },
true /*is_done*/ };
auto const UbuntuTorrent = tr_handshake_mediator::torrent_info{ *tr_sha1_from_string(
"2c6b6858d61da9543d4231a71db4b1c9264b0685"sv),
tr_peerIdInit(),
tr_torrent_id_t{ 101 },
false /*is_done*/ };
auto createIncomingIo(tr_session* session)
{
auto sockpair = std::array<evutil_socket_t, 2>{ -1, -1 };
EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno);
auto const now = tr_time();
auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]);
auto* const
io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, &DefaultPeerAddr, DefaultPeerPort, now, peer_socket);
return std::make_pair(io, sockpair[1]);
}
auto createOutgoingIo(tr_session* session, tr_sha1_digest_t const& info_hash)
{
auto sockpair = std::array<evutil_socket_t, 2>{ -1, -1 };
EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno);
auto const now = tr_time();
auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]);
auto* const io = tr_peerIo::create(
session,
&session->top_bandwidth_,
&DefaultPeerAddr,
DefaultPeerPort,
now,
&info_hash,
false /*is_incoming*/,
false /*is_seed*/,
peer_socket);
return std::make_pair(io, sockpair[1]);
}
constexpr auto makePeerId(std::string_view sv)
{
auto peer_id = tr_peer_id_t{};
for (size_t i = 0, n = std::size(sv); i < n; ++i)
{
peer_id[i] = sv[i];
}
return peer_id;
}
auto makeRandomPeerId()
{
auto peer_id = tr_peer_id_t{};
tr_rand_buffer(std::data(peer_id), std::size(peer_id));
auto const peer_id_prefix = "-UW110Q-"sv;
std::copy(std::begin(peer_id_prefix), std::end(peer_id_prefix), std::begin(peer_id));
return peer_id;
}
auto runHandshake(
std::shared_ptr<tr_handshake_mediator> mediator,
tr_peerIo* io,
tr_encryption_mode encryption_mode = TR_CLEAR_PREFERRED)
{
auto result = std::optional<tr_handshake_result>{};
static auto const DoneCallback = [](auto const& resin)
{
*static_cast<std::optional<tr_handshake_result>*>(resin.userData) = resin;
return true;
tr_session* const session_;
std::map<tr_sha1_digest_t, torrent_info> torrents;
tr_message_stream_encryption::DH::private_key_bigend_t private_key_ = {};
};
tr_handshakeNew(std::move(mediator), io, encryption_mode, DoneCallback, &result);
template<typename Span>
void sendToClient(evutil_socket_t sock, Span const& data)
{
auto const* walk = std::data(data);
static_assert(sizeof(*walk) == 1);
size_t len = std::size(data);
waitFor([&result]() { return result.has_value(); }, MaxWaitMsec);
while (len > 0)
{
#if defined(_WIN32)
auto const n = send(sock, reinterpret_cast<char const*>(walk), len, 0);
#else
auto const n = write(sock, walk, len);
#endif
assert(n >= 0);
len -= n;
walk += n;
}
}
return result;
}
void sendB64ToClient(evutil_socket_t sock, std::string_view b64)
{
sendToClient(sock, tr_base64_decode(b64));
}
static auto constexpr ReservedBytesNoExtensions = std::array<uint8_t, 8>{ 0, 0, 0, 0, 0, 0, 0, 0 };
static auto constexpr PlaintextProtocolName = "\023BitTorrent protocol"sv;
tr_address const DefaultPeerAddr = *tr_address::fromString("127.0.0.1"sv);
tr_port const DefaultPeerPort = tr_port::fromHost(8080);
tr_handshake_mediator::torrent_info const TorrentWeAreSeeding{ tr_sha1::digest("abcde"sv),
tr_peerIdInit(),
tr_torrent_id_t{ 100 },
true /*is_done*/ };
tr_handshake_mediator::torrent_info const UbuntuTorrent{ *tr_sha1_from_string("2c6b6858d61da9543d4231a71db4b1c9264b0685"sv),
tr_peerIdInit(),
tr_torrent_id_t{ 101 },
false /*is_done*/ };
auto createIncomingIo(tr_session* session)
{
auto sockpair = std::array<evutil_socket_t, 2>{ -1, -1 };
EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno);
auto const now = tr_time();
auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]);
auto
io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, &DefaultPeerAddr, DefaultPeerPort, now, peer_socket);
return std::make_pair(io, sockpair[1]);
}
auto createOutgoingIo(tr_session* session, tr_sha1_digest_t const& info_hash)
{
auto sockpair = std::array<evutil_socket_t, 2>{ -1, -1 };
EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno);
auto const now = tr_time();
auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]);
auto io = tr_peerIo::create(
session,
&session->top_bandwidth_,
&DefaultPeerAddr,
DefaultPeerPort,
now,
&info_hash,
false /*is_incoming*/,
false /*is_seed*/,
peer_socket);
return std::make_pair(io, sockpair[1]);
}
static constexpr auto makePeerId(std::string_view sv)
{
auto peer_id = tr_peer_id_t{};
for (size_t i = 0, n = std::size(sv); i < n; ++i)
{
peer_id[i] = sv[i];
}
return peer_id;
}
static auto makeRandomPeerId()
{
auto peer_id = tr_peer_id_t{};
tr_rand_buffer(std::data(peer_id), std::size(peer_id));
auto const peer_id_prefix = "-UW110Q-"sv;
std::copy(std::begin(peer_id_prefix), std::end(peer_id_prefix), std::begin(peer_id));
return peer_id;
}
static auto runHandshake(
std::shared_ptr<tr_handshake_mediator> mediator,
std::shared_ptr<tr_peerIo> io,
tr_encryption_mode encryption_mode = TR_CLEAR_PREFERRED)
{
auto result = std::optional<tr_handshake_result>{};
static auto const DoneCallback = [](auto const& resin)
{
*static_cast<std::optional<tr_handshake_result>*>(resin.userData) = resin;
return true;
};
tr_handshakeNew(std::move(mediator), std::move(io), encryption_mode, DoneCallback, &result);
waitFor([&result]() { return result.has_value(); }, MaxWaitMsec);
return result;
}
};
TEST_F(HandshakeTest, incomingPlaintext)
{
@ -257,7 +259,6 @@ TEST_F(HandshakeTest, incomingPlaintext)
EXPECT_TRUE(io->torrentHash());
EXPECT_EQ(TorrentWeAreSeeding.info_hash, *io->torrentHash());
tr_peerIoUnref(io);
evutil_closesocket(sock);
}
@ -284,7 +285,6 @@ TEST_F(HandshakeTest, incomingPlaintextUnknownInfoHash)
EXPECT_FALSE(res->peer_id);
EXPECT_FALSE(io->torrentHash());
tr_peerIoUnref(io);
evutil_closesocket(sock);
}
@ -313,7 +313,6 @@ TEST_F(HandshakeTest, outgoingPlaintext)
EXPECT_EQ(UbuntuTorrent.info_hash, *io->torrentHash());
EXPECT_EQ(tr_sha1_to_string(UbuntuTorrent.info_hash), tr_sha1_to_string(*io->torrentHash()));
tr_peerIoUnref(io);
evutil_closesocket(sock);
}
@ -353,7 +352,6 @@ TEST_F(HandshakeTest, incomingEncrypted)
EXPECT_EQ(UbuntuTorrent.info_hash, *io->torrentHash());
EXPECT_EQ(tr_sha1_to_string(UbuntuTorrent.info_hash), tr_sha1_to_string(*io->torrentHash()));
tr_peerIoUnref(io);
evutil_closesocket(sock);
}
@ -387,7 +385,6 @@ TEST_F(HandshakeTest, incomingEncryptedUnknownInfoHash)
EXPECT_TRUE(res->readAnythingFromPeer);
EXPECT_FALSE(io->torrentHash());
tr_peerIoUnref(io);
evutil_closesocket(sock);
}
@ -432,7 +429,6 @@ TEST_F(HandshakeTest, outgoingEncrypted)
EXPECT_EQ(UbuntuTorrent.info_hash, *io->torrentHash());
EXPECT_EQ(tr_sha1_to_string(UbuntuTorrent.info_hash), tr_sha1_to_string(*io->torrentHash()));
tr_peerIoUnref(io);
evutil_closesocket(sock);
}