diff --git a/darc/keychain.py b/darc/keychain.py index fbc604f12..a30d451ab 100644 --- a/darc/keychain.py +++ b/darc/keychain.py @@ -9,7 +9,7 @@ from Crypto.Hash import SHA256, HMAC from Crypto.PublicKey import RSA from Crypto.Util import Counter -from Crypto.Util.number import bytes_to_long +from Crypto.Util.number import bytes_to_long, long_to_bytes from .helpers import IntegrityError, zero_pad from .oaep import OAEP @@ -25,6 +25,7 @@ def __init__(self, path=None): self._key_cache = {} self.read_key = os.urandom(32) self.create_key = os.urandom(32) + self.counter = Counter.new(64, prefix='\0' * 8) self.aes_id = self.rsa_read = self.rsa_create = None self.path = path if path: @@ -58,7 +59,7 @@ def open(self, path): self.create_encrypted = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32)) self.create_encrypted = zero_pad(self.rsa_create.encrypt(self.create_encrypted, '')[0], 256) - def encrypt(self, data, password): + def encrypt_keychain(self, data, password): salt = os.urandom(32) iterations = 2000 key = pbkdf2(password, salt, 32, iterations, hashlib.sha256) @@ -91,7 +92,7 @@ def save(self, path, password): 'rsa_read': self.rsa_read.exportKey('PEM'), 'rsa_create': self.rsa_create.exportKey('PEM'), } - data = self.encrypt(msgpack.packb(chain), password) + data = self.encrypt_keychain(msgpack.packb(chain), password) with open(path, 'wb') as fd: fd.write(self.FILE_ID) fd.write(data) @@ -135,23 +136,36 @@ def generate(path): return 0 def id_hash(self, data): + """Return HMAC hash using the "id" AES key + """ return HMAC.new(self.aes_id, data, SHA256).digest() + def _encrypt(self, id, rsa_key, key, data): + """Helper function used by `encrypt_read` and `encrypt_create` + """ + data = zlib.compress(data) + nonce = long_to_bytes(self.counter.next_value(), 8) + data = nonce + rsa_key + AES.new(key, AES.MODE_CTR, '', counter=self.counter).encrypt(data) + hash = self.id_hash(data) + return ''.join((id, hash, data)), hash + def encrypt_read(self, data): - data = zlib.compress(data) - hash = self.id_hash(data) - counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True) - data = AES.new(self.read_key, AES.MODE_CTR, '', counter=counter).encrypt(data) - return ''.join((self.READ, self.read_encrypted, hash, data)), hash + """Encrypt `data` using the AES "read" key - def encrypt_create(self, data): - data = zlib.compress(data) - hash = self.id_hash(data) - counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True) - data = AES.new(self.create_key, AES.MODE_CTR, '', counter=counter).encrypt(data) - return ''.join((self.CREATE, self.create_encrypted, hash, data)), hash + An RSA encrypted version of the AES key is included in the header + """ + return self._encrypt(self.READ, self.read_encrypted, self.read_key, data) - def decrypt_key(self, data, rsa_key): + def encrypt_create(self, data, iv=None): + """Encrypt `data` using the AES "create" key + + An RSA encrypted version of the AES key is included in the header + """ + return self._encrypt(self.CREATE, self.create_encrypted, self.create_key, data) + + def _decrypt_key(self, data, rsa_key): + """Helper function used by `decrypt` + """ try: return self._key_cache[data] except KeyError: @@ -159,25 +173,22 @@ def decrypt_key(self, data, rsa_key): return self._key_cache[data] def decrypt(self, data): + """Decrypt `data` previously encrypted by `encrypt_create` or `encrypt_read` + """ type = data[0] + hash = data[1:33] + if self.id_hash(data[33:]) != hash: + raise IntegrityError('Encryption integrity error') + nonce = bytes_to_long(data[33:41]) + counter = Counter.new(64, prefix='\0' * 8, initial_value=nonce) if type == self.READ: - key = self.decrypt_key(data[1:257], self.rsa_read) - hash = data[257:289] - counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True) - data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[289:]) - if self.id_hash(data) != hash: - raise IntegrityError('decryption failed') - return zlib.decompress(data), hash + key = self._decrypt_key(data[41:297], self.rsa_read) elif type == self.CREATE: - key = self.decrypt_key(data[1:257], self.rsa_create) - hash = data[257:289] - counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True) - data = AES.new(key, AES.MODE_CTR, '', counter=counter).decrypt(data[289:]) - if self.id_hash(data) != hash: - raise IntegrityError('decryption failed') - return zlib.decompress(data), hash + key = self.decrypt_key(data[41:297], self.rsa_create) else: raise Exception('Unknown pack type %d found' % ord(type)) + data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[297:]) + return zlib.decompress(data), hash