refactor: tr_peer_socket keeps track of peer count (#4534)

This commit is contained in:
Charles Kerr 2023-01-04 15:37:55 -06:00 committed by GitHub
parent c95891ec60
commit b47c34726b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 56 additions and 68 deletions

View File

@ -142,10 +142,8 @@ void tr_netSetCongestionControl([[maybe_unused]] tr_socket_t s, [[maybe_unused]]
#endif
}
static tr_socket_t createSocket(tr_session* session, int domain, int type)
static tr_socket_t createSocket(int domain, int type)
{
TR_ASSERT(session != nullptr);
auto const sockfd = socket(domain, type, 0);
if (sockfd == TR_BAD_SOCKET)
{
@ -160,9 +158,9 @@ static tr_socket_t createSocket(tr_session* session, int domain, int type)
return TR_BAD_SOCKET;
}
if ((evutil_make_socket_nonblocking(sockfd) == -1) || !session->incPeerCount())
if (evutil_make_socket_nonblocking(sockfd) == -1)
{
tr_netClose(session, sockfd);
tr_net_close_socket(sockfd);
return TR_BAD_SOCKET;
}
@ -193,19 +191,15 @@ static tr_socket_t createSocket(tr_session* session, int domain, int type)
tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed)
{
TR_ASSERT(addr.is_valid());
TR_ASSERT(!tr_peer_socket::limit_reached(session));
if (!session->allowsTCP())
{
return {};
}
if (!addr.is_valid_for_peers(port))
if (tr_peer_socket::limit_reached(session) || !session->allowsTCP() || !addr.is_valid_for_peers(port))
{
return {};
}
static auto constexpr Domains = std::array<int, NUM_TR_AF_INET_TYPES>{ AF_INET, AF_INET6 };
auto const s = createSocket(session, Domains[addr.type], SOCK_STREAM);
auto const s = createSocket(Domains[addr.type], SOCK_STREAM);
if (s == TR_BAD_SOCKET)
{
return {};
@ -236,7 +230,7 @@ tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr,
fmt::arg("socket", s),
fmt::arg("error", tr_net_strerror(sockerrno)),
fmt::arg("error_code", sockerrno)));
tr_netClose(session, s);
tr_net_close_socket(s);
return {};
}
@ -258,7 +252,7 @@ tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr,
fmt::arg("error_code", tmperrno)));
}
tr_netClose(session, s);
tr_net_close_socket(s);
}
else
{
@ -286,7 +280,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
if (evutil_make_socket_nonblocking(fd) == -1)
{
*err_out = sockerrno;
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
return TR_BAD_SOCKET;
}
@ -301,7 +295,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
(sockerrno != ENOPROTOOPT)) // if the kernel doesn't support it, ignore it
{
*err_out = sockerrno;
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
return TR_BAD_SOCKET;
}
@ -325,7 +319,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
fmt::arg("error_code", err)));
}
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
*err_out = err;
return TR_BAD_SOCKET;
}
@ -354,7 +348,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool
#endif /* _WIN32 */
{
*err_out = sockerrno;
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
return TR_BAD_SOCKET;
}
@ -384,7 +378,7 @@ bool tr_net_hasIPv6(tr_port port)
if (fd != TR_BAD_SOCKET)
{
tr_netCloseSocket(fd);
tr_net_close_socket(fd);
}
already_done = true;
@ -410,26 +404,20 @@ std::optional<std::tuple<tr_address, tr_port, tr_socket_t>> tr_netAccept(tr_sess
// make the socket unblocking,
// and confirm we don't have too many peers
auto const addrport = tr_address::from_sockaddr(reinterpret_cast<struct sockaddr*>(&sock));
if (!addrport || evutil_make_socket_nonblocking(sockfd) == -1 || !session->incPeerCount())
if (!addrport || evutil_make_socket_nonblocking(sockfd) == -1 || tr_peer_socket::limit_reached(session))
{
tr_netCloseSocket(sockfd);
tr_net_close_socket(sockfd);
return {};
}
return std::make_tuple(addrport->first, addrport->second, sockfd);
}
void tr_netCloseSocket(tr_socket_t sockfd)
void tr_net_close_socket(tr_socket_t sockfd)
{
evutil_closesocket(sockfd);
}
void tr_netClose(tr_session* session, tr_socket_t sockfd)
{
tr_netCloseSocket(sockfd);
session->decPeerCount();
}
// code in global_ipv6_herlpers is written by Juliusz Chroboczek
// and is covered under the same license as dht.cc.
// Please feel free to copy them into your software if it can help

