refactor: make tr_session_id a class (#3598)

This commit is contained in:
Charles Kerr 2022-08-06 14:27:37 -05:00 committed by GitHub
parent 1f5c650f56
commit 31a733fab7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 224 additions and 240 deletions

View File

@ -111,7 +111,7 @@ static bool constexpr tr_rpc_address_is_valid(tr_rpc_address const& a)
static char const* get_current_session_id(tr_rpc_server* server)
{
return tr_session_id_get_current(server->session->session_id);
return server->session->session_id.c_str();
}
/**

View File

@ -2346,7 +2346,7 @@ static void addSessionField(tr_session const* s, tr_variant* d, tr_quark key)
break;
case TR_KEY_session_id:
tr_variantDictAddStr(d, key, tr_session_id_get_current(s->session_id));
tr_variantDictAddStr(d, key, s->session_id.sv());
break;
}
}

View File

@ -5,6 +5,7 @@
#include <ctime>
#include <string_view>
#include <iterator> // for std::back_inserter
#ifndef _WIN32
#include <sys/stat.h>
@ -13,81 +14,54 @@
#include <fmt/format.h>
#include "transmission.h"
#include "crypto-utils.h"
#include "error.h"
#include "session-id.h"
#include "crypto-utils.h" // for tr_rand_buf()
#include "error-types.h"
#include "error.h"
#include "file.h"
#include "log.h"
#include "platform.h"
#include "session-id.h"
#include "utils.h"
#include "tr-strbuf.h" // for tr_pathbuf
#include "utils.h" // for _()
using namespace std::literals;
static auto constexpr SessionIdSize = size_t{ 48 };
static auto constexpr SessionIdDurationSec = time_t{ 60 * 60 }; /* expire in an hour */
struct tr_session_id
namespace
{
char* current_value;
char* previous_value;
tr_sys_file_t current_lock_file;
tr_sys_file_t previous_lock_file;
time_t expires_at;
};
static char* generate_new_session_id_value()
void get_lockfile_path(std::string_view session_id, tr_pathbuf& path)
{
char const pool[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
size_t const pool_size = sizeof(pool) - 1;
auto* buf = tr_new(char, SessionIdSize + 1);
tr_rand_buffer(buf, SessionIdSize);
for (size_t i = 0; i < SessionIdSize; ++i)
{
buf[i] = pool[(unsigned char)buf[i] % pool_size];
}
buf[SessionIdSize] = '\0';
return buf;
fmt::format_to(std::back_inserter(path), FMT_STRING("{:s}/tr_session_id_{:s}"), tr_getSessionIdDir(), session_id);
}
static std::string get_session_id_lock_file_path(std::string_view session_id)
tr_sys_file_t create_lockfile(std::string_view session_id)
{
return fmt::format(FMT_STRING("{:s}/tr_session_id_{:s}"), tr_getSessionIdDir(), session_id);
}
static tr_sys_file_t create_session_id_lock_file(char const* session_id)
{
if (session_id == nullptr)
if (std::empty(session_id))
{
return TR_BAD_SYS_FILE;
}
auto const lock_file_path = get_session_id_lock_file_path(session_id);
tr_error* error = nullptr;
auto lock_file = tr_sys_file_open(
lock_file_path.c_str(),
TR_SYS_FILE_READ | TR_SYS_FILE_WRITE | TR_SYS_FILE_CREATE,
0600,
&error);
auto lockfile_path = tr_pathbuf{};
get_lockfile_path(session_id, lockfile_path);
if (lock_file != TR_BAD_SYS_FILE)
tr_error* error = nullptr;
auto lockfile_fd = tr_sys_file_open(lockfile_path, TR_SYS_FILE_READ | TR_SYS_FILE_WRITE | TR_SYS_FILE_CREATE, 0600, &error);
if (lockfile_fd != TR_BAD_SYS_FILE)
{
if (tr_sys_file_lock(lock_file, TR_SYS_FILE_LOCK_EX | TR_SYS_FILE_LOCK_NB, &error))
if (tr_sys_file_lock(lockfile_fd, TR_SYS_FILE_LOCK_EX | TR_SYS_FILE_LOCK_NB, &error))
{
#ifndef _WIN32
/* Allow any user to lock the file regardless of current umask */
fchmod(lock_file, 0644);
fchmod(lockfile_fd, 0644);
#endif
}
else
{
tr_sys_file_close(lock_file);
lock_file = TR_BAD_SYS_FILE;
tr_sys_file_close(lockfile_fd);
lockfile_fd = TR_BAD_SYS_FILE;
}
}
@ -95,119 +69,116 @@ static tr_sys_file_t create_session_id_lock_file(char const* session_id)
{
tr_logAddWarn(fmt::format(
_("Couldn't create '{path}': {error} ({error_code})"),
fmt::arg("path", lock_file_path),
fmt::arg("path", lockfile_path),
fmt::arg("error", error->message),
fmt::arg("error_code", error->code)));
tr_error_free(error);
}
return lock_file;
return lockfile_fd;
}
static void destroy_session_id_lock_file(tr_sys_file_t lock_file, char const* session_id)
void destroy_lockfile(tr_sys_file_t lockfile_fd, std::string_view session_id)
{
if (lock_file != TR_BAD_SYS_FILE)
if (lockfile_fd != TR_BAD_SYS_FILE)
{
tr_sys_file_close(lock_file);
tr_sys_file_close(lockfile_fd);
}
if (session_id != nullptr)
if (!std::empty(session_id))
{
auto const lock_file_path = get_session_id_lock_file_path(session_id);
tr_sys_path_remove(lock_file_path);
auto lockfile_path = tr_pathbuf{};
get_lockfile_path(session_id, lockfile_path);
tr_sys_path_remove(lockfile_path);
}
}
tr_session_id_t tr_session_id_new()
#ifndef _WIN32
auto constexpr WouldBlock = EWOULDBLOCK;
#else
auto constexpr WouldBlock = ERROR_LOCK_VIOLATION;
#endif
} // namespace
tr_session_id::session_id_t tr_session_id::make_session_id()
{
auto const session_id = tr_new0(struct tr_session_id, 1);
session_id->current_lock_file = TR_BAD_SYS_FILE;
session_id->previous_lock_file = TR_BAD_SYS_FILE;
auto session_id = session_id_t{};
tr_rand_buffer(std::data(session_id), std::size(session_id));
static auto constexpr Pool = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"sv;
for (auto& chr : session_id)
{
chr = Pool[static_cast<unsigned char>(chr) % std::size(Pool)];
}
session_id.back() = '\0';
return session_id;
}
void tr_session_id_free(tr_session_id_t session_id)
tr_session_id::~tr_session_id()
{
if (session_id == nullptr)
{
return;
}
destroy_session_id_lock_file(session_id->previous_lock_file, session_id->previous_value);
destroy_session_id_lock_file(session_id->current_lock_file, session_id->current_value);
tr_free(session_id->previous_value);
tr_free(session_id->current_value);
tr_free(session_id);
destroy_lockfile(current_lock_file_, std::data(current_value_));
destroy_lockfile(previous_lock_file_, std::data(previous_value_));
}
char const* tr_session_id_get_current(tr_session_id_t session_id)
bool tr_session_id::isLocal(std::string_view session_id) noexcept
{
time_t const now = tr_time();
if (session_id->current_value == nullptr || now >= session_id->expires_at)
if (std::empty(session_id))
{
destroy_session_id_lock_file(session_id->previous_lock_file, session_id->previous_value);
tr_free(session_id->previous_value);
session_id->previous_value = session_id->current_value;
session_id->current_value = generate_new_session_id_value();
session_id->previous_lock_file = session_id->current_lock_file;
session_id->current_lock_file = create_session_id_lock_file(session_id->current_value);
session_id->expires_at = now + SessionIdDurationSec;
return false;
}
return session_id->current_value;
}
bool tr_session_id_is_local(char const* session_id)
{
bool ret = false;
if (session_id != nullptr)
auto is_local = bool{ false };
auto lockfile_path = tr_pathbuf{};
get_lockfile_path(session_id, lockfile_path);
tr_error* error = nullptr;
if (auto lockfile_fd = tr_sys_file_open(lockfile_path, TR_SYS_FILE_READ, 0, &error); lockfile_fd == TR_BAD_SYS_FILE)
{
auto const lock_file_path = get_session_id_lock_file_path(session_id);
tr_error* error = nullptr;
auto lock_file = tr_sys_file_open(lock_file_path.c_str(), TR_SYS_FILE_READ, 0, &error);
if (lock_file == TR_BAD_SYS_FILE)
if (TR_ERROR_IS_ENOENT(error->code))
{
if (TR_ERROR_IS_ENOENT(error->code))
{
tr_error_clear(&error);
}
}
else
{
if (!tr_sys_file_lock(lock_file, TR_SYS_FILE_LOCK_SH | TR_SYS_FILE_LOCK_NB, &error) &&
#ifndef _WIN32
(error->code == EWOULDBLOCK))
#else
(error->code == ERROR_LOCK_VIOLATION))
#endif
{
ret = true;
tr_error_clear(&error);
}
tr_sys_file_close(lock_file);
}
if (error != nullptr)
{
tr_logAddWarn(fmt::format(
_("Couldn't open session lock file '{path}': {error} ({error_code})"),
fmt::arg("path", lock_file_path),
fmt::arg("error", error->message),
fmt::arg("error_code", error->code)));
tr_error_free(error);
tr_error_clear(&error);
}
}
else
{
if (!tr_sys_file_lock(lockfile_fd, TR_SYS_FILE_LOCK_SH | TR_SYS_FILE_LOCK_NB, &error) && (error->code == WouldBlock))
{
is_local = true;
tr_error_clear(&error);
}
return ret;
tr_sys_file_close(lockfile_fd);
}
if (error != nullptr)
{
tr_logAddWarn(fmt::format(
_("Couldn't open session lock file '{path}': {error} ({error_code})"),
fmt::arg("path", lockfile_path),
fmt::arg("error", error->message),
fmt::arg("error_code", error->code)));
tr_error_free(error);
}
return is_local;
}
std::string_view tr_session_id::sv() const noexcept
{
if (auto const now = get_current_time_(); now >= expires_at_)
{
destroy_lockfile(previous_lock_file_, std::data(previous_value_));
previous_value_ = current_value_;
previous_lock_file_ = current_lock_file_;
current_value_ = make_session_id();
current_lock_file_ = create_lockfile(std::data(current_value_));
expires_at_ = now + SessionIdDurationSec;
}
// -1 to strip the '\0'
return std::string_view{ std::data(current_value_), std::size(current_value_) - 1 };
}
char const* tr_session_id::c_str() const noexcept
{
return std::data(sv()); // current_value_ is zero-terminated
}

