refactor: make DHT unblocking (#4122)

This commit is contained in:
Charles Kerr 2022-11-11 10:09:24 -06:00 committed by GitHub
parent 92b74fee74
commit 9e06cf8f2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1413 additions and 851 deletions

View File

@ -108,9 +108,44 @@ tr_peer_id_t tr_peerIdInit()
return peer_id;
}
/***
****
***/
///
std::vector<tr_torrent_id_t> tr_session::DhtMediator::torrentsAllowingDHT() const
{
auto ids = std::vector<tr_torrent_id_t>{};
auto const& torrents = session_.torrents();
ids.reserve(std::size(torrents));
for (auto const* const tor : torrents)
{
if (tor->isRunning && tor->allowsDht())
{
ids.push_back(tor->id());
}
}
return ids;
}
tr_sha1_digest_t tr_session::DhtMediator::torrentInfoHash(tr_torrent_id_t id) const
{
if (auto const* const tor = session_.torrents().get(id); tor != nullptr)
{
return tor->infoHash();
}
return {};
}
void tr_session::DhtMediator::addPex(tr_sha1_digest_t const& info_hash, tr_pex const* pex, size_t n_pex)
{
if (auto* const tor = session_.torrents().get(info_hash); tor != nullptr)
{
tr_peerMgrAddPex(tor, TR_PEER_FROM_DHT, pex, n_pex);
}
}
///
bool tr_session::LpdMediator::onPeerFound(std::string_view info_hash_str, tr_address address, tr_port port)
{
@ -468,7 +503,6 @@ void tr_session::onNowTimer()
// tr_session upkeep tasks to perform once per second
tr_timeUpdate(time(nullptr));
udp_core_->dhtUpkeep();
alt_speeds_.checkScheduler();
// TODO: this seems a little silly. Why do we increment this
@ -607,24 +641,28 @@ void tr_session::setSettings(tr_session_settings settings_in, bool force)
port_changed = true;
}
bool addr_changed = false;
if (new_settings.tcp_enabled)
{
if (auto const& val = new_settings.bind_address_ipv4; force || port_changed || val != old_settings.bind_address_ipv4)
{
auto const [addr, is_default] = publicAddress(TR_AF_INET);
bound_ipv4_.emplace(eventBase(), addr, local_peer_port_, &tr_session::onIncomingPeerConnection, this);
addr_changed = true;
}
if (auto const& val = new_settings.bind_address_ipv6; force || port_changed || val != old_settings.bind_address_ipv6)
{
auto const [addr, is_default] = publicAddress(TR_AF_INET6);
bound_ipv6_.emplace(eventBase(), addr, local_peer_port_, &tr_session::onIncomingPeerConnection, this);
addr_changed = true;
}
}
else
{
bound_ipv4_.reset();
bound_ipv6_.reset();
addr_changed = true;
}
if (port_changed)
@ -653,6 +691,15 @@ void tr_session::setSettings(tr_session_settings settings_in, bool force)
}
}
if (!allowsDHT())
{
dht_.reset();
}
else if (force || !dht_ || port_changed || addr_changed || dht_changed)
{
dht_ = tr_dht::create(dht_mediator_, localPeerPort(), udp_core_->socket4(), udp_core_->socket6());
}
// We need to update bandwidth if speed settings changed.
// It's a harmless call, so just call it instead of checking for settings changes
updateBandwidth(this, TR_UP);
@ -801,6 +848,14 @@ tr_port_forwarding_state tr_sessionGetPortForwarding(tr_session const* session)
return session->port_forwarding_->state();
}
void tr_session::onAdvertisedPeerPortChanged()
{
for (auto* const tor : torrents())
{
tr_torrentChangeMyPort(tor);
}
}
/***
****
***/
@ -1161,13 +1216,14 @@ void tr_session::closeImplPart1(std::promise<void>* closed_promise)
save_timer_.reset();
now_timer_.reset();
rpc_server_.reset();
dht_.reset();
lpd_.reset();
port_forwarding_.reset();
bound_ipv6_.reset();
bound_ipv4_.reset();
// tell other items to start shutting down
udp_core_->startShutdown();
announcer_udp_->startShutdown();
// Close the torrents in order of most active to least active

View File

