diff --git a/libtransmission/handshake.cc b/libtransmission/handshake.cc index 867817d1b..ee6878484 100644 --- a/libtransmission/handshake.cc +++ b/libtransmission/handshake.cc @@ -49,7 +49,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io) { if (peer_io->read_buffer_size() < std::size(HandshakeName)) { - return READ_LATER; + return ReadState::Later; } // Jump to plain handshake @@ -57,7 +57,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io) { tr_logAddTraceHand(this, "in read_yb... got a plain incoming handshake"); set_state(tr_handshake::State::AwaitingHandshake); - return READ_NOW; + return ReadState::Now; } auto peer_public_key = key_bigend_t{}; @@ -66,7 +66,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io) fmt::format("in read_yb... need {}, have {}", std::size(peer_public_key), peer_io->read_buffer_size())); if (peer_io->read_buffer_size() < std::size(peer_public_key)) { - return READ_LATER; + return ReadState::Later; } have_read_anything_from_peer_ = true; @@ -123,7 +123,7 @@ ReadState tr_handshake::read_yb(tr_peerIo* peer_io) /* send it */ set_state(State::AwaitingVc); peer_io->write(outbuf, false); - return READ_NOW; + return ReadState::Now; } // MSE spec: "Since the length of [PadB is] unknown, @@ -152,7 +152,7 @@ ReadState tr_handshake::read_vc(tr_peerIo* peer_io) tr_logAddTraceHand( this, fmt::format("in read_vc... need {}, read {}, have {}", Needlen, pad_b_recv_len_, peer_io->read_buffer_size())); - return READ_LATER; + return ReadState::Later; } if (peer_io->read_buffer_starts_with(*encrypted_vc_)) @@ -163,7 +163,7 @@ ReadState tr_handshake::read_vc(tr_peerIo* peer_io) peer_io->decrypt_init(peer_io->is_incoming(), get_dh(), info_hash); peer_io->read_buffer_discard(Needlen); set_state(tr_handshake::State::AwaitingCryptoSelect); - return READ_NOW; + return ReadState::Now; } peer_io->read_buffer_discard(1U); @@ -177,7 +177,7 @@ ReadState tr_handshake::read_crypto_select(tr_peerIo* peer_io) { if (static auto constexpr NeedLen = sizeof(crypto_select_) + sizeof(pad_d_len_); peer_io->read_buffer_size() < NeedLen) { - return READ_LATER; + return ReadState::Later; } peer_io->read_uint32(&crypto_select_); @@ -198,7 +198,7 @@ ReadState tr_handshake::read_crypto_select(tr_peerIo* peer_io) } set_state(tr_handshake::State::AwaitingPadD); - return READ_NOW; + return ReadState::Now; } ReadState tr_handshake::read_pad_d(tr_peerIo* peer_io) @@ -206,7 +206,7 @@ ReadState tr_handshake::read_pad_d(tr_peerIo* peer_io) tr_logAddTraceHand(this, fmt::format("PadD: need {}, got {}", pad_d_len_, peer_io->read_buffer_size())); if (peer_io->read_buffer_size() < pad_d_len_) { - return READ_LATER; + return ReadState::Later; } peer_io->read_buffer_discard(pad_d_len_); @@ -219,7 +219,7 @@ ReadState tr_handshake::read_pad_d(tr_peerIo* peer_io) } set_state(tr_handshake::State::AwaitingHandshake); - return READ_NOW; + return ReadState::Now; } // --- Incoming and Outgoing Connections @@ -230,7 +230,7 @@ ReadState tr_handshake::read_handshake(tr_peerIo* peer_io) tr_logAddTraceHand(this, fmt::format("read_handshake: need {}, got {}", Needlen, peer_io->read_buffer_size())); if (peer_io->read_buffer_size() < Needlen) { - return READ_LATER; + return ReadState::Later; } if (ia_len_ > 0U) @@ -309,7 +309,7 @@ ReadState tr_handshake::read_handshake(tr_peerIo* peer_io) } set_state(State::AwaitingPeerId); - return READ_NOW; + return ReadState::Now; } ReadState tr_handshake::read_peer_id(tr_peerIo* peer_io) @@ -320,7 +320,7 @@ ReadState tr_handshake::read_peer_id(tr_peerIo* peer_io) tr_logAddTraceHand(this, fmt::format("read_peer_id: need {}, got {}", Needlen, peer_io->read_buffer_size())); if (peer_io->read_buffer_size() < Needlen) { - return READ_LATER; + return ReadState::Later; } peer_io->read_bytes(std::data(peer_id), Needlen); set_peer_id(peer_id); @@ -343,7 +343,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io) { if (peer_io->read_buffer_size() < std::size(HandshakeName)) { - return READ_LATER; + return ReadState::Later; } // Jump to plain handshake @@ -351,7 +351,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io) { tr_logAddTraceHand(this, "in read_ya... got a plain incoming handshake"); set_state(tr_handshake::State::AwaitingHandshake); - return READ_NOW; + return ReadState::Now; } auto peer_public_key = key_bigend_t{}; @@ -360,7 +360,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io) fmt::format("in read_ya... need {}, have {}", std::size(peer_public_key), peer_io->read_buffer_size())); if (peer_io->read_buffer_size() < std::size(peer_public_key)) { - return READ_LATER; + return ReadState::Later; } have_read_anything_from_peer_ = true; @@ -374,7 +374,7 @@ ReadState tr_handshake::read_ya(tr_peerIo* peer_io) send_public_key_and_pad(peer_io); set_state(State::AwaitingPadA); - return READ_NOW; + return ReadState::Now; } ReadState tr_handshake::read_pad_a(tr_peerIo* peer_io) @@ -394,7 +394,7 @@ ReadState tr_handshake::read_pad_a(tr_peerIo* peer_io) Needlen, pad_a_recv_len_, peer_io->read_buffer_size())); - return READ_LATER; + return ReadState::Later; } if (peer_io->read_buffer_starts_with(needle)) @@ -402,7 +402,7 @@ ReadState tr_handshake::read_pad_a(tr_peerIo* peer_io) tr_logAddTraceHand(this, "found HASH('req1', S)!"); peer_io->read_buffer_discard(Needlen); set_state(State::AwaitingCryptoProvide); - return READ_NOW; + return ReadState::Now; } peer_io->read_buffer_discard(1U); @@ -421,7 +421,7 @@ ReadState tr_handshake::read_crypto_provide(tr_peerIo* peer_io) if (peer_io->read_buffer_size() < Needlen) { - return READ_LATER; + return ReadState::Later; } /* This next piece is HASH('req2', SKEY) xor HASH('req3', S) ... @@ -473,14 +473,14 @@ ReadState tr_handshake::read_crypto_provide(tr_peerIo* peer_io) } set_state(State::AwaitingPadC); - return READ_NOW; + return ReadState::Now; } ReadState tr_handshake::read_pad_c(tr_peerIo* peer_io) { if (auto const needlen = pad_c_len_ + sizeof(ia_len_); peer_io->read_buffer_size() < needlen) { - return READ_LATER; + return ReadState::Later; } // read the throwaway padc @@ -490,7 +490,7 @@ ReadState tr_handshake::read_pad_c(tr_peerIo* peer_io) peer_io->read_uint16(&ia_len_); tr_logAddTraceHand(this, fmt::format("len(IA) is {}", ia_len_)); set_state(State::AwaitingIa); - return READ_NOW; + return ReadState::Now; } ReadState tr_handshake::read_ia(tr_peerIo* peer_io) @@ -501,7 +501,7 @@ ReadState tr_handshake::read_ia(tr_peerIo* peer_io) if (peer_io->read_buffer_size() < needlen) { - return READ_LATER; + return ReadState::Later; } // B->A: ENCRYPT(VC, crypto_select, len(padD), padD), ENCRYPT2(Payload Stream) @@ -551,7 +551,7 @@ ReadState tr_handshake::read_ia(tr_peerIo* peer_io) /* now await the handshake */ set_state(State::AwaitingHandshake); - return READ_NOW; + return ReadState::Now; } // --- @@ -565,8 +565,8 @@ ReadState tr_handshake::can_read(tr_peerIo* peer_io, void* vhandshake, size_t* p tr_logAddTraceHand(handshake, fmt::format("handling can_read; state is [{}]", handshake->state_string())); - ReadState ret = READ_NOW; - while (ret == READ_NOW) + auto ret = ReadState::Now; + while (ret == ReadState::Now) { switch (handshake->state()) { @@ -616,8 +616,7 @@ ReadState tr_handshake::can_read(tr_peerIo* peer_io, void* vhandshake, size_t* p default: TR_ASSERT_MSG(false, fmt::format("unhandled handshake state {:d}", static_cast(handshake->state()))); - ret = READ_ERR; - break; + return ReadState::Err; } } diff --git a/libtransmission/handshake.h b/libtransmission/handshake.h index 355c4000f..fcc8f9dc6 100644 --- a/libtransmission/handshake.h +++ b/libtransmission/handshake.h @@ -134,7 +134,10 @@ private: ReadState done(bool is_connected) { peer_io_->clear_callbacks(); - return fire_done(is_connected) ? READ_LATER : READ_ERR; + + // The responding client of a handshake usually starts sending BT messages immediately after + // the handshake, so we need to return ReadState::Break to ensure those messages are processed. + return fire_done(is_connected) ? ReadState::Break : ReadState::Err; } [[nodiscard]] auto is_incoming() const noexcept diff --git a/libtransmission/peer-io.cc b/libtransmission/peer-io.cc index 9af576221..9f909989e 100644 --- a/libtransmission/peer-io.cc +++ b/libtransmission/peer-io.cc @@ -389,7 +389,7 @@ void tr_peerIo::can_read_wrapper() { size_t piece = 0U; auto const old_len = read_buffer_size(); - auto const read_state = can_read_ != nullptr ? can_read_(this, user_data_, &piece) : READ_ERR; + auto const read_state = can_read_ != nullptr ? can_read_(this, user_data_, &piece) : ReadState::Err; auto const used = old_len - read_buffer_size(); auto const overhead = socket_.guess_packet_overhead(used); @@ -410,20 +410,19 @@ void tr_peerIo::can_read_wrapper() switch (read_state) { - case READ_NOW: - if (!std::empty(inbuf_)) + case ReadState::Now: + case ReadState::Break: + if (std::empty(inbuf_)) { - continue; + done = true; } + break; + case ReadState::Later: done = true; break; - case READ_LATER: - done = true; - break; - - case READ_ERR: + case ReadState::Err: err = true; break; } diff --git a/libtransmission/peer-io.h b/libtransmission/peer-io.h index f729624ac..c5b9ec2c8 100644 --- a/libtransmission/peer-io.h +++ b/libtransmission/peer-io.h @@ -39,11 +39,12 @@ namespace libtransmission::test class HandshakeTest; } // namespace libtransmission::test -enum ReadState +enum class ReadState : uint8_t { - READ_NOW, - READ_LATER, - READ_ERR + Now, + Later, + Break, + Err }; enum tr_preferred_transport : uint8_t diff --git a/libtransmission/peer-msgs.cc b/libtransmission/peer-msgs.cc index a7e7995b0..8d0ffb824 100644 --- a/libtransmission/peer-msgs.cc +++ b/libtransmission/peer-msgs.cc @@ -1377,7 +1377,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl static_cast(id), std::size(payload))); publish(tr_peer_event::GotError(EMSGSIZE)); - return { READ_ERR, {} }; + return { ReadState::Err, {} }; } switch (id) @@ -1420,7 +1420,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl if (tor_.has_metainfo() && ui32 >= tor_.piece_count()) { publish(tr_peer_event::GotError(ERANGE)); - return { READ_ERR, {} }; + return { ReadState::Err, {} }; } /* a peer can send the same HAVE message twice... */ @@ -1508,7 +1508,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl else { publish(tr_peer_event::GotError(EMSGSIZE)); - return { READ_ERR, {} }; + return { ReadState::Err, {} }; } break; @@ -1524,7 +1524,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl else { publish(tr_peer_event::GotError(EMSGSIZE)); - return { READ_ERR, {} }; + return { ReadState::Err, {} }; } break; @@ -1540,7 +1540,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl else { publish(tr_peer_event::GotError(EMSGSIZE)); - return { READ_ERR, {} }; + return { ReadState::Err, {} }; } break; @@ -1556,7 +1556,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl else { publish(tr_peer_event::GotError(EMSGSIZE)); - return { READ_ERR, {} }; + return { ReadState::Err, {} }; } break; @@ -1575,7 +1575,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl else { publish(tr_peer_event::GotError(EMSGSIZE)); - return { READ_ERR, {} }; + return { ReadState::Err, {} }; } break; @@ -1591,7 +1591,7 @@ ReadResult tr_peerMsgsImpl::process_peer_message(uint8_t id, MessageReader& payl break; } - return { READ_NOW, {} }; + return { ReadState::Now, {} }; } ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload) @@ -1610,13 +1610,13 @@ ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload) if (loc.block_offset + len > block_size) { logwarn(this, fmt::format("got unaligned piece {:d}:{:d}->{:d}", piece, offset, len)); - return { READ_ERR, len }; + return { ReadState::Err, len }; } if (!tr_peerMgrDidPeerRequest(&tor_, this, block)) { logwarn(this, fmt::format("got unrequested piece {:d}:{:d}->{:d}", piece, offset, len)); - return { READ_ERR, len }; + return { ReadState::Err, len }; } publish(tr_peer_event::GotPieceData(len)); @@ -1626,7 +1626,7 @@ ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload) auto buf = std::make_unique(block_size); payload.to_buf(std::data(*buf), len); auto const ok = client_got_block(std::move(buf), block) == 0; - return { ok ? READ_NOW : READ_ERR, len }; + return { ok ? ReadState::Now : ReadState::Err, len }; } auto& blocks = incoming_.blocks; @@ -1635,18 +1635,18 @@ ReadResult tr_peerMsgsImpl::read_piece_data(MessageReader& payload) if (!incoming_block.add_span(loc.block_offset, loc.block_offset + len)) { - return { READ_ERR, len }; // invalid span + return { ReadState::Err, len }; // invalid span } if (!incoming_block.has_all()) { - return { READ_LATER, len }; // we don't have the full block yet + return { ReadState::Later, len }; // we don't have the full block yet } auto block_buf = std::move(incoming_block.buf); blocks.erase(block); // note: invalidates `incoming_block` local auto const ok = client_got_block(std::move(block_buf), block) == 0; - return { ok ? READ_NOW : READ_ERR, len }; + return { ok ? ReadState::Now : ReadState::Err, len }; } // returns 0 on success, or an errno on failure @@ -1729,7 +1729,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece) auto message_len = uint32_t{}; if (io->read_buffer_size() < sizeof(message_len)) { - return READ_LATER; + return ReadState::Later; } io->read_uint32(&message_len); @@ -1740,7 +1740,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece) if (message_len == 0U) { logtrace(msgs, "got KeepAlive"); - return READ_NOW; + return ReadState::Now; } current_message_len = message_len; @@ -1753,7 +1753,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece) auto message_type = uint8_t{}; if (io->read_buffer_size() < sizeof(message_type)) { - return READ_LATER; + return ReadState::Later; } io->read_uint8(&message_type); @@ -1772,7 +1772,7 @@ ReadState tr_peerMsgsImpl::can_read(tr_peerIo* io, void* vmsgs, size_t* piece) if (n_left > 0U) { - return READ_LATER; + return ReadState::Later; } // The incoming message is now complete. After processing the message