From 9a5d9a0ba2203bee54ec82212c5077210881a6fb Mon Sep 17 00:00:00 2001 From: Charles Kerr Date: Mon, 5 Dec 2022 18:53:31 -0600 Subject: [PATCH] refactor: tr_peer_socket (#4325) * refactor: make tr_peer_socket.type private * refactor: reimplement tr_peerIo::address() as a wrapper around tr_peer_socket::address() * refactor: remove tr_address, tr_port from tr_peerIo * refactor: replace tr_netClosePeerSocket() with tr_peer_socket::close() --- Transmission.xcodeproj/project.pbxproj | 12 ++- libtransmission/CMakeLists.txt | 1 + libtransmission/handshake.cc | 2 +- libtransmission/net.cc | 49 +--------- libtransmission/peer-io.cc | 115 +++++++---------------- libtransmission/peer-io.h | 62 ++++++------ libtransmission/peer-mgr.cc | 10 +- libtransmission/peer-msgs.cc | 2 +- libtransmission/peer-socket.cc | 26 ++++++ libtransmission/peer-socket.h | 119 +++++++++++++++++++----- libtransmission/session.cc | 2 +- libtransmission/tr-utp.cc | 2 +- tests/libtransmission/handshake-test.cc | 11 +-- 13 files changed, 207 insertions(+), 206 deletions(-) create mode 100644 libtransmission/peer-socket.cc diff --git a/Transmission.xcodeproj/project.pbxproj b/Transmission.xcodeproj/project.pbxproj index a4a6cd793..a77a1f1c6 100644 --- a/Transmission.xcodeproj/project.pbxproj +++ b/Transmission.xcodeproj/project.pbxproj @@ -331,7 +331,8 @@ C1305EBE186A13B100F03351 /* file.cc in Sources */ = {isa = PBXBuildFile; fileRef = C1305EB8186A134000F03351 /* file.cc */; }; C1425B361EE9C605001DB85F /* tr-assert.h in Headers */ = {isa = PBXBuildFile; fileRef = C1425B331EE9C5EA001DB85F /* tr-assert.h */; }; C1425B371EE9C705001DB85F /* tr-macros.h in Headers */ = {isa = PBXBuildFile; fileRef = C1425B341EE9C5EA001DB85F /* tr-macros.h */; }; - C1425B381EE9C805001DB85F /* peer-socket.h in Headers */ = {isa = PBXBuildFile; fileRef = C1425B351EE9C5EA001DB85F /* peer-socket.h */; }; + C1425B381EE9C805001DB850 /* peer-socket.h in Headers */ = {isa = PBXBuildFile; fileRef = C1425B381EE9C805001DB851 /* peer-socket.h */; }; + C1425B381EE9C805001DB852 /* peer-socket.cc in Sources */ = {isa = PBXBuildFile; fileRef = C1425B381EE9C805001DB853 /* peer-socket.cc */; }; C16089EF1F092A1E00CEFC36 /* utp_api.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C16089E41F092A1E00CEFC36 /* utp_api.cpp */; }; C16089F01F092A1E00CEFC36 /* utp_callbacks.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C16089E51F092A1E00CEFC36 /* utp_callbacks.cpp */; }; C16089F11F092A1E00CEFC36 /* utp_callbacks.h in Headers */ = {isa = PBXBuildFile; fileRef = C16089E61F092A1E00CEFC36 /* utp_callbacks.h */; }; @@ -1094,6 +1095,7 @@ BEFC1E1D0C07861A00B0BB3C /* completion.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = completion.cc; sourceTree = ""; }; BEFC1E1E0C07861A00B0BB3C /* clients.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = clients.h; sourceTree = ""; }; BEFC1E1F0C07861A00B0BB3C /* clients.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = clients.cc; sourceTree = ""; }; + C1425B381EE9C805001DB853 /* peer-socket.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "peer-socket.cc"; sourceTree = ""; }; C1033E031A3279B800EF44D8 /* crypto-utils-fallback.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "crypto-utils-fallback.cc"; sourceTree = ""; }; C1033E041A3279B800EF44D8 /* crypto-utils-ccrypto.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "crypto-utils-ccrypto.cc"; sourceTree = ""; }; C1033E051A3279B800EF44D8 /* crypto-utils.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "crypto-utils.cc"; sourceTree = ""; }; @@ -1112,7 +1114,7 @@ C1425B321EE9C5EA001DB85F /* tr-assert.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = "tr-assert.cc"; sourceTree = ""; }; C1425B331EE9C5EA001DB85F /* tr-assert.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "tr-assert.h"; sourceTree = ""; }; C1425B341EE9C5EA001DB85F /* tr-macros.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "tr-macros.h"; sourceTree = ""; }; - C1425B351EE9C5EA001DB85F /* peer-socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "peer-socket.h"; sourceTree = ""; }; + C1425B381EE9C805001DB851 /* peer-socket.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "peer-socket.h"; sourceTree = ""; }; C16089E41F092A1E00CEFC36 /* utp_api.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = utp_api.cpp; sourceTree = ""; }; C16089E51F092A1E00CEFC36 /* utp_callbacks.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = utp_callbacks.cpp; sourceTree = ""; }; C16089E61F092A1E00CEFC36 /* utp_callbacks.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = utp_callbacks.h; sourceTree = ""; }; @@ -1683,7 +1685,8 @@ 4D36BA610CA2F00800A63CA5 /* peer-mse.h */, 4D36BA6A0CA2F00800A63CA5 /* peer-msgs.cc */, 4D36BA6B0CA2F00800A63CA5 /* peer-msgs.h */, - C1425B351EE9C5EA001DB85F /* peer-socket.h */, + C1425B381EE9C805001DB851 /* peer-socket.h */, + C1425B381EE9C805001DB853 /* peer-socket.cc */, A23FAE52178BC2950053DC5B /* platform-quota.cc */, A23FAE53178BC2950053DC5B /* platform-quota.h */, BEFC1E030C07861A00B0BB3C /* platform.cc */, @@ -2183,7 +2186,7 @@ C1425B361EE9C605001DB85F /* tr-assert.h in Headers */, C1425B371EE9C705001DB85F /* tr-macros.h in Headers */, 888A256631B3DE536FEB8B00 /* tr-strbuf.h in Headers */, - C1425B381EE9C805001DB85F /* peer-socket.h in Headers */, + C1425B381EE9C805001DB850 /* peer-socket.h in Headers */, BEFC1E450C07861A00B0BB3C /* net.h in Headers */, BEFC1E4D0C07861A00B0BB3C /* session.h in Headers */, CCEBA596277340F6DF9F4482 /* session-alt-speeds.h in Headers */, @@ -2921,6 +2924,7 @@ C1FEE5781C3223CC00D62832 /* watchdir-generic.cc in Sources */, BEFC1E560C07861A00B0BB3C /* completion.cc in Sources */, BEFC1E580C07861A00B0BB3C /* clients.cc in Sources */, + C1425B381EE9C805001DB852 /* peer-socket.cc in Sources */, A2BE9C520C1E4AF5002D16E6 /* makemeta.cc in Sources */, A24621420C769D0900088E81 /* session-thread.cc in Sources */, C11DEA161FCD31C0009E22B9 /* subprocess-posix.cc in Sources */, diff --git a/libtransmission/CMakeLists.txt b/libtransmission/CMakeLists.txt index ddaffc2b2..2aba15792 100644 --- a/libtransmission/CMakeLists.txt +++ b/libtransmission/CMakeLists.txt @@ -41,6 +41,7 @@ set(PROJECT_FILES peer-mgr.cc peer-mse.cc peer-msgs.cc + peer-socket.cc platform-quota.cc platform.cc port-forwarding-natpmp.cc diff --git a/libtransmission/handshake.cc b/libtransmission/handshake.cc index ad465553c..3acdc6aac 100644 --- a/libtransmission/handshake.cc +++ b/libtransmission/handshake.cc @@ -1048,7 +1048,7 @@ static void gotError(tr_peerIo* io, short what, void* vhandshake) int const errcode = errno; auto* handshake = static_cast(vhandshake); - if (io->socket.type == TR_PEER_SOCKET_TYPE_UTP && !io->isIncoming() && handshake->state == AWAITING_YB) + if (io->socket.is_utp() && !io->isIncoming() && handshake->state == AWAITING_YB) { // the peer probably doesn't speak µTP. diff --git a/libtransmission/net.cc b/libtransmission/net.cc index 1db68af48..e6730e39d 100644 --- a/libtransmission/net.cc +++ b/libtransmission/net.cc @@ -328,7 +328,7 @@ struct tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const } else { - ret = tr_peer_socket_tcp_create(s); + ret = tr_peer_socket{ *addr, port, s }; } tr_logAddTrace(fmt::format("New OUTGOING connection {} ({})", s, addr->readable(port))); @@ -348,15 +348,15 @@ struct tr_peer_socket tr_netOpenPeerUTPSocket( { auto const [ss, sslen] = addr->toSockaddr(port); - if (auto* const socket = utp_create_socket(session->utp_context); socket != nullptr) + if (auto* const sock = utp_create_socket(session->utp_context); sock != nullptr) { - if (utp_connect(socket, reinterpret_cast(&ss), sslen) != -1) + if (utp_connect(sock, reinterpret_cast(&ss), sslen) != -1) { - ret = tr_peer_socket_utp_create(socket); + ret = tr_peer_socket{ *addr, port, sock }; } else { - utp_close(socket); + utp_close(sock); } } } @@ -364,29 +364,6 @@ struct tr_peer_socket tr_netOpenPeerUTPSocket( return ret; } -void tr_netClosePeerSocket(tr_session* session, tr_peer_socket socket) -{ - switch (socket.type) - { - case TR_PEER_SOCKET_TYPE_NONE: - break; - - case TR_PEER_SOCKET_TYPE_TCP: - tr_netClose(session, socket.handle.tcp); - break; - -#ifdef WITH_UTP - case TR_PEER_SOCKET_TYPE_UTP: - utp_set_userdata(socket.handle.utp, nullptr); - utp_close(socket.handle.utp); - break; -#endif - - default: - TR_ASSERT_MSG(false, fmt::format(FMT_STRING("unsupported peer socket type {:d}"), static_cast(socket.type))); - } -} - static tr_socket_t tr_netBindTCPImpl(tr_address const& addr, tr_port port, bool suppress_msgs, int* err_out) { TR_ASSERT(tr_address_is_valid(&addr)); @@ -766,22 +743,6 @@ bool tr_address_is_valid_for_peers(tr_address const* addr, tr_port port) !isMartianAddr(addr); } -struct tr_peer_socket tr_peer_socket_tcp_create(tr_socket_t const handle) -{ - TR_ASSERT(handle != TR_BAD_SOCKET); - - return { TR_PEER_SOCKET_TYPE_TCP, { handle } }; -} - -struct tr_peer_socket tr_peer_socket_utp_create(struct UTPSocket* const handle) -{ - TR_ASSERT(handle != nullptr); - - auto ret = tr_peer_socket{ TR_PEER_SOCKET_TYPE_UTP, {} }; - ret.handle.utp = handle; - return ret; -} - /// tr_port std::pair tr_port::fromCompact(std::byte const* compact) noexcept diff --git a/libtransmission/peer-io.cc b/libtransmission/peer-io.cc index 66cc3a2fe..cf433401d 100644 --- a/libtransmission/peer-io.cc +++ b/libtransmission/peer-io.cc @@ -47,17 +47,6 @@ static constexpr auto UtpReadBufferSize = 256 * 1024; #define tr_logAddDebugIo(io, msg) tr_logAddDebug(msg, (io)->addrStr()) #define tr_logAddTraceIo(io, msg) tr_logAddTrace(msg, (io)->addrStr()) -#ifdef TR_ENABLE_ASSERTS -[[nodiscard]] static constexpr auto isSupportedSocket(tr_peer_socket const& sock) -{ -#ifdef WITH_UTP - return sock.type == TR_PEER_SOCKET_TYPE_TCP || sock.type == TR_PEER_SOCKET_TYPE_UTP; -#else - return sock.type == TR_PEER_SOCKET_TYPE_TCP; -#endif -} -#endif // TR_ENABLE_ASSERTS - static constexpr size_t guessPacketOverhead(size_t d) { /** @@ -93,7 +82,7 @@ static void didWriteWrapper(tr_peerIo* io, size_t bytes_transferred) size_t const payload = std::min(uint64_t{ n_bytes_left }, uint64_t{ bytes_transferred }); /* For µTP sockets, the overhead is computed in utp_on_overhead. */ - size_t const overhead = io->socket.type == TR_PEER_SOCKET_TYPE_TCP ? guessPacketOverhead(payload) : 0; + size_t const overhead = io->socket.is_tcp() ? guessPacketOverhead(payload) : 0; uint64_t const now = tr_time_msec(); io->bandwidth().notifyBandwidthConsumed(TR_UP, payload, is_piece_data, now); @@ -194,7 +183,7 @@ static void event_read_cb(evutil_socket_t fd, short /*event*/, void* vio) auto* io = static_cast(vio); TR_ASSERT(tr_isPeerIo(io)); - TR_ASSERT(io->socket.type == TR_PEER_SOCKET_TYPE_TCP); + TR_ASSERT(io->socket.is_tcp()); /* Limit the input buffer to 256K, so it doesn't grow too large */ tr_direction const dir = TR_DOWN; @@ -264,7 +253,7 @@ static void event_write_cb(evutil_socket_t fd, short /*event*/, void* vio) auto* io = static_cast(vio); TR_ASSERT(tr_isPeerIo(io)); - TR_ASSERT(io->socket.type == TR_PEER_SOCKET_TYPE_TCP); + TR_ASSERT(io->socket.is_tcp()); io->pendingEvents &= ~EV_WRITE; @@ -472,49 +461,43 @@ static uint64 utp_callback(utp_callback_arguments* args) std::shared_ptr tr_peerIo::create( tr_session* session, tr_bandwidth* parent, - tr_address const* addr, - tr_port port, tr_sha1_digest_t const* torrent_hash, bool is_incoming, bool is_seed, - struct tr_peer_socket const socket) + tr_peer_socket socket) { TR_ASSERT(session != nullptr); auto lock = session->unique_lock(); - TR_ASSERT(isSupportedSocket(socket)); - TR_ASSERT(session->allowsTCP() || socket.type != TR_PEER_SOCKET_TYPE_TCP); + TR_ASSERT(socket.is_valid()); + TR_ASSERT(session->allowsTCP() || !socket.is_tcp()); - if (socket.type == TR_PEER_SOCKET_TYPE_TCP) + if (socket.is_tcp()) { - session->setSocketTOS(socket.handle.tcp, addr->type); + session->setSocketTOS(socket.handle.tcp, socket.address().type); maybeSetCongestionAlgorithm(socket.handle.tcp, session->peerCongestionAlgorithm()); } - auto io = std::shared_ptr{ new tr_peerIo{ session, torrent_hash, is_incoming, *addr, port, is_seed, parent } }; - io->socket = socket; + auto io = std::make_shared(session, torrent_hash, is_incoming, is_seed, parent, socket); io->bandwidth().setPeer(io); tr_logAddTraceIo(io, fmt::format("bandwidth is {}; its parent is {}", fmt::ptr(&io->bandwidth()), fmt::ptr(parent))); - switch (socket.type) + if (socket.is_tcp()) { - case TR_PEER_SOCKET_TYPE_TCP: tr_logAddTraceIo(io, fmt::format("socket (tcp) is {}", socket.handle.tcp)); io->event_read.reset(event_new(session->eventBase(), socket.handle.tcp, EV_READ, event_read_cb, io.get())); io->event_write.reset(event_new(session->eventBase(), socket.handle.tcp, EV_WRITE, event_write_cb, io.get())); - break; - + } #ifdef WITH_UTP - - case TR_PEER_SOCKET_TYPE_UTP: + else if (socket.is_utp()) + { tr_logAddTraceIo(io, fmt::format("socket (µTP) is {}", fmt::ptr(socket.handle.utp))); utp_set_userdata(socket.handle.utp, io.get()); - break; - + } #endif - - default: - TR_ASSERT_MSG(false, fmt::format("unsupported peer socket type {:d}", static_cast(socket.type))); + else + { + TR_ASSERT_MSG(false, "unsupported peer socket type"); } return io; @@ -535,17 +518,11 @@ void tr_peerIo::utpInit([[maybe_unused]] struct_utp_context* ctx) #endif } -std::shared_ptr tr_peerIo::newIncoming( - tr_session* session, - tr_bandwidth* parent, - tr_address const* addr, - tr_port port, - struct tr_peer_socket const socket) +std::shared_ptr tr_peerIo::newIncoming(tr_session* session, tr_bandwidth* parent, tr_peer_socket socket) { TR_ASSERT(session != nullptr); - TR_ASSERT(tr_address_is_valid(addr)); - return tr_peerIo::create(session, parent, addr, port, nullptr, true, false, socket); + return tr_peerIo::create(session, parent, nullptr, true, false, socket); } std::shared_ptr tr_peerIo::newOutgoing( @@ -568,20 +545,18 @@ std::shared_ptr tr_peerIo::newOutgoing( socket = tr_netOpenPeerUTPSocket(session, addr, port, is_seed); } - if (socket.type == TR_PEER_SOCKET_TYPE_NONE) + if (!socket.is_valid()) { socket = tr_netOpenPeerSocket(session, addr, port, is_seed); - tr_logAddDebug(fmt::format( - "tr_netOpenPeerSocket returned {}", - socket.type != TR_PEER_SOCKET_TYPE_NONE ? socket.handle.tcp : TR_BAD_SOCKET)); + tr_logAddDebug(fmt::format("tr_netOpenPeerSocket returned {}", socket.is_tcp() ? socket.handle.tcp : TR_BAD_SOCKET)); } - if (socket.type == TR_PEER_SOCKET_TYPE_NONE) + if (!socket.is_valid()) { return nullptr; } - return create(session, parent, addr, port, &torrent_hash, false, is_seed, socket); + return create(session, parent, &torrent_hash, false, is_seed, socket); } /*** @@ -592,7 +567,7 @@ static void event_enable(tr_peerIo* io, short event) { TR_ASSERT(io->session != nullptr); - bool const need_events = io->socket.type == TR_PEER_SOCKET_TYPE_TCP; + bool const need_events = io->socket.is_tcp(); TR_ASSERT(!need_events || io->event_read); TR_ASSERT(!need_events || io->event_write); @@ -623,7 +598,7 @@ static void event_enable(tr_peerIo* io, short event) static void event_disable(tr_peerIo* io, short event) { - bool const need_events = io->socket.type == TR_PEER_SOCKET_TYPE_TCP; + bool const need_events = io->socket.is_tcp(); TR_ASSERT(!need_events || io->event_read); TR_ASSERT(!need_events || io->event_write); @@ -674,28 +649,7 @@ void tr_peerIo::setEnabled(tr_direction dir, bool is_enabled) static void io_close_socket(tr_peerIo* io) { - switch (io->socket.type) - { - case TR_PEER_SOCKET_TYPE_NONE: - break; - - case TR_PEER_SOCKET_TYPE_TCP: - tr_netClose(io->session, io->socket.handle.tcp); - break; - -#ifdef WITH_UTP - - case TR_PEER_SOCKET_TYPE_UTP: - utp_set_userdata(io->socket.handle.utp, nullptr); - utp_close(io->socket.handle.utp); - break; - -#endif - - default: - tr_logAddDebugIo(io, fmt::format("unsupported peer socket type {}", static_cast(io->socket.type))); - } - + io->socket.close(io->session); io->event_write.reset(); io->event_read.reset(); io->socket = {}; @@ -711,11 +665,6 @@ tr_peerIo::~tr_peerIo() io_close_socket(this); } -std::string tr_peerIo::addrStr() const -{ - return tr_isPeerIo(this) ? this->addr_.readable(this->port_) : "error"; -} - void tr_peerIo::setCallbacks(tr_can_read_cb readcb, tr_did_write_cb writecb, tr_net_error_cb errcb, void* user_data) { this->canRead = readcb; @@ -746,7 +695,7 @@ int tr_peerIo::reconnect() auto const [addr, port] = this->socketAddress(); this->socket = tr_netOpenPeerSocket(session, &addr, port, this->isSeed()); - if (this->socket.type != TR_PEER_SOCKET_TYPE_TCP) + if (!this->socket.is_tcp()) { return -1; } @@ -868,8 +817,8 @@ static size_t tr_peerIoTryRead(tr_peerIo* io, size_t howmuch, tr_error** error) return n_read; } - TR_ASSERT(isSupportedSocket(io->socket)); - if (io->socket.type == TR_PEER_SOCKET_TYPE_TCP) + TR_ASSERT(io->socket.is_valid()); + if (io->socket.is_tcp()) { tr_error* my_error = nullptr; n_read = io->inbuf.addSocket(io->socket.handle.tcp, howmuch, &my_error); @@ -902,7 +851,7 @@ static size_t tr_peerIoTryRead(tr_peerIo* io, size_t howmuch, tr_error** error) } } #ifdef WITH_UTP - else if (io->socket.type == TR_PEER_SOCKET_TYPE_UTP) + else if (io->socket.is_utp()) { // UTP_RBDrained notifies libutp that your read buffer is empty. // It opens up the congestion window by sending an ACK (soonish) @@ -930,7 +879,7 @@ static size_t tr_peerIoTryWrite(tr_peerIo* io, size_t howmuch, tr_error** error) return n_written; } - if (io->socket.type == TR_PEER_SOCKET_TYPE_TCP) + if (io->socket.is_tcp()) { tr_error* my_error = nullptr; n_written = io->outbuf.toSocket(io->socket.handle.tcp, howmuch, &my_error); @@ -965,7 +914,7 @@ static size_t tr_peerIoTryWrite(tr_peerIo* io, size_t howmuch, tr_error** error) } } #ifdef WITH_UTP - else if (io->socket.type == TR_PEER_SOCKET_TYPE_UTP) + else if (io->socket.is_utp()) { auto iov = io->outbuf.vecs(howmuch); errno = 0; diff --git a/libtransmission/peer-io.h b/libtransmission/peer-io.h index 25a2c3a7d..8d18dd2e5 100644 --- a/libtransmission/peer-io.h +++ b/libtransmission/peer-io.h @@ -64,7 +64,6 @@ class tr_peerIo final : public std::enable_shared_from_this public: ~tr_peerIo(); - // TODO: 8 constructor args is too many; maybe a builder object? static std::shared_ptr newOutgoing( tr_session* session, tr_bandwidth* parent, @@ -74,12 +73,7 @@ public: bool is_seed, bool utp); - static std::shared_ptr newIncoming( - tr_session* session, - tr_bandwidth* parent, - struct tr_address const* addr, - tr_port port, - struct tr_peer_socket const socket); + static std::shared_ptr newIncoming(tr_session* session, tr_bandwidth* parent, tr_peer_socket socket); void clear(); @@ -97,17 +91,20 @@ public: void setEnabled(tr_direction dir, bool is_enabled); - [[nodiscard]] constexpr tr_address const& address() const noexcept + [[nodiscard]] constexpr auto const& address() const noexcept { - return addr_; + return socket.address(); } - [[nodiscard]] constexpr std::pair socketAddress() const noexcept + [[nodiscard]] constexpr auto socketAddress() const noexcept { - return std::make_pair(addr_, port_); + return socket.socketAddress(); } - std::string addrStr() const; + [[nodiscard]] auto addrStr() const + { + return socket.readable(); + } void readBufferDrain(size_t byte_count); @@ -283,6 +280,22 @@ public: static void utpInit(struct_utp_context* ctx); + tr_peerIo( + tr_session* session_in, + tr_sha1_digest_t const* torrent_hash, + bool is_incoming, + bool is_seed, + tr_bandwidth* parent_bandwidth, + tr_peer_socket sock) + : socket{ sock } + , session{ session_in } + , bandwidth_{ parent_bandwidth } + , torrent_hash_{ torrent_hash != nullptr ? *torrent_hash : tr_sha1_digest_t{} } + , is_seed_{ is_seed } + , is_incoming_{ is_incoming } + { + } + private: friend class libtransmission::test::HandshakeTest; @@ -291,30 +304,10 @@ private: static std::shared_ptr create( tr_session* session, tr_bandwidth* parent, - tr_address const* addr, - tr_port port, tr_sha1_digest_t const* torrent_hash, bool is_incoming, bool is_seed, - struct tr_peer_socket const socket); - - tr_peerIo( - tr_session* session_in, - tr_sha1_digest_t const* torrent_hash, - bool is_incoming, - tr_address const& addr, - tr_port port, - bool is_seed, - tr_bandwidth* parent_bandwidth) - : session{ session_in } - , bandwidth_{ parent_bandwidth } - , torrent_hash_{ torrent_hash != nullptr ? *torrent_hash : tr_sha1_digest_t{} } - , addr_{ addr } - , port_{ port } - , is_seed_{ is_seed } - , is_incoming_{ is_incoming } - { - } + tr_peer_socket socket); tr_bandwidth bandwidth_; @@ -322,9 +315,6 @@ private: tr_sha1_digest_t torrent_hash_; - tr_address const addr_; - tr_port const port_; - bool const is_seed_; bool const is_incoming_; diff --git a/libtransmission/peer-mgr.cc b/libtransmission/peer-mgr.cc index d4a8fe87d..ce2b25149 100644 --- a/libtransmission/peer-mgr.cc +++ b/libtransmission/peer-mgr.cc @@ -1190,7 +1190,7 @@ static bool on_handshake_done(tr_handshake_result const& result) /* In principle, this flag specifies whether the peer groks µTP, not whether it's currently connected over µTP. */ - if (result.io->socket.type == TR_PEER_SOCKET_TYPE_UTP) + if (result.io->socket.is_utp()) { atom->flags |= ADDED_F_UTP_FLAGS; } @@ -1227,7 +1227,7 @@ static bool on_handshake_done(tr_handshake_result const& result) return success; } -void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const& addr, tr_port port, struct tr_peer_socket const socket) +void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const& addr, tr_port port, tr_peer_socket socket) { TR_ASSERT(manager->session != nullptr); auto const lock = manager->unique_lock(); @@ -1237,17 +1237,17 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const& addr, tr_port if (session->addressIsBlocked(addr)) { tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", addr.readable(port))); - tr_netClosePeerSocket(session, socket); + socket.close(session); } else if (manager->incoming_handshakes.contains(addr)) { - tr_netClosePeerSocket(session, socket); + socket.close(session); } else /* we don't have a connection to them yet... */ { auto* const handshake = tr_handshakeNew( manager->handshake_mediator_, - tr_peerIo::newIncoming(session, &session->top_bandwidth_, &addr, port, socket), + tr_peerIo::newIncoming(session, &session->top_bandwidth_, socket), session->encryptionMode(), on_handshake_done, manager); diff --git a/libtransmission/peer-msgs.cc b/libtransmission/peer-msgs.cc index 7d8638d89..c8d6bdf53 100644 --- a/libtransmission/peer-msgs.cc +++ b/libtransmission/peer-msgs.cc @@ -374,7 +374,7 @@ public: [[nodiscard]] bool is_utp_connection() const noexcept override { - return io->socket.type == TR_PEER_SOCKET_TYPE_UTP; + return io->socket.is_utp(); } [[nodiscard]] bool is_encrypted() const override diff --git a/libtransmission/peer-socket.cc b/libtransmission/peer-socket.cc new file mode 100644 index 000000000..fd29b2db5 --- /dev/null +++ b/libtransmission/peer-socket.cc @@ -0,0 +1,26 @@ +// This file Copyright © 2017-2022 Mnemosyne LLC. +// It may be used under GPLv2 (SPDX: GPL-2.0-only), GPLv3 (SPDX: GPL-3.0-only), +// or any future license endorsed by Mnemosyne LLC. +// License text can be found in the licenses/ folder. + +#include + +#include "transmission.h" + +#include "peer-socket.h" +#include "net.h" + +void tr_peer_socket::close(tr_session* session) +{ + if (is_tcp()) + { + tr_netClose(session, handle.tcp); + } +#ifdef WITH_UTP + else if (is_utp()) + { + utp_set_userdata(handle.utp, nullptr); + utp_close(handle.utp); + } +#endif +} diff --git a/libtransmission/peer-socket.h b/libtransmission/peer-socket.h index b88e7dce2..2544b9e49 100644 --- a/libtransmission/peer-socket.h +++ b/libtransmission/peer-socket.h @@ -9,36 +9,107 @@ #error only libtransmission should #include this header. #endif +#include "transmission.h" + #include "net.h" +#include "tr-assert.h" -enum tr_peer_socket_type -{ - TR_PEER_SOCKET_TYPE_NONE, - TR_PEER_SOCKET_TYPE_TCP, - TR_PEER_SOCKET_TYPE_UTP -}; - -union tr_peer_socket_handle -{ - tr_socket_t tcp; - struct UTPSocket* utp; -}; +struct UTPSocket; +struct tr_session; struct tr_peer_socket { - enum tr_peer_socket_type type = TR_PEER_SOCKET_TYPE_NONE; - union tr_peer_socket_handle handle; + tr_peer_socket() = default; + + tr_peer_socket(tr_address const& address, tr_port port, tr_socket_t sock) + : handle{ sock } + , address_{ address } + , port_{ port } + , type_{ Type::TCP } + { + TR_ASSERT(sock != TR_BAD_SOCKET); + } + + tr_peer_socket(tr_address const& address, tr_port port, struct UTPSocket* const sock) + : address_{ address } + , port_{ port } + , type_{ Type::UTP } + { + TR_ASSERT(sock != nullptr); + handle.utp = sock; + } + + void close(tr_session* session); + + [[nodiscard]] constexpr std::pair socketAddress() const noexcept + { + return std::make_pair(address_, port_); + } + + [[nodiscard]] constexpr auto const& address() const noexcept + { + return address_; + } + + [[nodiscard]] constexpr auto const& port() const noexcept + { + return port_; + } + + template + OutputIt readable(OutputIt out) + { + return address_.readable(out, port_); + } + + [[nodiscard]] std::string_view readable(char* out, size_t outlen) const + { + return address_.readable(out, outlen, port_); + } + + [[nodiscard]] std::string readable() const + { + return address_.readable(port_); + } + + [[nodiscard]] constexpr auto is_utp() const noexcept + { + return type_ == Type::UTP; + } + + [[nodiscard]] constexpr auto is_tcp() const noexcept + { + return type_ == Type::TCP; + } + + [[nodiscard]] constexpr auto is_valid() const noexcept + { +#ifdef WITH_UTP + return is_tcp() || is_utp(); +#else + return is_tcp(); +#endif + } + + union + { + tr_socket_t tcp; + struct UTPSocket* utp; + } handle = {}; + +private: + enum class Type + { + None, + TCP, + UTP + }; + + tr_address address_; + tr_port port_; + + enum Type type_ = Type::None; }; -struct tr_peer_socket tr_peer_socket_tcp_create(tr_socket_t const handle); - -struct tr_peer_socket tr_peer_socket_utp_create(struct UTPSocket* const handle); - -struct tr_session; -struct tr_address; - struct tr_peer_socket tr_netOpenPeerSocket(tr_session* session, tr_address const* addr, tr_port port, bool client_is_seed); - struct tr_peer_socket tr_netOpenPeerUTPSocket(tr_session* session, tr_address const* addr, tr_port port, bool client_is_seed); - -void tr_netClosePeerSocket(tr_session* session, tr_peer_socket socket); diff --git a/libtransmission/session.cc b/libtransmission/session.cc index 5a0bb1102..6b3f923de 100644 --- a/libtransmission/session.cc +++ b/libtransmission/session.cc @@ -297,7 +297,7 @@ void tr_session::onIncomingPeerConnection(tr_socket_t fd, void* vsession) { auto const& [addr, port, sock] = *incoming_info; tr_logAddTrace(fmt::format("new incoming connection {} ({})", sock, addr.readable(port))); - session->addIncoming(addr, port, tr_peer_socket_tcp_create(sock)); + session->addIncoming(addr, port, tr_peer_socket{ addr, port, sock }); } } diff --git a/libtransmission/tr-utp.cc b/libtransmission/tr-utp.cc index b8aca0ae3..d197a0854 100644 --- a/libtransmission/tr-utp.cc +++ b/libtransmission/tr-utp.cc @@ -93,7 +93,7 @@ static void utp_on_accept(tr_session* const session, UTPSocket* const utp_sock) if (auto addrport = tr_address::fromSockaddr(reinterpret_cast(&from_storage)); addrport) { auto const& [addr, port] = *addrport; - session->addIncoming(addr, port, tr_peer_socket_utp_create(utp_sock)); + session->addIncoming(addr, port, tr_peer_socket{ addr, port, utp_sock }); } else { diff --git a/tests/libtransmission/handshake-test.cc b/tests/libtransmission/handshake-test.cc index 5bd28c905..5a2b6fd67 100644 --- a/tests/libtransmission/handshake-test.cc +++ b/tests/libtransmission/handshake-test.cc @@ -158,8 +158,10 @@ public: { auto sockpair = std::array{ -1, -1 }; EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno); - auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]); - auto io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, &DefaultPeerAddr, DefaultPeerPort, peer_socket); + auto io = tr_peerIo::newIncoming( + session, + &session->top_bandwidth_, + tr_peer_socket(DefaultPeerAddr, DefaultPeerPort, sockpair[0])); return std::make_pair(io, sockpair[1]); } @@ -167,16 +169,13 @@ public: { auto sockpair = std::array{ -1, -1 }; EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno); - auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]); auto io = tr_peerIo::create( session, &session->top_bandwidth_, - &DefaultPeerAddr, - DefaultPeerPort, &info_hash, false /*is_incoming*/, false /*is_seed*/, - peer_socket); + tr_peer_socket(DefaultPeerAddr, DefaultPeerPort, sockpair[0])); return std::make_pair(io, sockpair[1]); }