minor key.encrypt api change/cleanup

we already have .decrypt(id, data, ...).
i changed .encrypt(chunk) to .encrypt(id, data).

the old borg crypto won't really need or use the id,
but the new AEAD crypto will authenticate the id in future.
This commit is contained in:
Thomas Waldmann 2022-03-21 12:33:11 +01:00
parent 41b8a04d82
commit d3b78a6cf5
7 changed files with 46 additions and 36 deletions

View File

@ -1789,7 +1789,7 @@ class ArchiveChecker:
def add_callback(chunk): def add_callback(chunk):
id_ = self.key.id_hash(chunk) id_ = self.key.id_hash(chunk)
cdata = self.key.encrypt(chunk) cdata = self.key.encrypt(id_, chunk)
add_reference(id_, len(chunk), len(cdata), cdata) add_reference(id_, len(chunk), len(cdata), cdata)
return id_ return id_
@ -1811,7 +1811,7 @@ class ArchiveChecker:
def replacement_chunk(size): def replacement_chunk(size):
chunk = Chunk(None, allocation=CH_ALLOC, size=size) chunk = Chunk(None, allocation=CH_ALLOC, size=size)
chunk_id, data = cached_hash(chunk, self.key.id_hash) chunk_id, data = cached_hash(chunk, self.key.id_hash)
cdata = self.key.encrypt(data) cdata = self.key.encrypt(chunk_id, data)
csize = len(cdata) csize = len(cdata)
return chunk_id, size, csize, cdata return chunk_id, size, csize, cdata
@ -1998,7 +1998,7 @@ class ArchiveChecker:
archive.items = items_buffer.chunks archive.items = items_buffer.chunks
data = msgpack.packb(archive.as_dict()) data = msgpack.packb(archive.as_dict())
new_archive_id = self.key.id_hash(data) new_archive_id = self.key.id_hash(data)
cdata = self.key.encrypt(data) cdata = self.key.encrypt(new_archive_id, data)
add_reference(new_archive_id, len(data), len(cdata), cdata) add_reference(new_archive_id, len(data), len(cdata), cdata)
self.manifest.archives[info.name] = (new_archive_id, info.ts) self.manifest.archives[info.name] = (new_archive_id, info.ts)
pi.finish() pi.finish()

View File

