cleanup crypto.pyx, make it easier to adapt to other modes

There were some small issues:

 a) it never called EVP_EncryptFinal_ex.
For CTR mode, this had no visible consequences as EVP_EncryptUpdate already yielded all ciphertext.
For cleanliness and to have correctness even in other modes, the missing call was added.

b) decrypt = encrypt hack
This is a nice hack to abbreviate, but it only works for modes without padding and without authentication.
For cleanliness and to have correctness even in other modes, the missing usage of the decrypt api was added.

c) outl == inl assumption
Again, True for CTR mode, but not for padding or authenticating modes.
Fixed so it computes the ciphertext / plaintext length based on api return values.

Other changes:
As encrypt and decrypt API calls are different even for initialization/reset, added a is_encrypt flag.

Defensive output buffer allocation. Added the length of one extra AES block (16bytes) so it would
work even with padding modes. 16bytes are needed because a full block of padding might get
added when the plaintext was a multiple of aes block size.

These changes are based on some experimental code I did for aes-cbc and aes-gcm.
While we likely won't ever want aes-cbc in attic (maybe gcm though?), I think it is cleaner
to not make too many mode specific assumptions and hacks, but just use the API as it
was meant to be used.
This commit is contained in:
Thomas Waldmann 2015-03-03 19:19:28 +01:00
parent 4ab4ecc7af
commit 6c7c2e2e40
3 changed files with 64 additions and 17 deletions

View File

