refactor: use reserve_space() in peer-io (#5532)

This commit is contained in:
Charles Kerr 2023-05-17 13:57:27 -05:00 committed by GitHub
parent b9698210ef
commit 4fd5f3a490
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 59 additions and 83 deletions

View File

@ -244,10 +244,11 @@ ReadState tr_handshake::read_vc(tr_peerIo* peer_io)
// so calculate and cache the value of `ENCRYPT(VC)`.
if (!encrypted_vc_)
{
auto needle = VC;
auto filter = tr_message_stream_encryption::Filter{};
filter.encryptInit(true, dh_, info_hash);
filter.encrypt(std::size(needle), std::data(needle));
filter.encrypt_init(true, dh_, info_hash);
auto needle = decltype(VC){};
filter.encrypt(std::data(VC), std::size(VC), std::data(needle));
encrypted_vc_ = needle;
}

View File

@ -575,44 +575,8 @@ size_t tr_peerIo::get_write_buffer_space(uint64_t now) const noexcept
return desired_len > current_len ? desired_len - current_len : 0U;
}
void tr_peerIo::write(libtransmission::Buffer& buf, bool is_piece_data)
{
auto [bytes, len] = buf.pullup();
encrypt(len, bytes);
outbuf_info_.emplace_back(std::size(buf), is_piece_data);
outbuf_.add(buf);
buf.clear();
}
void tr_peerIo::write_bytes(void const* bytes, size_t n_bytes, bool is_piece_data)
{
auto const old_size = std::size(outbuf_);
outbuf_.reserve(old_size + n_bytes);
outbuf_.add(bytes, n_bytes);
for (auto iter = std::begin(outbuf_) + old_size, end = std::end(outbuf_); iter != end; ++iter)
{
encrypt(1, &*iter);
}
outbuf_info_.emplace_back(n_bytes, is_piece_data);
}
// ---
void tr_peerIo::read_bytes(void* bytes, size_t byte_count)
{
TR_ASSERT(read_buffer_size() >= byte_count);
inbuf_.to_buf(bytes, byte_count);
if (is_encrypted())
{
decrypt(byte_count, bytes);
}
}
void tr_peerIo::read_uint16(uint16_t* setme)
{
auto tmp = uint16_t{};

View File

@ -108,7 +108,12 @@ public:
void read_buffer_drain(size_t byte_count);
void read_bytes(void* bytes, size_t byte_count);
void read_bytes(void* bytes, size_t n_bytes)
{
n_bytes = std::min(n_bytes, std::size(inbuf_));
filter_.decrypt(std::data(inbuf_), n_bytes, reinterpret_cast<std::byte*>(bytes));
inbuf_.drain(n_bytes);
}
void read_uint8(uint8_t* setme)
{
@ -123,11 +128,23 @@ public:
[[nodiscard]] size_t get_write_buffer_space(uint64_t now) const noexcept;
void write_bytes(void const* bytes, size_t n_bytes, bool is_piece_data);
void write_bytes(void const* bytes, size_t n_bytes, bool is_piece_data)
{
outbuf_info_.emplace_back(n_bytes, is_piece_data);
auto [resbuf, reslen] = outbuf_.reserve_space(n_bytes);
filter_.encrypt(reinterpret_cast<std::byte const*>(bytes), n_bytes, resbuf);
outbuf_.commit_space(n_bytes);
}
// Write all the data from `buf`.
// This is a destructive add: `buf` is empty after this call.
void write(libtransmission::Buffer& buf, bool is_piece_data);
void write(libtransmission::Buffer& buf, bool is_piece_data)
{
auto const n_bytes = std::size(buf);
write_bytes(std::data(buf), n_bytes, is_piece_data);
buf.drain(n_bytes);
}
size_t flush_outgoing_protocol_msgs();
@ -258,12 +275,12 @@ public:
void decrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash)
{
filter_.decryptInit(is_incoming, dh, info_hash);
filter_.decrypt_init(is_incoming, dh, info_hash);
}
void encrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash)
{
filter_.encryptInit(is_incoming, dh, info_hash);
filter_.encrypt_init(is_incoming, dh, info_hash);
}
///
@ -288,16 +305,6 @@ private:
}
}
void decrypt(size_t buflen, void* buf)
{
filter_.decrypt(buflen, buf);
}
void encrypt(size_t buflen, void* buf)
{
filter_.encrypt(buflen, buf);
}
void on_utp_state_change(int new_state);
void on_utp_error(int errcode);

View File

