refactor: blocklists (#6189)

This commit is contained in:
Charles Kerr 2023-10-31 19:20:01 -04:00 committed by GitHub
parent 1c18737e67
commit e54b17d92e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 233 additions and 174 deletions

View File

@ -22,6 +22,8 @@
#include <fmt/core.h>
#include "libtransmission/transmission.h"
#include "libtransmission/blocklist.h"
#include "libtransmission/error.h"
#include "libtransmission/file.h"
@ -314,7 +316,7 @@ auto getFilenamesInDir(std::string_view folder)
} // namespace
void Blocklist::ensureLoaded() const
void Blocklists::Blocklist::ensureLoaded() const
{
if (!std::empty(rules_))
{
@ -398,42 +400,7 @@ void Blocklist::ensureLoaded() const
fmt::arg("count", std::size(rules_))));
}
std::vector<Blocklist> Blocklist::loadBlocklists(std::string_view const blocklist_dir, bool const is_enabled)
{
// check for files that need to be updated
for (auto const& src_file : getFilenamesInDir(blocklist_dir))
{
if (tr_strv_ends_with(src_file, BinFileSuffix))
{
continue;
}
// ensure this src_file has an up-to-date corresponding bin_file
auto const src_info = tr_sys_path_get_info(src_file);
auto const bin_file = tr_pathbuf{ src_file, BinFileSuffix };
auto const bin_info = tr_sys_path_get_info(bin_file);
auto const bin_needs_update = src_info && (!bin_info || bin_info->last_modified_at <= src_info->last_modified_at);
if (bin_needs_update)
{
if (auto const ranges = parseFile(src_file); !std::empty(ranges))
{
save(bin_file, std::data(ranges), std::size(ranges));
}
}
}
auto ret = std::vector<Blocklist>{};
for (auto const& bin_file : getFilenamesInDir(blocklist_dir))
{
if (tr_strv_ends_with(bin_file, BinFileSuffix))
{
ret.emplace_back(bin_file, is_enabled);
}
}
return ret;
}
bool Blocklist::contains(tr_address const& addr) const
bool Blocklists::Blocklist::contains(tr_address const& addr) const
{
TR_ASSERT(addr.is_valid());
@ -478,7 +445,10 @@ bool Blocklist::contains(tr_address const& addr) const
return std::binary_search(std::begin(rules_), std::end(rules_), addr, Compare{});
}
std::optional<Blocklist> Blocklist::saveNew(std::string_view external_file, std::string_view bin_file, bool is_enabled)
std::optional<Blocklists::Blocklist> Blocklists::Blocklist::saveNew(
std::string_view external_file,
std::string_view bin_file,
bool is_enabled)
{
// if we can't parse the file, do nothing
auto rules = parseFile(external_file);
@ -514,4 +484,94 @@ std::optional<Blocklist> Blocklist::saveNew(std::string_view external_file, std:
return ret;
}
// ---
void Blocklists::set_enabled(bool is_enabled)
{
for (auto& blocklist : blocklists_)
{
blocklist.setEnabled(is_enabled);
}
changed_.emit();
}
void Blocklists::load(std::string_view folder, bool is_enabled)
{
folder_ = folder;
blocklists_ = load_folder(folder, is_enabled);
changed_.emit();
}
// static
std::vector<Blocklists::Blocklist> Blocklists::Blocklists::load_folder(std::string_view const folder, bool const is_enabled)
{
// check for files that need to be updated
for (auto const& src_file : getFilenamesInDir(folder))
{
if (tr_strv_ends_with(src_file, BinFileSuffix))
{
continue;
}
// ensure this src_file has an up-to-date corresponding bin_file
auto const src_info = tr_sys_path_get_info(src_file);
auto const bin_file = tr_pathbuf{ src_file, BinFileSuffix };
auto const bin_info = tr_sys_path_get_info(bin_file);
auto const bin_needs_update = src_info && (!bin_info || bin_info->last_modified_at <= src_info->last_modified_at);
if (bin_needs_update)
{
if (auto const ranges = parseFile(src_file); !std::empty(ranges))
{
save(bin_file, std::data(ranges), std::size(ranges));
}
}
}
auto ret = std::vector<Blocklist>{};
for (auto const& bin_file : getFilenamesInDir(folder))
{
if (tr_strv_ends_with(bin_file, BinFileSuffix))
{
ret.emplace_back(bin_file, is_enabled);
}
}
return ret;
}
size_t Blocklists::update_primary_blocklist(std::string_view external_file, bool is_enabled)
{
// These rules will replace the default blocklist.
// Build the path of the default blocklist .bin file where we'll save these rules.
auto const bin_file = tr_pathbuf{ folder_, '/', DEFAULT_BLOCKLIST_FILENAME };
// Try to save it
auto added = Blocklist::saveNew(external_file, bin_file, is_enabled);
if (!added)
{
return 0U;
}
auto const n_rules = std::size(*added);
// Add (or replace) it in our blocklists_ vector
if (auto iter = std::find_if(
std::begin(blocklists_),
std::end(blocklists_),
[&bin_file](auto const& candidate) { return bin_file == candidate.binFile(); });
iter != std::end(blocklists_))
{
*iter = std::move(*added);
}
else
{
blocklists_.emplace_back(std::move(*added));
}
changed_.emit();
return n_rules;
}
} // namespace libtransmission