View File

@ -322,9 +322,7 @@ tr_socket_t tr_netBindTCP(tr_address const& addr, tr_port port, bool suppress_ms
void tr_netSetCongestionControl(tr_socket_t s, char const* algorithm);
void tr_netClose(tr_session* session, tr_socket_t s);
void tr_netCloseSocket(tr_socket_t fd);
void tr_net_close_socket(tr_socket_t fd);
bool tr_net_hasIPv6(tr_port);

View File

@ -91,6 +91,7 @@ std::shared_ptr<tr_peerIo> tr_peerIo::new_outgoing(
bool is_seed,
bool utp)
{
TR_ASSERT(!tr_peer_socket::limit_reached(session));
TR_ASSERT(session != nullptr);
TR_ASSERT(addr.is_valid());
TR_ASSERT(utp || session->allowsTCP());
@ -166,7 +167,7 @@ void tr_peerIo::set_socket(tr_peer_socket socket_in)
void tr_peerIo::close()
{
socket_.close(session_);
socket_.close();
event_write_.reset();
event_read_.reset();
}
@ -189,6 +190,11 @@ bool tr_peerIo::reconnect()
close();
if (tr_peer_socket::limit_reached(session_))
{
return false;
}
auto const [addr, port] = socket_address();
socket_ = tr_netOpenPeerSocket(session_, addr, port, is_seed());

View File

@ -1198,11 +1198,11 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket)
if (session->addressIsBlocked(socket.address()))
{
tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", socket.display_name()));
socket.close(session);
socket.close();
}
else if (manager->incoming_handshakes.count(socket.address()) != 0U)
{
socket.close(session);
socket.close();
}
else /* we don't have a connection to them yet... */
{
@ -2726,7 +2726,9 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
utp = utp && (atom.flags & ADDED_F_UTP_FLAGS) != 0;
}
if (!utp && !mgr->session->allowsTCP())
auto* const session = mgr->session;
if (tr_peer_socket::limit_reached(session) || (!utp && !session->allowsTCP()))
{
return;
}
@ -2736,8 +2738,8 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
fmt::format("Starting an OUTGOING {} connection with {}", utp ? " µTP" : "TCP", atom.display_name()));
auto peer_io = tr_peerIo::new_outgoing(
mgr->session,
&mgr->session->top_bandwidth_,
session,
&session->top_bandwidth_,
atom.addr,
atom.port,
s->tor->infoHash(),
@ -2756,7 +2758,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom)
atom.addr,
&mgr->handshake_mediator_,
peer_io,
mgr->session->encryptionMode(),
session->encryptionMode(),
[mgr](tr_handshake::Result const& result) { return on_handshake_done(mgr, result); });
}

View File