View File

@ -5,41 +5,55 @@
#pragma once
using tr_session_id_t = struct tr_session_id*;
#include <array>
#include <cstddef> // for size_t
#include <ctime> // for time_t
#include <string_view>
/**
* Create new session identifier object.
*
* @return New session identifier object.
*/
tr_session_id_t tr_session_id_new(void);
#include "file.h" // tr_sys_file_t
/**
* Free session identifier object.
*
* @param[in] session_id Session identifier object.
*/
void tr_session_id_free(tr_session_id_t session_id);
class tr_session_id
{
public:
using current_time_func_t = time_t (*)();
/**
* Get current session identifier as string.
*
* @param[in] session_id Session identifier object.
*
* @return String representation of current session identifier.
*/
char const* tr_session_id_get_current(tr_session_id_t session_id);
tr_session_id(current_time_func_t get_current_time)
: get_current_time_{ get_current_time }
{
}
/**
* Check if session ID corresponds to session running on the same machine as
* the caller.
*
* This is useful for various behavior alterations, such as transforming
* relative paths to absolute before passing through RPC, or presenting
* different UI for local and remote sessions.
*
* @param[in] session_id String representation of session identifier object.
*
* @return `True` if session is valid and local, `false` otherwise.
*/
bool tr_session_id_is_local(char const* session_id);
tr_session_id(tr_session_id&&) = delete;
tr_session_id(tr_session_id const&) = delete;
tr_session_id& operator=(tr_session_id&&) = delete;
tr_session_id& operator=(tr_session_id const&) = delete;
~tr_session_id();
/**
* Check if session ID corresponds to session running on the same machine as
* the caller.
*
* This is useful for various behavior alterations, such as transforming
* relative paths to absolute before passing through RPC, or presenting
* different UI for local and remote sessions.
*/
[[nodiscard]] static bool isLocal(std::string_view) noexcept;
// current session identifier
[[nodiscard]] std::string_view sv() const noexcept;
[[nodiscard]] char const* c_str() const noexcept;
private:
static auto constexpr SessionIdSize = size_t{ 48 };
static auto constexpr SessionIdDurationSec = time_t{ 60 * 60 }; /* expire in an hour */
using session_id_t = std::array<char, SessionIdSize + 1>; // +1 for '\0';
static session_id_t make_session_id();
current_time_func_t const get_current_time_;
mutable session_id_t current_value_;
mutable session_id_t previous_value_;
mutable tr_sys_file_t current_lock_file_ = TR_BAD_SYS_FILE;
mutable tr_sys_file_t previous_lock_file_ = TR_BAD_SYS_FILE;
mutable time_t expires_at_ = 0;
};