@ -42,6 +42,7 @@
#include "session-thread.h"
#include "stats.h"
#include "torrents.h"
#include "tr-dht.h"
#include "tr-lpd.h"
#include "utils-ev.h"
#include "verify.h"
@ -150,6 +151,36 @@ private:
tr_session& session_;
};
class DhtMediator : public tr_dht::Mediator
{
public:
DhtMediator(tr_session& session) noexcept
: session_{ session }
{
}
~DhtMediator() noexcept override = default;
[[nodiscard]] std::vector<tr_torrent_id_t> torrentsAllowingDHT() const override;
[[nodiscard]] tr_sha1_digest_t torrentInfoHash(tr_torrent_id_t id) const override;
[[nodiscard]] std::string_view configDir() const override
{
return session_.config_dir_;
}
[[nodiscard]] libtransmission::TimerMaker& timerMaker() override
{
return session_.timerMaker();
}
void addPex(tr_sha1_digest_t const&, tr_pex const* pex, size_t n_pex) override;
private:
tr_session& session_;
};
class PortForwardingMediator final : public tr_port_forwarding::Mediator
{
public:
@ -175,7 +206,11 @@ private:
void onPortForwarded(tr_port public_port) override
{
session_.advertised_peer_port_ = public_port;
if (session_.advertised_peer_port_ != public_port)
{
session_.advertised_peer_port_ = public_port;
session_.onAdvertisedPeerPortChanged();
}
}
private:
@ -241,12 +276,21 @@ private:
{
public:
tr_udp_core(tr_session& session, tr_port udp_port);
~tr_udp_core();
static void startShutdown();
static void dhtUpkeep();
void sendto(void const* buf, size_t buflen, struct sockaddr const* to, socklen_t const tolen) const;
[[nodiscard]] constexpr auto socket4() const noexcept
{
return udp_socket_;
}
[[nodiscard]] constexpr auto socket6() const noexcept
{
return udp6_socket_;
}
private:
void set_socket_buffers();
void set_socket_tos()
@ -255,11 +299,6 @@ private:
session_.setSocketTOS(udp6_socket_, TR_AF_INET6);
}
void sendto(void const* buf, size_t buflen, struct sockaddr const* to, socklen_t const tolen) const;
void addDhtNode(tr_address const& addr, tr_port port);
private:
tr_port const udp_port_;
tr_session& session_;
tr_socket_t udp_socket_ = TR_BAD_SOCKET;
@ -847,9 +886,9 @@ public:
void addDhtNode(tr_address const& addr, tr_port port)
{
if (udp_core_)
if (dht_)
{
udp_core_->addDhtNode(addr, port);
dht_->addNode(addr, port);
}
}
@ -886,7 +925,7 @@ private:
[[nodiscard]] tr_port randomPort() const;
void setPeerPort(tr_port port);
void onAdvertisedPeerPortChanged();
struct init_data;
void initImpl(init_data&);
@ -1061,7 +1100,7 @@ private:
std::optional<BoundSocket> bound_ipv6_;
public:
// depends-on: announcer_udp_
// depends-on: settings_, announcer_udp_
// FIXME(ckerr): circular dependency udp_core -> announcer_udp -> announcer_udp_mediator -> udp_core
std::unique_ptr<tr_udp_core> udp_core_;
@ -1104,6 +1143,9 @@ private:
// depends-on: udp_core_
AnnouncerUdpMediator announcer_udp_mediator_{ *this };
// depends-on: timer_maker_, torrents_, peer_mgr_
DhtMediator dht_mediator_{ *this };
public:
// depends-on: announcer_udp_mediator_
std::unique_ptr<tr_announcer_udp> announcer_udp_ = tr_announcer_udp::create(announcer_udp_mediator_);
@ -1111,6 +1153,9 @@ public:
// depends-on: settings_, torrents_, web_, announcer_udp_
struct tr_announcer* announcer = nullptr;
// depends-on: public_peer_port_, udp_core_, dht_mediator_
std::unique_ptr<tr_dht> dht_;
private:
// depends-on: session_thread_, timer_maker_, settings_, torrents_, web_
std::unique_ptr<tr_rpc_server> rpc_server_;

View File

@ -1312,8 +1312,6 @@ static void torrentStartImpl(tr_torrent* const tor)
tr_torrentResetTransferStats(tor);
tr_announcerTorrentStarted(tor);
tor->dhtAnnounceAt = now + tr_rand_int_weak(20);
tor->dhtAnnounce6At = now + tr_rand_int_weak(20);
tor->lpdAnnounceAt = now;
tr_peerMgrStartTorrent(tor);
}

View File

@ -748,9 +748,6 @@ public:
time_t peer_id_creation_time_ = 0;
time_t dhtAnnounceAt = 0;
time_t dhtAnnounce6At = 0;
time_t lpdAnnounceAt = 0;
time_t activityDate = 0;

File diff suppressed because it is too large Load Diff

View File

@ -8,17 +8,103 @@
#error only libtransmission should #include this header.
#endif
#include <memory>
#include <string_view>
#include <vector>
#include <dht/dht.h>
#include "transmission.h"
#include "net.h" // tr_port
int tr_dhtInit(tr_session*, tr_socket_t udp4_socket, tr_socket_t udp6_socket);
void tr_dhtUninit();
struct tr_pex;
bool tr_dhtEnabled();
namespace libtransmission
{
class TimerMaker;
} // namespace libtransmission
bool tr_dhtAddNode(tr_address, tr_port, bool bootstrap);
void tr_dhtUpkeep();
void tr_dhtCallback(unsigned char* buf, int buflen, struct sockaddr* from, socklen_t fromlen);
class tr_dht
{
public:
// Wrapper around DHT library.
// This calls `jech/dht` in production, but makes it possible for tests to inject a mock.
struct API
{
virtual ~API() = default;
virtual int get_nodes(struct sockaddr_in* sin, int* num, struct sockaddr_in6* sin6, int* num6)
{
return ::dht_get_nodes(sin, num, sin6, num6);
}
virtual int nodes(int af, int* good_return, int* dubious_return, int* cached_return, int* incoming_return)
{
return ::dht_nodes(af, good_return, dubious_return, cached_return, incoming_return);
}
virtual int periodic(
void const* buf,
size_t buflen,
struct sockaddr const* from,
int fromlen,
time_t* tosleep,
dht_callback_t callback,
void* closure)
{
return ::dht_periodic(buf, buflen, from, fromlen, tosleep, callback, closure);
}
virtual int ping_node(struct sockaddr const* sa, int salen)
{
return ::dht_ping_node(sa, salen);
}
virtual int search(unsigned char const* id, int port, int af, dht_callback_t callback, void* closure)
{
return ::dht_search(id, port, af, callback, closure);
}
virtual int init(int s, int s6, unsigned const char* id, unsigned const char* v)
{
return ::dht_init(s, s6, id, v);
}
virtual int uninit()
{
return ::dht_uninit();
}
};
class Mediator
{
public:
virtual ~Mediator() = default;
[[nodiscard]] virtual std::vector<tr_torrent_id_t> torrentsAllowingDHT() const = 0;
[[nodiscard]] virtual tr_sha1_digest_t torrentInfoHash(tr_torrent_id_t) const = 0;
[[nodiscard]] virtual std::string_view configDir() const = 0;
[[nodiscard]] virtual libtransmission::TimerMaker& timerMaker() = 0;
[[nodiscard]] virtual API& api()
{
return api_;
}
virtual void addPex(tr_sha1_digest_t const&, tr_pex const* pex, size_t n_pex) = 0;
private:
API api_;
};
[[nodiscard]] static std::unique_ptr<tr_dht> create(
Mediator& mediator,
tr_port peer_port,
tr_socket_t udp4_socket,
tr_socket_t udp6_socket);
virtual ~tr_dht() = default;
virtual void addNode(tr_address const& address, tr_port port) = 0;
virtual void handleMessage(unsigned char const* msg, size_t msglen, struct sockaddr* from, socklen_t fromlen) = 0;
};