@ -942,7 +942,7 @@ class LocalCache(CacheStatsMixin):
refcount = self.seen_chunk(id, size) refcount = self.seen_chunk(id, size)
if refcount and not overwrite: if refcount and not overwrite:
return self.chunk_incref(id, stats) return self.chunk_incref(id, stats)
data = self.key.encrypt(chunk) data = self.key.encrypt(id, chunk)
csize = len(data) csize = len(data)
self.repository.put(id, data, wait=wait) self.repository.put(id, data, wait=wait)
self.chunks.add(id, 1, size, csize) self.chunks.add(id, 1, size, csize)
@ -1107,7 +1107,7 @@ Chunk index: {0.total_unique_chunks:20d} unknown"""
refcount = self.seen_chunk(id, size) refcount = self.seen_chunk(id, size)
if refcount: if refcount:
return self.chunk_incref(id, stats, size=size) return self.chunk_incref(id, stats, size=size)
data = self.key.encrypt(chunk) data = self.key.encrypt(id, chunk)
csize = len(data) csize = len(data)
self.repository.put(id, data, wait=wait) self.repository.put(id, data, wait=wait)
self.chunks.add(id, 1, size, csize) self.chunks.add(id, 1, size, csize)

View File

@ -158,7 +158,7 @@ class KeyBase:
""" """
raise NotImplementedError raise NotImplementedError
def encrypt(self, chunk): def encrypt(self, id, data):
pass pass
def decrypt(self, id, data, decompress=True): def decrypt(self, id, data, decompress=True):
@ -264,8 +264,8 @@ class PlaintextKey(KeyBase):
def id_hash(self, data): def id_hash(self, data):
return sha256(data).digest() return sha256(data).digest()
def encrypt(self, chunk): def encrypt(self, id, data):
data = self.compressor.compress(chunk) data = self.compressor.compress(data)
return b''.join([self.TYPE_STR, data]) return b''.join([self.TYPE_STR, data])
def decrypt(self, id, data, decompress=True): def decrypt(self, id, data, decompress=True):
@ -340,8 +340,8 @@ class AESKeyBase(KeyBase):
logically_encrypted = True logically_encrypted = True
def encrypt(self, chunk): def encrypt(self, id, data):
data = self.compressor.compress(chunk) data = self.compressor.compress(data)
next_iv = self.nonce_manager.ensure_reservation(self.cipher.next_iv(), next_iv = self.nonce_manager.ensure_reservation(self.cipher.next_iv(),
self.cipher.block_count(len(data))) self.cipher.block_count(len(data)))
return self.cipher.encrypt(data, header=self.TYPE_STR, iv=next_iv) return self.cipher.encrypt(data, header=self.TYPE_STR, iv=next_iv)
@ -678,8 +678,8 @@ class AuthenticatedKeyBase(AESKeyBase, FlexiKey):
if manifest_data is not None: if manifest_data is not None:
self.assert_type(manifest_data[0]) self.assert_type(manifest_data[0])
def encrypt(self, chunk): def encrypt(self, id, data):
data = self.compressor.compress(chunk) data = self.compressor.compress(data)
return b''.join([self.TYPE_STR, data]) return b''.join([self.TYPE_STR, data])
def decrypt(self, id, data, decompress=True): def decrypt(self, id, data, decompress=True):
@ -732,9 +732,9 @@ class AEADKeyBase(KeyBase):
logically_encrypted = True logically_encrypted = True
def encrypt(self, chunk): def encrypt(self, id, data):
# to encrypt new data in this session we use always self.cipher and self.sessionid # to encrypt new data in this session we use always self.cipher and self.sessionid
data = self.compressor.compress(chunk) data = self.compressor.compress(data)
reserved = b'\0' reserved = b'\0'
iv = self.cipher.next_iv() iv = self.cipher.next_iv()
iv_48bit = iv.to_bytes(6, 'big') iv_48bit = iv.to_bytes(6, 'big')

View File

@ -261,4 +261,4 @@ class Manifest:
self.tam_verified = True self.tam_verified = True
data = self.key.pack_and_authenticate_metadata(manifest.as_dict()) data = self.key.pack_and_authenticate_metadata(manifest.as_dict())
self.id = self.key.id_hash(data) self.id = self.key.id_hash(data)
self.repository.put(self.MANIFEST_ID, self.key.encrypt(data)) self.repository.put(self.MANIFEST_ID, self.key.encrypt(self.MANIFEST_ID, data))

View File

@ -3806,7 +3806,7 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
'version': 1, 'version': 1,
}) })
archive_id = key.id_hash(archive) archive_id = key.id_hash(archive)
repository.put(archive_id, key.encrypt(archive)) repository.put(archive_id, key.encrypt(archive_id, archive))
repository.commit(compact=False) repository.commit(compact=False)
self.cmd('check', self.repository_location, exit_code=1) self.cmd('check', self.repository_location, exit_code=1)
self.cmd('check', '--repair', self.repository_location, exit_code=0) self.cmd('check', '--repair', self.repository_location, exit_code=0)
@ -3894,7 +3894,7 @@ class ManifestAuthenticationTest(ArchiverTestCaseBase):
def spoof_manifest(self, repository): def spoof_manifest(self, repository):
with repository: with repository:
_, key = Manifest.load(repository, Manifest.NO_OPERATION_CHECK) _, key = Manifest.load(repository, Manifest.NO_OPERATION_CHECK)
repository.put(Manifest.MANIFEST_ID, key.encrypt(msgpack.packb({ repository.put(Manifest.MANIFEST_ID, key.encrypt(Manifest.MANIFEST_ID, msgpack.packb({
'version': 1, 'version': 1,
'archives': {}, 'archives': {},
'config': {}, 'config': {},
@ -3907,7 +3907,7 @@ class ManifestAuthenticationTest(ArchiverTestCaseBase):
repository = Repository(self.repository_path, exclusive=True) repository = Repository(self.repository_path, exclusive=True)
with repository: with repository:
manifest, key = Manifest.load(repository, Manifest.NO_OPERATION_CHECK) manifest, key = Manifest.load(repository, Manifest.NO_OPERATION_CHECK)
repository.put(Manifest.MANIFEST_ID, key.encrypt(msgpack.packb({ repository.put(Manifest.MANIFEST_ID, key.encrypt(Manifest.MANIFEST_ID, msgpack.packb({
'version': 1, 'version': 1,
'archives': {}, 'archives': {},
'timestamp': (datetime.utcnow() + timedelta(days=1)).strftime(ISO_FORMAT), 'timestamp': (datetime.utcnow() + timedelta(days=1)).strftime(ISO_FORMAT),
@ -3929,7 +3929,7 @@ class ManifestAuthenticationTest(ArchiverTestCaseBase):
manifest = msgpack.unpackb(key.decrypt(None, repository.get(Manifest.MANIFEST_ID))) manifest = msgpack.unpackb(key.decrypt(None, repository.get(Manifest.MANIFEST_ID)))
del manifest[b'tam'] del manifest[b'tam']
repository.put(Manifest.MANIFEST_ID, key.encrypt(msgpack.packb(manifest))) repository.put(Manifest.MANIFEST_ID, key.encrypt(Manifest.MANIFEST_ID, msgpack.packb(manifest)))
repository.commit(compact=False) repository.commit(compact=False)
output = self.cmd('list', '--debug', self.repository_location) output = self.cmd('list', '--debug', self.repository_location)
assert 'archive1234' in output assert 'archive1234' in output

View File

@ -114,18 +114,21 @@ class TestKey:
def test_plaintext(self): def test_plaintext(self):
key = PlaintextKey.create(None, None) key = PlaintextKey.create(None, None)
chunk = b'foo' chunk = b'foo'
assert hexlify(key.id_hash(chunk)) == b'2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae' id = key.id_hash(chunk)
assert chunk == key.decrypt(key.id_hash(chunk), key.encrypt(chunk)) assert hexlify(id) == b'2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae'
assert chunk == key.decrypt(id, key.encrypt(id, chunk))
def test_keyfile(self, monkeypatch, keys_dir): def test_keyfile(self, monkeypatch, keys_dir):
monkeypatch.setenv('BORG_PASSPHRASE', 'test') monkeypatch.setenv('BORG_PASSPHRASE', 'test')
key = KeyfileKey.create(self.MockRepository(), self.MockArgs()) key = KeyfileKey.create(self.MockRepository(), self.MockArgs())
assert key.cipher.next_iv() == 0 assert key.cipher.next_iv() == 0
manifest = key.encrypt(b'ABC') chunk = b'ABC'
id = key.id_hash(chunk)
manifest = key.encrypt(id, chunk)
assert key.cipher.extract_iv(manifest) == 0 assert key.cipher.extract_iv(manifest) == 0
manifest2 = key.encrypt(b'ABC') manifest2 = key.encrypt(id, chunk)
assert manifest != manifest2 assert manifest != manifest2
assert key.decrypt(None, manifest) == key.decrypt(None, manifest2) assert key.decrypt(id, manifest) == key.decrypt(id, manifest2)
assert key.cipher.extract_iv(manifest2) == 1 assert key.cipher.extract_iv(manifest2) == 1
iv = key.cipher.extract_iv(manifest) iv = key.cipher.extract_iv(manifest)
key2 = KeyfileKey.detect(self.MockRepository(), manifest) key2 = KeyfileKey.detect(self.MockRepository(), manifest)
@ -134,7 +137,8 @@ class TestKey:
assert len({key2.id_key, key2.enc_key, key2.enc_hmac_key}) == 3 assert len({key2.id_key, key2.enc_key, key2.enc_hmac_key}) == 3
assert key2.chunk_seed != 0 assert key2.chunk_seed != 0
chunk = b'foo' chunk = b'foo'
assert chunk == key2.decrypt(key.id_hash(chunk), key.encrypt(chunk)) id = key.id_hash(chunk)
assert chunk == key2.decrypt(id, key.encrypt(id, chunk))
def test_keyfile_nonce_rollback_protection(self, monkeypatch, keys_dir): def test_keyfile_nonce_rollback_protection(self, monkeypatch, keys_dir):
monkeypatch.setenv('BORG_PASSPHRASE', 'test') monkeypatch.setenv('BORG_PASSPHRASE', 'test')
@ -142,9 +146,11 @@ class TestKey:
with open(os.path.join(get_security_dir(repository.id_str), 'nonce'), "w") as fd: with open(os.path.join(get_security_dir(repository.id_str), 'nonce'), "w") as fd:
fd.write("0000000000002000") fd.write("0000000000002000")
key = KeyfileKey.create(repository, self.MockArgs()) key = KeyfileKey.create(repository, self.MockArgs())
data = key.encrypt(b'ABC') chunk = b'ABC'
id = key.id_hash(chunk)
data = key.encrypt(id, chunk)
assert key.cipher.extract_iv(data) == 0x2000 assert key.cipher.extract_iv(data) == 0x2000
assert key.decrypt(None, data) == b'ABC' assert key.decrypt(id, data) == chunk
def test_keyfile_kfenv(self, tmpdir, monkeypatch): def test_keyfile_kfenv(self, tmpdir, monkeypatch):
keyfile = tmpdir.join('keyfile') keyfile = tmpdir.join('keyfile')
@ -155,7 +161,7 @@ class TestKey:
assert keyfile.exists() assert keyfile.exists()
chunk = b'ABC' chunk = b'ABC'
chunk_id = key.id_hash(chunk) chunk_id = key.id_hash(chunk)
chunk_cdata = key.encrypt(chunk) chunk_cdata = key.encrypt(chunk_id, chunk)
key = KeyfileKey.detect(self.MockRepository(), chunk_cdata) key = KeyfileKey.detect(self.MockRepository(), chunk_cdata)
assert chunk == key.decrypt(chunk_id, chunk_cdata) assert chunk == key.decrypt(chunk_id, chunk_cdata)
keyfile.remove() keyfile.remove()
@ -212,18 +218,20 @@ class TestKey:
def test_roundtrip(self, key): def test_roundtrip(self, key):
repository = key.repository repository = key.repository
plaintext = b'foo' plaintext = b'foo'
encrypted = key.encrypt(plaintext) id = key.id_hash(plaintext)
encrypted = key.encrypt(id, plaintext)
identified_key_class = identify_key(encrypted) identified_key_class = identify_key(encrypted)
assert identified_key_class == key.__class__ assert identified_key_class == key.__class__
loaded_key = identified_key_class.detect(repository, encrypted) loaded_key = identified_key_class.detect(repository, encrypted)
decrypted = loaded_key.decrypt(None, encrypted) decrypted = loaded_key.decrypt(id, encrypted)
assert decrypted == plaintext assert decrypted == plaintext
def test_decrypt_decompress(self, key): def test_decrypt_decompress(self, key):
plaintext = b'123456789' plaintext = b'123456789'
encrypted = key.encrypt(plaintext) id = key.id_hash(plaintext)
assert key.decrypt(None, encrypted, decompress=False) != plaintext encrypted = key.encrypt(id, plaintext)
assert key.decrypt(None, encrypted) == plaintext assert key.decrypt(id, encrypted, decompress=False) != plaintext
assert key.decrypt(id, encrypted) == plaintext
def test_assert_id(self, key): def test_assert_id(self, key):
plaintext = b'123456789' plaintext = b'123456789'
@ -243,7 +251,8 @@ class TestKey:
assert AuthenticatedKey.id_hash is ID_HMAC_SHA_256.id_hash assert AuthenticatedKey.id_hash is ID_HMAC_SHA_256.id_hash
assert len(key.id_key) == 32 assert len(key.id_key) == 32
plaintext = b'123456789' plaintext = b'123456789'
authenticated = key.encrypt(plaintext) id = key.id_hash(plaintext)
authenticated = key.encrypt(id, plaintext)
# 0x07 is the key TYPE, \x0000 identifies no compression. # 0x07 is the key TYPE, \x0000 identifies no compression.
assert authenticated == b'\x07\x00\x00' + plaintext assert authenticated == b'\x07\x00\x00' + plaintext
@ -253,7 +262,8 @@ class TestKey:
assert Blake2AuthenticatedKey.id_hash is ID_BLAKE2b_256.id_hash assert Blake2AuthenticatedKey.id_hash is ID_BLAKE2b_256.id_hash
assert len(key.id_key) == 128 assert len(key.id_key) == 128
plaintext = b'123456789' plaintext = b'123456789'
authenticated = key.encrypt(plaintext) id = key.id_hash(plaintext)
authenticated = key.encrypt(id, plaintext)
# 0x06 is the key TYPE, 0x0000 identifies no compression. # 0x06 is the key TYPE, 0x0000 identifies no compression.
assert authenticated == b'\x06\x00\x00' + plaintext assert authenticated == b'\x06\x00\x00' + plaintext

View File

@ -165,7 +165,7 @@ class TestRepositoryCache:
def _put_encrypted_object(self, key, repository, data): def _put_encrypted_object(self, key, repository, data):
id_ = key.id_hash(data) id_ = key.id_hash(data)
repository.put(id_, key.encrypt(data)) repository.put(id_, key.encrypt(id_, data))
return id_ return id_
@pytest.fixture @pytest.fixture