View File

@ -9,63 +9,110 @@
#error only libtransmission should #include this header.
#endif
#include <numeric>
#include <optional>
#include <string>
#include <string_view>
#include <utility> // for std::pair
#include <vector>
#include "net.h" // for tr_address
#include "libtransmission/tr-macros.h" // for TR_CONSTEXPR20
#include "libtransmission/net.h" // for tr_address
#include "libtransmission/observable.h"
namespace libtransmission
{
class Blocklist
class Blocklists
{
public:
[[nodiscard]] static std::vector<Blocklist> loadBlocklists(std::string_view const blocklist_dir, bool const is_enabled);
Blocklists() = default;
static std::optional<Blocklist> saveNew(std::string_view external_file, std::string_view bin_file, bool is_enabled);
Blocklist() = default;
Blocklist(std::string_view bin_file, bool is_enabled)
: bin_file_{ bin_file }
, is_enabled_{ is_enabled }
[[nodiscard]] bool contains(tr_address const& addr) const noexcept
{
return std::any_of(
std::begin(blocklists_),
std::end(blocklists_),
[&addr](auto const& blocklist) { return blocklist.enabled() && blocklist.contains(addr); });
}
[[nodiscard]] bool contains(tr_address const& addr) const;
[[nodiscard]] auto size() const
[[nodiscard]] TR_CONSTEXPR20 auto num_lists() const noexcept
{
ensureLoaded();
return std::size(rules_);
return std::size(blocklists_);
}
[[nodiscard]] constexpr bool enabled() const noexcept
[[nodiscard]] TR_CONSTEXPR20 auto num_rules() const noexcept
{
return is_enabled_;
return std::accumulate(
std::begin(blocklists_),
std::end(blocklists_),
size_t{},
[](int sum, auto& cur) { return sum + std::size(cur); });
}
constexpr void setEnabled(bool is_enabled) noexcept
{
is_enabled_ = is_enabled;
}
void load(std::string_view folder, bool is_enabled);
void set_enabled(bool is_enabled);
size_t update_primary_blocklist(std::string_view external_file, bool is_enabled);
[[nodiscard]] constexpr auto const& binFile() const noexcept
template<typename Observer>
[[nodiscard]] auto observe_changes(Observer observer)
{
return bin_file_;
return changed_.observe(std::move(observer));
}
private:
void ensureLoaded() const;
class Blocklist
{
public:
static std::optional<Blocklist> saveNew(std::string_view external_file, std::string_view bin_file, bool is_enabled);
mutable std::vector<std::pair<tr_address, tr_address>> rules_;
Blocklist() = default;
std::string bin_file_;
bool is_enabled_ = false;
Blocklist(std::string_view bin_file, bool is_enabled)
: bin_file_{ bin_file }
, is_enabled_{ is_enabled }
{
}
[[nodiscard]] bool contains(tr_address const& addr) const;
[[nodiscard]] auto size() const
{
ensureLoaded();
return std::size(rules_);
}
[[nodiscard]] constexpr bool enabled() const noexcept
{
return is_enabled_;
}
constexpr void setEnabled(bool is_enabled) noexcept
{
is_enabled_ = is_enabled;
}
[[nodiscard]] constexpr auto const& binFile() const noexcept
{
return bin_file_;
}
private:
void ensureLoaded() const;
mutable std::vector<std::pair<tr_address, tr_address>> rules_;
std::string bin_file_;
bool is_enabled_ = false;
};
std::vector<Blocklist> blocklists_;
std::string folder_;
libtransmission::SimpleObservable<> changed_;
[[nodiscard]] static std::vector<Blocklist> load_folder(std::string_view folder, bool is_enabled);
};
} // namespace libtransmission

