fix: circular dependency in udp core and dht init (#3862)

This commit is contained in:
Charles Kerr 2022-10-02 13:18:23 -05:00 committed by GitHub
parent 257d98545b
commit 0a0c15d17c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 117 additions and 113 deletions

View File

@ -104,7 +104,7 @@ public:
[[nodiscard]] bool isDHTEnabled() const override
{
return tr_dhtEnabled(&session_);
return tr_dhtEnabled();
}
[[nodiscard]] bool allowsTCP() const override

View File

@ -282,7 +282,7 @@ public:
tellPeerWhatWeHave(this);
if (auto const port = tr_dhtPort(torrent->session); io->supportsDHT() && port.has_value())
if (auto const port = tr_dhtPort(); io->supportsDHT() && port.has_value())
{
// only send PORT over IPv6 iff IPv6 DHT is running (BEP-32).
if (io->address().isIPv4() || tr_globalIPv6(nullptr).has_value())
@ -1693,7 +1693,7 @@ static ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
if (auto const dht_port = tr_port::fromNetwork(nport); !std::empty(dht_port))
{
msgs->dht_port = dht_port;
tr_dhtAddNode(msgs->session, msgs->io->address(), msgs->dht_port, false);
tr_dhtAddNode(msgs->io->address(), msgs->dht_port, false);
}
}
break;

View File

@ -53,9 +53,19 @@
using namespace std::literals;
static std::unique_ptr<libtransmission::Timer> dht_timer;
static std::array<unsigned char, 20> myid;
static tr_session* my_session = nullptr;
namespace
{
struct Impl
{
std::unique_ptr<libtransmission::Timer> timer;
std::array<unsigned char, 20> id = {};
tr_socket_t udp4_socket = TR_BAD_SOCKET;
tr_socket_t udp6_socket = TR_BAD_SOCKET;
tr_session* session = nullptr;
};
Impl impl = {};
} // namespace
// mutex-locked wrapper around libdht's API
namespace locked_dht
@ -155,29 +165,29 @@ static constexpr std::string_view printableStatus(Status status)
}
}
bool tr_dhtEnabled(tr_session const* session)
bool tr_dhtEnabled()
{
return session != nullptr && session == my_session;
return impl.session != nullptr && (impl.udp4_socket != TR_BAD_SOCKET || impl.udp4_socket != TR_BAD_SOCKET);
}
static auto getUdpSocket(tr_session const* const session, int af)
static constexpr auto getUdpSocket(int af)
{
switch (af)
{
case AF_INET:
return session->udp_core_->udp_socket();
return impl.udp4_socket;
case AF_INET6:
return session->udp_core_->udp6_socket();
return impl.udp6_socket;
default:
return TR_BAD_SOCKET;
}
}
static auto getStatus(tr_session const* const session, int af, int* const setme_node_count = nullptr)
static auto getStatus(int af, int* const setme_node_count = nullptr)
{
if (!tr_dhtEnabled(session) || (getUdpSocket(session, af) == TR_BAD_SOCKET))
if (getUdpSocket(af) == TR_BAD_SOCKET)
{
if (setme_node_count != nullptr)
{
@ -220,19 +230,19 @@ static constexpr auto isReady(Status const status)
return status >= Status::Firewalled;
}
static auto isReady(tr_session const* const session, int af)
static auto isReady(int af)
{
return isReady(getStatus(session, af));
return isReady(getStatus(af));
}
static bool isBootstrapDone(tr_session const* const session, int af)
static bool isBootstrapDone(int af = 0)
{
if (af == 0)
{
return isBootstrapDone(session, AF_INET) && isBootstrapDone(session, AF_INET6);
return isBootstrapDone(AF_INET) && isBootstrapDone(AF_INET6);
}
auto const status = getStatus(session, af, nullptr);
auto const status = getStatus(af, nullptr);
return status == Status::Stopped || isReady(status);
}
@ -243,14 +253,14 @@ static void nap(int roughly_sec)
tr_wait_msec(msec);
}
static int getBootstrappedAF(tr_session const* const session)
static int getBootstrappedAF()
{
if (isBootstrapDone(session, AF_INET6))
if (isBootstrapDone(AF_INET6))
{
return AF_INET;
}
if (isBootstrapDone(session, AF_INET))
if (isBootstrapDone(AF_INET))
{
return AF_INET6;
}
@ -258,7 +268,7 @@ static int getBootstrappedAF(tr_session const* const session)
return 0;
}
static void bootstrapFromName(tr_session const* const session, char const* name, tr_port port, int af)
static void bootstrapFromName(char const* name, tr_port port, int af)
{
auto hints = addrinfo{};
hints.ai_socktype = SOCK_DGRAM;
@ -286,7 +296,7 @@ static void bootstrapFromName(tr_session const* const session, char const* name,
nap(15);
if (isBootstrapDone(session, af))
if (isBootstrapDone(af))
{
break;
}
@ -297,15 +307,15 @@ static void bootstrapFromName(tr_session const* const session, char const* name,
freeaddrinfo(info);
}
static void bootstrapFromFile(tr_session const* const session)
static void bootstrapFromFile(std::string_view config_dir)
{
if (isBootstrapDone(session, 0))
if (isBootstrapDone())
{
return;
}
// check for a manual bootstrap file.
auto in = std::ifstream{ tr_pathbuf{ session->configDir(), "/dht.bootstrap"sv } };
auto in = std::ifstream{ tr_pathbuf{ config_dir, "/dht.bootstrap"sv } };
if (!in.is_open())
{
return;
@ -314,7 +324,7 @@ static void bootstrapFromFile(tr_session const* const session)
// format is each line has address, a space char, and port number
tr_logAddTrace("Attempting manual bootstrap");
auto line = std::string{};
while (!isBootstrapDone(session, 0) && std::getline(in, line))
while (!isBootstrapDone() && std::getline(in, line))
{
auto line_stream = std::istringstream{ line };
auto addrstr = std::string{};
@ -327,14 +337,14 @@ static void bootstrapFromFile(tr_session const* const session)
}
else
{
bootstrapFromName(session, addrstr.c_str(), tr_port::fromHost(hport), getBootstrappedAF(session));
bootstrapFromName(addrstr.c_str(), tr_port::fromHost(hport), getBootstrappedAF());
}
}
}
static void bootstrapStart(tr_session* const session, std::vector<uint8_t> nodes4, std::vector<uint8_t> nodes6)
static void bootstrapStart(std::string config_dir, std::vector<uint8_t> nodes4, std::vector<uint8_t> nodes6)
{
TR_ASSERT(tr_dhtEnabled(session));
TR_ASSERT(tr_dhtEnabled());
auto const num4 = std::size(nodes4) / 6;
if (num4 > 0)
@ -352,22 +362,22 @@ static void bootstrapStart(tr_session* const session, std::vector<uint8_t> nodes
auto const* walk6 = std::data(nodes6);
for (size_t i = 0; i < std::max(num4, num6); ++i)
{
if (i < num4 && !isBootstrapDone(session, AF_INET))
if (i < num4 && !isBootstrapDone(AF_INET))
{
auto addr = tr_address{};
auto port = tr_port{};
std::tie(addr, walk4) = tr_address::fromCompact4(walk4);
std::tie(port, walk4) = tr_port::fromCompact(walk4);
tr_dhtAddNode(session, addr, port, true);
tr_dhtAddNode(addr, port, true);
}
if (i < num6 && !isBootstrapDone(session, AF_INET6))
if (i < num6 && !isBootstrapDone(AF_INET6))
{
auto addr = tr_address{};
auto port = tr_port{};
std::tie(addr, walk6) = tr_address::fromCompact6(walk6);
std::tie(port, walk6) = tr_port::fromCompact(walk6);
tr_dhtAddNode(session, addr, port, true);
tr_dhtAddNode(addr, port, true);
}
/* Our DHT code is able to take up to 9 nodes in a row without
@ -382,18 +392,18 @@ static void bootstrapStart(tr_session* const session, std::vector<uint8_t> nodes
nap(15);
}
if (isBootstrapDone(session, 0))
if (isBootstrapDone())
{
break;
}
}
if (!isBootstrapDone(session, 0))
if (!isBootstrapDone())
{
bootstrapFromFile(session);
bootstrapFromFile(config_dir);
}
if (!isBootstrapDone(session, 0))
if (!isBootstrapDone())
{
for (int i = 0; i < 6; ++i)
{
@ -403,7 +413,7 @@ static void bootstrapStart(tr_session* const session, std::vector<uint8_t> nodes
node, for example because we've just been restarted. */
nap(40);
if (isBootstrapDone(session, 0))
if (isBootstrapDone())
{
break;
}
@ -413,16 +423,16 @@ static void bootstrapStart(tr_session* const session, std::vector<uint8_t> nodes
tr_logAddDebug("Attempting bootstrap from dht.transmissionbt.com");
}
bootstrapFromName(session, "dht.transmissionbt.com", tr_port::fromHost(6881), getBootstrappedAF(session));
bootstrapFromName("dht.transmissionbt.com", tr_port::fromHost(6881), getBootstrappedAF());
}
}
tr_logAddTrace("Finished bootstrapping");
}
int tr_dhtInit(tr_session* session)
int tr_dhtInit(tr_session* session, tr_socket_t udp4_socket, tr_socket_t udp6_socket)
{
if (my_session != nullptr) /* already initialized */
if (impl.session != nullptr) /* already initialized */
{
return -1;
}
@ -446,9 +456,9 @@ int tr_dhtInit(tr_session* session)
{
auto sv = std::string_view{};
have_id = tr_variantDictFindStrView(&benc, TR_KEY_id, &sv);
if (have_id && std::size(sv) == 20)
if (have_id && std::size(sv) == std::size(impl.id))
{
std::copy(std::begin(sv), std::end(sv), std::data(myid));
std::copy(std::begin(sv), std::end(sv), std::data(impl.id));
}
size_t raw_len = 0U;
@ -476,44 +486,46 @@ int tr_dhtInit(tr_session* session)
/* Note that DHT ids need to be distributed uniformly,
* so it should be something truly random. */
tr_logAddTrace("Generating new id");
tr_rand_buffer(std::data(myid), std::size(myid));
tr_rand_buffer(std::data(impl.id), std::size(impl.id));
}
if (locked_dht::init(getUdpSocket(session, AF_INET), getUdpSocket(session, AF_INET6), std::data(myid), nullptr) < 0)
if (locked_dht::init(udp4_socket, udp6_socket, std::data(impl.id), nullptr) < 0)
{
auto const errcode = errno;
tr_logAddDebug(fmt::format("DHT initialization failed: {} ({})", tr_strerror(errcode), errcode));
my_session = nullptr;
impl = {};
return -1;
}
my_session = session;
impl.session = session;
impl.udp4_socket = udp4_socket;
impl.udp6_socket = udp4_socket;
std::thread(bootstrapStart, session, nodes, nodes6).detach();
std::thread(bootstrapStart, std::string{ session->configDir() }, nodes, nodes6).detach();
dht_timer = session->timerMaker().create([session]() { tr_dhtCallback(session, nullptr, 0, nullptr, 0); });
auto const random_percent = tr_rand_int_weak(1000) / 1000.0;
static auto constexpr MinInterval = 10ms;
static auto constexpr MaxInterval = 1s;
auto const random_percent = tr_rand_int_weak(1000) / 1000.0;
auto interval = MinInterval + random_percent * (MaxInterval - MinInterval);
dht_timer->startSingleShot(std::chrono::duration_cast<std::chrono::milliseconds>(interval));
impl.timer = session->timerMaker().create([]() { tr_dhtCallback(nullptr, 0, nullptr, 0); });
impl.timer->startSingleShot(std::chrono::duration_cast<std::chrono::milliseconds>(interval));
tr_logAddDebug("DHT initialized");
return 1;
}
void tr_dhtUninit(tr_session const* session)
void tr_dhtUninit()
{
TR_ASSERT(tr_dhtEnabled(session));
TR_ASSERT(tr_dhtEnabled());
tr_logAddTrace("Uninitializing DHT");
dht_timer.reset();
impl.timer.reset();
/* Since we only save known good nodes,
* avoid erasing older data if we don't know enough nodes. */
if (!isReady(session, AF_INET) && !isReady(session, AF_INET6))
if (!isReady(AF_INET) && !isReady(AF_INET6))
{
tr_logAddTrace("Not saving nodes, DHT not ready");
}
@ -535,7 +547,7 @@ void tr_dhtUninit(tr_session const* session)
tr_variant benc;
tr_variantInitDict(&benc, 3);
tr_variantDictAddRaw(&benc, TR_KEY_id, std::data(myid), std::size(myid));
tr_variantDictAddRaw(&benc, TR_KEY_id, std::data(impl.id), std::size(impl.id));
if (num > 0)
{
@ -567,8 +579,7 @@ void tr_dhtUninit(tr_session const* session)
tr_variantDictAddRaw(&benc, TR_KEY_nodes6, std::data(compact6), out6 - std::data(compact6));
}
auto const dat_file = tr_pathbuf{ session->configDir(), "/dht.dat"sv };
tr_variantToFile(&benc, TR_VARIANT_FMT_BENC, dat_file.sv());
tr_variantToFile(&benc, TR_VARIANT_FMT_BENC, tr_pathbuf{ impl.session->configDir(), "/dht.dat"sv });
tr_variantClear(&benc);
}
@ -576,24 +587,22 @@ void tr_dhtUninit(tr_session const* session)
tr_logAddTrace("Done uninitializing DHT");
my_session = nullptr;
impl = {};
}
std::optional<tr_port> tr_dhtPort(tr_session const* session)
std::optional<tr_port> tr_dhtPort()
{
if (!tr_dhtEnabled(session))
if (impl.session == nullptr)
{
return {};
}
return session->udp_core_->port();
return impl.session->udp_core_->port();
}
bool tr_dhtAddNode(tr_session* ss, tr_address const& addr, tr_port port, bool bootstrap)
bool tr_dhtAddNode(tr_address addr, tr_port port, bool bootstrap)
{
int const af = addr.isIPv4() ? AF_INET : AF_INET6;
if (!tr_dhtEnabled(ss))
if (!tr_dhtEnabled())
{
return false;
}
@ -601,7 +610,7 @@ bool tr_dhtAddNode(tr_session* ss, tr_address const& addr, tr_port port, bool bo
/* Since we don't want to abuse our bootstrap nodes,
* we don't ping them if the DHT is in a good state. */
if (bootstrap && isReady(ss, af))
if (bootstrap && isReady(addr.isIPv4() ? AF_INET : AF_INET6))
{
return false;
}
@ -665,23 +674,16 @@ static void callback(void* vsession, int event, unsigned char const* info_hash,
}
}
enum class AnnounceResult
{
INVALID,
OK,
FAILED
};
static AnnounceResult announceTorrent(tr_session const* const session, tr_torrent const* const tor, int af, bool announce)
static bool announceTorrent(tr_torrent const* const tor, int af, bool announce, tr_port incoming_peer_port)
{
TR_ASSERT(tor->allowsDht());
int numnodes = 0;
auto const status = getStatus(session, af, &numnodes);
auto const status = getStatus(af, &numnodes);
if (status == Status::Stopped)
{
// let the caller believe everything is all right.
return AnnounceResult::OK;
return true;
}
if (status < Status::Poor)
@ -693,12 +695,12 @@ static AnnounceResult announceTorrent(tr_session const* const session, tr_torren
af == AF_INET6 ? "IPv6" : "IPv4",
printableStatus(status),
numnodes));
return AnnounceResult::FAILED;
return false;
}
auto const* dht_hash = reinterpret_cast<unsigned char const*>(std::data(tor->infoHash()));
auto const hport = announce ? session->peerPort().host() : 0;
int const rc = locked_dht::search(dht_hash, hport, af, callback, nullptr);
auto const hport = announce ? incoming_peer_port.host() : 0;
int const rc = locked_dht::search(dht_hash, hport, af, callback, impl.session);
if (rc < 0)
{
auto const error_code = errno;
@ -710,7 +712,7 @@ static AnnounceResult announceTorrent(tr_session const* const session, tr_torren
fmt::arg("state", printableStatus(status)),
fmt::arg("error_code", error_code),
fmt::arg("error", tr_strerror(error_code))));
return AnnounceResult::FAILED;
return false;
}
tr_logAddTraceTor(
@ -721,17 +723,18 @@ static AnnounceResult announceTorrent(tr_session const* const session, tr_torren
printableStatus(status),
numnodes));
return AnnounceResult::OK;
return true;
}
void tr_dhtUpkeep(tr_session* session)
void tr_dhtUpkeep()
{
TR_ASSERT(tr_dhtEnabled(session));
TR_ASSERT(impl.session != nullptr);
auto lock = session->unique_lock();
auto lock = impl.session->unique_lock();
auto const now = tr_time();
auto const incoming_peer_port = impl.session->peerPort();
for (auto* const tor : session->torrents())
for (auto* const tor : impl.session->torrents())
{
if (!tor->isRunning || !tor->allowsDht())
{
@ -740,31 +743,29 @@ void tr_dhtUpkeep(tr_session* session)
if (tor->dhtAnnounceAt <= now)
{
auto const rc = announceTorrent(session, tor, AF_INET, true);
tor->dhtAnnounceAt = now +
((rc == AnnounceResult::FAILED) ? 5 + tr_rand_int_weak(5) : 25 * 60 + tr_rand_int_weak(3 * 60));
auto const ok = announceTorrent(tor, AF_INET, true, incoming_peer_port);
auto const interval = ok ? 25 * 60 + tr_rand_int_weak(3 * 60) : 5 + tr_rand_int_weak(5);
tor->dhtAnnounceAt = now + interval;
}
if (tor->dhtAnnounce6At <= now)
{
auto const rc = announceTorrent(session, tor, AF_INET6, true);
tor->dhtAnnounce6At = now +
((rc == AnnounceResult::FAILED) ? 5 + tr_rand_int_weak(5) : 25 * 60 + tr_rand_int_weak(3 * 60));
auto const ok = announceTorrent(tor, AF_INET6, true, incoming_peer_port);
auto const interval = ok ? 25 * 60 + tr_rand_int_weak(3 * 60) : 5 + tr_rand_int_weak(5);
tor->dhtAnnounce6At = now + interval;
}
}
}
void tr_dhtCallback(tr_session* session, unsigned char* buf, int buflen, struct sockaddr* from, socklen_t fromlen)
void tr_dhtCallback(unsigned char* buf, int buflen, struct sockaddr* from, socklen_t fromlen)
{
if (!tr_dhtEnabled(session))
if (!tr_dhtEnabled())
{
return;
}
time_t tosleep = 0;
int const rc = locked_dht::periodic(buf, buflen, from, fromlen, &tosleep, callback, nullptr);
int const rc = locked_dht::periodic(buf, buflen, from, fromlen, &tosleep, callback, impl.session);
if (rc < 0)
{
@ -792,12 +793,11 @@ void tr_dhtCallback(tr_session* session, unsigned char* buf, int buflen, struct
auto const min_interval = std::chrono::seconds{ tosleep };
auto const max_interval = std::chrono::seconds{ tosleep + 1 };
auto const interval = min_interval + random_percent * (max_interval - min_interval);
dht_timer->startSingleShot(std::chrono::duration_cast<std::chrono::milliseconds>(interval));
impl.timer->startSingleShot(std::chrono::duration_cast<std::chrono::milliseconds>(interval));
}
extern "C"
{
// This function should return true when a node is blacklisted.
// We don't support using a blacklist with the DHT in Transmission,
// since massive (ab)use of this feature could harm the DHT. However,

View File

@ -9,15 +9,19 @@
#endif
#include <optional>
#include <string_view>
#include "transmission.h"
#include "net.h" // tr_port
int tr_dhtInit(tr_session*);
void tr_dhtUninit(tr_session const*);
bool tr_dhtEnabled(tr_session const*);
std::optional<tr_port> tr_dhtPort(tr_session const*);
bool tr_dhtAddNode(tr_session*, tr_address const&, tr_port, bool bootstrap);
void tr_dhtUpkeep(tr_session*);
void tr_dhtCallback(tr_session*, unsigned char* buf, int buflen, struct sockaddr* from, socklen_t fromlen);
int tr_dhtInit(tr_session*, tr_socket_t udp4_socket, tr_socket_t udp6_socket);
void tr_dhtUninit();
bool tr_dhtEnabled();
std::optional<tr_port> tr_dhtPort();
bool tr_dhtAddNode(tr_address, tr_port, bool bootstrap);
void tr_dhtUpkeep();
void tr_dhtCallback(unsigned char* buf, int buflen, struct sockaddr* from, socklen_t fromlen);

View File

@ -205,7 +205,7 @@ static void event_callback(evutil_socket_t s, [[maybe_unused]] short type, void*
if (session->allowsDHT())
{
buf[rc] = '\0'; /* required by the DHT code */
tr_dhtCallback(session, std::data(buf), rc, (struct sockaddr*)&from, fromlen);
tr_dhtCallback(std::data(buf), rc, (struct sockaddr*)&from, fromlen);
}
}
else if (rc >= 8 && buf[0] == 0 && buf[1] == 0 && buf[2] == 0 && buf[3] <= 3)
@ -301,7 +301,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session)
if (session_.allowsDHT())
{
tr_dhtInit(&session_);
tr_dhtInit(&session_, udp_socket_, udp6_socket_);
}
if (udp_event_ != nullptr)
@ -316,17 +316,17 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session)
void tr_session::tr_udp_core::dhtUpkeep()
{
if (tr_dhtEnabled(&session_))
if (tr_dhtEnabled())
{
tr_dhtUpkeep(&session_);
tr_dhtUpkeep();
}
}
void tr_session::tr_udp_core::dhtUninit()
{
if (tr_dhtEnabled(&session_))
if (tr_dhtEnabled())
{
tr_dhtUninit(&session_);
tr_dhtUninit();
}
}