pymumble/pymumble_py3/crypto.py

384 lines
12 KiB
Python

'''
OCB2 crypto, broadly following the implementation from Mumble
'''
from typing import Tuple
import struct
import time
from math import ceil
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
AES_BLOCK_SIZE = 128 // 8 # Number of bytes in a block
AES_KEY_SIZE_BITS = 128
AES_KEY_SIZE_BYTES = AES_KEY_SIZE_BITS // 8
SHIFTBITS = 63 # Shift size for S2 operation
MAX64 = (1 << 64) - 1 # Maximum value of uint64
class EncryptFailedException(Exception):
pass
class DecryptFailedException(Exception):
pass
class CryptStateOCB2:
"""
State tracker for AES-OCB2 crypto.
All encryption/decryption should be done through this class
and not the `ocb_*` functions.
A random key and IVs are chosen upon initialization; these can be
replaced using `set_key`.
Attributes intended for external access:
raw_key
encrypt_iv
decrypt_iv
decrypt_history
uiGood
uiLate
uiLost
tLastGood
"""
_raw_key: bytes # AES key; access through `raw_key` property
_aes: object # pycrypto AES cipher object, replaced when `raw_key` is changed
_encrypt_iv: bytearray # IV for encryption, access through `encrypt_iv` property
_decrypt_iv: bytearray # IV for decryption, access through `decrypt_iv` property
decrypt_history: bytearray # History of previous decrypt_iv values
# Statistics:
uiGood: int # Number of packets successfully decrypted
uiLate: int # Number of packets which arrived out of order
uiLost: int # Number of packets which did not arrive in order (may arrive late)
tLastGood: float # time.perf_counter() value for latest good packet
def __init__(self):
self.uiGood = 0
self.uiLate = 0
self.uiLost = 0
self.tLastGood = 0
self._raw_key = get_random_bytes(AES_KEY_SIZE_BYTES)
self._encrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
self._decrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
self._aes = None
self.decrypt_history = bytearray(0x100)
@property
def raw_key(self) -> bytes:
return self._raw_key
@raw_key.setter
def raw_key(self, rkey: bytes):
if len(rkey) != AES_KEY_SIZE_BYTES:
raise Exception('raw_key has wrong length')
self._raw_key = bytes(rkey)
self._aes = AES.new(key=self.raw_key, mode=AES.MODE_ECB)
@property
def encrypt_iv(self) -> bytearray:
return self._encrypt_iv
@encrypt_iv.setter
def encrypt_iv(self, eiv: bytearray):
if len(eiv) != AES_BLOCK_SIZE:
raise Exception('encrypt_iv wrong length')
self._encrypt_iv = bytearray(eiv)
@property
def decrypt_iv(self) -> bytearray:
return self._decrypt_iv
@decrypt_iv.setter
def decrypt_iv(self, div: bytearray):
if len(div) != AES_BLOCK_SIZE:
raise Exception('decrypt_iv has wrong length')
self._decrypt_iv = bytearray(div)
def gen_key(self):
"""
Randomly generate new keys
"""
self.raw_key = get_random_bytes(AES_KEY_SIZE_BYTES)
self.encrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
self.decrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
def set_key(self, raw_key: bytes, encrypt_iv: bytearray, decrypt_iv: bytearray):
"""
Set new keys
Args:
raw_key: AES key
encrypt_iv: IV for encryption
decrypt_iv: IV for decrpytion
"""
self.raw_key = raw_key
self.encrypt_iv = encrypt_iv
self.decrypt_iv = decrypt_iv
def encrypt(self, source: bytes) -> bytes:
"""
Encrypt a message
Args:
source: The plaintext bytes to be encrypted
Returns:
Encrypted (ciphertext) bytes
Raises:
EncryptFailedException if `source` would result in a vulnerable packet
"""
eiv = increment_iv(self.encrypt_iv)
self.encrypt_iv = eiv
dst, tag = ocb_encrypt(self._aes, source, bytes(eiv))
head = bytes((eiv[0], *tag[:3]))
return head + dst
def decrypt(self, source: bytes, len_plain: int) -> bytes:
"""
Decrypt a message
Args:
source: The ciphertext bytes to be decrypted
len_plain: The length of the plaintext
Returns:
Decrypted (plaintext) bytes
Raises:
DecryptFailedException:
- if `source` is too short
- packet is out of order or duplicate
- packet was could have been tampered with
"""
if len(source) < 4:
raise DecryptFailedException('Source <4 bytes long!')
div = self.decrypt_iv.copy()
ivbyte = source[0]
late = False
lost = 0
if (div[0] + 1) & 0xFF == ivbyte:
# In order as expected.
if ivbyte > div[0]:
div[0] = ivbyte
elif ivbyte < div[0]:
div[0] = ivbyte
div = increment_iv(div, 1)
else:
raise DecryptFailedException('ivbyte == decrypt_iv[0]')
else:
# This is either out of order or a repeat.
diff = ivbyte - div[0]
if diff > 128:
diff -= 256
elif diff < -128:
diff += 256
if ivbyte < div[0] and -30 < diff < 0:
# Late packet, but no wraparound.
late = True
lost = -1
div[0] = ivbyte
elif ivbyte > div[0] and -30 < diff < 0:
# Last was 0x02, here comes 0xff from last round
late = True
lost = -1
div[0] = ivbyte
div = decrement_iv(div, 1)
elif ivbyte > div[0] and diff > 0:
# Lost a few packets, but beyond that we're good.
lost = ivbyte - div[0] - 1
div[0] = ivbyte
elif ivbyte < div[0] and diff > 0:
# Lost a few packets, and wrapped around
lost = 0x100 - div[0] + ivbyte - 1
div[0] = ivbyte
div = increment_iv(div, 1)
else:
raise DecryptFailedException('Lost too many packets?')
if self.decrypt_history[div[0]] == div[1]:
raise DecryptFailedException('decrypt_iv in history')
dst, tag = ocb_decrypt(self._aes, source[4:], bytes(div), len_plain)
if tag[:3] != source[1:4]:
raise DecryptFailedException('Tag did not match!')
self.decrypt_history[div[0]] = div[1]
if not late:
self.decrypt_iv = div
else:
self.uiLate += 1
self.uiGood += 1
self.uiLost += lost
self.tLastGood = time.perf_counter()
return dst
def ocb_encrypt(aes: object,
plain: bytes,
nonce: bytes,
*,
insecure=False,
) -> Tuple[bytes, bytes]:
"""
Encrypt a message.
This should be called from CryptStateOCB2.encrypt() and not independently.
Args:
aes: AES-ECB cipher object
plain: The plaintext bytes to be encrypted
nonce: The encryption IV
Returns:
Encrypted (ciphertext) bytes and tag
Raises:
EncryptFailedException if `source` would result in a vulnerable packet
"""
delta = aes.encrypt(nonce)
checksum = bytes(AES_BLOCK_SIZE)
plain_block = b''
pos = 0
encrypted = bytearray(ceil(len(plain) / AES_BLOCK_SIZE) * AES_BLOCK_SIZE)
while len(plain) - pos > AES_BLOCK_SIZE:
plain_block = plain[pos:pos + AES_BLOCK_SIZE]
delta = S2(delta)
encrypted_block = xor(delta, aes.encrypt(xor(delta, plain_block)))
checksum = xor(checksum, plain_block)
encrypted[pos:pos + AES_BLOCK_SIZE] = encrypted_block
pos += AES_BLOCK_SIZE
# Counter-cryptanalysis described in section 9 of https://eprint.iacr.org/2019/311
# For an attack, the second to last block (i.e. the last iteration of this loop)
# must be all 0 except for the last byte (which may be 0 - 128).
if not insecure and bytes(plain_block[:-1]) == bytes(AES_BLOCK_SIZE - 1):
raise EncryptFailedException('Insecure input block: ' +
'see section 9 of https://eprint.iacr.org/2019/311')
len_remaining = len(plain) - pos
delta = S2(delta)
pad_in = struct.pack('>QQ', 0, len_remaining * 8)
pad = aes.encrypt(xor(pad_in, delta))
plain_block = plain[pos:] + pad[len_remaining - AES_BLOCK_SIZE:]
checksum = xor(checksum, plain_block)
encrypted_block = xor(pad, plain_block)
encrypted[pos:] = encrypted_block
delta = xor(delta, S2(delta))
tag = aes.encrypt(xor(delta, checksum))
return encrypted, tag
def ocb_decrypt(aes: object,
encrypted: bytes,
nonce: bytes,
len_plain: int,
*,
insecure=False,
) -> Tuple[bytes, bytes]:
"""
Decrypt a message.
This should be called from CryptStateOCB2.decrypt() and not independently.
Args:
aes: AES-ECB cipher object
encrypted: The ciphertext bytes to be decrypted
nonce: The decryption IV
len_plain: The length of the desired plaintext
Returns:
Decrypted (plaintext) bytes and tag
Raises:
DecryptFailedException:
- if `source` is too short
- packet is out of order or duplicate
- packet was could have been tampered with
"""
delta = aes.encrypt(nonce)
checksum = bytes(AES_BLOCK_SIZE)
plain = bytearray(len_plain)
pos = 0
while len_plain - pos > AES_BLOCK_SIZE:
encrypted_block = encrypted[pos:pos + AES_BLOCK_SIZE]
delta = S2(delta)
tmp = aes.decrypt(xor(delta, encrypted_block))
plain_block = xor(delta, tmp)
checksum = xor(checksum, plain_block)
plain[pos:pos + AES_BLOCK_SIZE] = plain_block
pos += AES_BLOCK_SIZE
len_remaining = len_plain - pos
delta = S2(delta)
pad_in = struct.pack('>QQ', 0, len_remaining * 8)
pad = aes.encrypt(xor(pad_in, delta))
encrypted_zeropad = encrypted[pos:] + bytes(AES_BLOCK_SIZE - len_remaining)
plain_block = xor(encrypted_zeropad, pad)
checksum = xor(checksum, plain_block)
plain[pos:] = plain_block[:len_remaining]
# Counter-cryptanalysis described in section 9 of https://eprint.iacr.org/2019/311
# In an attack, the decrypted last block would need to equal `delta ^ len(128)`.
# With a bit of luck (or many packets), smaller values than 128 (i.e. non-full blocks) are also
# feasible, so we check `plain_block` instead of `plain`.
# Since our `len` only ever modifies the last byte, we simply check all remaining ones.
if not insecure and plain_block[:-1] == delta[:-1]:
raise DecryptFailedException('Possibly tampered/able block, discarding.')
delta = xor(delta, S2(delta))
tag = aes.encrypt(xor(delta, checksum))
return plain, tag
def increment_iv(iv: bytearray, start: int = 0) -> bytearray:
for i in range(start, AES_BLOCK_SIZE):
iv[i] = (iv[i] + 1) % 0x100
if iv[i] != 0:
break
return iv
def decrement_iv(iv: bytearray, start: int = 0) -> bytearray:
for i in range(start, AES_BLOCK_SIZE):
iv[i] = (iv[i] - 1) % 0x100
if iv[i] != 0xFF:
break
return iv
def xor(a: bytes, b: bytes) -> bytes:
return bytes(aa ^ bb for aa, bb in zip(a, b))
def S2(block: bytes) -> bytes:
ll, uu = struct.unpack('>QQ', block)
carry = ll >> 63
block = struct.pack('>QQ',
((ll << 1) | (uu >> 63)) & MAX64,
((uu << 1) ^ (carry * 0x87)) & MAX64)
return block