@ -28,8 +28,16 @@ cdef extern from "openssl/evp.h":
int EVP_EncryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher, ENGINE *impl,
const unsigned char *key, const unsigned char *iv)
int EVP_DecryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher, ENGINE *impl,
const unsigned char *key, const unsigned char *iv)
int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out,
int *outl, const unsigned char *in_, int inl)
int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out,
int *outl, const unsigned char *in_, int inl)
int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out,
int *outl)
int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out,
int *outl)
int PKCS5_PBKDF2_HMAC(const char *password, int passwordlen,
const unsigned char *salt, int saltlen, int iter,
@ -85,11 +93,19 @@ cdef class AES:
"""A thin wrapper around the OpenSSL EVP cipher API
"""
cdef EVP_CIPHER_CTX ctx
cdef int is_encrypt
def __cinit__(self, key, iv=None):
def __cinit__(self, is_encrypt, key, iv=None):
EVP_CIPHER_CTX_init(&self.ctx)
if not EVP_EncryptInit_ex(&self.ctx, EVP_aes_256_ctr(), NULL, NULL, NULL):
raise Exception('EVP_EncryptInit_ex failed')
self.is_encrypt = is_encrypt
# Set cipher type and mode
cipher_mode = EVP_aes_256_ctr()
if self.is_encrypt:
if not EVP_EncryptInit_ex(&self.ctx, cipher_mode, NULL, NULL, NULL):
raise Exception('EVP_EncryptInit_ex failed')
else: # decrypt
if not EVP_DecryptInit_ex(&self.ctx, cipher_mode, NULL, NULL, NULL):
raise Exception('EVP_DecryptInit_ex failed')
self.reset(key, iv)
def __dealloc__(self):
@ -102,8 +118,13 @@ cdef class AES:
key2 = key
if iv:
iv2 = iv
if not EVP_EncryptInit_ex(&self.ctx, NULL, NULL, key2, iv2):
raise Exception('EVP_EncryptInit_ex failed')
# Initialise key and IV
if self.is_encrypt:
if not EVP_EncryptInit_ex(&self.ctx, NULL, NULL, key2, iv2):
raise Exception('EVP_EncryptInit_ex failed')
else: # decrypt
if not EVP_DecryptInit_ex(&self.ctx, NULL, NULL, key2, iv2):
raise Exception('EVP_DecryptInit_ex failed')
@property
def iv(self):
@ -111,15 +132,37 @@ cdef class AES:
def encrypt(self, data):
cdef int inl = len(data)
cdef int outl
cdef unsigned char *out = <unsigned char *>malloc(inl)
cdef int ctl = 0
cdef int outl = 0
# note: modes that use padding, need up to one extra AES block (16b)
cdef unsigned char *out = <unsigned char *>malloc(inl+16)
if not out:
raise MemoryError
try:
if not EVP_EncryptUpdate(&self.ctx, out, &outl, data, inl):
raise Exception('EVP_EncryptUpdate failed')
return out[:inl]
ctl = outl
if not EVP_EncryptFinal_ex(&self.ctx, out+ctl, &outl):
raise Exception('EVP_EncryptFinal failed')
ctl += outl
return out[:ctl]
finally:
free(out)
decrypt = encrypt
def decrypt(self, data):
cdef int inl = len(data)
cdef int ptl = 0
cdef int outl = 0
cdef unsigned char *out = <unsigned char *>malloc(inl)
if not out:
raise MemoryError
try:
if not EVP_DecryptUpdate(&self.ctx, out, &outl, data, inl):
raise Exception('EVP_DecryptUpdate failed')
ptl = outl
if EVP_DecryptFinal_ex(&self.ctx, out+outl, &outl) <= 0:
raise Exception('EVP_DecryptFinal failed')
ptl += outl
return out[:ptl]
finally:
free(out)

View File

@ -144,8 +144,8 @@ class AESKeyBase(KeyBase):
self.chunk_seed = self.chunk_seed - 0xffffffff - 1
def init_ciphers(self, enc_iv=b''):
self.enc_cipher = AES(self.enc_key, enc_iv)
self.dec_cipher = AES(self.enc_key)
self.enc_cipher = AES(is_encrypt=True, key=self.enc_key, iv=enc_iv)
self.dec_cipher = AES(is_encrypt=False, key=self.enc_key)
class PassphraseKey(AESKeyBase):
@ -244,7 +244,7 @@ class KeyfileKey(AESKeyBase):
assert d[b'version'] == 1
assert d[b'algorithm'] == b'sha256'
key = pbkdf2_sha256(passphrase.encode('utf-8'), d[b'salt'], d[b'iterations'], 32)
data = AES(key).decrypt(d[b'data'])
data = AES(is_encrypt=False, key=key).decrypt(d[b'data'])
if HMAC(key, data, sha256).digest() != d[b'hash']:
return None
return data
@ -254,7 +254,7 @@ class KeyfileKey(AESKeyBase):
iterations = 100000
key = pbkdf2_sha256(passphrase.encode('utf-8'), salt, iterations, 32)
hash = HMAC(key, data, sha256).digest()
cdata = AES(key).encrypt(data)
cdata = AES(is_encrypt=True, key=key).encrypt(data)
d = {
'version': 1,
'salt': salt,

View File

@ -30,11 +30,15 @@ class CryptoTestCase(AtticTestCase):
def test_aes(self):
key = b'X' * 32
data = b'foo' * 10
aes = AES(key)
# encrypt
aes = AES(is_encrypt=True, key=key)
self.assert_equal(bytes_to_long(aes.iv, 8), 0)
cdata = aes.encrypt(data)
self.assert_equal(hexlify(cdata), b'c6efb702de12498f34a2c2bbc8149e759996d08bf6dc5c610aefc0c3a466')
self.assert_equal(bytes_to_long(aes.iv, 8), 2)
self.assert_not_equal(data, aes.decrypt(cdata))
aes.reset(iv=b'\0' * 16)
self.assert_equal(data, aes.decrypt(cdata))
# decrypt
aes = AES(is_encrypt=False, key=key)
self.assert_equal(bytes_to_long(aes.iv, 8), 0)
pdata = aes.decrypt(cdata)
self.assert_equal(data, pdata)
self.assert_equal(bytes_to_long(aes.iv, 8), 2)