Add decompress arg to Key.decrypt

This commit is contained in:
Marian Beermann 2016-07-31 21:47:31 +02:00
parent 774609cd9f
commit a80b371d09
2 changed files with 25 additions and 5 deletions

View File

@ -105,7 +105,7 @@ class KeyBase:
def encrypt(self, chunk): def encrypt(self, chunk):
pass pass
def decrypt(self, id, data): def decrypt(self, id, data, decompress=True):
pass pass
@ -130,10 +130,13 @@ class PlaintextKey(KeyBase):
chunk = self.compress(chunk) chunk = self.compress(chunk)
return b''.join([self.TYPE_STR, chunk.data]) return b''.join([self.TYPE_STR, chunk.data])
def decrypt(self, id, data): def decrypt(self, id, data, decompress=True):
if data[0] != self.TYPE: if data[0] != self.TYPE:
raise IntegrityError('Invalid encryption envelope') raise IntegrityError('Invalid encryption envelope')
data = self.compressor.decompress(memoryview(data)[1:]) payload = memoryview(data)[1:]
if not decompress:
return Chunk(payload)
data = self.compressor.decompress(payload)
if id and sha256(data).digest() != id: if id and sha256(data).digest() != id:
raise IntegrityError('Chunk id verification failed') raise IntegrityError('Chunk id verification failed')
return Chunk(data) return Chunk(data)
@ -166,7 +169,7 @@ class AESKeyBase(KeyBase):
hmac = hmac_sha256(self.enc_hmac_key, data) hmac = hmac_sha256(self.enc_hmac_key, data)
return b''.join((self.TYPE_STR, hmac, data)) return b''.join((self.TYPE_STR, hmac, data))
def decrypt(self, id, data): def decrypt(self, id, data, decompress=True):
if not (data[0] == self.TYPE or if not (data[0] == self.TYPE or
data[0] == PassphraseKey.TYPE and isinstance(self, RepoKey)): data[0] == PassphraseKey.TYPE and isinstance(self, RepoKey)):
raise IntegrityError('Invalid encryption envelope') raise IntegrityError('Invalid encryption envelope')
@ -176,7 +179,10 @@ class AESKeyBase(KeyBase):
if not compare_digest(hmac_computed, hmac_given): if not compare_digest(hmac_computed, hmac_given):
raise IntegrityError('Encryption envelope checksum mismatch') raise IntegrityError('Encryption envelope checksum mismatch')
self.dec_cipher.reset(iv=PREFIX + data[33:41]) self.dec_cipher.reset(iv=PREFIX + data[33:41])
data = self.compressor.decompress(self.dec_cipher.decrypt(data_view[41:])) payload = self.dec_cipher.decrypt(data_view[41:])
if not decompress:
return Chunk(payload)
data = self.compressor.decompress(payload)
if id: if id:
hmac_given = id hmac_given = id
hmac_computed = hmac_sha256(self.id_key, data) hmac_computed = hmac_sha256(self.id_key, data)

View File

@ -43,6 +43,14 @@ class TestKey:
monkeypatch.setenv('BORG_KEYS_DIR', tmpdir) monkeypatch.setenv('BORG_KEYS_DIR', tmpdir)
return tmpdir return tmpdir
@pytest.fixture(params=(
KeyfileKey,
PlaintextKey
))
def key(self, request, monkeypatch):
monkeypatch.setenv('BORG_PASSPHRASE', 'test')
return request.param.create(self.MockRepository(), self.MockArgs())
class MockRepository: class MockRepository:
class _Location: class _Location:
orig = '/some/place' orig = '/some/place'
@ -155,6 +163,12 @@ class TestKey:
id[12] = 0 id[12] = 0
key.decrypt(id, data) key.decrypt(id, data)
def test_decrypt_decompress(self, key):
plaintext = Chunk(b'123456789')
encrypted = key.encrypt(plaintext)
assert key.decrypt(None, encrypted, decompress=False) != plaintext
assert key.decrypt(None, encrypted) == plaintext
class TestPassphrase: class TestPassphrase:
def test_passphrase_new_verification(self, capsys, monkeypatch): def test_passphrase_new_verification(self, capsys, monkeypatch):