View File

@ -610,7 +610,6 @@ tr_session* tr_sessionInit(char const* config_dir, bool message_queueing_enabled
session->udp6_socket = TR_BAD_SOCKET;
session->cache = std::make_unique<Cache>(session->torrents(), 1024 * 1024 * 2);
session->magicNumber = SESSION_MAGIC_NUMBER;
session->session_id = tr_session_id_new();
bandwidthGroupRead(session, config_dir);
/* nice to start logging at the very beginning */
@ -2030,7 +2029,6 @@ void tr_sessionClose(tr_session* session)
/* free the session memory */
delete session->turtle.minutes;
tr_session_id_free(session->session_id);
delete session;
}
@ -3017,7 +3015,8 @@ auto makeTorrentDir(std::string_view config_dir)
} // namespace
tr_session::tr_session(std::string_view config_dir)
: config_dir_{ config_dir }
: session_id{ tr_time }
, config_dir_{ config_dir }
, resume_dir_{ makeResumeDir(config_dir) }
, torrent_dir_{ makeTorrentDir(config_dir) }
, session_stats_{ config_dir, time(nullptr) }

View File

@ -32,6 +32,7 @@
#include "net.h" // tr_socket_t
#include "open-files.h"
#include "quark.h"
#include "session-id.h"
#include "stats.h"
#include "torrents.h"
#include "web.h"
@ -563,7 +564,7 @@ public:
WebMediator web_mediator{ this };
std::unique_ptr<tr_web> web;
struct tr_session_id* session_id = nullptr;
tr_session_id session_id;
tr_rpc_func rpc_func = nullptr;
void* rpc_func_user_data = nullptr;

