refactor: add tr_peerIo::peek() (#3798)

This commit is contained in:
Charles Kerr 2022-09-09 13:12:47 -05:00 committed by GitHub
parent ad125edea9
commit 80d9d5a63b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 140 deletions

View File

@ -7,7 +7,6 @@
#include <array>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <string_view>
#include <utility>
@ -38,10 +37,14 @@ using namespace std::literals;
****
***/
#define HANDSHAKE_NAME "\023BitTorrent protocol"
static auto constexpr HandshakeName = std::array<std::byte, 20>{
std::byte{ 19 }, std::byte{ 'B' }, std::byte{ 'i' }, std::byte{ 't' }, std::byte{ 'T' },
std::byte{ 'o' }, std::byte{ 'r' }, std::byte{ 'r' }, std::byte{ 'e' }, std::byte{ 'n' },
std::byte{ 't' }, std::byte{ ' ' }, std::byte{ 'p' }, std::byte{ 'r' }, std::byte{ 'o' },
std::byte{ 't' }, std::byte{ 'o' }, std::byte{ 'c' }, std::byte{ 'o' }, std::byte{ 'l' }
};
// bittorrent handshake constants
static auto constexpr HandshakeNameLen = int{ 20 };
static auto constexpr HandshakeFlagsLen = int{ 8 };
static auto constexpr HandshakeSize = int{ 68 };
static auto constexpr IncomingHandshakeLen = int{ 48 };
@ -224,9 +227,9 @@ static bool buildHandshakeMessage(tr_handshake const* const handshake, uint8_t*
uint8_t* walk = buf;
walk = std::copy_n(HANDSHAKE_NAME, HandshakeNameLen, walk);
walk = std::copy_n(reinterpret_cast<uint8_t const*>(std::data(HandshakeName)), std::size(HandshakeName), walk);
memset(walk, 0, HandshakeFlagsLen);
std::fill_n(walk, HandshakeFlagsLen, 0);
HANDSHAKE_SET_LTEP(walk);
HANDSHAKE_SET_FASTEXT(walk);
/* Note that this doesn't depend on whether the torrent is private.
@ -255,31 +258,31 @@ enum handshake_parse_err_t
HANDSHAKE_PEER_IS_SELF,
};
static handshake_parse_err_t parseHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
static handshake_parse_err_t parseHandshake(tr_handshake* handshake, tr_peerIo* peer_io)
{
tr_logAddTraceHand(handshake, fmt::format("payload: need {}, got {}", HandshakeSize, evbuffer_get_length(inbuf)));
tr_logAddTraceHand(handshake, fmt::format("payload: need {}, got {}", HandshakeSize, peer_io->readBufferSize()));
if (evbuffer_get_length(inbuf) < HandshakeSize)
if (peer_io->readBufferSize() < HandshakeSize)
{
return HANDSHAKE_ENCRYPTION_WRONG;
}
/* confirm the protocol */
auto name = std::array<uint8_t, HandshakeNameLen>{};
handshake->io->readBytes(std::data(name), std::size(name));
if (memcmp(std::data(name), HANDSHAKE_NAME, std::size(name)) != 0)
auto name = decltype(HandshakeName){};
peer_io->readBytes(std::data(name), std::size(name));
if (name != HandshakeName)
{
return HANDSHAKE_ENCRYPTION_WRONG;
}
/* read the reserved bytes */
auto reserved = std::array<uint8_t, HandshakeFlagsLen>{};
handshake->io->readBytes(std::data(reserved), std::size(reserved));
peer_io->readBytes(std::data(reserved), std::size(reserved));
/* torrent hash */
auto hash = tr_sha1_digest_t{};
handshake->io->readBytes(std::data(hash), std::size(hash));
if (auto const torrent_hash = handshake->io->torrentHash(); !torrent_hash || *torrent_hash != hash)
peer_io->readBytes(std::data(hash), std::size(hash));
if (auto const torrent_hash = peer_io->torrentHash(); !torrent_hash || *torrent_hash != hash)
{
tr_logAddTraceHand(handshake, "peer returned the wrong hash. wtf?");
return HANDSHAKE_BAD_TORRENT;
@ -287,7 +290,7 @@ static handshake_parse_err_t parseHandshake(tr_handshake* handshake, struct evbu
// peer_id
auto peer_id = tr_peer_id_t{};
handshake->io->readBytes(std::data(peer_id), std::size(peer_id));
peer_io->readBytes(std::data(peer_id), std::size(peer_id));
handshake->peer_id = peer_id;
/* peer id */
@ -304,9 +307,9 @@ static handshake_parse_err_t parseHandshake(tr_handshake* handshake, struct evbu
*** Extensions
**/
handshake->io->enableDHT(HANDSHAKE_HAS_DHT(reserved));
handshake->io->enableLTEP(HANDSHAKE_HAS_LTEP(reserved));
handshake->io->enableFEXT(HANDSHAKE_HAS_FASTEXT(reserved));
peer_io->enableDHT(HANDSHAKE_HAS_DHT(reserved));
peer_io->enableLTEP(HANDSHAKE_HAS_LTEP(reserved));
peer_io->enableFEXT(HANDSHAKE_HAS_FASTEXT(reserved));
return HANDSHAKE_OK;
}
@ -369,26 +372,19 @@ static constexpr uint32_t getCryptoSelect(tr_encryption_mode encryption_mode, ui
return 0;
}
static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readYb(tr_handshake* handshake, tr_peerIo* peer_io)
{
size_t needlen = HandshakeNameLen;
if (evbuffer_get_length(inbuf) < needlen)
auto const* const peek = peer_io->peek(std::size(HandshakeName));
if (peek == nullptr)
{
return READ_LATER;
}
bool const is_encrypted = memcmp(evbuffer_pullup(inbuf, HandshakeNameLen), HANDSHAKE_NAME, HandshakeNameLen) != 0;
bool const is_encrypted = !std::equal(std::begin(HandshakeName), std::end(HandshakeName), peek);
auto peer_public_key = DH::key_bigend_t{};
if (is_encrypted)
if (is_encrypted && (peer_io->readBufferSize() < std::size(peer_public_key)))
{
needlen = std::size(peer_public_key);
if (evbuffer_get_length(inbuf) < needlen)
{
return READ_LATER;
}
return READ_LATER;
}
tr_logAddTraceHand(handshake, is_encrypted ? "got an encrypted handshake" : "got a plain handshake");
@ -402,7 +398,7 @@ static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
handshake->haveReadAnythingFromPeer = true;
// get the peer's public key
evbuffer_remove(inbuf, std::data(peer_public_key), std::size(peer_public_key));
peer_io->readBytes(std::data(peer_public_key), std::size(peer_public_key));
handshake->dh.setPeerPublicKey(peer_public_key);
/* now send these: HASH('req1', S), HASH('req2', SKEY) xor HASH('req3', S),
@ -413,7 +409,7 @@ static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
auto const req1 = tr_sha1::digest("req1"sv, handshake->dh.secret());
evbuffer_add(outbuf, std::data(req1), std::size(req1));
auto const info_hash = handshake->io->torrentHash();
auto const info_hash = peer_io->torrentHash();
if (!info_hash)
{
tr_logAddTraceHand(handshake, "error while computing req2/req3 hash after Yb");
@ -436,8 +432,8 @@ static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
/* ENCRYPT(VC, crypto_provide, len(PadC), PadC
* PadC is reserved for future extensions to the handshake...
* standard practice at this time is for it to be zero-length */
handshake->io->writeBuf(outbuf, false);
handshake->io->encryptInit(handshake->io->isIncoming(), handshake->dh, *info_hash);
peer_io->writeBuf(outbuf, false);
peer_io->encryptInit(peer_io->isIncoming(), handshake->dh, *info_hash);
evbuffer_add(outbuf, std::data(VC), std::size(VC));
evbuffer_add_uint32(outbuf, handshake->cryptoProvide());
evbuffer_add_uint16(outbuf, 0);
@ -455,9 +451,8 @@ static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
}
/* send it */
handshake->io->decryptInit(handshake->io->isIncoming(), handshake->dh, *info_hash);
setReadState(handshake, AWAITING_VC);
handshake->io->writeBuf(outbuf, false);
peer_io->writeBuf(outbuf, false);
/* cleanup */
evbuffer_free(outbuf);
@ -466,51 +461,52 @@ static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
// MSE spec: "Since the length of [PadB is] unknown,
// A will be able to resynchronize on ENCRYPT(VC)"
static ReadState readVC(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readVC(tr_handshake* handshake, tr_peerIo* peer_io)
{
// find the end of PadB by looking for `ENCRYPT(VC)`
auto needle = VC;
auto filter = tr_message_stream_encryption::Filter{};
filter.encryptInit(true, handshake->dh, *handshake->io->torrentHash());
filter.encryptInit(true, handshake->dh, *peer_io->torrentHash());
filter.encrypt(std::size(needle), std::data(needle));
for (size_t i = 0; i < PadbMaxlen; ++i)
{
if (evbuffer_get_length(inbuf) < std::size(needle))
auto const* const peek = peer_io->peek(std::size(needle));
if (peek == nullptr)
{
tr_logAddTraceHand(handshake, "not enough bytes... returning read_more");
return READ_LATER;
}
auto const* peek = reinterpret_cast<std::byte const*>(evbuffer_pullup(inbuf, std::size(needle)));
if (std::equal(std::begin(needle), std::end(needle), peek))
{
tr_logAddTraceHand(handshake, "got it!");
// We already know it's a match; now we just need to
// consume it from the read buffer.
handshake->io->readBytes(std::data(needle), std::size(needle));
peer_io->decryptInit(peer_io->isIncoming(), handshake->dh, *peer_io->torrentHash());
peer_io->readBytes(std::data(needle), std::size(needle));
setState(handshake, AWAITING_CRYPTO_SELECT);
return READ_NOW;
}
evbuffer_drain(inbuf, 1);
peer_io->readBufferDrain(1);
}
tr_logAddTraceHand(handshake, "couldn't find ENCRYPT(VC)");
return tr_handshakeDone(handshake, false);
}
static ReadState readCryptoSelect(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readCryptoSelect(tr_handshake* handshake, tr_peerIo* peer_io)
{
static size_t const needlen = sizeof(uint32_t) + sizeof(uint16_t);
if (evbuffer_get_length(inbuf) < needlen)
if (peer_io->readBufferSize() < needlen)
{
return READ_LATER;
}
uint32_t crypto_select = 0;
handshake->io->readUint32(&crypto_select);
peer_io->readUint32(&crypto_select);
handshake->crypto_select = crypto_select;
tr_logAddTraceHand(handshake, fmt::format("crypto select is {}", crypto_select));
@ -521,7 +517,7 @@ static ReadState readCryptoSelect(tr_handshake* handshake, struct evbuffer* inbu
}
uint16_t pad_d_len = 0;
handshake->io->readUint16(&pad_d_len);
peer_io->readUint16(&pad_d_len);
tr_logAddTraceHand(handshake, fmt::format("pad_d_len is {}", pad_d_len));
if (pad_d_len > 512)
@ -536,18 +532,18 @@ static ReadState readCryptoSelect(tr_handshake* handshake, struct evbuffer* inbu
return READ_NOW;
}
static ReadState readPadD(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readPadD(tr_handshake* handshake, tr_peerIo* peer_io)
{
size_t const needlen = handshake->pad_d_len;
tr_logAddTraceHand(handshake, fmt::format("pad d: need {}, got {}", needlen, evbuffer_get_length(inbuf)));
tr_logAddTraceHand(handshake, fmt::format("pad d: need {}, got {}", needlen, peer_io->readBufferSize()));
if (evbuffer_get_length(inbuf) < needlen)
if (peer_io->readBufferSize() < needlen)
{
return READ_LATER;
}
handshake->io->readBufferDrain(needlen);
peer_io->readBufferDrain(needlen);
setState(handshake, AWAITING_HANDSHAKE);
return READ_NOW;
@ -559,20 +555,21 @@ static ReadState readPadD(tr_handshake* handshake, struct evbuffer* inbuf)
****
***/
static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readHandshake(tr_handshake* handshake, tr_peerIo* peer_io)
{
tr_logAddTraceHand(handshake, fmt::format("payload: need {}, got {}", IncomingHandshakeLen, evbuffer_get_length(inbuf)));
tr_logAddTraceHand(handshake, fmt::format("payload: need {}, got {}", IncomingHandshakeLen, peer_io->readBufferSize()));
if (evbuffer_get_length(inbuf) < IncomingHandshakeLen)
auto const* const peek = peer_io->peek(IncomingHandshakeLen);
if (peek == nullptr)
{
return READ_LATER;
}
handshake->haveReadAnythingFromPeer = true;
uint8_t pstrlen = evbuffer_pullup(inbuf, 1)[0]; /* peek, don't read. We may be handing inbuf to AWAITING_YA */
if (pstrlen == 19) /* unencrypted */
// peek instead of reading, because if we decide the handshake is
// encrypted we'll pass the unconsumed buffer to AWAITING_YA
if (std::equal(std::begin(HandshakeName), std::end(HandshakeName), peek)) // unencrypted
{
if (handshake->encryption_mode == TR_ENCRYPTION_REQUIRED)
{
@ -580,7 +577,7 @@ static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
return tr_handshakeDone(handshake, false);
}
}
else /* encrypted or corrupt */
else // either encrypted or corrupt
{
if (handshake->isIncoming())
{
@ -588,44 +585,30 @@ static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
setState(handshake, AWAITING_YA);
return READ_NOW;
}
handshake->io->decrypt(1, &pstrlen);
if (pstrlen != 19)
{
tr_logAddTraceHand(handshake, "I think peer has sent us a corrupt handshake...");
return tr_handshakeDone(handshake, false);
}
}
evbuffer_drain(inbuf, 1);
/* pstr (BitTorrent) */
TR_ASSERT(pstrlen == 19);
auto pstr = std::array<uint8_t, 20>{};
handshake->io->readBytes(std::data(pstr), pstrlen);
pstr[pstrlen] = '\0';
if (strncmp(reinterpret_cast<char const*>(std::data(pstr)), "BitTorrent protocol", 19) != 0)
auto name = decltype(HandshakeName){};
peer_io->readBytes(std::data(name), std::size(name));
if (name != HandshakeName)
{
return tr_handshakeDone(handshake, false);
}
/* reserved bytes */
auto reserved = std::array<uint8_t, HandshakeFlagsLen>{};
handshake->io->readBytes(std::data(reserved), std::size(reserved));
peer_io->readBytes(std::data(reserved), std::size(reserved));
/**
*** Extensions
**/
handshake->io->enableDHT(HANDSHAKE_HAS_DHT(reserved));
handshake->io->enableLTEP(HANDSHAKE_HAS_LTEP(reserved));
handshake->io->enableFEXT(HANDSHAKE_HAS_FASTEXT(reserved));
peer_io->enableDHT(HANDSHAKE_HAS_DHT(reserved));
peer_io->enableLTEP(HANDSHAKE_HAS_LTEP(reserved));
peer_io->enableFEXT(HANDSHAKE_HAS_FASTEXT(reserved));
/* torrent hash */
auto hash = tr_sha1_digest_t{};
handshake->io->readBytes(std::data(hash), std::size(hash));
peer_io->readBytes(std::data(hash), std::size(hash));
if (handshake->isIncoming())
{
@ -635,11 +618,11 @@ static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
return tr_handshakeDone(handshake, false);
}
handshake->io->setTorrentHash(hash);
peer_io->setTorrentHash(hash);
}
else /* outgoing */
{
auto const torrent_hash = handshake->io->torrentHash();
auto const torrent_hash = peer_io->torrentHash();
if (!torrent_hash || *torrent_hash != hash)
{
@ -661,7 +644,7 @@ static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
return tr_handshakeDone(handshake, false);
}
handshake->io->writeBytes(std::data(msg), std::size(msg), false);
peer_io->writeBytes(std::data(msg), std::size(msg), false);
handshake->haveSentBitTorrentHandshake = true;
}
@ -669,15 +652,15 @@ static ReadState readHandshake(tr_handshake* handshake, struct evbuffer* inbuf)
return READ_NOW;
}
static ReadState readPeerId(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readPeerId(tr_handshake* handshake, tr_peerIo* peer_io)
{
// read the peer_id
auto peer_id = tr_peer_id_t{};
if (evbuffer_get_length(inbuf) < std::size(peer_id))
if (peer_io->readBufferSize() < std::size(peer_id))
{
return READ_LATER;
}
handshake->io->readBytes(std::data(peer_id), std::size(peer_id));
peer_io->readBytes(std::data(peer_id), std::size(peer_id));
handshake->peer_id = peer_id;
auto client = std::array<char, 128>{};
@ -687,27 +670,27 @@ static ReadState readPeerId(tr_handshake* handshake, struct evbuffer* inbuf)
fmt::format("peer-id is '{}' ... isIncoming is {}", std::data(client), handshake->isIncoming()));
// if we've somehow connected to ourselves, don't keep the connection
auto const hash = handshake->io->torrentHash();
auto const hash = peer_io->torrentHash();
auto const info = hash ? handshake->mediator->torrentInfo(*hash) : std::nullopt;
auto const connected_to_self = info && info->client_peer_id == peer_id;
return tr_handshakeDone(handshake, !connected_to_self);
}
static ReadState readYa(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readYa(tr_handshake* handshake, tr_peerIo* peer_io)
{
auto peer_public_key = DH::key_bigend_t{};
tr_logAddTraceHand(
handshake,
fmt::format("in readYa... need {}, have {}", std::size(peer_public_key), evbuffer_get_length(inbuf)));
fmt::format("in readYa... need {}, have {}", std::size(peer_public_key), peer_io->readBufferSize()));
if (evbuffer_get_length(inbuf) < std::size(peer_public_key))
if (peer_io->readBufferSize() < std::size(peer_public_key))
{
return READ_LATER;
}
/* read the incoming peer's public key */
evbuffer_remove(inbuf, std::data(peer_public_key), std::size(peer_public_key));
peer_io->readBytes(std::data(peer_public_key), std::size(peer_public_key));
handshake->dh.setPeerPublicKey(peer_public_key);
// send our public key to the peer
@ -718,36 +701,36 @@ static ReadState readYa(tr_handshake* handshake, struct evbuffer* inbuf)
return READ_NOW;
}
static ReadState readPadA(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readPadA(tr_handshake* handshake, tr_peerIo* peer_io)
{
// find the end of PadA by looking for HASH('req1', S)
auto const needle = tr_sha1::digest("req1"sv, handshake->dh.secret());
for (size_t i = 0; i < PadaMaxlen; ++i)
{
if (evbuffer_get_length(inbuf) < std::size(needle))
auto const* const peek = peer_io->peek(std::size(needle));
if (peek == nullptr)
{
tr_logAddTraceHand(handshake, "not enough bytes... returning read_more");
return READ_LATER;
}
auto const* peek = reinterpret_cast<std::byte const*>(evbuffer_pullup(inbuf, std::size(needle)));
if (std::equal(std::begin(needle), std::end(needle), peek))
{
tr_logAddTraceHand(handshake, "found it... looking setting to awaiting_crypto_provide");
evbuffer_drain(inbuf, std::size(needle));
peer_io->readBufferDrain(std::size(needle));
setState(handshake, AWAITING_CRYPTO_PROVIDE);
return READ_NOW;
}
evbuffer_drain(inbuf, 1);
peer_io->readBufferDrain(1U);
}
tr_logAddTraceHand(handshake, "couldn't find HASH('req', S)");
return tr_handshakeDone(handshake, false);
}
static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readCryptoProvide(tr_handshake* handshake, tr_peerIo* peer_io)
{
/* HASH('req2', SKEY) xor HASH('req3', S), ENCRYPT(VC, crypto_provide, len(PadC)) */
@ -757,7 +740,7 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb
size_t const needlen = sizeof(obfuscated_hash) + /* HASH('req2', SKEY) xor HASH('req3', S) */
std::size(VC) + sizeof(crypto_provide) + sizeof(padc_len);
if (evbuffer_get_length(inbuf) < needlen)
if (peer_io->readBufferSize() < needlen)
{
return READ_LATER;
}
@ -767,7 +750,7 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb
* by building the latter and xor'ing it with what the peer sent us */
tr_logAddTraceHand(handshake, "reading obfuscated torrent hash...");
auto req2 = tr_sha1_digest_t{};
evbuffer_remove(inbuf, std::data(req2), std::size(req2));
peer_io->readBytes(std::data(req2), std::size(req2));
auto const req3 = tr_sha1::digest("req3"sv, handshake->dh.secret());
for (size_t i = 0; i < std::size(obfuscated_hash); ++i)
@ -778,9 +761,9 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb
if (auto const info = handshake->mediator->torrentInfoFromObfuscated(obfuscated_hash); info)
{
bool const client_is_seed = info->is_done;
bool const peer_is_seed = handshake->mediator->isPeerKnownSeed(info->id, handshake->io->address());
bool const peer_is_seed = handshake->mediator->isPeerKnownSeed(info->id, peer_io->address());
tr_logAddTraceHand(handshake, fmt::format("got INCOMING connection's encrypted handshake for torrent [{}]", info->id));
handshake->io->setTorrentHash(info->info_hash);
peer_io->setTorrentHash(info->info_hash);
if (client_is_seed && peer_is_seed)
{
@ -796,16 +779,16 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb
/* next part: ENCRYPT(VC, crypto_provide, len(PadC), */
handshake->io->decryptInit(handshake->io->isIncoming(), handshake->dh, *handshake->io->torrentHash());
peer_io->decryptInit(peer_io->isIncoming(), handshake->dh, *peer_io->torrentHash());
auto vc_in = vc_t{};
handshake->io->readBytes(std::data(vc_in), std::size(vc_in));
peer_io->readBytes(std::data(vc_in), std::size(vc_in));
handshake->io->readUint32(&crypto_provide);
peer_io->readUint32(&crypto_provide);
handshake->crypto_provide = crypto_provide;
tr_logAddTraceHand(handshake, fmt::format("crypto_provide is {}", crypto_provide));
handshake->io->readUint16(&padc_len);
peer_io->readUint16(&padc_len);
tr_logAddTraceHand(handshake, fmt::format("padc is {}", padc_len));
if (padc_len > PadcMaxlen)
{
@ -818,33 +801,33 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb
return READ_NOW;
}
static ReadState readPadC(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readPadC(tr_handshake* handshake, tr_peerIo* peer_io)
{
if (auto const needlen = handshake->pad_c_len + sizeof(uint16_t); evbuffer_get_length(inbuf) < needlen)
if (auto const needlen = handshake->pad_c_len + sizeof(uint16_t); peer_io->readBufferSize() < needlen)
{
return READ_LATER;
}
// read the throwaway padc
auto pad_c = std::array<char, PadcMaxlen>{};
handshake->io->readBytes(std::data(pad_c), handshake->pad_c_len);
peer_io->readBytes(std::data(pad_c), handshake->pad_c_len);
/* read ia_len */
uint16_t ia_len = 0;
handshake->io->readUint16(&ia_len);
peer_io->readUint16(&ia_len);
tr_logAddTraceHand(handshake, fmt::format("ia_len is {}", ia_len));
handshake->ia_len = ia_len;
setState(handshake, AWAITING_IA);
return READ_NOW;
}
static ReadState readIA(tr_handshake* handshake, struct evbuffer const* inbuf)
static ReadState readIA(tr_handshake* handshake, tr_peerIo* peer_io)
{
size_t const needlen = handshake->ia_len;
tr_logAddTraceHand(handshake, fmt::format("reading IA... have {}, need {}", evbuffer_get_length(inbuf), needlen));
tr_logAddTraceHand(handshake, fmt::format("reading IA... have {}, need {}", peer_io->readBufferSize(), needlen));
if (evbuffer_get_length(inbuf) < needlen)
if (peer_io->readBufferSize() < needlen)
{
return READ_LATER;
}
@ -853,7 +836,7 @@ static ReadState readIA(tr_handshake* handshake, struct evbuffer const* inbuf)
*** B->A: ENCRYPT(VC, crypto_select, len(padD), padD), ENCRYPT2(Payload Stream)
**/
handshake->io->encryptInit(handshake->io->isIncoming(), handshake->dh, *handshake->io->torrentHash());
peer_io->encryptInit(peer_io->isIncoming(), handshake->dh, *peer_io->torrentHash());
evbuffer* const outbuf = evbuffer_new();
// send VC
@ -888,7 +871,7 @@ static ReadState readIA(tr_handshake* handshake, struct evbuffer const* inbuf)
/* maybe de-encrypt our connection */
if (crypto_select == CryptoProvidePlaintext)
{
handshake->io->writeBuf(outbuf, false);
peer_io->writeBuf(outbuf, false);
}
tr_logAddTraceHand(handshake, "sending handshake");
@ -905,7 +888,7 @@ static ReadState readIA(tr_handshake* handshake, struct evbuffer const* inbuf)
}
/* send it out */
handshake->io->writeBuf(outbuf, false);
peer_io->writeBuf(outbuf, false);
evbuffer_free(outbuf);
/* now await the handshake */
@ -913,21 +896,21 @@ static ReadState readIA(tr_handshake* handshake, struct evbuffer const* inbuf)
return READ_NOW;
}
static ReadState readPayloadStream(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readPayloadStream(tr_handshake* handshake, tr_peerIo* peer_io)
{
size_t const needlen = HandshakeSize;
tr_logAddTraceHand(
handshake,
fmt::format("reading payload stream... have {}, need {}", evbuffer_get_length(inbuf), needlen));
fmt::format("reading payload stream... have {}, need {}", peer_io->readBufferSize(), needlen));
if (evbuffer_get_length(inbuf) < needlen)
if (peer_io->readBufferSize() < needlen)
{
return READ_LATER;
}
/* parse the handshake ... */
handshake_parse_err_t const i = parseHandshake(handshake, inbuf);
handshake_parse_err_t const i = parseHandshake(handshake, peer_io);
tr_logAddTraceHand(handshake, fmt::format("parseHandshake returned {}", i));
if (i != HANDSHAKE_OK)
@ -945,13 +928,12 @@ static ReadState readPayloadStream(tr_handshake* handshake, struct evbuffer* inb
****
***/
static ReadState canRead(tr_peerIo* io, void* vhandshake, size_t* piece)
static ReadState canRead(tr_peerIo* peer_io, void* vhandshake, size_t* piece)
{
TR_ASSERT(tr_isPeerIo(io));
TR_ASSERT(tr_isPeerIo(peer_io));
auto* handshake = static_cast<tr_handshake*>(vhandshake);
auto* const inbuf = io->readBuffer();
bool ready_for_more = true;
/* no piece data in handshake */
@ -965,51 +947,51 @@ static ReadState canRead(tr_peerIo* io, void* vhandshake, size_t* piece)
switch (handshake->state)
{
case AWAITING_HANDSHAKE:
ret = readHandshake(handshake, inbuf);
ret = readHandshake(handshake, peer_io);
break;
case AWAITING_PEER_ID:
ret = readPeerId(handshake, inbuf);
ret = readPeerId(handshake, peer_io);
break;
case AWAITING_YA:
ret = readYa(handshake, inbuf);
ret = readYa(handshake, peer_io);
break;
case AWAITING_PAD_A:
ret = readPadA(handshake, inbuf);
ret = readPadA(handshake, peer_io);
break;
case AWAITING_CRYPTO_PROVIDE:
ret = readCryptoProvide(handshake, inbuf);
ret = readCryptoProvide(handshake, peer_io);
break;
case AWAITING_PAD_C:
ret = readPadC(handshake, inbuf);
ret = readPadC(handshake, peer_io);
break;
case AWAITING_IA:
ret = readIA(handshake, inbuf);
ret = readIA(handshake, peer_io);
break;
case AWAITING_PAYLOAD_STREAM:
ret = readPayloadStream(handshake, inbuf);
ret = readPayloadStream(handshake, peer_io);
break;
case AWAITING_YB:
ret = readYb(handshake, inbuf);
ret = readYb(handshake, peer_io);
break;
case AWAITING_VC:
ret = readVC(handshake, inbuf);
ret = readVC(handshake, peer_io);
break;
case AWAITING_CRYPTO_SELECT:
ret = readCryptoSelect(handshake, inbuf);
ret = readCryptoSelect(handshake, peer_io);
break;
case AWAITING_PAD_D:
ret = readPadD(handshake, inbuf);
ret = readPadD(handshake, peer_io);
break;
default:
@ -1027,15 +1009,15 @@ static ReadState canRead(tr_peerIo* io, void* vhandshake, size_t* piece)
}
else if (handshake->state == AWAITING_PAD_C)
{
ready_for_more = evbuffer_get_length(inbuf) >= handshake->pad_c_len;
ready_for_more = peer_io->readBufferSize() >= handshake->pad_c_len;
}
else if (handshake->state == AWAITING_PAD_D)
{
ready_for_more = evbuffer_get_length(inbuf) >= handshake->pad_d_len;
ready_for_more = peer_io->readBufferSize() >= handshake->pad_d_len;
}
else if (handshake->state == AWAITING_IA)
{
ready_for_more = evbuffer_get_length(inbuf) >= handshake->ia_len;
ready_for_more = peer_io->readBufferSize() >= handshake->ia_len;
}
}

View File

@ -140,6 +140,16 @@ public:
return evbuffer_get_length(inbuf.get());
}
[[nodiscard]] std::byte const* peek(size_t n_bytes) const noexcept
{
if (readBufferSize() < n_bytes)
{
return nullptr;
}
return reinterpret_cast<std::byte const*>(evbuffer_pullup(inbuf.get(), n_bytes));
}
void readBufferAdd(void const* data, size_t n_bytes);
int flushOutgoingProtocolMsgs();

View File

@ -331,13 +331,13 @@ TEST_F(FileTest, readFile)
// read from closed file
n_read = 0;
EXPECT_FALSE(tr_sys_file_read(fd, std::data(buf), std::size(buf), &n_read, &err)); // coverity USE_AFTER_FREE
EXPECT_FALSE(tr_sys_file_read(fd, std::data(buf), std::size(buf), &n_read, &err)); // coverity[USE_AFTER_FREE]
EXPECT_EQ(0, n_read);
EXPECT_NE(nullptr, err);
tr_error_clear(&err);
// read_at from closed file
EXPECT_FALSE(tr_sys_file_read_at(fd, std::data(buf), std::size(buf), offset, &n_read, &err)); // coverity USE_AFTER_FREE
EXPECT_FALSE(tr_sys_file_read_at(fd, std::data(buf), std::size(buf), offset, &n_read, &err)); // coverity[USE_AFTER_FREE]
EXPECT_EQ(0, n_read);
EXPECT_NE(nullptr, err);
tr_error_clear(&err);
@ -1211,7 +1211,7 @@ TEST_F(FileTest, fileTruncate)
EXPECT_EQ(25U, info->size);
// try to truncate a closed file
EXPECT_FALSE(tr_sys_file_truncate(fd, 10, &err)); // coverity USE_AFTER_FREE
EXPECT_FALSE(tr_sys_file_truncate(fd, 10, &err)); // coverity[USE_AFTER_FREE]
EXPECT_NE(nullptr, err);
tr_error_clear(&err);