@ -103,7 +103,7 @@ void DH::setPeerPublicKey(key_bigend_t const& peer_public_key)
// --- Filter
void Filter::decryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash)
void Filter::decrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash)
{
auto const key = is_incoming ? "keyA"sv : "keyB"sv;
auto const buf = tr_sha1::digest(key, dh.secret(), info_hash);
@ -112,7 +112,7 @@ void Filter::decryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const&
dec_key_.discard(1024);
}
void Filter::encryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash)
void Filter::encrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash)
{
auto const key = is_incoming ? "keyB"sv : "keyA"sv;
auto const buf = tr_sha1::digest(key, dh.secret(), info_hash);

View File

@ -11,6 +11,7 @@
#error only libtransmission should #include this header.
#endif
#include <algorithm> // for std::copy_n()
#include <array>
#include <cstddef> // size_t, std::byte
#include <memory>
@ -77,26 +78,20 @@ private:
class Filter
{
public:
void decryptInit(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash);
void decrypt_init(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash);
template<typename T>
constexpr void decrypt(size_t buf_len, T* buf)
constexpr void decrypt(T const* buf_in, size_t buf_len, T* buf_out) noexcept
{
if (dec_active_)
{
dec_key_.process(buf, buf, buf_len);
}
process(buf_in, buf_len, buf_out, dec_active_, dec_key_);
}
void encryptInit(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash);
void encrypt_init(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash);
template<typename T>
constexpr void encrypt(size_t buf_len, T* buf)
constexpr void encrypt(T const* buf_in, size_t buf_len, T* buf_out) noexcept
{
if (enc_active_)
{
enc_key_.process(buf, buf, buf_len);
}
process(buf_in, buf_len, buf_out, enc_active_, enc_key_);
}
[[nodiscard]] constexpr auto is_active() const noexcept
@ -105,6 +100,19 @@ public:
}
private:
template<typename T>
static constexpr void process(T const* buf_in, size_t buf_len, T* buf_out, bool active, tr_arc4& arc4) noexcept
{
if (active)
{
arc4.process(reinterpret_cast<uint8_t const*>(buf_in), buf_len, reinterpret_cast<uint8_t*>(buf_out));
}
else
{
std::copy_n(buf_in, buf_len, buf_out);
}
}
tr_arc4 dec_key_ = {};
tr_arc4 enc_key_ = {};
bool dec_active_ = false;

View File

@ -47,11 +47,11 @@ public:
}
}
constexpr void process(void const* src_data, void* dst_data, size_t data_length)
constexpr void process(uint8_t const* const src, size_t n_bytes, uint8_t* const tgt)
{
for (size_t i = 0; i < data_length; ++i)
for (size_t i = 0; i != n_bytes; ++i)
{
static_cast<uint8_t*>(dst_data)[i] = static_cast<uint8_t const*>(src_data)[i] ^ arc4_next();
tgt[i] = src[i] ^ arc4_next();
}
}

View File

@ -79,25 +79,21 @@ TEST(Crypto, encryptDecrypt)
auto decrypted1 = std::array<char, 128>{};
auto a = tr_message_stream_encryption::Filter{};
a.encryptInit(false, a_dh, SomeHash);
std::copy_n(std::begin(Input1), std::size(Input1), std::begin(encrypted1));
a.encrypt(std::size(Input1), std::data(encrypted1));
a.encrypt_init(false, a_dh, SomeHash);
a.encrypt(std::data(Input1), std::size(Input1), std::data(encrypted1));
auto b = tr_message_stream_encryption::Filter{};
b.decryptInit(true, b_dh, SomeHash);
std::copy_n(std::begin(encrypted1), std::size(Input1), std::begin(decrypted1));
b.decrypt(std::size(Input1), std::data(decrypted1));
b.decrypt_init(true, b_dh, SomeHash);
b.decrypt(std::data(encrypted1), std::size(Input1), std::data(decrypted1));
EXPECT_EQ(Input1, std::data(decrypted1)) << "Input1 " << Input1 << " decrypted1 " << std::data(decrypted1);
auto constexpr Input2 = "@#)C$@)#(*%bvkdjfhwbc039bc4603756VB3)"sv;
auto encrypted2 = std::array<char, 128>{};
auto decrypted2 = std::array<char, 128>{};
b.encryptInit(true, b_dh, SomeHash);
std::copy_n(std::begin(Input2), std::size(Input2), std::begin(encrypted2));
b.encrypt(std::size(Input2), std::data(encrypted2));
a.decryptInit(false, a_dh, SomeHash);
std::copy_n(std::begin(encrypted2), std::size(Input2), std::begin(decrypted2));
a.decrypt(std::size(Input2), std::data(decrypted2));
b.encrypt_init(true, b_dh, SomeHash);
b.encrypt(std::data(Input2), std::size(Input2), std::data(encrypted2));
a.decrypt_init(false, a_dh, SomeHash);
a.decrypt(std::data(encrypted2), std::size(Input2), std::data(decrypted2));
EXPECT_EQ(Input2, std::data(decrypted2)) << "Input2 " << Input2 << " decrypted2 " << std::data(decrypted2);
}