View File

@ -516,7 +516,8 @@ std::string tr_win32_native_to_utf8(std::wstring_view in)
{
auto out = std::string{};
out.resize(WideCharToMultiByte(CP_UTF8, 0, std::data(in), std::size(in), nullptr, 0, nullptr, nullptr));
auto len = WideCharToMultiByte(CP_UTF8, 0, std::data(in), std::size(in), std::data(out), std::size(out), nullptr, nullptr);
[[maybe_unused]] auto
len = WideCharToMultiByte(CP_UTF8, 0, std::data(in), std::size(in), std::data(out), std::size(out), nullptr, nullptr);
TR_ASSERT(len == std::size(out));
return out;
}
@ -540,7 +541,7 @@ std::wstring tr_win32_utf8_to_native(std::string_view in)
{
auto out = std::wstring{};
out.resize(MultiByteToWideChar(CP_UTF8, 0, std::data(in), std::size(in), nullptr, 0));
auto len = MultiByteToWideChar(CP_UTF8, 0, std::data(in), std::size(in), std::data(out), std::size(out));
[[maybe_unused]] auto len = MultiByteToWideChar(CP_UTF8, 0, std::data(in), std::size(in), std::data(out), std::size(out));
TR_ASSERT(len == std::size(out));
return out;
}

