refactor: peer-socket pt 2 (#4326)

* refactor: tr_netOpenPeerSocket() now takes a tr_address reference

* refactor: disable copy assignment, copy constructor

* refactor: move log statements to peer_socket constructor
This commit is contained in:
Charles Kerr 2022-12-06 10:28:28 -06:00 committed by GitHub
parent 9a5d9a0ba2
commit 22a3a5db25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 132 additions and 111 deletions

View File

@ -256,22 +256,22 @@ static tr_socket_t createSocket(tr_session* session, int domain, int type)
return sockfd;
}
struct tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const* addr, tr_port port, bool client_is_seed)
tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed)
{
TR_ASSERT(tr_address_is_valid(addr));
TR_ASSERT(tr_address_is_valid(&addr));
if (!session->allowsTCP())
{
return {};
}
if (!tr_address_is_valid_for_peers(addr, port))
if (!tr_address_is_valid_for_peers(&addr, port))
{
return {};
}
static auto constexpr Domains = std::array<int, NUM_TR_AF_INET_TYPES>{ AF_INET, AF_INET6 };
auto const s = createSocket(session, Domains[addr->type], SOCK_STREAM);
auto const s = createSocket(session, Domains[addr.type], SOCK_STREAM);
if (s == TR_BAD_SOCKET)
{
return {};
@ -288,10 +288,10 @@ struct tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const
}
}
auto const [sock, addrlen] = addr->toSockaddr(port);
auto const [sock, addrlen] = addr.toSockaddr(port);
// set source address
auto const [source_addr, is_default_addr] = session->publicAddress(addr->type);
auto const [source_addr, is_default_addr] = session->publicAddress(addr.type);
auto const [source_sock, sourcelen] = source_addr.toSockaddr({});
if (bind(s, reinterpret_cast<sockaddr const*>(&source_sock), sourcelen) == -1)
@ -313,12 +313,12 @@ struct tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const
#endif
sockerrno != EINPROGRESS)
{
if (auto const tmperrno = sockerrno; (tmperrno != ENETUNREACH && tmperrno != EHOSTUNREACH) || addr->isIPv4())
if (auto const tmperrno = sockerrno; (tmperrno != ENETUNREACH && tmperrno != EHOSTUNREACH) || addr.isIPv4())
{
tr_logAddWarn(fmt::format(
_("Couldn't connect socket {socket} to {address}:{port}: {error} ({error_code})"),
fmt::arg("socket", s),
fmt::arg("address", addr->readable()),
fmt::arg("address", addr.readable()),
fmt::arg("port", port.host()),
fmt::arg("error", tr_net_strerror(tmperrno)),
fmt::arg("error_code", tmperrno)));
@ -328,31 +328,27 @@ struct tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const
}
else
{
ret = tr_peer_socket{ *addr, port, s };
ret = tr_peer_socket{ session, addr, port, s };
}
tr_logAddTrace(fmt::format("New OUTGOING connection {} ({})", s, addr->readable(port)));
tr_logAddTrace(fmt::format("New OUTGOING connection {} ({})", s, addr.readable(port)));
return ret;
}
struct tr_peer_socket tr_netOpenPeerUTPSocket(
tr_session* session,
tr_address const* addr,
tr_port port,
bool /*client_is_seed*/)
tr_peer_socket tr_netOpenPeerUTPSocket(tr_session* session, tr_address const& addr, tr_port port, bool /*client_is_seed*/)
{
auto ret = tr_peer_socket{};
if (session->utp_context != nullptr && tr_address_is_valid_for_peers(addr, port))
if (session->utp_context != nullptr && tr_address_is_valid_for_peers(&addr, port))
{
auto const [ss, sslen] = addr->toSockaddr(port);
auto const [ss, sslen] = addr.toSockaddr(port);
if (auto* const sock = utp_create_socket(session->utp_context); sock != nullptr)
{
if (utp_connect(sock, reinterpret_cast<sockaddr const*>(&ss), sslen) != -1)
{
ret = tr_peer_socket{ *addr, port, sock };
ret = tr_peer_socket{ addr, port, sock };
}
else
{

View File

@ -308,14 +308,6 @@ static void event_write_cb(evutil_socket_t fd, short /*event*/, void* vio)
***
**/
static void maybeSetCongestionAlgorithm(tr_socket_t socket, std::string const& algorithm)
{
if (!std::empty(algorithm))
{
tr_netSetCongestionControl(socket, algorithm.c_str());
}
}
#ifdef WITH_UTP
/* µTP callbacks */
@ -458,48 +450,54 @@ static uint64 utp_callback(utp_callback_arguments* args)
#endif /* #ifdef WITH_UTP */
std::shared_ptr<tr_peerIo> tr_peerIo::create(
tr_session* session,
tr_bandwidth* parent,
tr_peerIo::tr_peerIo(
tr_session* session_in,
tr_sha1_digest_t const* torrent_hash,
bool is_incoming,
bool is_seed,
tr_peer_socket socket)
tr_bandwidth* parent_bandwidth,
tr_peer_socket sock)
: socket{ std::move(sock) }
, session{ session_in }
, bandwidth_{ parent_bandwidth }
, torrent_hash_{ torrent_hash != nullptr ? *torrent_hash : tr_sha1_digest_t{} }
, is_seed_{ is_seed }
, is_incoming_{ is_incoming }
{
TR_ASSERT(session != nullptr);
auto lock = session->unique_lock();
TR_ASSERT(socket.is_valid());
TR_ASSERT(session->allowsTCP() || !socket.is_tcp());
if (socket.is_tcp())
{
session->setSocketTOS(socket.handle.tcp, socket.address().type);
maybeSetCongestionAlgorithm(socket.handle.tcp, session->peerCongestionAlgorithm());
}
auto io = std::make_shared<tr_peerIo>(session, torrent_hash, is_incoming, is_seed, parent, socket);
io->bandwidth().setPeer(io);
tr_logAddTraceIo(io, fmt::format("bandwidth is {}; its parent is {}", fmt::ptr(&io->bandwidth()), fmt::ptr(parent)));
if (socket.is_tcp())
{
tr_logAddTraceIo(io, fmt::format("socket (tcp) is {}", socket.handle.tcp));
io->event_read.reset(event_new(session->eventBase(), socket.handle.tcp, EV_READ, event_read_cb, io.get()));
io->event_write.reset(event_new(session->eventBase(), socket.handle.tcp, EV_WRITE, event_write_cb, io.get()));
event_read.reset(event_new(session->eventBase(), socket.handle.tcp, EV_READ, event_read_cb, this));
event_write.reset(event_new(session->eventBase(), socket.handle.tcp, EV_WRITE, event_write_cb, this));
}
#ifdef WITH_UTP
else if (socket.is_utp())
{
tr_logAddTraceIo(io, fmt::format("socket (µTP) is {}", fmt::ptr(socket.handle.utp)));
utp_set_userdata(socket.handle.utp, io.get());
utp_set_userdata(socket.handle.utp, this);
}
#endif
else
{
TR_ASSERT_MSG(false, "unsupported peer socket type");
}
}
std::shared_ptr<tr_peerIo> tr_peerIo::create(
tr_session* session,
tr_bandwidth* parent,
tr_sha1_digest_t const* torrent_hash,
bool is_incoming,
bool is_seed,
tr_peer_socket sock)
{
TR_ASSERT(session != nullptr);
auto lock = session->unique_lock();
TR_ASSERT(sock.is_valid());
TR_ASSERT(session->allowsTCP() || !sock.is_tcp());
auto io = std::make_shared<tr_peerIo>(session, torrent_hash, is_incoming, is_seed, parent, std::move(sock));
io->bandwidth().setPeer(io);
tr_logAddTraceIo(io, fmt::format("bandwidth is {}; its parent is {}", fmt::ptr(&io->bandwidth()), fmt::ptr(parent)));
return io;
}
@ -522,20 +520,20 @@ std::shared_ptr<tr_peerIo> tr_peerIo::newIncoming(tr_session* session, tr_bandwi
{
TR_ASSERT(session != nullptr);
return tr_peerIo::create(session, parent, nullptr, true, false, socket);
return tr_peerIo::create(session, parent, nullptr, true, false, std::move(socket));
}
std::shared_ptr<tr_peerIo> tr_peerIo::newOutgoing(
tr_session* session,
tr_bandwidth* parent,
tr_address const* addr,
tr_address const& addr,
tr_port port,
tr_sha1_digest_t const& torrent_hash,
bool is_seed,
bool utp)
{
TR_ASSERT(session != nullptr);
TR_ASSERT(tr_address_is_valid(addr));
TR_ASSERT(tr_address_is_valid(&addr));
TR_ASSERT(utp || session->allowsTCP());
auto socket = tr_peer_socket{};
@ -556,7 +554,7 @@ std::shared_ptr<tr_peerIo> tr_peerIo::newOutgoing(
return nullptr;
}
return create(session, parent, &torrent_hash, false, is_seed, socket);
return create(session, parent, &torrent_hash, false, is_seed, std::move(socket));
}
/***
@ -692,8 +690,8 @@ int tr_peerIo::reconnect()
io_close_socket(this);
auto const [addr, port] = this->socketAddress();
this->socket = tr_netOpenPeerSocket(session, &addr, port, this->isSeed());
auto const [addr, port] = socketAddress();
this->socket = tr_netOpenPeerSocket(session, addr, port, this->isSeed());
if (!this->socket.is_tcp())
{
@ -704,8 +702,6 @@ int tr_peerIo::reconnect()
this->event_write.reset(event_new(session->eventBase(), this->socket.handle.tcp, EV_WRITE, event_write_cb, this));
event_enable(this, pending_events);
this->session->setSocketTOS(this->socket.handle.tcp, addr.type);
maybeSetCongestionAlgorithm(this->socket.handle.tcp, session->peerCongestionAlgorithm());
return 0;
}

View File

@ -67,7 +67,7 @@ public:
static std::shared_ptr<tr_peerIo> newOutgoing(
tr_session* session,
tr_bandwidth* parent,
struct tr_address const* addr,
tr_address const& addr,
tr_port port,
tr_sha1_digest_t const& torrent_hash,
bool is_seed,
@ -222,7 +222,7 @@ public:
setCallbacks(nullptr, nullptr, nullptr, nullptr);
}
struct tr_peer_socket socket = {};
tr_peer_socket socket = {};
tr_session* const session;
@ -286,15 +286,7 @@ public:
bool is_incoming,
bool is_seed,
tr_bandwidth* parent_bandwidth,
tr_peer_socket sock)
: socket{ sock }
, session{ session_in }
, bandwidth_{ parent_bandwidth }
, torrent_hash_{ torrent_hash != nullptr ? *torrent_hash : tr_sha1_digest_t{} }
, is_seed_{ is_seed }
, is_incoming_{ is_incoming }
{
}
tr_peer_socket sock);
private:
friend class libtransmission::test::HandshakeTest;

View File

@ -1227,31 +1227,33 @@ static bool on_handshake_done(tr_handshake_result const& result)
return success;
}
void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const& addr, tr_port port, tr_peer_socket socket)
void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket)
{
TR_ASSERT(manager->session != nullptr);
auto const lock = manager->unique_lock();
tr_session* session = manager->session;
if (session->addressIsBlocked(addr))
if (session->addressIsBlocked(socket.address()))
{
tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", addr.readable(port)));
tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", socket.readable()));
socket.close(session);
}
else if (manager->incoming_handshakes.contains(addr))
else if (manager->incoming_handshakes.contains(socket.address()))
{
socket.close(session);
}
else /* we don't have a connection to them yet... */
{
auto* const handshake = tr_handshakeNew(
manager->handshake_mediator_,
tr_peerIo::newIncoming(session, &session->top_bandwidth_, socket),
session->encryptionMode(),
on_handshake_done,
manager);
manager->incoming_handshakes.add(addr, handshake);
auto address = socket.address();
manager->incoming_handshakes.add(
address,
tr_handshakeNew(
manager->handshake_mediator_,
tr_peerIo::newIncoming(session, &session->top_bandwidth_, std::move(socket)),
session->encryptionMode(),
on_handshake_done,
manager));
}
}
@ -2796,7 +2798,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
auto io = tr_peerIo::newOutgoing(
mgr->session,
&mgr->session->top_bandwidth_,
&atom.addr,
atom.addr,
atom.port,
s->tor->infoHash(),
s->tor->completeness == TR_SEED,

View File

@ -170,7 +170,7 @@ void tr_peerMgrClientSentRequests(tr_torrent* torrent, tr_peer* peer, tr_block_s
[[nodiscard]] size_t tr_peerMgrCountActiveRequestsToPeer(tr_torrent const* torrent, tr_peer const* peer);
void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const& addr, tr_port port, struct tr_peer_socket const socket);
void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket);
size_t tr_peerMgrAddPex(tr_torrent* tor, uint8_t from, tr_pex const* pex, size_t n_pex);

View File

@ -3,12 +3,49 @@
// or any future license endorsed by Mnemosyne LLC.
// License text can be found in the licenses/ folder.
#include <fmt/format.h>
#include <libutp/utp.h>
#include "transmission.h"
#include "peer-socket.h"
#include "net.h"
#include "session.h"
#define tr_logAddErrorIo(io, msg) tr_logAddError(msg, (io)->readable())
#define tr_logAddWarnIo(io, msg) tr_logAddWarn(msg, (io)->readable())
#define tr_logAddDebugIo(io, msg) tr_logAddDebug(msg, (io)->readable())
#define tr_logAddTraceIo(io, msg) tr_logAddTrace(msg, (io)->readable())
tr_peer_socket::tr_peer_socket(tr_session* session, tr_address const& address, tr_port port, tr_socket_t sock)
: handle{ sock }
, address_{ address }
, port_{ port }
, type_{ Type::TCP }
{
TR_ASSERT(sock != TR_BAD_SOCKET);
session->setSocketTOS(sock, address_.type);
if (auto const& algo = session->peerCongestionAlgorithm(); !std::empty(algo))
{
tr_netSetCongestionControl(sock, algo.c_str());
}
tr_logAddTraceIo(this, fmt::format("socket (tcp) is {}", handle.tcp));
}
tr_peer_socket::tr_peer_socket(tr_address const& address, tr_port port, struct UTPSocket* const sock)
: address_{ address }
, port_{ port }
, type_{ Type::UTP }
{
TR_ASSERT(sock != nullptr);
handle.utp = sock;
tr_logAddTraceIo(this, fmt::format("socket (µTP) is {}", fmt::ptr(handle.utp)));
}
void tr_peer_socket::close(tr_session* session)
{
@ -23,4 +60,7 @@ void tr_peer_socket::close(tr_session* session)
utp_close(handle.utp);
}
#endif
type_ = Type::None;
handle = {};
}

View File

@ -9,6 +9,10 @@
#error only libtransmission should #include this header.
#endif
#include <string>
#include <string_view>
#include <utility> // for std::make_pair()
#include "transmission.h"
#include "net.h"
@ -17,27 +21,17 @@
struct UTPSocket;
struct tr_session;
struct tr_peer_socket
class tr_peer_socket
{
public:
tr_peer_socket() = default;
tr_peer_socket(tr_address const& address, tr_port port, tr_socket_t sock)
: handle{ sock }
, address_{ address }
, port_{ port }
, type_{ Type::TCP }
{
TR_ASSERT(sock != TR_BAD_SOCKET);
}
tr_peer_socket(tr_address const& address, tr_port port, struct UTPSocket* const sock)
: address_{ address }
, port_{ port }
, type_{ Type::UTP }
{
TR_ASSERT(sock != nullptr);
handle.utp = sock;
}
tr_peer_socket(tr_session* session, tr_address const& address, tr_port port, tr_socket_t sock);
tr_peer_socket(tr_address const& address, tr_port port, struct UTPSocket* const sock);
tr_peer_socket(tr_peer_socket&&) = default;
tr_peer_socket(tr_peer_socket const&) = delete;
tr_peer_socket& operator=(tr_peer_socket&&) = default;
tr_peer_socket& operator=(tr_peer_socket const&) = delete;
~tr_peer_socket() = default;
void close(tr_session* session);
@ -111,5 +105,5 @@ private:
enum Type type_ = Type::None;
};
struct tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const* addr, tr_port port, bool client_is_seed);
struct tr_peer_socket tr_netOpenPeerUTPSocket(tr_session* session, tr_address const* addr, tr_port port, bool client_is_seed);
tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed);
tr_peer_socket tr_netOpenPeerUTPSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed);

