diff --git a/libtransmission/handshake.cc b/libtransmission/handshake.cc index 759da0452..6f6ad28c1 100644 --- a/libtransmission/handshake.cc +++ b/libtransmission/handshake.cc @@ -244,10 +244,11 @@ ReadState tr_handshake::read_vc(tr_peerIo* peer_io) // so calculate and cache the value of `ENCRYPT(VC)`. if (!encrypted_vc_) { - auto needle = VC; auto filter = tr_message_stream_encryption::Filter{}; - filter.encryptInit(true, dh_, info_hash); - filter.encrypt(std::size(needle), std::data(needle)); + filter.encrypt_init(true, dh_, info_hash); + + auto needle = decltype(VC){}; + filter.encrypt(std::data(VC), std::size(VC), std::data(needle)); encrypted_vc_ = needle; } diff --git a/libtransmission/peer-io.cc b/libtransmission/peer-io.cc index aa95744e9..f5cd7e0b6 100644 --- a/libtransmission/peer-io.cc +++ b/libtransmission/peer-io.cc @@ -575,44 +575,8 @@ size_t tr_peerIo::get_write_buffer_space(uint64_t now) const noexcept return desired_len > current_len ? desired_len - current_len : 0U; } -void tr_peerIo::write(libtransmission::Buffer& buf, bool is_piece_data) -{ - auto [bytes, len] = buf.pullup(); - encrypt(len, bytes); - outbuf_info_.emplace_back(std::size(buf), is_piece_data); - outbuf_.add(buf); - buf.clear(); -} - -void tr_peerIo::write_bytes(void const* bytes, size_t n_bytes, bool is_piece_data) -{ - auto const old_size = std::size(outbuf_); - - outbuf_.reserve(old_size + n_bytes); - outbuf_.add(bytes, n_bytes); - - for (auto iter = std::begin(outbuf_) + old_size, end = std::end(outbuf_); iter != end; ++iter) - { - encrypt(1, &*iter); - } - - outbuf_info_.emplace_back(n_bytes, is_piece_data); -} - // --- -void tr_peerIo::read_bytes(void* bytes, size_t byte_count) -{ - TR_ASSERT(read_buffer_size() >= byte_count); - - inbuf_.to_buf(bytes, byte_count); - - if (is_encrypted()) - { - decrypt(byte_count, bytes); - } -} - void tr_peerIo::read_uint16(uint16_t* setme) { auto tmp = uint16_t{}; diff --git a/libtransmission/peer-io.h b/libtransmission/peer-io.h index 07c36a75d..683826316 100644 --- a/libtransmission/peer-io.h +++ b/libtransmission/peer-io.h @@ -108,7 +108,12 @@ public: void read_buffer_drain(size_t byte_count); - void read_bytes(void* bytes, size_t byte_count); + void read_bytes(void* bytes, size_t n_bytes) + { + n_bytes = std::min(n_bytes, std::size(inbuf_)); + filter_.decrypt(std::data(inbuf_), n_bytes, reinterpret_cast(bytes)); + inbuf_.drain(n_bytes); + } void read_uint8(uint8_t* setme) { @@ -123,11 +128,23 @@ public: [[nodiscard]] size_t get_write_buffer_space(uint64_t now) const noexcept; - void write_bytes(void const* bytes, size_t n_bytes, bool is_piece_data); + void write_bytes(void const* bytes, size_t n_bytes, bool is_piece_data) + { + outbuf_info_.emplace_back(n_bytes, is_piece_data); + + auto [resbuf, reslen] = outbuf_.reserve_space(n_bytes); + filter_.encrypt(reinterpret_cast(bytes), n_bytes, resbuf); + outbuf_.commit_space(n_bytes); + } // Write all the data from `buf`. // This is a destructive add: `buf` is empty after this call. - void write(libtransmission::Buffer& buf, bool is_piece_data); + void write(libtransmission::Buffer& buf, bool is_piece_data) + { + auto const n_bytes = std::size(buf); + write_bytes(std::data(buf), n_bytes, is_piece_data); + buf.drain(n_bytes); + } size_t flush_outgoing_protocol_msgs(); @@ -258,12 +275,12 @@ public: void decrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash) { - filter_.decryptInit(is_incoming, dh, info_hash); + filter_.decrypt_init(is_incoming, dh, info_hash); } void encrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash) { - filter_.encryptInit(is_incoming, dh, info_hash); + filter_.encrypt_init(is_incoming, dh, info_hash); } /// @@ -288,16 +305,6 @@ private: } } - void decrypt(size_t buflen, void* buf) - { - filter_.decrypt(buflen, buf); - } - - void encrypt(size_t buflen, void* buf) - { - filter_.encrypt(buflen, buf); - } - void on_utp_state_change(int new_state); void on_utp_error(int errcode); diff --git a/libtransmission/peer-mse.cc b/libtransmission/peer-mse.cc index cdf168069..5e117d663 100644 --- a/libtransmission/peer-mse.cc +++ b/libtransmission/peer-mse.cc @@ -103,7 +103,7 @@ void DH::setPeerPublicKey(key_bigend_t const& peer_public_key) // --- Filter -void Filter::decryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash) +void Filter::decrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash) { auto const key = is_incoming ? "keyA"sv : "keyB"sv; auto const buf = tr_sha1::digest(key, dh.secret(), info_hash); @@ -112,7 +112,7 @@ void Filter::decryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const& dec_key_.discard(1024); } -void Filter::encryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash) +void Filter::encrypt_init(bool is_incoming, DH const& dh, tr_sha1_digest_t const& info_hash) { auto const key = is_incoming ? "keyB"sv : "keyA"sv; auto const buf = tr_sha1::digest(key, dh.secret(), info_hash); diff --git a/libtransmission/peer-mse.h b/libtransmission/peer-mse.h index d8c4b0a6c..5928d04bb 100644 --- a/libtransmission/peer-mse.h +++ b/libtransmission/peer-mse.h @@ -11,6 +11,7 @@ #error only libtransmission should #include this header. #endif +#include // for std::copy_n() #include #include // size_t, std::byte #include @@ -77,26 +78,20 @@ private: class Filter { public: - void decryptInit(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash); + void decrypt_init(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash); template - constexpr void decrypt(size_t buf_len, T* buf) + constexpr void decrypt(T const* buf_in, size_t buf_len, T* buf_out) noexcept { - if (dec_active_) - { - dec_key_.process(buf, buf, buf_len); - } + process(buf_in, buf_len, buf_out, dec_active_, dec_key_); } - void encryptInit(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash); + void encrypt_init(bool is_incoming, DH const&, tr_sha1_digest_t const& info_hash); template - constexpr void encrypt(size_t buf_len, T* buf) + constexpr void encrypt(T const* buf_in, size_t buf_len, T* buf_out) noexcept { - if (enc_active_) - { - enc_key_.process(buf, buf, buf_len); - } + process(buf_in, buf_len, buf_out, enc_active_, enc_key_); } [[nodiscard]] constexpr auto is_active() const noexcept @@ -105,6 +100,19 @@ public: } private: + template + static constexpr void process(T const* buf_in, size_t buf_len, T* buf_out, bool active, tr_arc4& arc4) noexcept + { + if (active) + { + arc4.process(reinterpret_cast(buf_in), buf_len, reinterpret_cast(buf_out)); + } + else + { + std::copy_n(buf_in, buf_len, buf_out); + } + } + tr_arc4 dec_key_ = {}; tr_arc4 enc_key_ = {}; bool dec_active_ = false; diff --git a/libtransmission/tr-arc4.h b/libtransmission/tr-arc4.h index 9ef14f8bc..52f58f949 100644 --- a/libtransmission/tr-arc4.h +++ b/libtransmission/tr-arc4.h @@ -47,11 +47,11 @@ public: } } - constexpr void process(void const* src_data, void* dst_data, size_t data_length) + constexpr void process(uint8_t const* const src, size_t n_bytes, uint8_t* const tgt) { - for (size_t i = 0; i < data_length; ++i) + for (size_t i = 0; i != n_bytes; ++i) { - static_cast(dst_data)[i] = static_cast(src_data)[i] ^ arc4_next(); + tgt[i] = src[i] ^ arc4_next(); } } diff --git a/tests/libtransmission/crypto-test.cc b/tests/libtransmission/crypto-test.cc index 8691e0bbc..fe3296c16 100644 --- a/tests/libtransmission/crypto-test.cc +++ b/tests/libtransmission/crypto-test.cc @@ -79,25 +79,21 @@ TEST(Crypto, encryptDecrypt) auto decrypted1 = std::array{}; auto a = tr_message_stream_encryption::Filter{}; - a.encryptInit(false, a_dh, SomeHash); - std::copy_n(std::begin(Input1), std::size(Input1), std::begin(encrypted1)); - a.encrypt(std::size(Input1), std::data(encrypted1)); + a.encrypt_init(false, a_dh, SomeHash); + a.encrypt(std::data(Input1), std::size(Input1), std::data(encrypted1)); auto b = tr_message_stream_encryption::Filter{}; - b.decryptInit(true, b_dh, SomeHash); - std::copy_n(std::begin(encrypted1), std::size(Input1), std::begin(decrypted1)); - b.decrypt(std::size(Input1), std::data(decrypted1)); + b.decrypt_init(true, b_dh, SomeHash); + b.decrypt(std::data(encrypted1), std::size(Input1), std::data(decrypted1)); EXPECT_EQ(Input1, std::data(decrypted1)) << "Input1 " << Input1 << " decrypted1 " << std::data(decrypted1); auto constexpr Input2 = "@#)C$@)#(*%bvkdjfhwbc039bc4603756VB3)"sv; auto encrypted2 = std::array{}; auto decrypted2 = std::array{}; - b.encryptInit(true, b_dh, SomeHash); - std::copy_n(std::begin(Input2), std::size(Input2), std::begin(encrypted2)); - b.encrypt(std::size(Input2), std::data(encrypted2)); - a.decryptInit(false, a_dh, SomeHash); - std::copy_n(std::begin(encrypted2), std::size(Input2), std::begin(decrypted2)); - a.decrypt(std::size(Input2), std::data(decrypted2)); + b.encrypt_init(true, b_dh, SomeHash); + b.encrypt(std::data(Input2), std::size(Input2), std::data(encrypted2)); + a.decrypt_init(false, a_dh, SomeHash); + a.decrypt(std::data(encrypted2), std::size(Input2), std::data(decrypted2)); EXPECT_EQ(Input2, std::data(decrypted2)) << "Input2 " << Input2 << " decrypted2 " << std::data(decrypted2); }