@ -26,6 +26,7 @@ tr_peer_socket::tr_peer_socket(tr_session const* session, tr_address const& addr
{
TR_ASSERT(sock != TR_BAD_SOCKET);
++n_open_sockets_;
session->setSocketTOS(sock, address_.type);
if (auto const& algo = session->peerCongestionAlgorithm(); !std::empty(algo))
@ -42,20 +43,24 @@ tr_peer_socket::tr_peer_socket(tr_address const& address, tr_port port, struct U
, type_{ Type::UTP }
{
TR_ASSERT(sock != nullptr);
++n_open_sockets_;
handle.utp = sock;
tr_logAddTraceIo(this, fmt::format("socket (µTP) is {}", fmt::ptr(handle.utp)));
}
void tr_peer_socket::close(tr_session* session)
void tr_peer_socket::close()
{
if (is_tcp() && (handle.tcp != TR_BAD_SOCKET))
{
tr_netClose(session, handle.tcp);
--n_open_sockets_;
tr_net_close_socket(handle.tcp);
}
#ifdef WITH_UTP
else if (is_utp())
{
--n_open_sockets_;
utp_set_userdata(handle.utp, nullptr);
utp_close(handle.utp);
}
@ -126,3 +131,8 @@ size_t tr_peer_socket::try_read(Buffer& buf, size_t max, tr_error** error) const
return {};
}
bool tr_peer_socket::limit_reached(tr_session* const session) noexcept
{
return n_open_sockets_.load() >= session->peerLimit();
}

View File

@ -9,6 +9,7 @@
#error only libtransmission should #include this header.
#endif
#include <atomic>
#include <string>
#include <string_view>
#include <utility> // for std::make_pair()
@ -37,7 +38,7 @@ public:
tr_peer_socket& operator=(tr_peer_socket const&) = delete;
~tr_peer_socket() = default;
void close(tr_session* session);
void close();
size_t try_write(Buffer& buf, size_t max, tr_error** error) const;
size_t try_read(Buffer& buf, size_t max, tr_error** error) const;
@ -124,6 +125,8 @@ public:
struct UTPSocket* utp;
} handle = {};
[[nodiscard]] static bool limit_reached(tr_session* const session) noexcept;
private:
enum class Type
{
@ -136,6 +139,8 @@ private:
tr_port port_;
enum Type type_ = Type::None;
static inline std::atomic<size_t> n_open_sockets_ = {};
};
tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed);

View File

@ -331,7 +331,7 @@ tr_session::BoundSocket::~BoundSocket()
if (socket_ != TR_BAD_SOCKET)
{
tr_netCloseSocket(socket_);
tr_net_close_socket(socket_);
socket_ = TR_BAD_SOCKET;
}
}

View File

@ -496,25 +496,6 @@ public:
return settings_.peer_limit_per_torrent;
}
[[nodiscard]] constexpr bool incPeerCount() noexcept
{
if (this->peer_count_ >= this->peerLimit())
{
return false;
}
++this->peer_count_;
return true;
}
constexpr void decPeerCount() noexcept
{
if (this->peer_count_ > 0)
{
--this->peer_count_;
}
}
// bandwidth
[[nodiscard]] tr_bandwidth& getBandwidthGroup(std::string_view name);
@ -1059,8 +1040,6 @@ private:
// port than the one requested by Transmission.
tr_port advertised_peer_port_;
uint16_t peer_count_ = 0;
bool is_closing_ = false;
/// fields that aren't trivial,

View File

@ -159,7 +159,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port)
fmt::arg("error", tr_strerror(error_code)),
fmt::arg("error_code", error_code)));
tr_netCloseSocket(sock);
tr_net_close_socket(sock);
}
else
{
@ -193,7 +193,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port)
fmt::arg("error", tr_strerror(error_code)),
fmt::arg("error_code", error_code)));
tr_netCloseSocket(sock);
tr_net_close_socket(sock);
}
else
{
@ -220,7 +220,7 @@ tr_session::tr_udp_core::~tr_udp_core()
if (udp6_socket_ != TR_BAD_SOCKET)
{
tr_netCloseSocket(udp6_socket_);
tr_net_close_socket(udp6_socket_);
udp6_socket_ = TR_BAD_SOCKET;
}
@ -228,7 +228,7 @@ tr_session::tr_udp_core::~tr_udp_core()
if (udp4_socket_ != TR_BAD_SOCKET)
{
tr_netCloseSocket(udp4_socket_);
tr_net_close_socket(udp4_socket_);
udp4_socket_ = TR_BAD_SOCKET;
}
}

View File

@ -82,7 +82,7 @@ static void utp_on_accept(tr_session* const session, UTPSocket* const utp_sock)
auto* const from = (struct sockaddr*)&from_storage;
socklen_t fromlen = sizeof(from_storage);
if (!session->allowsUTP())
if (!session->allowsUTP() || tr_peer_socket::limit_reached(session))
{
utp_close(utp_sock);
return;