View File

@ -297,7 +297,7 @@ void tr_session::onIncomingPeerConnection(tr_socket_t fd, void* vsession)
{
auto const& [addr, port, sock] = *incoming_info;
tr_logAddTrace(fmt::format("new incoming connection {} ({})", sock, addr.readable(port)));
session->addIncoming(addr, port, tr_peer_socket{ addr, port, sock });
session->addIncoming(tr_peer_socket{ session, addr, port, sock });
}
}
@ -2196,9 +2196,9 @@ tr_session::tr_session(std::string_view config_dir, tr_variant* settings_dict)
verifier_->addCallback(tr_torrentOnVerifyDone);
}
void tr_session::addIncoming(tr_address const& addr, tr_port port, struct tr_peer_socket const socket)
void tr_session::addIncoming(tr_peer_socket&& socket)
{
tr_peerMgrAddIncoming(peer_mgr_.get(), addr, port, socket);
tr_peerMgrAddIncoming(peer_mgr_.get(), std::move(socket));
}
void tr_session::addTorrent(tr_torrent* tor)

View File

@ -55,6 +55,7 @@ tr_peer_id_t tr_peerIdInit();
struct event_base;
class tr_lpd;
class tr_peer_socket;
class tr_port_forwarding;
class tr_rpc_server;
class tr_session_thread;
@ -881,7 +882,7 @@ public:
return bandwidth_groups_;
}
void addIncoming(tr_address const& addr, tr_port port, struct tr_peer_socket const socket);
void addIncoming(tr_peer_socket&& socket);
void addTorrent(tr_torrent* tor);

View File

@ -93,7 +93,7 @@ static void utp_on_accept(tr_session* const session, UTPSocket* const utp_sock)
if (auto addrport = tr_address::fromSockaddr(reinterpret_cast<struct sockaddr*>(&from_storage)); addrport)
{
auto const& [addr, port] = *addrport;
session->addIncoming(addr, port, tr_peer_socket{ addr, port, utp_sock });
session->addIncoming(tr_peer_socket{ addr, port, utp_sock });
}
else
{

View File

@ -161,7 +161,7 @@ public:
auto io = tr_peerIo::newIncoming(
session,
&session->top_bandwidth_,
tr_peer_socket(DefaultPeerAddr, DefaultPeerPort, sockpair[0]));
tr_peer_socket(session, DefaultPeerAddr, DefaultPeerPort, sockpair[0]));
return std::make_pair(io, sockpair[1]);
}
@ -175,7 +175,7 @@ public:
&info_hash,
false /*is_incoming*/,
false /*is_seed*/,
tr_peer_socket(DefaultPeerAddr, DefaultPeerPort, sockpair[0]));
tr_peer_socket(session, DefaultPeerAddr, DefaultPeerPort, sockpair[0]));
return std::make_pair(io, sockpair[1]);
}