diff --git a/libtransmission/net.cc b/libtransmission/net.cc index 174c84cb3..090911cbb 100644 --- a/libtransmission/net.cc +++ b/libtransmission/net.cc @@ -142,10 +142,8 @@ void tr_netSetCongestionControl([[maybe_unused]] tr_socket_t s, [[maybe_unused]] #endif } -static tr_socket_t createSocket(tr_session* session, int domain, int type) +static tr_socket_t createSocket(int domain, int type) { - TR_ASSERT(session != nullptr); - auto const sockfd = socket(domain, type, 0); if (sockfd == TR_BAD_SOCKET) { @@ -160,9 +158,9 @@ static tr_socket_t createSocket(tr_session* session, int domain, int type) return TR_BAD_SOCKET; } - if ((evutil_make_socket_nonblocking(sockfd) == -1) || !session->incPeerCount()) + if (evutil_make_socket_nonblocking(sockfd) == -1) { - tr_netClose(session, sockfd); + tr_net_close_socket(sockfd); return TR_BAD_SOCKET; } @@ -193,19 +191,15 @@ static tr_socket_t createSocket(tr_session* session, int domain, int type) tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed) { TR_ASSERT(addr.is_valid()); + TR_ASSERT(!tr_peer_socket::limit_reached(session)); - if (!session->allowsTCP()) - { - return {}; - } - - if (!addr.is_valid_for_peers(port)) + if (tr_peer_socket::limit_reached(session) || !session->allowsTCP() || !addr.is_valid_for_peers(port)) { return {}; } static auto constexpr Domains = std::array{ AF_INET, AF_INET6 }; - auto const s = createSocket(session, Domains[addr.type], SOCK_STREAM); + auto const s = createSocket(Domains[addr.type], SOCK_STREAM); if (s == TR_BAD_SOCKET) { return {}; @@ -236,7 +230,7 @@ tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, fmt::arg("socket", s), fmt::arg("error", tr_net_strerror(sockerrno)), fmt::arg("error_code", sockerrno))); - tr_netClose(session, s); + tr_net_close_socket(s); return {}; } @@ -258,7 +252,7 @@ tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, fmt::arg("error_code", tmperrno))); } - tr_netClose(session, s); + tr_net_close_socket(s); } else { @@ -286,7 +280,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool if (evutil_make_socket_nonblocking(fd) == -1) { *err_out = sockerrno; - tr_netCloseSocket(fd); + tr_net_close_socket(fd); return TR_BAD_SOCKET; } @@ -301,7 +295,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool (sockerrno != ENOPROTOOPT)) // if the kernel doesn't support it, ignore it { *err_out = sockerrno; - tr_netCloseSocket(fd); + tr_net_close_socket(fd); return TR_BAD_SOCKET; } @@ -325,7 +319,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool fmt::arg("error_code", err))); } - tr_netCloseSocket(fd); + tr_net_close_socket(fd); *err_out = err; return TR_BAD_SOCKET; } @@ -354,7 +348,7 @@ static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool #endif /* _WIN32 */ { *err_out = sockerrno; - tr_netCloseSocket(fd); + tr_net_close_socket(fd); return TR_BAD_SOCKET; } @@ -384,7 +378,7 @@ bool tr_net_hasIPv6(tr_port port) if (fd != TR_BAD_SOCKET) { - tr_netCloseSocket(fd); + tr_net_close_socket(fd); } already_done = true; @@ -410,26 +404,20 @@ std::optional> tr_netAccept(tr_sess // make the socket unblocking, // and confirm we don't have too many peers auto const addrport = tr_address::from_sockaddr(reinterpret_cast(&sock)); - if (!addrport || evutil_make_socket_nonblocking(sockfd) == -1 || !session->incPeerCount()) + if (!addrport || evutil_make_socket_nonblocking(sockfd) == -1 || tr_peer_socket::limit_reached(session)) { - tr_netCloseSocket(sockfd); + tr_net_close_socket(sockfd); return {}; } return std::make_tuple(addrport->first, addrport->second, sockfd); } -void tr_netCloseSocket(tr_socket_t sockfd) +void tr_net_close_socket(tr_socket_t sockfd) { evutil_closesocket(sockfd); } -void tr_netClose(tr_session* session, tr_socket_t sockfd) -{ - tr_netCloseSocket(sockfd); - session->decPeerCount(); -} - // code in global_ipv6_herlpers is written by Juliusz Chroboczek // and is covered under the same license as dht.cc. // Please feel free to copy them into your software if it can help diff --git a/libtransmission/net.h b/libtransmission/net.h index 5f1a37d9e..14d446137 100644 --- a/libtransmission/net.h +++ b/libtransmission/net.h @@ -322,9 +322,7 @@ tr_socket_t tr_netBindTCP(tr_address const& addr, tr_port port, bool suppress_ms void tr_netSetCongestionControl(tr_socket_t s, char const* algorithm); -void tr_netClose(tr_session* session, tr_socket_t s); - -void tr_netCloseSocket(tr_socket_t fd); +void tr_net_close_socket(tr_socket_t fd); bool tr_net_hasIPv6(tr_port); diff --git a/libtransmission/peer-io.cc b/libtransmission/peer-io.cc index 9af367aba..9a60c8d95 100644 --- a/libtransmission/peer-io.cc +++ b/libtransmission/peer-io.cc @@ -91,6 +91,7 @@ std::shared_ptr tr_peerIo::new_outgoing( bool is_seed, bool utp) { + TR_ASSERT(!tr_peer_socket::limit_reached(session)); TR_ASSERT(session != nullptr); TR_ASSERT(addr.is_valid()); TR_ASSERT(utp || session->allowsTCP()); @@ -166,7 +167,7 @@ void tr_peerIo::set_socket(tr_peer_socket socket_in) void tr_peerIo::close() { - socket_.close(session_); + socket_.close(); event_write_.reset(); event_read_.reset(); } @@ -189,6 +190,11 @@ bool tr_peerIo::reconnect() close(); + if (tr_peer_socket::limit_reached(session_)) + { + return false; + } + auto const [addr, port] = socket_address(); socket_ = tr_netOpenPeerSocket(session_, addr, port, is_seed()); diff --git a/libtransmission/peer-mgr.cc b/libtransmission/peer-mgr.cc index 717fb4822..443724e35 100644 --- a/libtransmission/peer-mgr.cc +++ b/libtransmission/peer-mgr.cc @@ -1198,11 +1198,11 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket) if (session->addressIsBlocked(socket.address())) { tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", socket.display_name())); - socket.close(session); + socket.close(); } else if (manager->incoming_handshakes.count(socket.address()) != 0U) { - socket.close(session); + socket.close(); } else /* we don't have a connection to them yet... */ { @@ -2726,7 +2726,9 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom) utp = utp && (atom.flags & ADDED_F_UTP_FLAGS) != 0; } - if (!utp && !mgr->session->allowsTCP()) + auto* const session = mgr->session; + + if (tr_peer_socket::limit_reached(session) || (!utp && !session->allowsTCP())) { return; } @@ -2736,8 +2738,8 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom) fmt::format("Starting an OUTGOING {} connection with {}", utp ? " µTP" : "TCP", atom.display_name())); auto peer_io = tr_peerIo::new_outgoing( - mgr->session, - &mgr->session->top_bandwidth_, + session, + &session->top_bandwidth_, atom.addr, atom.port, s->tor->infoHash(), @@ -2756,7 +2758,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom) atom.addr, &mgr->handshake_mediator_, peer_io, - mgr->session->encryptionMode(), + session->encryptionMode(), [mgr](tr_handshake::Result const& result) { return on_handshake_done(mgr, result); }); } diff --git a/libtransmission/peer-socket.cc b/libtransmission/peer-socket.cc index 8f14c46e3..912cc4984 100644 --- a/libtransmission/peer-socket.cc +++ b/libtransmission/peer-socket.cc @@ -26,6 +26,7 @@ tr_peer_socket::tr_peer_socket(tr_session const* session, tr_address const& addr { TR_ASSERT(sock != TR_BAD_SOCKET); + ++n_open_sockets_; session->setSocketTOS(sock, address_.type); if (auto const& algo = session->peerCongestionAlgorithm(); !std::empty(algo)) @@ -42,20 +43,24 @@ tr_peer_socket::tr_peer_socket(tr_address const& address, tr_port port, struct U , type_{ Type::UTP } { TR_ASSERT(sock != nullptr); + + ++n_open_sockets_; handle.utp = sock; tr_logAddTraceIo(this, fmt::format("socket (µTP) is {}", fmt::ptr(handle.utp))); } -void tr_peer_socket::close(tr_session* session) +void tr_peer_socket::close() { if (is_tcp() && (handle.tcp != TR_BAD_SOCKET)) { - tr_netClose(session, handle.tcp); + --n_open_sockets_; + tr_net_close_socket(handle.tcp); } #ifdef WITH_UTP else if (is_utp()) { + --n_open_sockets_; utp_set_userdata(handle.utp, nullptr); utp_close(handle.utp); } @@ -126,3 +131,8 @@ size_t tr_peer_socket::try_read(Buffer& buf, size_t max, tr_error** error) const return {}; } + +bool tr_peer_socket::limit_reached(tr_session* const session) noexcept +{ + return n_open_sockets_.load() >= session->peerLimit(); +} diff --git a/libtransmission/peer-socket.h b/libtransmission/peer-socket.h index 2f5d1cf13..b0feb3558 100644 --- a/libtransmission/peer-socket.h +++ b/libtransmission/peer-socket.h @@ -9,6 +9,7 @@ #error only libtransmission should #include this header. #endif +#include #include #include #include // for std::make_pair() @@ -37,7 +38,7 @@ public: tr_peer_socket& operator=(tr_peer_socket const&) = delete; ~tr_peer_socket() = default; - void close(tr_session* session); + void close(); size_t try_write(Buffer& buf, size_t max, tr_error** error) const; size_t try_read(Buffer& buf, size_t max, tr_error** error) const; @@ -124,6 +125,8 @@ public: struct UTPSocket* utp; } handle = {}; + [[nodiscard]] static bool limit_reached(tr_session* const session) noexcept; + private: enum class Type { @@ -136,6 +139,8 @@ private: tr_port port_; enum Type type_ = Type::None; + + static inline std::atomic n_open_sockets_ = {}; }; tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const& addr, tr_port port, bool client_is_seed); diff --git a/libtransmission/session.cc b/libtransmission/session.cc index 6b6f675fc..abcb417f5 100644 --- a/libtransmission/session.cc +++ b/libtransmission/session.cc @@ -331,7 +331,7 @@ tr_session::BoundSocket::~BoundSocket() if (socket_ != TR_BAD_SOCKET) { - tr_netCloseSocket(socket_); + tr_net_close_socket(socket_); socket_ = TR_BAD_SOCKET; } } diff --git a/libtransmission/session.h b/libtransmission/session.h index a48b6e77e..8c0acc23a 100644 --- a/libtransmission/session.h +++ b/libtransmission/session.h @@ -496,25 +496,6 @@ public: return settings_.peer_limit_per_torrent; } - [[nodiscard]] constexpr bool incPeerCount() noexcept - { - if (this->peer_count_ >= this->peerLimit()) - { - return false; - } - - ++this->peer_count_; - return true; - } - - constexpr void decPeerCount() noexcept - { - if (this->peer_count_ > 0) - { - --this->peer_count_; - } - } - // bandwidth [[nodiscard]] tr_bandwidth& getBandwidthGroup(std::string_view name); @@ -1059,8 +1040,6 @@ private: // port than the one requested by Transmission. tr_port advertised_peer_port_; - uint16_t peer_count_ = 0; - bool is_closing_ = false; /// fields that aren't trivial, diff --git a/libtransmission/tr-udp.cc b/libtransmission/tr-udp.cc index c713b24e3..eb9e6dba7 100644 --- a/libtransmission/tr-udp.cc +++ b/libtransmission/tr-udp.cc @@ -159,7 +159,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port) fmt::arg("error", tr_strerror(error_code)), fmt::arg("error_code", error_code))); - tr_netCloseSocket(sock); + tr_net_close_socket(sock); } else { @@ -193,7 +193,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port) fmt::arg("error", tr_strerror(error_code)), fmt::arg("error_code", error_code))); - tr_netCloseSocket(sock); + tr_net_close_socket(sock); } else { @@ -220,7 +220,7 @@ tr_session::tr_udp_core::~tr_udp_core() if (udp6_socket_ != TR_BAD_SOCKET) { - tr_netCloseSocket(udp6_socket_); + tr_net_close_socket(udp6_socket_); udp6_socket_ = TR_BAD_SOCKET; } @@ -228,7 +228,7 @@ tr_session::tr_udp_core::~tr_udp_core() if (udp4_socket_ != TR_BAD_SOCKET) { - tr_netCloseSocket(udp4_socket_); + tr_net_close_socket(udp4_socket_); udp4_socket_ = TR_BAD_SOCKET; } } diff --git a/libtransmission/tr-utp.cc b/libtransmission/tr-utp.cc index af91c5286..dca91e9be 100644 --- a/libtransmission/tr-utp.cc +++ b/libtransmission/tr-utp.cc @@ -82,7 +82,7 @@ static void utp_on_accept(tr_session* const session, UTPSocket* const utp_sock) auto* const from = (struct sockaddr*)&from_storage; socklen_t fromlen = sizeof(from_storage); - if (!session->allowsUTP()) + if (!session->allowsUTP() || tr_peer_socket::limit_reached(session)) { utp_close(utp_sock); return;