View File

@ -134,18 +134,6 @@ using Handshakes = std::unordered_map<tr_socket_address, tr_handshake>;
} // anonymous namespace
bool tr_peer_info::is_blocklisted(tr_session const* session) const
{
if (blocklisted_)
{
return *blocklisted_;
}
auto const value = session->addressIsBlocked(listen_address());
blocklisted_ = value;
return value;
}
void tr_peer_info::merge(tr_peer_info& that) noexcept
{
TR_ASSERT(is_connectable_.value_or(true) || !is_connected());
@ -935,14 +923,19 @@ public:
using OutboundCandidates = small::
max_size_vector<std::pair<tr_torrent_id_t, tr_socket_address>, OutboundCandidateListCapacity>;
explicit tr_peerMgr(tr_session* session_in, libtransmission::TimerMaker& timer_maker, tr_torrents& torrents)
explicit tr_peerMgr(
tr_session* session_in,
libtransmission::TimerMaker& timer_maker,
tr_torrents& torrents,
libtransmission::Blocklists& blocklist)
: session{ session_in }
, torrents_{ torrents }
, blocklists_{ blocklist }
, handshake_mediator_{ *session, timer_maker, torrents }
, bandwidth_timer_{ timer_maker.create([this]() { bandwidth_pulse(); }) }
, rechoke_timer_{ timer_maker.create([this]() { rechoke_pulse_marshall(); }) }
, refill_upkeep_timer_{ timer_maker.create([this]() { refill_upkeep(); }) }
, blocklist_tag_{ session->blocklist_changed_.observe([this]() { on_blocklist_changed(); }) }
, blocklists_tag_{ blocklist.observe_changes([this]() { on_blocklists_changed(); }) }
{
bandwidth_timer_->start_repeating(BandwidthTimerPeriod);
rechoke_timer_->start_repeating(RechokePeriod);
@ -978,6 +971,7 @@ public:
tr_session* const session;
tr_torrents& torrents_;
libtransmission::Blocklists const& blocklists_;
Handshakes incoming_handshakes;
HandshakeMediator handshake_mediator_;
@ -995,7 +989,7 @@ private:
rechoke_timer_->set_interval(RechokePeriod);
}
void on_blocklist_changed() const
void on_blocklists_changed() const
{
/* we cache whether or not a peer is blocklisted...
since the blocklist has changed, erase that cached value */
@ -1017,7 +1011,7 @@ private:
std::unique_ptr<libtransmission::Timer> const rechoke_timer_;
std::unique_ptr<libtransmission::Timer> const refill_upkeep_timer_;
libtransmission::ObserverTag const blocklist_tag_;
libtransmission::ObserverTag const blocklists_tag_;
};
// --- tr_peer virtual functions
@ -1041,7 +1035,7 @@ tr_peer::~tr_peer()
tr_peerMgr* tr_peerMgrNew(tr_session* session)
{
return new tr_peerMgr{ session, session->timerMaker(), session->torrents() };
return new tr_peerMgr{ session, session->timerMaker(), session->torrents(), session->blocklist() };
}
void tr_peerMgrFree(tr_peerMgr* manager)
@ -1297,12 +1291,9 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket)
{
using namespace handshake_helpers;
TR_ASSERT(manager->session != nullptr);
auto const lock = manager->unique_lock();
auto* const session = manager->session;
if (session->addressIsBlocked(socket.address()))
if (manager->blocklists_.contains(socket.address()))
{
tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", socket.display_name()));
socket.close();
@ -1311,9 +1302,10 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_peer_socket&& socket)
{
socket.close();
}
else /* we don't have a connection to them yet... */
else // we don't have a connection to them yet...
{
auto socket_address = socket.socket_address();
auto const socket_address = socket.socket_address();
auto* const session = manager->session;
manager->incoming_handshakes.try_emplace(
socket_address,
&manager->handshake_mediator_,
@ -1332,7 +1324,7 @@ size_t tr_peerMgrAddPex(tr_torrent* tor, tr_peer_from from, tr_pex const* pex, s
for (tr_pex const* const end = pex + n_pex; pex != end; ++pex)
{
if (tr_isPex(pex) && /* safeguard against corrupt data */
!s->manager->session->addressIsBlocked(pex->socket_address.address()) && pex->is_valid_for_peers() &&
!s->manager->blocklists_.contains(pex->socket_address.address()) && pex->is_valid_for_peers() &&
from != TR_PEER_FROM_INCOMING && (from != TR_PEER_FROM_PEX || (pex->flags & ADDED_F_CONNECTABLE) != 0))
{
// we store this peer since it is supposedly connectable (socket address should be the peer's listening address)
@ -1410,7 +1402,7 @@ namespace get_peers_helpers
return true;
}
if (info.is_blocklisted(tor->session))
if (info.is_blocklisted(tor->session->blocklist()))
{
return false;
}
@ -2188,8 +2180,7 @@ void tr_peerMgr::reconnect_pulse()
// remove crappy peers
auto bad_peers_buf = bad_peers_t{};
auto& torrents = torrents_;
for (auto* const tor : torrents)
for (auto* const tor : torrents_)
{
auto* const swarm = tor->swarm;
@ -2204,7 +2195,7 @@ void tr_peerMgr::reconnect_pulse()
}
// if we're over the per-torrent peer limits, cull some peers
for (auto* const tor : torrents)
for (auto* const tor : torrents_)
{
if (tor->is_running())
{
@ -2213,7 +2204,7 @@ void tr_peerMgr::reconnect_pulse()
}
// if we're over the per-session peer limits, cull some peers
enforceSessionPeerLimit(session->peerLimit(), torrents);
enforceSessionPeerLimit(session->peerLimit(), torrents_);
// try to make new peer connections
make_new_peer_connections();
@ -2294,7 +2285,7 @@ namespace connect_helpers
}
// not if they're blocklisted
if (peer_info.is_blocklisted(tor->session))
if (peer_info.is_blocklisted(tor->session->blocklist()))
{
return false;
}

View File

@ -19,6 +19,7 @@
#include "libtransmission/transmission.h" // tr_block_span_t (ptr only)
#include "libtransmission/blocklist.h"
#include "libtransmission/handshake.h"
#include "libtransmission/net.h" /* tr_address */
#include "libtransmission/tr-assert.h"
@ -243,7 +244,15 @@ public:
// ---
[[nodiscard]] bool is_blocklisted(tr_session const* session) const;
[[nodiscard]] bool is_blocklisted(libtransmission::Blocklists const& blocklist) const
{
if (!blocklisted_.has_value())
{
blocklisted_ = blocklist.contains(listen_address());
}
return *blocklisted_;
}
void set_blocklisted_dirty()
{

View File

@ -1715,7 +1715,7 @@ char const* sessionSet(tr_session* session, tr_variant* args_in, tr_variant* /*a
if (auto val = bool{}; tr_variantDictFindBool(args_in, TR_KEY_blocklist_enabled, &val))
{
session->useBlocklist(val);
session->set_blocklist_enabled(val);
}
if (tr_variantDictFindStrView(args_in, TR_KEY_blocklist_url, &sv))
@ -1999,7 +1999,7 @@ void addSessionField(tr_session const* s, tr_variant* d, tr_quark key)
break;
case TR_KEY_blocklist_enabled:
tr_variantDictAddBool(d, key, s->useBlocklist());
tr_variantDictAddBool(d, key, s->blocklist_enabled());
break;
case TR_KEY_blocklist_url:

View File

@ -13,7 +13,6 @@
#include <iterator> // for std::back_inserter
#include <limits> // std::numeric_limits
#include <memory>
#include <numeric> // for std::accumulate()
#include <string>
#include <string_view>
#include <utility>
@ -727,7 +726,7 @@ void tr_session::initImpl(init_data& data)
tr_logSetQueueEnabled(data.message_queuing_enabled);
this->blocklists_ = libtransmission::Blocklist::loadBlocklists(blocklist_dir_, useBlocklist());
blocklists_.load(blocklist_dir_, blocklist_enabled());
tr_logAddInfo(fmt::format(_("Transmission version {version} starting"), fmt::arg("version", LONG_VERSION_STRING)));
@ -794,7 +793,7 @@ void tr_session::setSettings(tr_session_settings&& settings_in, bool force)
bool const utp_changed = new_settings.utp_enabled != old_settings.utp_enabled;
useBlocklist(new_settings.blocklist_enabled);
set_blocklist_enabled(new_settings.blocklist_enabled);
auto local_peer_port = force && settings_.peer_port_random_on_start ? randomPort() : new_settings.peer_port;
bool port_changed = false;
@ -1678,93 +1677,43 @@ bool tr_sessionIsPortForwardingEnabled(tr_session const* session)
// ---
void tr_session::useBlocklist(bool enabled)
{
settings_.blocklist_enabled = enabled;
std::for_each(
std::begin(blocklists_),
std::end(blocklists_),
[enabled](auto& blocklist) { blocklist.setEnabled(enabled); });
}
bool tr_session::addressIsBlocked(tr_address const& addr) const noexcept
{
return std::any_of(
std::begin(blocklists_),
std::end(blocklists_),
[&addr](auto& blocklist) { return blocklist.contains(addr); });
}
void tr_sessionReloadBlocklists(tr_session* session)
{
session->blocklists_ = libtransmission::Blocklist::loadBlocklists(session->blocklist_dir_, session->useBlocklist());
session->blocklist_changed_.emit();
session->blocklists_.load(session->blocklist_dir_, session->blocklist_enabled());
}
size_t tr_blocklistGetRuleCount(tr_session const* session)
{
TR_ASSERT(session != nullptr);
auto& src = session->blocklists_;
return std::accumulate(std::begin(src), std::end(src), 0, [](int sum, auto& cur) { return sum + std::size(cur); });
return session->blocklists_.num_rules();
}
bool tr_blocklistIsEnabled(tr_session const* session)
{
TR_ASSERT(session != nullptr);
return session->useBlocklist();
return session->blocklist_enabled();
}
void tr_blocklistSetEnabled(tr_session* session, bool enabled)
{
TR_ASSERT(session != nullptr);
session->useBlocklist(enabled);
session->set_blocklist_enabled(enabled);
}
bool tr_blocklistExists(tr_session const* session)
{
TR_ASSERT(session != nullptr);
return !std::empty(session->blocklists_);
return session->blocklists_.num_lists() > 0U;
}
size_t tr_blocklistSetContent(tr_session* session, char const* content_filename)
{
auto const lock = session->unique_lock();
// These rules will replace the default blocklist.
// Build the path of the default blocklist .bin file where we'll save these rules.
auto const bin_file = tr_pathbuf{ session->blocklist_dir_, '/', DEFAULT_BLOCKLIST_FILENAME };
// Try to save it
auto added = libtransmission::Blocklist::saveNew(content_filename, bin_file, session->useBlocklist());
if (!added)
{
return 0U;
}
auto const n_rules = std::size(*added);
// Add (or replace) it in our blocklists_ vector
auto& src = session->blocklists_;
if (auto iter = std::find_if(
std::begin(src),
std::end(src),
[&bin_file](auto const& candidate) { return bin_file == candidate.binFile(); });
iter != std::end(src))
{
*iter = std::move(*added);
}
else
{
src.emplace_back(std::move(*added));
}
return n_rules;
return session->blocklists_.update_primary_blocklist(content_filename, session->blocklist_enabled());
}
void tr_blocklistSetURL(tr_session* session, char const* url)

View File

@ -46,7 +46,6 @@
#include "libtransmission/global-ip-cache.h"
#include "libtransmission/interned-string.h"
#include "libtransmission/net.h" // tr_socket_t
#include "libtransmission/observable.h"
#include "libtransmission/open-files.h"
#include "libtransmission/port-forwarding.h"
#include "libtransmission/quark.h"
@ -488,13 +487,22 @@ public:
// blocklist
[[nodiscard]] constexpr auto useBlocklist() const noexcept
[[nodiscard]] constexpr auto& blocklist() noexcept
{
return blocklists_;
}
void set_blocklist_enabled(bool is_enabled)
{
settings_.blocklist_enabled = is_enabled;
blocklist().set_enabled(is_enabled);
}
[[nodiscard]] auto blocklist_enabled() const noexcept
{
return settings_.blocklist_enabled;
}
void useBlocklist(bool enabled);
[[nodiscard]] constexpr auto const& blocklistUrl() const noexcept
{
return settings_.blocklist_url;
@ -809,8 +817,6 @@ public:
[[nodiscard]] size_t count_queue_free_slots(tr_direction dir) const noexcept;
[[nodiscard]] bool addressIsBlocked(tr_address const& addr) const noexcept;
[[nodiscard]] bool has_ip_protocol(tr_address_type type) const noexcept
{
TR_ASSERT(tr_address::is_valid(type));
@ -1100,10 +1106,7 @@ private:
tr_open_files open_files_;
std::vector<libtransmission::Blocklist> blocklists_;
public:
libtransmission::SimpleObservable<> blocklist_changed_;
libtransmission::Blocklists blocklists_;
private:
/// other fields
@ -1155,7 +1158,7 @@ public:
std::unique_ptr<Cache> cache = std::make_unique<Cache>(torrents_, 1024 * 1024 * 2);
private:
// depends-on: timer_maker_, top_bandwidth_, utp_context, torrents_, web_, blocklist_changed_
// depends-on: timer_maker_, blocklists_, top_bandwidth_, utp_context, torrents_, web_
std::unique_ptr<struct tr_peerMgr, void (*)(struct tr_peerMgr*)> peer_mgr_;
// depends-on: peer_mgr_, advertised_peer_port_, torrents_

View File

@ -41,7 +41,7 @@ protected:
bool addressIsBlocked(char const* address_str)
{
auto const addr = tr_address::from_string(address_str);
return !addr || session_->addressIsBlocked(*addr);
return !addr || session_->blocklist().contains(*addr);
}
};

View File

@ -168,12 +168,12 @@ TEST_F(SessionTest, propertiesApi)
for (auto const value : { true, false })
{
session->useBlocklist(value);
EXPECT_EQ(value, session->useBlocklist());
session->set_blocklist_enabled(value);
EXPECT_EQ(value, session->blocklist_enabled());
EXPECT_EQ(value, tr_blocklistIsEnabled(session));
tr_sessionSetIncompleteDirEnabled(session, value);
EXPECT_EQ(value, session->useBlocklist());
EXPECT_EQ(value, session->blocklist_enabled());
EXPECT_EQ(value, tr_blocklistIsEnabled(session));
}
}