View File

@ -7,12 +7,6 @@
#include <cstdint>
#include <cstring> /* memcmp(), memset() */
#ifdef _WIN32
#include <io.h> /* dup2() */
#else
#include <unistd.h> /* dup2() */
#endif
#include <event2/event.h>
#include <fmt/core.h>
@ -22,7 +16,6 @@
#include "net.h"
#include "session.h"
#include "tr-assert.h"
#include "tr-dht.h"
#include "tr-utp.h"
#include "utils.h"
@ -90,11 +83,6 @@ static void set_socket_buffers(tr_socket_t fd, bool large)
}
}
void tr_session::tr_udp_core::addDhtNode(tr_address const& addr, tr_port port)
{
tr_dhtAddNode(addr, port, false);
}
void tr_session::tr_udp_core::set_socket_buffers()
{
bool const utp = session_.allowsUTP();
@ -188,13 +176,16 @@ static void event_callback(evutil_socket_t s, [[maybe_unused]] short type, void*
TR_ASSERT(vsession != nullptr);
TR_ASSERT(type == EV_READ);
auto buf = std::array<unsigned char, 4096>{};
auto buf = std::array<unsigned char, 8192>{};
auto from = sockaddr_storage{};
auto* session = static_cast<tr_session*>(vsession);
socklen_t fromlen = sizeof(from);
auto const
rc = recvfrom(s, reinterpret_cast<char*>(std::data(buf)), std::size(buf) - 1, 0, (struct sockaddr*)&from, &fromlen);
auto fromlen = socklen_t{ sizeof(from) };
auto const rc = recvfrom(
s,
reinterpret_cast<char*>(std::data(buf)),
std::size(buf) - 1,
0,
reinterpret_cast<sockaddr*>(&from),
&fromlen);
/* Since most packets we receive here are µTP, make quick inline
checks for the other protocols. The logic is as follows:
@ -203,14 +194,15 @@ static void event_callback(evutil_socket_t s, [[maybe_unused]] short type, void*
is between 0 and 3
- the above cannot be µTP packets, since these start with a 4-bit
version number (1). */
auto* session = static_cast<tr_session*>(vsession);
if (rc > 0)
{
if (buf[0] == 'd')
{
if (session->allowsDHT())
if (session->dht_)
{
buf[rc] = '\0'; /* required by the DHT code */
tr_dhtCallback(std::data(buf), rc, (struct sockaddr*)&from, fromlen);
buf[rc] = '\0'; // libdht requires zero-terminated messages
session->dht_->handleMessage(std::data(buf), rc, reinterpret_cast<sockaddr*>(&from), fromlen);
}
}
else if (rc >= 8 && buf[0] == 0 && buf[1] == 0 && buf[2] == 0 && buf[3] <= 3)
@ -294,12 +286,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port)
set_socket_buffers();
set_socket_tos();
if (session_.allowsDHT())
{
tr_dhtInit(&session_, udp_socket_, udp6_socket_);
}
if (udp4_event_)
if (udp4_event_ != nullptr)
{
event_add(udp4_event_.get(), nullptr);
}
@ -309,26 +296,8 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port)
}
}
void tr_session::tr_udp_core::dhtUpkeep()
{
if (tr_dhtEnabled())
{
tr_dhtUpkeep();
}
}
void tr_session::tr_udp_core::startShutdown()
{
if (tr_dhtEnabled())
{
tr_dhtUninit();
}
}
tr_session::tr_udp_core::~tr_udp_core()
{
startShutdown();
udp6_event_.reset();
if (udp6_socket_ != TR_BAD_SOCKET)

View File

@ -13,6 +13,7 @@ add_executable(libtransmission-test
crypto-test-ref.h
crypto-test.cc
error-test.cc
dht-test.cc
file-piece-map-test.cc
file-test.cc
getopt-test.cc
@ -65,6 +66,7 @@ target_include_directories(libtransmission-test SYSTEM
${WIDE_INTEGER_INCLUDE_DIRS}
${B64_INCLUDE_DIRS}
${CURL_INCLUDE_DIRS}
${DHT_INCLUDE_DIRS}
${EVENT2_INCLUDE_DIRS})
target_compile_options(libtransmission-test

View File

@ -0,0 +1,643 @@
// This file Copyright (C) 2022 Mnemosyne LLC.
// It may be used under GPLv2 (SPDX: GPL-2.0-only), GPLv3 (SPDX: GPL-3.0-only),
// or any future license endorsed by Mnemosyne LLC.
// License text can be found in the licenses/ folder.
#include <algorithm>
#include <chrono>
#include <fstream>
#include <memory>
#include <utility>
#include <event2/event.h>
#include "transmission.h"
#include "file.h"
#include "timer-ev.h"
#include "session-thread.h" // for tr_evthread_init();
#include "gtest/gtest.h"
#include "test-fixtures.h"
#ifdef _WIN32
#undef gai_strerror
#define gai_strerror gai_strerrorA
#endif
using namespace std::literals;
namespace libtransmission::test
{
bool waitFor(struct event_base* event_base, std::chrono::milliseconds msec)
{
return waitFor(
event_base,
[]() { return false; },
msec);
}
namespace
{
auto constexpr IdLength = size_t{ 20U };
auto constexpr MockTimerInterval = 40ms;
} // namespace
class DhtTest : public SandboxedTest
{
protected:
// Helper for creating a mock dht.dat state file
struct MockStateFile
{
// Fake data to be written to the test state file
std::array<char, IdLength> id_ = tr_randObj<std::array<char, IdLength>>();
std::vector<std::pair<tr_address, tr_port>> ipv4_nodes_ = {
std::make_pair(*tr_address::fromString("10.10.10.1"), tr_port::fromHost(128)),
std::make_pair(*tr_address::fromString("10.10.10.2"), tr_port::fromHost(129)),
std::make_pair(*tr_address::fromString("10.10.10.3"), tr_port::fromHost(130)),
std::make_pair(*tr_address::fromString("10.10.10.4"), tr_port::fromHost(131)),
std::make_pair(*tr_address::fromString("10.10.10.5"), tr_port::fromHost(132))
};
std::vector<std::pair<tr_address, tr_port>> ipv6_nodes_ = {
std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3217"), tr_port::fromHost(6881)),
std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3218"), tr_port::fromHost(6882)),
std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3219"), tr_port::fromHost(6883)),
std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3220"), tr_port::fromHost(6884)),
std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3221"), tr_port::fromHost(6885))
};
[[nodiscard]] auto nodesString() const
{
auto str = std::string{};
for (auto const& [addr, port] : ipv4_nodes_)
{
str += addr.readable(port);
str += ',';
}
for (auto const& [addr, port] : ipv6_nodes_)
{
str += addr.readable(port);
str += ',';
}
return str;
}
[[nodiscard]] static auto filename(std::string_view dirname)
{
return std::string{ dirname } + "/dht.dat";
}
void save(std::string_view path) const
{
auto const dat_file = MockStateFile::filename(path);
auto dict = tr_variant{};
tr_variantInitDict(&dict, 3U);
tr_variantDictAddRaw(&dict, TR_KEY_id, std::data(id_), std::size(id_));
auto compact = std::vector<std::byte>{};
for (auto const& [addr, port] : ipv4_nodes_)
{
addr.toCompact4(std::back_inserter(compact), port);
}
tr_variantDictAddRaw(&dict, TR_KEY_nodes, std::data(compact), std::size(compact));
compact.clear();
for (auto const& [addr, port] : ipv6_nodes_)
{
addr.toCompact6(std::back_inserter(compact), port);
}
tr_variantDictAddRaw(&dict, TR_KEY_nodes6, std::data(compact), std::size(compact));
tr_variantToFile(&dict, TR_VARIANT_FMT_BENC, dat_file);
tr_variantClear(&dict);
}
};
// A fake libdht for the tests to call
class MockDht final : public tr_dht::API
{
public:
int get_nodes(struct sockaddr_in* /*sin*/, int* /*max*/, struct sockaddr_in6* /*sin6*/, int* /*max6*/) override
{
return 0;
}
int nodes(int /*af*/, int* good, int* dubious, int* cached, int* incoming) override
{
if (good != nullptr)
{
*good = good_;
}
if (dubious != nullptr)
{
*dubious = dubious_;
}
if (cached != nullptr)
{
*cached = cached_;
}
if (incoming != nullptr)
{
*incoming = incoming_;
}
return 0;
}
int periodic(
void const* /*buf*/,
size_t /*buflen*/,
sockaddr const /*from*/*,
int /*fromlen*/,
time_t* /*tosleep*/,
dht_callback_t /*callback*/,
void* /*closure*/) override
{
++n_periodic_calls_;
return 0;
}
int ping_node(struct sockaddr const* sa, int /*salen*/) override
{
auto addrport = tr_address::fromSockaddr(sa);
auto const [addr, port] = *addrport;
pinged_.push_back(Pinged{ addr, port, tr_time() });
return 0;
}
int search(unsigned char const* id, int port, int af, dht_callback_t /*callback*/, void* /*closure*/) override
{
auto info_hash = tr_sha1_digest_t{};
std::copy_n(reinterpret_cast<std::byte const*>(id), std::size(info_hash), std::data(info_hash));
searched_.push_back(Searched{ info_hash, tr_port::fromHost(port), af });
return 0;
}
int init(int dht_socket, int dht_socket6, unsigned const char* id, unsigned const char* /*v*/) override
{
inited_ = true;
dht_socket_ = dht_socket;
dht_socket6_ = dht_socket6;
std::copy_n(id, std::size(id_), std::begin(id_));
return 0;
}
int uninit() override
{
inited_ = false;
return 0;
}
constexpr void setHealthySwarm()
{
good_ = 50;
incoming_ = 10;
}
constexpr void setFirewalledSwarm()
{
good_ = 50;
incoming_ = 0;
}
constexpr void setPoorSwarm()
{
good_ = 10;
incoming_ = 1;
}
struct Searched
{
tr_sha1_digest_t info_hash;
tr_port port;
int af;
};
struct Pinged
{
tr_address address;
tr_port port;
time_t timestamp;
};
int good_ = 0;
int dubious_ = 0;
int cached_ = 0;
int incoming_ = 0;
size_t n_periodic_calls_ = 0;
bool inited_ = false;
std::vector<Pinged> pinged_;
std::vector<Searched> searched_;
std::array<char, IdLength> id_ = {};
int dht_socket_ = TR_BAD_SOCKET;
int dht_socket6_ = TR_BAD_SOCKET;
};
// Creates real timers, but with shortened intervals so that tests can run faster
class MockTimer final : public libtransmission::Timer
{
public:
explicit MockTimer(std::unique_ptr<Timer> real_timer)
: real_timer_{ std::move(real_timer) }
{
}
void stop() override
{
real_timer_->stop();
}
void setCallback(std::function<void()> callback) override
{
real_timer_->setCallback(std::move(callback));
}
void setRepeating(bool repeating = true) override
{
real_timer_->setRepeating(repeating);
}
void setInterval(std::chrono::milliseconds /*interval*/) override
{
real_timer_->setInterval(MockTimerInterval);
}
void start() override
{
real_timer_->start();
}
[[nodiscard]] std::chrono::milliseconds interval() const noexcept override
{
return real_timer_->interval();
}
[[nodiscard]] bool isRepeating() const noexcept override
{
return real_timer_->isRepeating();
}
private:
std::unique_ptr<Timer> const real_timer_;
};
// Creates MockTimers
class MockTimerMaker final : public libtransmission::TimerMaker
{
public:
explicit MockTimerMaker(struct event_base* evb)
: real_timer_maker_{ evb }
{
}
[[nodiscard]] std::unique_ptr<Timer> create() override
{
return std::make_unique<MockTimer>(real_timer_maker_.create());
}
EvTimerMaker real_timer_maker_;
};
class MockMediator final : public tr_dht::Mediator
{
public:
explicit MockMediator(struct event_base* event_base)
: mock_timer_maker_{ event_base }
{
}
[[nodiscard]] std::vector<tr_torrent_id_t> torrentsAllowingDHT() const override
{
return torrents_allowing_dht_;
}
[[nodiscard]] tr_sha1_digest_t torrentInfoHash(tr_torrent_id_t id) const override
{
if (auto const iter = info_hashes_.find(id); iter != std::end(info_hashes_))
{
return iter->second;
}
return {};
}
[[nodiscard]] std::string_view configDir() const override
{
return config_dir_;
}
[[nodiscard]] libtransmission::TimerMaker& timerMaker() override
{
return mock_timer_maker_;
}
[[nodiscard]] tr_dht::API& api() override
{
return mock_dht_;
}
void addPex(tr_sha1_digest_t const& /*info_hash*/, tr_pex const* /*pex*/, size_t /*n_pex*/) override
{
}
std::string config_dir_;
std::vector<tr_torrent_id_t> torrents_allowing_dht_;
std::map<tr_torrent_id_t, tr_sha1_digest_t> info_hashes_;
MockDht mock_dht_;
MockTimerMaker mock_timer_maker_;
};
[[nodiscard]] static std::pair<tr_address, tr_port> getSockaddr(std::string_view name, tr_port port)
{
auto hints = addrinfo{};
hints.ai_socktype = SOCK_DGRAM;
hints.ai_family = AF_UNSPEC;
auto const szname = tr_urlbuf{ name };
auto const port_str = std::to_string(port.host());
addrinfo* info = nullptr;
if (int const rc = getaddrinfo(szname.c_str(), std::data(port_str), &hints, &info); rc != 0)
{
tr_logAddWarn(fmt::format(
_("Couldn't look up '{address}:{port}': {error} ({error_code})"),
fmt::arg("address", name),
fmt::arg("port", port.host()),
fmt::arg("error", gai_strerror(rc)),
fmt::arg("error_code", rc)));
return {};
}
auto opt = tr_address::fromSockaddr(info->ai_addr);
freeaddrinfo(info);
if (opt)
{
return *opt;
}
return {};
}
void SetUp() override
{
SandboxedTest::SetUp();
tr_session_thread::tr_evthread_init();
event_base_ = event_base_new();
}
void TearDown() override
{
event_base_free(event_base_);
event_base_ = nullptr;
SandboxedTest::TearDown();
}
struct event_base* event_base_ = nullptr;
// Arbitrary values. Several tests requires socket/port values
// to be provided but they aren't central to the tests, so they're
// declared here with "Arbitrary" in the name to make that clear.
static auto constexpr ArbitrarySock4 = tr_socket_t{ 404 };
static auto constexpr ArbitrarySock6 = tr_socket_t{ 418 };
static auto constexpr ArbitraryPeerPort = tr_port::fromHost(909);
};
TEST_F(DhtTest, initsWithCorrectSockets)
{
static auto constexpr Sock4 = tr_socket_t{ 1000 };
static auto constexpr Sock6 = tr_socket_t{ 2000 };
// Make the DHT
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, Sock4, Sock6);
// Confirm that dht_init() was called with the right sockets
EXPECT_EQ(Sock4, mediator.mock_dht_.dht_socket_);
EXPECT_EQ(Sock6, mediator.mock_dht_.dht_socket6_);
}
TEST_F(DhtTest, callsUninitOnDestruct)
{
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
EXPECT_FALSE(mediator.mock_dht_.inited_);
{
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
EXPECT_TRUE(mediator.mock_dht_.inited_);
// dht goes out-of-scope here
}
EXPECT_FALSE(mediator.mock_dht_.inited_);
}
TEST_F(DhtTest, loadsStateFromStateFile)
{
auto const state_file = MockStateFile{};
state_file.save(sandboxDir());
// Make the DHT
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
// Wait for all the state nodes to be pinged
auto& pinged = mediator.mock_dht_.pinged_;
auto const n_expected_nodes = std::size(state_file.ipv4_nodes_) + std::size(state_file.ipv6_nodes_);
waitFor(event_base_, [&pinged, n_expected_nodes]() { return std::size(pinged) >= n_expected_nodes; });
auto actual_nodes_str = std::string{};
for (auto const& [addr, port, timestamp] : pinged)
{
actual_nodes_str += addr.readable(port);
actual_nodes_str += ',';
}
/// Confirm that the state was loaded
// dht_init() should have been called with the state file's id
EXPECT_EQ(state_file.id_, mediator.mock_dht_.id_);
// dht_ping_nodedht_init() should have been called with state file's nodes
EXPECT_EQ(state_file.nodesString(), actual_nodes_str);
}
TEST_F(DhtTest, stopsBootstrappingWhenSwarmHealthIsGoodEnough)
{
auto const state_file = MockStateFile{};
state_file.save(sandboxDir());
// Make the DHT
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
// Wait for N pings to occur...
auto& mock_dht = mediator.mock_dht_;
static auto constexpr TurnGoodAfterNthPing = size_t{ 3 };
waitFor(event_base_, [&mock_dht]() { return std::size(mock_dht.pinged_) == TurnGoodAfterNthPing; });
EXPECT_EQ(TurnGoodAfterNthPing, std::size(mock_dht.pinged_));
// Now fake that libdht says the swarm is healthy.
// This should cause bootstrapping to end.
mock_dht.setHealthySwarm();
// Now test to see if bootstrapping is done.
// There's not public API for `isBootstrapping()`,
// so to test this we just a moment to confirm that no more bootstrap nodes are pinged.
waitFor(event_base_, MockTimerInterval * 10);
// Confirm that the number of nodes pinged is unchanged,
// indicating that boostrapping is done
EXPECT_EQ(TurnGoodAfterNthPing, std::size(mock_dht.pinged_));
}
TEST_F(DhtTest, savesStateIfSwarmIsGood)
{
auto const state_file = MockStateFile{};
auto const dat_file = MockStateFile::filename(sandboxDir());
EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str()));
{
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
mediator.mock_dht_.setHealthySwarm();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
// as dht goes out of scope,
// it should save its state if the swarm is healthy
EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str()));
}
EXPECT_TRUE(tr_sys_path_exists(dat_file.c_str()));
}
TEST_F(DhtTest, doesNotSaveStateIfSwarmIsBad)
{
auto const state_file = MockStateFile{};
auto const dat_file = MockStateFile::filename(sandboxDir());
EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str()));
{
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
mediator.mock_dht_.setPoorSwarm();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
// as dht goes out of scope,
// it should save its state if the swarm is healthy
EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str()));
}
EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str()));
}
TEST_F(DhtTest, usesBootstrapFile)
{
// Make the 'dht.bootstrap' file.
// This a file with each line holding `${host} ${port}`
// which tr-dht will try to ping as nodes
static auto constexpr BootstrapNodeName = "example.com"sv;
static auto constexpr BootstrapNodePort = tr_port::fromHost(8080);
if (auto ofs = std::ofstream{ tr_pathbuf{ sandboxDir(), "/dht.bootstrap" } }; ofs)
{
ofs << BootstrapNodeName << ' ' << BootstrapNodePort.host() << std::endl;
ofs.close();
}
// Make the DHT
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
// We didn't create a 'dht.dat' file to load state from,
// so 'dht.bootstrap' should be the first nodes in the bootstrap list.
// Confirm that BootstrapNodeName gets pinged first.
auto const expected = getSockaddr(BootstrapNodeName, BootstrapNodePort);
auto& pinged = mediator.mock_dht_.pinged_;
waitFor(
event_base_,
[&pinged]() { return !std::empty(pinged); },
5s);
ASSERT_EQ(1U, std::size(pinged));
auto const actual = pinged.front();
EXPECT_EQ(expected.first, actual.address);
EXPECT_EQ(expected.second, actual.port);
EXPECT_EQ(expected.first.readable(expected.second), actual.address.readable(actual.port));
}
TEST_F(DhtTest, pingsAddedNodes)
{
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
EXPECT_EQ(0U, std::size(mediator.mock_dht_.pinged_));
auto const addr = *tr_address::fromString("10.10.10.1");
auto constexpr Port = tr_port::fromHost(128);
dht->addNode(addr, Port);
ASSERT_EQ(1U, std::size(mediator.mock_dht_.pinged_));
EXPECT_EQ(addr, mediator.mock_dht_.pinged_.front().address);
EXPECT_EQ(Port, mediator.mock_dht_.pinged_.front().port);
}
TEST_F(DhtTest, announcesTorrents)
{
auto constexpr Id = tr_torrent_id_t{ 1 };
auto constexpr PeerPort = tr_port::fromHost(999);
auto const info_hash = tr_randObj<tr_sha1_digest_t>();
tr_timeUpdate(time(nullptr));
auto mediator = MockMediator{ event_base_ };
mediator.info_hashes_[Id] = info_hash;
mediator.torrents_allowing_dht_ = { Id };
mediator.config_dir_ = sandboxDir();
// Since we're mocking a swarm that's magically healthy out-of-the-box,
// the DHT object we create can skip bootstrapping and proceed straight
// to announces
auto& mock_dht = mediator.mock_dht_;
mock_dht.setHealthySwarm();
auto dht = tr_dht::create(mediator, PeerPort, ArbitrarySock4, ArbitrarySock6);
waitFor(event_base_, MockTimerInterval * 10);
ASSERT_EQ(2U, std::size(mock_dht.searched_));
EXPECT_EQ(info_hash, mock_dht.searched_[0].info_hash);
EXPECT_EQ(PeerPort, mock_dht.searched_[0].port);
EXPECT_EQ(AF_INET, mock_dht.searched_[0].af);
EXPECT_EQ(info_hash, mock_dht.searched_[1].info_hash);
EXPECT_EQ(PeerPort, mock_dht.searched_[1].port);
EXPECT_EQ(AF_INET6, mock_dht.searched_[1].af);
}
TEST_F(DhtTest, callsPeriodicPeriodically)
{
auto mediator = MockMediator{ event_base_ };
mediator.config_dir_ = sandboxDir();
auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6);
auto& mock_dht = mediator.mock_dht_;
auto const baseline = mock_dht.n_periodic_calls_;
static auto constexpr Periods = 10;
waitFor(event_base_, std::chrono::duration_cast<std::chrono::milliseconds>(MockTimerInterval * Periods));
EXPECT_NEAR(mock_dht.n_periodic_calls_, baseline + Periods, Periods / 2);
}
} // namespace libtransmission::test

