1
0
Fork 0
mirror of https://github.com/transmission/transmission synced 2024-12-21 23:32:35 +00:00

fix: process BT messages that immediately follows handshake (#6913)

* refactor: don't loop in `tr_handshake::can_read()`

* fix: return `READ_NOW` after handshake success

* code review: more accurate comment wording

* Revert "refactor: don't loop in `tr_handshake::can_read()`"

This reverts commit 4f33520cba6a38171ed203a071158aa37ddcd325.

* refactor: convert `ReadState` to enum class

* refactor: use new `ReadState` value to break out of loop
This commit is contained in:
Yat Ho 2024-08-25 06:04:28 +08:00 committed by GitHub
parent 34dbaaad7e
commit 1b57c294be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 63 additions and 61 deletions

View file

@ -49,7 +49,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io)
{
if (peer_io->read_buffer_size() < std::size(HandshakeName))
{
return READ_LATER;
return ReadState::Later;
}
// Jump to plain handshake
@ -57,7 +57,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io)
{
tr_logAddTraceHand(this, "in read_yb... got a plain incoming handshake");
set_state(tr_handshake::State::AwaitingHandshake);
return READ_NOW;
return ReadState::Now;
}
auto peer_public_key = key_bigend_t{};
@ -66,7 +66,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io)
fmt::format("in read_yb... need {}, have {}", std::size(peer_public_key), peer_io->read_buffer_size()));
if (peer_io->read_buffer_size() < std::size(peer_public_key))
{
return READ_LATER;
return ReadState::Later;
}
have_read_anything_from_peer_ = true;
@ -123,7 +123,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io)
/* send it */
set_state(State::AwaitingVc);
peer_io->write(outbuf, false);
return READ_NOW;
return ReadState::Now;
}
// MSE spec: "Since the length of [PadB is] unknown,
@ -152,7 +152,7 @@ ReadState tr_handshake::read_vc(tr_peerIo* peer_io)
tr_logAddTraceHand(
this,
fmt::format("in read_vc... need {}, read {}, have {}", Needlen, pad_b_recv_len_, peer_io->read_buffer_size()));
return READ_LATER;
return ReadState::Later;
}
if (peer_io->read_buffer_starts_with(*encrypted_vc_))
@ -163,7 +163,7 @@ ReadState tr_handshake::read_vc(tr_peerIo* peer_io)
peer_io->decrypt_init(peer_io->is_incoming(), get_dh(), info_hash);
peer_io->read_buffer_discard(Needlen);
set_state(tr_handshake::State::AwaitingCryptoSelect);
return READ_NOW;
return ReadState::Now;
}
peer_io->read_buffer_discard(1U);
@ -177,7 +177,7 @@ ReadState tr_handshake::read_crypto_select(tr_peerIo* peer_io)
{
if (static auto constexpr NeedLen = sizeof(crypto_select_) + sizeof(pad_d_len_); peer_io->read_buffer_size() < NeedLen)
{
return READ_LATER;
return ReadState::Later;
}
peer_io->read_uint32(&crypto_select_);
@ -198,7 +198,7 @@ ReadState tr_handshake::read_crypto_select(tr_peerIo* peer_io)
}
set_state(tr_handshake::State::AwaitingPadD);
return READ_NOW;
return ReadState::Now;
}
ReadState tr_handshake::read_pad_d(tr_peerIo* peer_io)
@ -206,7 +206,7 @@ ReadState tr_handshake::read_pad_d(tr_peerIo* peer_io)
tr_logAddTraceHand(this, fmt::format("PadD: need {}, got {}", pad_d_len_, peer_io->read_buffer_size()));
if (peer_io->read_buffer_size() < pad_d_len_)
{
return READ_LATER;
return ReadState::Later;
}
peer_io->read_buffer_discard(pad_d_len_);
@ -219,7 +219,7 @@ ReadState tr_handshake::read_pad_d(tr_peerIo* peer_io)
}
set_state(tr_handshake::State::AwaitingHandshake);
return READ_NOW;
return ReadState::Now;
}
// --- Incoming and Outgoing Connections
@ -230,7 +230,7 @@ ReadState tr_handshake::read_handshake(tr_peerIo* peer_io)
tr_logAddTraceHand(this, fmt::format("read_handshake: need {}, got {}", Needlen, peer_io->read_buffer_size()));
if (peer_io->read_buffer_size() < Needlen)
{
return READ_LATER;
return ReadState::Later;
}
if (ia_len_ > 0U)
@ -309,7 +309,7 @@ ReadState tr_handshake::read_handshake(tr_peerIo* peer_io)
}
set_state(State::AwaitingPeerId);
return READ_NOW;
return ReadState::Now;
}
ReadState tr_handshake::read_peer_id(tr_peerIo* peer_io)
@ -320,7 +320,7 @@ ReadState tr_handshake::read_peer_id(tr_peerIo* peer_io)
tr_logAddTraceHand(this, fmt::format("read_peer_id: need {}, got {}", Needlen, peer_io->read_buffer_size()));
if (peer_io->read_buffer_size() < Needlen)
{
return READ_LATER;
return ReadState::Later;
}
peer_io->read_bytes(std::data(peer_id), Needlen);
set_peer_id(peer_id);
@ -343,7 +343,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io)
{
if (peer_io->read_buffer_size() < std::size(HandshakeName))
{
return READ_LATER;
return ReadState::Later;
}
// Jump to plain handshake
@ -351,7 +351,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io)
{
tr_logAddTraceHand(this, "in read_ya... got a plain incoming handshake");
set_state(tr_handshake::State::AwaitingHandshake);
return READ_NOW;
return ReadState::Now;
}
auto peer_public_key = key_bigend_t{};
@ -360,7 +360,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io)
fmt::format("in read_ya... need {}, have {}", std::size(peer_public_key), peer_io->read_buffer_size()));
if (peer_io->read_buffer_size() < std::size(peer_public_key))
{
return READ_LATER;
return ReadState::Later;
}
have_read_anything_from_peer_ = true;
@ -374,7 +374,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io)
send_public_key_and_pad<PadbMaxlen>(peer_io);
set_state(State::AwaitingPadA);
return READ_NOW;
return ReadState::Now;
}
ReadState tr_handshake::read_pad_a(tr_peerIo* peer_io)
@ -394,7 +394,7 @@ ReadState tr_handshake::read_pad_a(tr_peerIo* peer_io)
Needlen,
pad_a_recv_len_,
peer_io->read_buffer_size()));
return READ_LATER;
return ReadState::Later;
}
if (peer_io->read_buffer_starts_with(needle))
@ -402,7 +402,7 @@ ReadState tr_handshake::read_pad_a(tr_peerIo* peer_io)
tr_logAddTraceHand(this, "found HASH('req1', S)!");
peer_io->read_buffer_discard(Needlen);
set_state(State::AwaitingCryptoProvide);
return READ_NOW;
return ReadState::Now;
}
peer_io->read_buffer_discard(1U);
@ -421,7 +421,7 @@ ReadState tr_handshake::read_crypto_provide(tr_peerIo* peer_io)
if (peer_io->read_buffer_size() < Needlen)
{
return READ_LATER;
return ReadState::Later;
}
/* This next piece is HASH('req2', SKEY) xor HASH('req3', S) ...
@ -473,14 +473,14 @@ ReadState tr_handshake::read_crypto_provide(tr_peerIo* peer_io)
}
set_state(State::AwaitingPadC);
return READ_NOW;
return ReadState::Now;
}
ReadState tr_handshake::read_pad_c(tr_peerIo* peer_io)
{
if (auto const needlen = pad_c_len_ + sizeof(ia_len_); peer_io->read_buffer_size() < needlen)
{
return READ_LATER;
return ReadState::Later;
}
// read the throwaway padc
@ -490,7 +490,7 @@ ReadState tr_handshake::read_pad_c(tr_peerIo* peer_io)
peer_io->read_uint16(&ia_len_);
tr_logAddTraceHand(this, fmt::format("len(IA) is {}", ia_len_));
set_state(State::AwaitingIa);
return READ_NOW;
return ReadState::Now;
}
ReadState tr_handshake::read_ia(tr_peerIo* peer_io)
@ -501,7 +501,7 @@ ReadState tr_handshake::read_ia(tr_peerIo* peer_io)
if (peer_io->read_buffer_size() < needlen)
{
return READ_LATER;
return ReadState::Later;
}
// B->A: ENCRYPT(VC, crypto_select, len(padD), padD), ENCRYPT2(Payload Stream)
@ -551,7 +551,7 @@ ReadState tr_handshake::read_ia(tr_peerIo* peer_io)
/* now await the handshake */
set_state(State::AwaitingHandshake);
return READ_NOW;
return ReadState::Now;
}
// ---
@ -565,8 +565,8 @@ ReadState tr_handshake::can_read(tr_peerIo* peer_io, void* vhandshake, size_t* p
tr_logAddTraceHand(handshake, fmt::format("handling can_read; state is [{}]", handshake->state_string()));
ReadState ret = READ_NOW;
while (ret == READ_NOW)
auto ret = ReadState::Now;
while (ret == ReadState::Now)
{
switch (handshake->state())
{
@ -616,8 +616,7 @@ ReadState tr_handshake::can_read(tr_peerIo* peer_io, void* vhandshake, size_t* p
default:
TR_ASSERT_MSG(false, fmt::format("unhandled handshake state {:d}", static_cast<int>(handshake->state())));
ret = READ_ERR;
break;
return ReadState::Err;
}
}

View file

@ -134,7 +134,10 @@ private:
ReadState done(bool is_connected)
{
peer_io_->clear_callbacks();
return fire_done(is_connected) ? READ_LATER : READ_ERR;
// The responding client of a handshake usually starts sending BT messages immediately after
// the handshake, so we need to return ReadState::Break to ensure those messages are processed.
return fire_done(is_connected) ? ReadState::Break : ReadState::Err;
}
[[nodiscard]] auto is_incoming() const noexcept

View file

@ -389,7 +389,7 @@ void tr_peerIo::can_read_wrapper()
{
size_t piece = 0U;
auto const old_len = read_buffer_size();
auto const read_state = can_read_ != nullptr ? can_read_(this, user_data_, &piece) : READ_ERR;
auto const read_state = can_read_ != nullptr ? can_read_(this, user_data_, &piece) : ReadState::Err;
auto const used = old_len - read_buffer_size();
auto const overhead = socket_.guess_packet_overhead(used);
@ -410,20 +410,19 @@ void tr_peerIo::can_read_wrapper()
switch (read_state)
{
case READ_NOW:
if (!std::empty(inbuf_))
case ReadState::Now:
case ReadState::Break:
if (std::empty(inbuf_))
{
continue;
done = true;
}
break;
case ReadState::Later:
done = true;
break;
case READ_LATER:
done = true;
break;
case READ_ERR:
case ReadState::Err:
err = true;
break;
}

View file

@ -39,11 +39,12 @@ namespace libtransmission::test
class HandshakeTest;
} // namespace libtransmission::test
enum ReadState
enum class ReadState : uint8_t
{
READ_NOW,
READ_LATER,
READ_ERR
Now,
Later,
Break,
Err
};
enum tr_preferred_transport : uint8_t

View file

@ -1377,7 +1377,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
static_cast<int>(id),
std::size(payload)));
publish(tr_peer_event::GotError(EMSGSIZE));
return { READ_ERR, {} };
return { ReadState::Err, {} };
}
switch (id)
@ -1420,7 +1420,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
if (tor_.has_metainfo() && ui32 >= tor_.piece_count())
{
publish(tr_peer_event::GotError(ERANGE));
return { READ_ERR, {} };
return { ReadState::Err, {} };
}
/* a peer can send the same HAVE message twice... */
@ -1508,7 +1508,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
else
{
publish(tr_peer_event::GotError(EMSGSIZE));
return { READ_ERR, {} };
return { ReadState::Err, {} };
}
break;
@ -1524,7 +1524,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
else
{
publish(tr_peer_event::GotError(EMSGSIZE));
return { READ_ERR, {} };
return { ReadState::Err, {} };
}
break;
@ -1540,7 +1540,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
else
{
publish(tr_peer_event::GotError(EMSGSIZE));
return { READ_ERR, {} };
return { ReadState::Err, {} };
}
break;
@ -1556,7 +1556,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
else
{
publish(tr_peer_event::GotError(EMSGSIZE));
return { READ_ERR, {} };
return { ReadState::Err, {} };
}
break;
@ -1575,7 +1575,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
else
{
publish(tr_peer_event::GotError(EMSGSIZE));
return { READ_ERR, {} };
return { ReadState::Err, {} };
}
break;
@ -1591,7 +1591,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl
break;
}
return { READ_NOW, {} };
return { ReadState::Now, {} };
}
ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload)
@ -1610,13 +1610,13 @@ ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload)
if (loc.block_offset + len > block_size)
{
logwarn(this, fmt::format("got unaligned piece {:d}:{:d}->{:d}", piece, offset, len));
return { READ_ERR, len };
return { ReadState::Err, len };
}
if (!tr_peerMgrDidPeerRequest(&tor_, this, block))
{
logwarn(this, fmt::format("got unrequested piece {:d}:{:d}->{:d}", piece, offset, len));
return { READ_ERR, len };
return { ReadState::Err, len };
}
publish(tr_peer_event::GotPieceData(len));
@ -1626,7 +1626,7 @@ ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload)
auto buf = std::make_unique<Cache::BlockData>(block_size);
payload.to_buf(std::data(*buf), len);
auto const ok = client_got_block(std::move(buf), block) == 0;
return { ok ? READ_NOW : READ_ERR, len };
return { ok ? ReadState::Now : ReadState::Err, len };
}
auto& blocks = incoming_.blocks;
@ -1635,18 +1635,18 @@ ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload)
if (!incoming_block.add_span(loc.block_offset, loc.block_offset + len))
{
return { READ_ERR, len }; // invalid span
return { ReadState::Err, len }; // invalid span
}
if (!incoming_block.has_all())
{
return { READ_LATER, len }; // we don't have the full block yet
return { ReadState::Later, len }; // we don't have the full block yet
}
auto block_buf = std::move(incoming_block.buf);
blocks.erase(block); // note: invalidates `incoming_block` local
auto const ok = client_got_block(std::move(block_buf), block) == 0;
return { ok ? READ_NOW : READ_ERR, len };
return { ok ? ReadState::Now : ReadState::Err, len };
}
// returns 0 on success, or an errno on failure
@ -1729,7 +1729,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece)
auto message_len = uint32_t{};
if (io->read_buffer_size() < sizeof(message_len))
{
return READ_LATER;
return ReadState::Later;
}
io->read_uint32(&message_len);
@ -1740,7 +1740,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece)
if (message_len == 0U)
{
logtrace(msgs, "got KeepAlive");
return READ_NOW;
return ReadState::Now;
}
current_message_len = message_len;
@ -1753,7 +1753,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece)
auto message_type = uint8_t{};
if (io->read_buffer_size() < sizeof(message_type))
{
return READ_LATER;
return ReadState::Later;
}
io->read_uint8(&message_type);
@ -1772,7 +1772,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece)
if (n_left > 0U)
{
return READ_LATER;
return ReadState::Later;
}
// The incoming message is now complete. After processing the message