refactor: base64 utils (#2381)

base64 encode/decode now take std::string_views and return std::strings
This commit is contained in:
Charles Kerr 2022-01-08 06:46:25 -06:00 committed by GitHub
parent 385a119fb1
commit 0c16c454ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 78 additions and 255 deletions

View File

@ -8,8 +8,9 @@
#include <algorithm>
#include <array>
#include <cstring> /* memmove(), memset(), strlen() */
#include <random> /* random_device, mt19937, uniform_int_distribution*/
#include <cstring> // memmove(), memset()
#include <iterator>
#include <random>
#include <string>
#include <string_view>
@ -137,105 +138,41 @@ bool tr_ssha1_matches(std::string_view ssha1, std::string_view plaintext)
****
***/
void* tr_base64_encode(void const* input, size_t input_length, size_t* output_length)
static size_t base64_alloc_size(std::string_view input)
{
char* ret = nullptr;
if (input != nullptr)
{
if (input_length != 0)
{
size_t ret_length = 4 * ((input_length + 2) / 3);
base64_encodestate state;
size_t ret_length = 4 * ((std::size(input) + 2) / 3);
#ifdef USE_SYSTEM_B64
/* Additional space is needed for newlines if we're using unpatched libb64 */
ret_length += ret_length / 72 + 1;
// Additional space is needed for newlines if we're using unpatched libb64
ret_length += ret_length / 72 + 1;
#endif
ret = tr_new(char, ret_length + 8);
base64_init_encodestate(&state);
ret_length = base64_encode_block(static_cast<char const*>(input), input_length, ret, &state);
ret_length += base64_encode_blockend(ret + ret_length, &state);
if (output_length != nullptr)
{
*output_length = ret_length;
}
ret[ret_length] = '\0';
return ret;
}
ret = tr_strdup("");
}
if (output_length != nullptr)
{
*output_length = 0;
}
return ret;
return ret_length * 8;
}
void* tr_base64_encode_str(char const* input, size_t* output_length)
std::string tr_base64_encode(std::string_view input)
{
return tr_base64_encode(input, input == nullptr ? 0 : strlen(input), output_length);
}
void* tr_base64_decode(void const* input, size_t input_length, size_t* output_length)
{
char* ret = nullptr;
if (input != nullptr)
{
if (input_length != 0)
{
size_t ret_length = input_length / 4 * 3;
base64_decodestate state;
ret = tr_new(char, ret_length + 8);
base64_init_decodestate(&state);
ret_length = base64_decode_block(static_cast<char const*>(input), input_length, ret, &state);
if (output_length != nullptr)
{
*output_length = ret_length;
}
ret[ret_length] = '\0';
return ret;
}
ret = tr_strdup("");
}
if (output_length != nullptr)
{
*output_length = 0;
}
return ret;
}
void* tr_base64_decode_str(char const* input, size_t* output_length)
{
return tr_base64_decode(input, input == nullptr ? 0 : strlen(input), output_length);
}
std::string tr_base64_decode_str(std::string_view input)
{
auto len = size_t{};
auto* buf = tr_base64_decode(std::data(input), std::size(input), &len);
auto str = std::string{ reinterpret_cast<char const*>(buf), len };
tr_free(buf);
auto buf = std::vector<char>(base64_alloc_size(input));
auto state = base64_encodestate{};
base64_init_encodestate(&state);
size_t len = base64_encode_block(std::data(input), std::size(input), std::data(buf), &state);
len += base64_encode_blockend(std::data(buf) + len, &state);
auto str = std::string{};
std::copy_if(
std::data(buf),
std::data(buf) + len,
std::back_inserter(str),
[](auto ch) { return !tr_strvContains("\r\n"sv, ch); });
return str;
}
std::string tr_base64_decode(std::string_view input)
{
auto buf = std::vector<char>(std::size(input) + 8);
auto state = base64_decodestate{};
base64_init_decodestate(&state);
size_t const len = base64_decode_block(std::data(input), std::size(input), std::data(buf), &state);
return std::string{ std::data(buf), len };
}
/***
****
***/

View File

@ -171,35 +171,17 @@ bool tr_ssha1_test(std::string_view text);
*/
bool tr_ssha1_matches(std::string_view ssha1, std::string_view plain_text);
/**
* @brief Translate a block of bytes into base64.
* @return a newly-allocated null-terminated string that can be freed with tr_free()
*/
void* tr_base64_encode(void const* input, size_t input_length, size_t* output_length) TR_GNUC_MALLOC;
/**
* @brief Translate null-terminated string into base64.
* @return a newly-allocated null-terminated string that can be freed with tr_free()
* @return a new std::string with the encoded contents
*/
void* tr_base64_encode_str(char const* input, size_t* output_length) TR_GNUC_MALLOC;
/**
* @brief Translate a block of bytes from base64 into raw form.
* @return a newly-allocated null-terminated string that can be freed with tr_free()
*/
void* tr_base64_decode(void const* input, size_t input_length, size_t* output_length) TR_GNUC_MALLOC;
/**
* @brief Translate null-terminated string from base64 into raw form.
* @return a newly-allocated null-terminated string that can be freed with tr_free()
*/
void* tr_base64_decode_str(char const* input, size_t* output_length) TR_GNUC_MALLOC;
std::string tr_base64_encode(std::string_view input);
/**
* @brief Translate a character range from base64 into raw form.
* @return a new std::string with the decoded contents.
*/
std::string tr_base64_decode_str(std::string_view input);
std::string tr_base64_decode(std::string_view input);
/**
* @brief Generate an ascii hex string for a sha1 digest.

View File

@ -177,11 +177,10 @@ static void handle_upload(struct evhttp_request* req, tr_rpc_server* server)
{
for (auto const& p : parts)
{
auto const& body = p.body;
auto body_len = std::size(body);
if (body_len >= 2 && memcmp(&body[body_len - 2], "\r\n", 2) == 0)
auto body = std::string_view{ p.body };
if (tr_strvEndsWith(body, "\r\n"sv))
{
body_len -= 2;
body.remove_suffix(2);
}
auto top = tr_variant{};
@ -199,9 +198,7 @@ static void handle_upload(struct evhttp_request* req, tr_rpc_server* server)
}
else if (tr_variantFromBuf(&test, TR_VARIANT_PARSE_BENC | TR_VARIANT_PARSE_INPLACE, body))
{
auto* b64 = static_cast<char*>(tr_base64_encode(body.c_str(), body_len, nullptr));
tr_variantDictAddStr(args, TR_KEY_metainfo, b64);
tr_free(b64);
tr_variantDictAddStrView(args, TR_KEY_metainfo, tr_base64_encode(body));
have_source = true;
}
@ -585,7 +582,7 @@ static bool isAuthorized(tr_rpc_server const* server, char const* auth_header)
}
auth.remove_prefix(std::size(Prefix));
auto const decoded_str = tr_base64_decode_str(auth);
auto const decoded_str = tr_base64_decode(auth);
auto decoded = std::string_view{ decoded_str };
auto const username = tr_strvSep(&decoded, ':');
auto const password = decoded;

View File

@ -681,9 +681,8 @@ static void initField(
if (tor->hasMetadata())
{
auto const bytes = tor->createPieceBitfield();
auto* enc = static_cast<char*>(tr_base64_encode(bytes.data(), std::size(bytes), nullptr));
tr_variantInitStr(initme, enc != nullptr ? std::string_view{ enc } : ""sv);
tr_free(enc);
auto const enc = tr_base64_encode({ reinterpret_cast<char const*>(std::data(bytes)), std::size(bytes) });
tr_variantInitStrView(initme, enc);
}
else
{
@ -1708,7 +1707,7 @@ static char const* torrentAdd(tr_session* session, tr_variant* args_in, tr_varia
{
if (std::empty(filename))
{
std::string const metainfo = tr_base64_decode_str(metainfo_base64);
auto const metainfo = tr_base64_decode(metainfo_base64);
tr_ctorSetMetainfo(ctor, std::data(metainfo), std::size(metainfo), nullptr);
}
else

View File

@ -11,7 +11,6 @@
#include <libtransmission/transmission.h>
#include <libtransmission/crypto-utils.h> // tr_base64_encode()
#include <libtransmission/torrent-metainfo.h>
#include <libtransmission/utils.h>
#include <libtransmission/error.h>
@ -64,13 +63,10 @@ int AddData::set(QString const& key)
}
else
{
size_t len;
void* raw = tr_base64_decode(key.toUtf8().constData(), key.toUtf8().size(), &len);
if (raw != nullptr)
auto raw = QByteArray::fromBase64(key.toUtf8());
if (!raw.isEmpty())
{
metainfo.append(static_cast<char const*>(raw), int(len));
tr_free(raw);
metainfo.append(raw);
type = METAINFO;
}
else
@ -84,17 +80,7 @@ int AddData::set(QString const& key)
QByteArray AddData::toBase64() const
{
QByteArray ret;
if (!metainfo.isEmpty())
{
size_t len;
void* b64 = tr_base64_encode(metainfo.constData(), metainfo.size(), &len);
ret = QByteArray(static_cast<char const*>(b64), int(len));
tr_free(b64);
}
return ret;
return metainfo.toBase64();
}
QString AddData::readableName() const

View File

@ -56,10 +56,8 @@
#define tr_ssha1_matches tr_ssha1_matches_
#define tr_ssha1_test tr_ssha1_test_
#define tr_base64_encode tr_base64_encode_
#define tr_base64_encode_str tr_base64_encode_str_
#define tr_base64_encode_impl tr_base64_encode_impl_
#define tr_base64_decode tr_base64_decode_
#define tr_base64_decode_str tr_base64_decode_str_
#define tr_base64_decode_impl tr_base64_decode_impl_
#define tr_sha1_to_string tr_sha1_to_string_
#define tr_sha1_from_string tr_sha1_from_string_
@ -116,10 +114,8 @@
#undef tr_ssha1_matches
#undef tr_ssha1_test
#undef tr_base64_encode
#undef tr_base64_encode_str
#undef tr_base64_encode_impl
#undef tr_base64_decode
#undef tr_base64_decode_str
#undef tr_base64_decode_impl
#undef tr_sha1_to_string
#undef tr_sha1_from_string
@ -169,10 +165,8 @@
#define tr_ssha1_matches_ tr_ssha1_matches
#define tr_ssha1_test_ tr_ssha1_test
#define tr_base64_encode_ tr_base64_encode
#define tr_base64_encode_str_ tr_base64_encode_str
#define tr_base64_encode_impl_ tr_base64_encode_impl
#define tr_base64_decode_ tr_base64_decode
#define tr_base64_decode_str_ tr_base64_decode_str
#define tr_base64_decode_impl_ tr_base64_decode_impl
#define tr_sha1_to_string_ tr_sha1_to_string
#define tr_sha1_from_string_ tr_sha1_from_string

View File

@ -203,88 +203,31 @@ TEST(Crypto, random)
}
}
static bool base64Eq(char const* a, char const* b)
{
for (;; ++a, ++b)
{
while (*a == '\r' || *a == '\n')
{
++a;
}
while (*b == '\r' || *b == '\n')
{
++b;
}
if (*a == '\0' || *b == '\0' || *a != *b)
{
break;
}
}
return *a == *b;
}
TEST(Crypto, base64)
{
auto len = size_t{};
auto* out = static_cast<char*>(tr_base64_encode_str("YOYO!", &len));
EXPECT_EQ(strlen(out), len);
EXPECT_TRUE(base64Eq("WU9ZTyE=", out));
auto* in = static_cast<char*>(tr_base64_decode_str(out, &len));
EXPECT_EQ(decltype(len){ 5 }, len);
EXPECT_STREQ("YOYO!", in);
tr_free(in);
tr_free(out);
auto raw = std::string_view{ "YOYO!"sv };
auto encoded = tr_base64_encode(raw);
EXPECT_EQ("WU9ZTyE="sv, encoded);
EXPECT_EQ(raw, tr_base64_decode(encoded));
out = static_cast<char*>(tr_base64_encode("", 0, &len));
EXPECT_EQ(size_t{}, len);
EXPECT_STREQ("", out);
tr_free(out);
out = static_cast<char*>(tr_base64_decode("", 0, &len));
EXPECT_EQ(0, len);
EXPECT_STREQ("", out);
tr_free(out);
out = static_cast<char*>(tr_base64_encode(nullptr, 0, &len));
EXPECT_EQ(0, len);
EXPECT_EQ(nullptr, out);
out = static_cast<char*>(tr_base64_decode(nullptr, 0, &len));
EXPECT_EQ(0, len);
EXPECT_EQ(nullptr, out);
EXPECT_EQ(""sv, tr_base64_encode(""sv));
EXPECT_EQ(""sv, tr_base64_decode(""sv));
static auto constexpr MaxBufSize = size_t{ 1024 };
for (size_t i = 1; i <= MaxBufSize; ++i)
{
auto buf = std::array<char, MaxBufSize + 1>{};
auto buf = std::string{};
for (size_t j = 0; j < i; ++j)
{
buf[j] = char(tr_rand_int_weak(256));
buf += char(tr_rand_int_weak(256));
}
EXPECT_EQ(buf, tr_base64_decode(tr_base64_encode(buf)));
out = static_cast<char*>(tr_base64_encode(buf.data(), i, &len));
EXPECT_EQ(strlen(out), len);
in = static_cast<char*>(tr_base64_decode(out, len, &len));
EXPECT_EQ(i, len);
EXPECT_EQ(0, memcmp(in, buf.data(), len));
tr_free(in);
tr_free(out);
buf = std::string{};
for (size_t j = 0; j < i; ++j)
{
buf[j] = char(1 + tr_rand_int_weak(255));
buf += char(1 + tr_rand_int_weak(255));
}
buf[i] = '\0';
out = static_cast<char*>(tr_base64_encode_str(buf.data(), &len));
EXPECT_EQ(strlen(out), len);
in = static_cast<char*>(tr_base64_decode_str(out, &len));
EXPECT_EQ(i, len);
EXPECT_STREQ(buf.data(), in);
tr_free(in);
tr_free(out);
EXPECT_EQ(buf, tr_base64_decode(tr_base64_encode(buf)));
}
}

View File

@ -69,12 +69,10 @@ protected:
tr_torrent* createTorrentFromBase64Metainfo(tr_ctor* ctor, char const* benc_base64)
{
// create the torrent ctor
size_t benc_len;
auto* benc = static_cast<char*>(tr_base64_decode_str(benc_base64, &benc_len));
EXPECT_NE(nullptr, benc);
EXPECT_LT(size_t(0), benc_len);
auto const benc = tr_base64_decode(benc_base64);
EXPECT_LT(0, std::size(benc));
tr_error* error = nullptr;
EXPECT_TRUE(tr_ctorSetMetainfo(ctor, benc, benc_len, &error));
EXPECT_TRUE(tr_ctorSetMetainfo(ctor, std::data(benc), std::size(benc), &error));
EXPECT_EQ(nullptr, error);
tr_ctorSetPaused(ctor, TR_FORCE, true);
@ -83,7 +81,6 @@ protected:
EXPECT_NE(nullptr, tor);
// cleanup
tr_free(benc);
return tor;
}

View File

@ -16,7 +16,7 @@
#include <string>
#include <thread>
#include "crypto-utils.h" // tr_base64_decode_str()
#include "crypto-utils.h" // tr_base64_decode()
#include "error.h"
#include "file.h" // tr_sys_file_*()
#include "platform.h" // TR_PATH_DELIMITER
@ -374,16 +374,13 @@ protected:
"OnByaXZhdGVpMGVlZQ==";
// create the torrent ctor
auto metainfo_len = size_t{};
auto* const metainfo = tr_base64_decode_str(metainfo_base64, &metainfo_len);
EXPECT_NE(nullptr, metainfo);
EXPECT_LT(size_t{ 0 }, metainfo_len);
auto const metainfo = tr_base64_decode(metainfo_base64);
EXPECT_LT(0, std::size(metainfo));
auto* ctor = tr_ctorNew(session_);
tr_error* error = nullptr;
EXPECT_TRUE(tr_ctorSetMetainfo(ctor, static_cast<char const*>(metainfo), metainfo_len, &error));
EXPECT_TRUE(tr_ctorSetMetainfo(ctor, std::data(metainfo), std::size(metainfo), &error));
EXPECT_EQ(nullptr, error);
tr_ctorSetPaused(ctor, TR_FORCE, true);
tr_free(metainfo);
// create the torrent
auto* const tor = tr_torrentNew(ctor, nullptr);

View File

@ -544,19 +544,15 @@ static char* netrc = nullptr;
static char* session_id = nullptr;
static bool UseSSL = false;
static char* getEncodedMetainfo(char const* filename)
static std::string getEncodedMetainfo(char const* filename)
{
size_t len = 0;
char* b64 = nullptr;
uint8_t* buf = tr_loadFile(filename, &len, nullptr);
if (buf != nullptr)
auto contents = std::vector<char>{};
if (tr_loadFile(contents, filename))
{
b64 = static_cast<char*>(tr_base64_encode(buf, len, nullptr));
tr_free(buf);
return tr_base64_encode({ std::data(contents), std::size(contents) });
}
return b64;
return {};
}
static void addIdArg(tr_variant* args, char const* id_str, char const* fallback)
@ -1362,19 +1358,18 @@ static void printPeers(tr_variant* top)
}
}
static void printPiecesImpl(uint8_t const* raw, size_t raw_len, size_t piece_count)
static void printPiecesImpl(std::string_view raw, size_t piece_count)
{
size_t len = 0;
auto* const str = static_cast<char*>(tr_base64_decode(raw, raw_len, &len));
auto const str = tr_base64_decode(raw);
printf(" ");
size_t piece = 0;
size_t const col_width = 64;
for (char const *it = str, *end = it + len; it != end; ++it)
for (auto const ch : str)
{
for (int bit = 0; piece < piece_count && bit < 8; ++bit, ++piece)
{
printf("%c", (*it & (1 << (7 - bit))) != 0 ? '1' : '0');
printf("%c", (ch & (1 << (7 - bit))) != 0 ? '1' : '0');
}
printf(" ");
@ -1386,7 +1381,6 @@ static void printPiecesImpl(uint8_t const* raw, size_t raw_len, size_t piece_cou
}
printf("\n");
tr_free(str);
}
static void printPieces(tr_variant* top)
@ -1399,15 +1393,14 @@ static void printPieces(tr_variant* top)
for (int i = 0, n = tr_variantListSize(torrents); i < n; ++i)
{
int64_t j;
uint8_t const* raw;
size_t rawlen;
auto raw = std::string_view{};
tr_variant* torrent = tr_variantListChild(torrents, i);
if (tr_variantDictFindRaw(torrent, TR_KEY_pieces, &raw, &rawlen) &&
if (tr_variantDictFindStrView(torrent, TR_KEY_pieces, &raw) &&
tr_variantDictFindInt(torrent, TR_KEY_pieceCount, &j))
{
assert(j >= 0);
printPiecesImpl(raw, rawlen, (size_t)j);
printPiecesImpl(raw, (size_t)j);
if (i + 1 < n)
{
@ -2359,18 +2352,16 @@ static int processArgs(char const* rpcurl, int argc, char const* const* argv)
if (tadd != nullptr)
{
tr_variant* args = tr_variantDictFind(tadd, Arguments);
char* tmp = getEncodedMetainfo(optarg);
std::string const tmp = getEncodedMetainfo(optarg);
if (tmp != nullptr)
if (!std::empty(tmp))
{
tr_variantDictAddStr(args, TR_KEY_metainfo, tmp);
tr_variantDictAddStrView(args, TR_KEY_metainfo, tmp);
}
else
{
tr_variantDictAddStr(args, TR_KEY_filename, optarg);
}
tr_free(tmp);
}
else
{