View File

@ -23,6 +23,7 @@
#include <libtransmission/variant.h>
#include "CustomVariantType.h"
#include "Filters.h"
#include "Prefs.h"
#include "VariantHelpers.h"

View File

@ -15,8 +15,6 @@
#include <libtransmission/quark.h>
#include <libtransmission/tr-macros.h>
#include "Filters.h"
class QDateTime;
extern "C"

View File

@ -3,8 +3,6 @@
// or any future license endorsed by Mnemosyne LLC.
// License text can be found in the licenses/ folder.
#include "Session.h"
#include <algorithm>
#include <array>
#include <cassert>
@ -29,6 +27,8 @@
#include <libtransmission/utils.h> // tr_free
#include <libtransmission/variant.h>
#include "Session.h"
#include "AddData.h"
#include "CustomVariantType.h"
#include "Prefs.h"
@ -969,7 +969,7 @@ void Session::updateInfo(tr_variant* d)
if (auto const str = dictFind<QString>(d, TR_KEY_session_id); str)
{
session_id_ = *str;
is_definitely_local_session_ = tr_session_id_is_local(session_id_.toUtf8().constData());
is_definitely_local_session_ = tr_session_id::isLocal(session_id_.toUtf8().constData());
}
else
{

View File

@ -14,6 +14,7 @@
#include <array>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <string_view>
@ -211,6 +212,25 @@ TEST_F(SessionTest, peerId)
}
}
namespace current_time_mock
{
namespace
{
auto value = time_t{};
}
time_t get()
{
return value;
}
void set(time_t now)
{
value = now;
}
} // namespace current_time_mock
TEST_F(SessionTest, sessionId)
{
#ifdef __sun
@ -218,75 +238,54 @@ TEST_F(SessionTest, sessionId)
GTEST_SKIP();
#endif
EXPECT_FALSE(tr_session_id_is_local(nullptr));
EXPECT_FALSE(tr_session_id_is_local(""));
EXPECT_FALSE(tr_session_id_is_local("test"));
EXPECT_FALSE(tr_session_id::isLocal(""));
EXPECT_FALSE(tr_session_id::isLocal("test"));
auto session_id = tr_session_id_new();
EXPECT_NE(nullptr, session_id);
current_time_mock::set(0U);
auto session_id = std::make_unique<tr_session_id>(current_time_mock::get);
tr_timeUpdate(0);
EXPECT_NE(""sv, session_id->sv());
EXPECT_EQ(session_id->sv(), session_id->c_str()) << session_id->sv() << ", " << session_id->c_str();
EXPECT_EQ(48U, strlen(session_id->c_str()));
auto session_id_str_1 = std::string{ session_id->sv() };
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_1));
auto const* session_id_str_1 = tr_session_id_get_current(session_id);
EXPECT_NE(nullptr, session_id_str_1);
EXPECT_EQ(48U, strlen(session_id_str_1));
session_id_str_1 = tr_strdup(session_id_str_1);
current_time_mock::set(current_time_mock::get() + (3600U - 1U));
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_1));
auto session_id_str_2 = std::string{ session_id->sv() };
EXPECT_EQ(session_id_str_1, session_id_str_2);
EXPECT_TRUE(tr_session_id_is_local(session_id_str_1));
current_time_mock::set(3600U);
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_1));
session_id_str_2 = std::string{ session_id->sv() };
EXPECT_NE(session_id_str_1, session_id_str_2);
EXPECT_EQ(session_id_str_2, session_id->c_str());
EXPECT_EQ(48U, strlen(session_id->c_str()));
tr_timeUpdate(60 * 60 - 1);
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_2));
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_1));
current_time_mock::set(3600U * 2U);
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_2));
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_1));
EXPECT_TRUE(tr_session_id_is_local(session_id_str_1));
auto const session_id_str_3 = std::string{ session_id->sv() };
EXPECT_EQ(48U, std::size(session_id_str_3));
EXPECT_NE(session_id_str_2, session_id_str_3);
EXPECT_NE(session_id_str_1, session_id_str_3);
auto const* session_id_str_2 = tr_session_id_get_current(session_id);
EXPECT_NE(nullptr, session_id_str_2);
EXPECT_EQ(48U, strlen(session_id_str_2));
EXPECT_STREQ(session_id_str_1, session_id_str_2);
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_3));
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_2));
EXPECT_FALSE(tr_session_id::isLocal(session_id_str_1));
tr_timeUpdate(60 * 60);
current_time_mock::set(60U * 60U * 10U);
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_3));
EXPECT_TRUE(tr_session_id::isLocal(session_id_str_2));
EXPECT_FALSE(tr_session_id::isLocal(session_id_str_1));
EXPECT_TRUE(tr_session_id_is_local(session_id_str_1));
session_id_str_2 = tr_session_id_get_current(session_id);
EXPECT_NE(nullptr, session_id_str_2);
EXPECT_EQ(48U, strlen(session_id_str_2));
EXPECT_STRNE(session_id_str_1, session_id_str_2);
session_id_str_2 = tr_strdup(session_id_str_2);
EXPECT_TRUE(tr_session_id_is_local(session_id_str_2));
EXPECT_TRUE(tr_session_id_is_local(session_id_str_1));
tr_timeUpdate(60 * 60 * 2);
EXPECT_TRUE(tr_session_id_is_local(session_id_str_2));
EXPECT_TRUE(tr_session_id_is_local(session_id_str_1));
auto const* session_id_str_3 = tr_session_id_get_current(session_id);
EXPECT_NE(nullptr, session_id_str_3);
EXPECT_EQ(48U, strlen(session_id_str_3));
EXPECT_STRNE(session_id_str_2, session_id_str_3);
EXPECT_STRNE(session_id_str_1, session_id_str_3);
session_id_str_3 = tr_strdup(session_id_str_3);
EXPECT_TRUE(tr_session_id_is_local(session_id_str_3));
EXPECT_TRUE(tr_session_id_is_local(session_id_str_2));
EXPECT_FALSE(tr_session_id_is_local(session_id_str_1));
tr_timeUpdate(60 * 60 * 10);
EXPECT_TRUE(tr_session_id_is_local(session_id_str_3));
EXPECT_TRUE(tr_session_id_is_local(session_id_str_2));
EXPECT_FALSE(tr_session_id_is_local(session_id_str_1));
tr_session_id_free(session_id);
EXPECT_FALSE(tr_session_id_is_local(session_id_str_3));
EXPECT_FALSE(tr_session_id_is_local(session_id_str_2));
EXPECT_FALSE(tr_session_id_is_local(session_id_str_1));
tr_free(const_cast<char*>(session_id_str_3));
tr_free(const_cast<char*>(session_id_str_2));
tr_free(const_cast<char*>(session_id_str_1));
session_id.reset();
EXPECT_FALSE(tr_session_id::isLocal(session_id_str_3));
EXPECT_FALSE(tr_session_id::isLocal(session_id_str_2));
EXPECT_FALSE(tr_session_id::isLocal(session_id_str_1));
}
} // namespace test