from math import inf, ceil
from typing import Callable

from celeste.math.util import clamp_max

SPACER = b'\xff'  # arbitrary spacing character


class DecryptFailed(Exception):
    '''Raised by crack() when the secret cannot be recovered.'''


def round_to_blocks(n: int, block_size: int) -> int:
    '''Return how many blocks of block_size bytes are needed to hold n bytes.'''
    return ceil(n / block_size)


def crack_secret_len(cipher: Callable[[bytes], bytes],
                     max_iters: int | float = inf,
                     pad_if_perfect: bool = True) -> int | None:
    '''
    Work out the length of the secret that `cipher` appends to its input.

    The input is grown one SPACER byte at a time until the ciphertext
    length differs from the length observed for a single SPACER byte.

    Args:
        cipher: Encryption oracle; called with bytes, and only len() of its
            result is used here.
        max_iters: Bound on the probing loop. Probing stops once the number
            of probes already made exceeds this value, so at least one probe
            is always attempted. The default (inf) never gives up.
        pad_if_perfect: True when the cipher's padding adds a whole extra
            block to an already block-aligned input (PKCS#7 behaviour).

    Returns:
        The secret length, or None if the ciphertext length never changed
        within the allowed number of probes.
    '''
    # reference length, measured with a single spacer byte
    init_len = len(cipher(SPACER))

    # With "pad if perfect" padding the length jumps exactly when
    # i + secret_len == init_len. Without it the jump happens one byte
    # later, i.e. when i + secret_len == init_len + 1.
    correction = 0 if pad_if_perfect else 1

    i = 2
    while i - 2 <= max_iters:
        if len(cipher(SPACER * i)) != init_len:
            return init_len - i + correction
        i += 1
    return None


def _fail(err_msg: str, debug: bool) -> None:
    '''Raise DecryptFailed, or only print the message when debugging.'''
    if not debug:
        raise DecryptFailed(err_msg)
    print(f'\n[!] {err_msg}')


def crack(cipher: Callable[[bytes], bytes],
          padfn: Callable[[bytes, int], bytes],
          charset: list,
          block_size: int,
          max_secret_iters: int | float = inf,
          batch_size: int = 1,
          pad_if_perfect: bool = True,
          debug: bool = False) -> bytes | None:
    '''
    Recover, last byte first, the secret that `cipher` appends to its input.

    Each round builds one candidate block per charset entry (candidate byte
    followed by the already known tail, padded by `padfn`), submits the
    candidates together with enough SPACER bytes to push the unknown byte to
    the start of a block, and looks for a candidate block whose ciphertext
    equals that block.

    NOTE: pad_if_perfect exists for PKCS#7, which will add a full block
    NOTE: of padding if the input is perfectly aligned to the blocks already.

    Args:
        cipher: Encryption oracle; called with bytes. Its result is split
            into block_size slices that are compared for equality.
        padfn: Called as padfn(tail, block_size) for tails shorter than one
            block. Presumably must pad exactly like the cipher does --
            confirm against the caller.
        charset: Candidate values for one secret byte. Entries are
            concatenated with bytes, so they are expected to be
            single-byte bytes objects.
        block_size: Cipher block size in bytes.
        max_secret_iters: Forwarded to crack_secret_len() as max_iters.
        batch_size: Number of candidates sent per cipher call; must divide
            len(charset).
        pad_if_perfect: See the NOTE above.
        debug: Print progress, and on failure print the error instead of
            raising.

    Returns:
        The recovered secret, or None when the attack fails with debug set.

    Raises:
        ValueError: If charset is empty or batch_size does not divide
            len(charset).
        DecryptFailed: If the attack fails and debug is not set.
    '''
    if not charset:
        raise ValueError('charset must not be empty')
    if len(charset) % batch_size:
        raise ValueError(f'batch_size={batch_size} does not divide len(charset)={len(charset)}')

    # calculate the secret length
    secret_len = crack_secret_len(cipher, max_iters=max_secret_iters,
                                  pad_if_perfect=pad_if_perfect)
    if secret_len is None:
        # the length probe gave up; nothing below can work without it
        return _fail('Could not determine the secret length', debug)
    if debug:
        print(f'[+] Found secret length: {secret_len}')

    # calculate how many blocks the secret stretches over
    # NOTE: secret_block_len - secret_len represents the number of
    # NOTE: bytes required to make the secret fill the blocks with no padding
    secret_block_len = round_to_blocks(secret_len, block_size) * block_size
    default_push = secret_block_len - secret_len

    num_batches = len(charset) // batch_size

    known = b''
    while len(known) < secret_len:
        # the "full tail" is all characters we know + the 1 we're cracking (target)
        # the "current tail" is the characters in the same block as the target
        full_tail_bytes = len(known) + 1
        tail_bytes = clamp_max(full_tail_bytes, block_size - 1)

        # generate ALL possible tails (avoid padding if no padding required)
        tails = [c + known[:tail_bytes] for c in charset]
        if len(tails[0]) != block_size:
            tails = [padfn(tail, block_size) for tail in tails]

        # calculate the "push" applied to the secret
        push_size = (default_push + full_tail_bytes) % block_size

        # position (counted from the end) of the block that starts with the
        # target byte; it does not depend on the batch, so compute it once
        oracle_pos = round_to_blocks(full_tail_bytes, block_size)
        if pad_if_perfect and (push_size + secret_len) % block_size == 0:
            # a whole padding block sits behind the tail
            oracle_pos += 1

        found = None
        for i in range(num_batches):
            if debug:
                print(f'{int(i/num_batches*100)}%', end='\r')

            batch = tails[i*batch_size:(i+1)*batch_size]
            batch = b''.join(batch) + (SPACER * push_size)  # apply spacing

            # encrypt batch and split the ciphertext into blocks
            ciphertext = cipher(batch)
            num_blocks = len(ciphertext) // block_size
            blocks = [ciphertext[k*block_size:(k+1)*block_size]
                      for k in range(num_blocks)]

            # the first batch_size blocks are the encrypted candidates
            for j, cipher_block in enumerate(blocks[:batch_size]):
                if cipher_block == blocks[-oracle_pos]:
                    found = charset[i*batch_size + j]
                    break
            if found is not None:
                break

        if found is None:
            # no candidate matched, so the attack cannot make progress
            return _fail('Padding oracle attack failed', debug)

        known = found + known
        if debug:
            print(f'[*] Found Tail: {known}')

    if debug:
        print('[+] SUCCESS')
    return known