View File

@ -43,6 +43,14 @@ namespace libtransmission
namespace test
{
template<typename T>
[[nodiscard]] static auto tr_randObj()
{
auto ret = T{};
tr_rand_buffer(&ret, sizeof(ret));
return ret;
}
using file_func_t = std::function<void(char const* filename)>;
static void depthFirstWalk(char const* path, file_func_t func)

View File

@ -37,7 +37,7 @@ protected:
return std::chrono::duration_cast<std::chrono::milliseconds>(val);
};
void sleep_msec(std::chrono::milliseconds msec)
void sleepMsec(std::chrono::milliseconds msec)
{
EXPECT_FALSE(waitFor(
evbase_.get(),
@ -45,7 +45,7 @@ protected:
msec));
}
static void EXPECT_TIME(
static void expectTime(
std::chrono::milliseconds expected,
std::chrono::milliseconds actual,
std::chrono::milliseconds allowed_deviation)
@ -59,12 +59,12 @@ protected:
// This checks that `actual` is in the bounds of [expected/2 ... expected*1.5]
// to confirm that the timer didn't kick too close to the previous or next interval.
static void EXPECT_INTERVAL(std::chrono::milliseconds expected, std::chrono::milliseconds actual)
static void expectInterval(std::chrono::milliseconds expected, std::chrono::milliseconds actual)
{
EXPECT_TIME(expected, actual, expected / 2);
expectTime(expected, actual, expected / 2);
}
[[nodiscard]] static auto current_time()
[[nodiscard]] static auto currentTime()
{
return std::chrono::steady_clock::now();
}
@ -133,17 +133,17 @@ TEST_F(TimerTest, singleShotHonorsInterval)
timer->setCallback(callback);
// run a single-shot timer
auto const begin_time = current_time();
auto const begin_time = currentTime();
static auto constexpr Interval = 100ms;
timer->startSingleShot(Interval);
EXPECT_FALSE(timer->isRepeating());
EXPECT_EQ(Interval, timer->interval());
waitFor(evbase_.get(), [&called] { return called; });
auto const end_time = current_time();
auto const end_time = currentTime();
// confirm that it kicked at the right interval
EXPECT_TRUE(called);
EXPECT_INTERVAL(Interval, AsMSec(end_time - begin_time));
expectInterval(Interval, AsMSec(end_time - begin_time));
}
TEST_F(TimerTest, repeatingHonorsInterval)
@ -160,17 +160,17 @@ TEST_F(TimerTest, repeatingHonorsInterval)
timer->setCallback(callback);
// start a repeating timer
auto const begin_time = current_time();
auto const begin_time = currentTime();
static auto constexpr Interval = 100ms;
static auto constexpr DesiredLoops = 3;
timer->startRepeating(Interval);
EXPECT_TRUE(timer->isRepeating());
EXPECT_EQ(Interval, timer->interval());
waitFor(evbase_.get(), [&n_calls] { return n_calls >= DesiredLoops; });
auto const end_time = current_time();
auto const end_time = currentTime();
// confirm that it kicked the right number of times
EXPECT_INTERVAL(Interval * DesiredLoops, AsMSec(end_time - begin_time));
expectInterval(Interval * DesiredLoops, AsMSec(end_time - begin_time));
EXPECT_EQ(DesiredLoops, n_calls);
}
@ -190,12 +190,12 @@ TEST_F(TimerTest, restartWithDifferentInterval)
auto const test = [this, &n_calls, &timer](auto interval)
{
auto const next = n_calls + 1;
auto const begin_time = current_time();
auto const begin_time = currentTime();
timer->startSingleShot(interval);
waitFor(evbase_.get(), [&n_calls, next]() { return n_calls >= next; });
auto const end_time = current_time();
auto const end_time = currentTime();
EXPECT_INTERVAL(interval, AsMSec(end_time - begin_time));
expectInterval(interval, AsMSec(end_time - begin_time));
};
test(100ms);
@ -219,12 +219,12 @@ TEST_F(TimerTest, restartWithSameInterval)
auto const test = [this, &n_calls, &timer](auto interval)
{
auto const next = n_calls + 1;
auto const begin_time = current_time();
auto const begin_time = currentTime();
timer->startSingleShot(interval);
waitFor(evbase_.get(), [&n_calls, next]() { return n_calls >= next; });
auto const end_time = current_time();
auto const end_time = currentTime();
EXPECT_INTERVAL(interval, AsMSec(end_time - begin_time));
expectInterval(interval, AsMSec(end_time - begin_time));
};
test(timer->interval());
@ -246,31 +246,31 @@ TEST_F(TimerTest, repeatingThenSingleShot)
timer->setCallback(callback);
// start a repeating timer and confirm that it's running
auto begin_time = current_time();
auto begin_time = currentTime();
static auto constexpr RepeatingInterval = 100ms;
static auto constexpr DesiredLoops = 2;
timer->startRepeating(RepeatingInterval);
EXPECT_EQ(RepeatingInterval, timer->interval());
EXPECT_TRUE(timer->isRepeating());
waitFor(evbase_.get(), [&n_calls]() { return n_calls >= DesiredLoops; });
auto end_time = current_time();
EXPECT_TIME(RepeatingInterval * DesiredLoops, AsMSec(end_time - begin_time), RepeatingInterval / 2);
auto end_time = currentTime();
expectTime(RepeatingInterval * DesiredLoops, AsMSec(end_time - begin_time), RepeatingInterval / 2);
// now restart it as a single shot
auto const baseline = n_calls;
begin_time = current_time();
begin_time = currentTime();
static auto constexpr SingleShotInterval = 25ms;
timer->startSingleShot(SingleShotInterval);
EXPECT_EQ(SingleShotInterval, timer->interval());
EXPECT_FALSE(timer->isRepeating());
waitFor(evbase_.get(), [&n_calls]() { return n_calls >= DesiredLoops + 1; });
end_time = current_time();
end_time = currentTime();
// confirm that the single shot interval was honored
EXPECT_INTERVAL(SingleShotInterval, AsMSec(end_time - begin_time));
expectInterval(SingleShotInterval, AsMSec(end_time - begin_time));
// confirm that the timer only kicks once, since it was converted into single-shot
sleep_msec(SingleShotInterval * 3);
sleepMsec(SingleShotInterval * 3);
EXPECT_EQ(baseline + 1, n_calls);
}
@ -294,13 +294,13 @@ TEST_F(TimerTest, singleShotStop)
EXPECT_FALSE(timer->isRepeating());
// wait half the interval, then stop the timer
sleep_msec(Interval / 2);
sleepMsec(Interval / 2);
EXPECT_EQ(0U, n_calls);
timer->stop();
// wait until the timer has gone past.
// since we stopped it, callback should not have been called.
sleep_msec(Interval);
sleepMsec(Interval);
EXPECT_EQ(0U, n_calls);
}
@ -324,13 +324,13 @@ TEST_F(TimerTest, repeatingStop)
EXPECT_TRUE(timer->isRepeating());
// wait half the interval, then stop the timer
sleep_msec(Interval / 2);
sleepMsec(Interval / 2);
EXPECT_EQ(0U, n_calls);
timer->stop();
// wait until the timer has gone past.
// since we stopped it, callback should not have been called.
sleep_msec(Interval);
sleepMsec(Interval);
EXPECT_EQ(0U, n_calls);
}
@ -354,13 +354,13 @@ TEST_F(TimerTest, destroyedTimersStop)
EXPECT_TRUE(timer->isRepeating());
// wait half the interval, then destroy the timer
sleep_msec(Interval / 2);
sleepMsec(Interval / 2);
EXPECT_EQ(0U, n_calls);
timer.reset();
// wait until the timer has gone past.
// since we destroyed it, callback should not have been called.
sleep_msec(Interval);
sleepMsec(Interval);
EXPECT_EQ(0U, n_calls);
}