diff --git a/2020/gctf/sharky/README.md b/2020/gctf/sharky/README.md new file mode 100644 index 0000000..cab9f51 --- /dev/null +++ b/2020/gctf/sharky/README.md @@ -0,0 +1,246 @@ +# sharky + +writeup by [haskal](https://awoo.systems) for [BLÅHAJ](https://blahaj.awoo.systems) + +**Crypto** +**231 points** +**39 solves** + +>Can you find the round keys? + +provided files: `challenge.py`, `sha256.py` + +## tldr + +for this i used [Z3](https://github.com/Z3Prover/z3) except with angr's +[claripy](https://docs.angr.io/advanced-topics/claripy) frontend instead of the direct frontend +because i've never used the direct frontend. using claripy/z3 we can construct a symbolic emulation +of the provided sha256 algorithm with the unknowns being symbolic variables. by applying constraints +for the hash output you can have z3 solve for the unknown hash constants + +## writeup + +initial analysis shows this is a modified sha256 algorithm with the first 8 round constants changed +to random values. the objective is to figure out what these values are given a known hash of a known +plaintext. from `challenge.py`: + +```python +NUM_KEYS = 8 +MSG = b'Encoded with random keys' +``` + +first, i checked to see if anything else was funky with the hand coded sha256 implementation by +initializing the "unknown" round constants to the standard constants and checking if the resulting +hash matched the expected "standard" hash for this message. it does, so the only difference is the +first 8 round constants + +[sha256](https://en.wikipedia.org/wiki/SHA-2) (this links to wikipedia because i literally studied +sha256 from the wikipedia article during the CTF) starts by padding the message to a multiple of 512 +bits. then, for each chunk of 512 bits it expands the chunk to 2048 bits by bit rotations, xors, and +arithmetic. the result is called `w`. `w` is compressed into a 256 bit state `h'` by taking the +previous state and doing 64 rounds of more bit rotations, bitwise ops, and arithmetic combining the +previous state with the nth word of `w` and the nth word of `k`, the sha256 constants table. `h` +also starts with known constants. also sha256 basically operates on 32-bit words, always. so all the +symbolic variables should be 32-bit + +in this case we see the message is short enough so there is only one 512-bit chunk of padded input. +this makes things easier. i started by precomputing the `w` array for the known input, since this +doesn't depend on any (modified) constants. then, i took the verbatim python implementation of the +algorithm and stripped out every component except the round function `compression_step` that +combines `h'`, `w[i]` and `k[i]`. i applied some basic simplifications to optimize symbolic +execution such as replacing the manual bit rotate operation with the `claripy.RotateRight` builtin +and removing the bitmasking done after addition, since claripy bit vectors can't change size. +otherwise it's the original provided code, except now it's operating on symbolic variables instead +of actual numbers. it looks like this + +```python +def rotate_right(v, n): + # w = (v >> n) | (v << (32 - n)) + # return w & 0xffffffff + return claripy.RotateRight(v, n) + +def compression_step(state, k_i, w_i): + a, b, c, d, e, f, g, h = state + s1 = rotate_right(e, 6) ^ rotate_right(e, 11) ^ rotate_right(e, 25) + ch = (e & f) ^ (~e & g) + tmp1 = (h + s1 + ch + k_i + w_i) + s0 = rotate_right(a, 2) ^ rotate_right(a, 13) ^ rotate_right(a, 22) + maj = (a & b) ^ (a & c) ^ (b & c) + tmp2 = (tmp1 + s0 + maj) + tmp3 = (d + tmp1) + return (tmp2, a, b, c, tmp3, e, f, g) +``` + +i figured this was good enough for z3 to work with and just threw it in + +``` +init_h = [ + # standard constants + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, + 0x1f83d9ab, 0x5be0cd19 +] +init_h = [claripy.BVV(x, 32) for x in init_h] + +secrets = [claripy.BVS(f"rc_{i}", 32) for i in range(8)] + +init_k = secrets + [claripy.BVV(x, 32) for x in [ + ... rest of the standard constants +]] + +def compression(state, w): + round_keys = init_k + for i in range(64): + state = compression_step(state, round_keys[i], w[i]) + return state + +w = [...] # computed w + +real_hash = sys.argv[1] +real_hash = binascii.unhexlify(real_hash) +real_hash = struct.unpack(">8L", real_hash) + +solver = claripy.Solver() + +state = init_h[:] +out_state = compression(state, w) +state = [x + y for x, y in zip(state, out_state)] + +for i in range(8): + solver.add(state[i] == real_hash[i]) + +for i in range(8): + print(solver.eval(secrets[i], 1)[0]) +``` + +this did Not terminate at all. this is a side effect of me never knowing what i'm doing but in +general it seems that z3 does very poorly when you have few constraints that are each very large and +complex. in fact these ones were so large you literally cannot print them to the console + +so i opted to do some handholding. i noticed that since constants 8-63 are the standard ones, the +state from round 8 to the final hash is completely known, and we can backtrack from the final hash +to the round 8 state by solving each round individually. here's the annotated code for this + +```python +# i set up round 63 with the final addition that occurs after all rounds are complete +solver = claripy.Solver() +state = [claripy.BVS(f"state_{i}", 32) for i in range(8)] +out_state = compression_step(state, w[63], init_k[63]) +for i in range(8): + solver.add(out_state[i] + init_h[i] == real_hash[i]) + +# this is a helper function that evaluates the input state of a round that produces a known output +# state and reassigns a new wanted_state for that +def get_wanted_state(): + global solver + global state + wanted_state = [None]*8 + for i in range(8): + res = solver.eval(state[i], 10) + if len(res) != 1: + print("ERROR", i, res) + raise SystemExit() + res = res[0] + wanted_state[i] = res + return wanted_state + +# calculate the input to round 63 +wanted_state = get_wanted_state() + +# emulate 62 downto 8 +# each round, we create a symbolic state array for the input, add constraints for the previously +# calculated output, and evaluate +for round in range(62, 8 - 1, -1): + print("emulating round", round) + + solver = claripy.Solver() + state = [claripy.BVS(f"state_{round}_{i}", 32) for i in range(8)] + out_state = compression_step(state, w[round], init_k[round]) + for i in range(8): + solver.add(out_state[i] == wanted_state[i]) + + wanted_state = get_wanted_state() + +# now we have the exact state that round 8 started with +print("round 8 input state: ", [hex(x) for x in wanted_state]) +``` + +now there should be a lot less work for z3 to guess on, it only needs to work through rounds 0-7 +with the symbolic round constants + +```python +state = init_h[:] +out_state = compression_round_0_7(state, w) +state = [x + y for x, y in zip(state, out_state)] + +for i in range(8): + solver.add(state[i] == wanted_state[i]) + +for i in range(8): + print(solver.eval(secrets[i], 1)[0]) +``` + +this took absolutely ages to complete and when it did the result was `unsat`. yikes + +i'm still not sure what exactly the issue was. it's likely i just made a typo somewhere and didn't +see it. but i deleted this and replaced it with a new strategy, which is to continue going backwards +from round 7 to round 0 instead of going forwards from 0 to 7. what this does is creates a larger +number of constraints -- instead of a single set of 8 constraints on the output we have constraints +for each intermediate state in between each round. but each constraint is individually less complex, +which results in the solving going faster and actually working (!!) + +```python +solver = claripy.Solver() + +# go from round 7 down to 0 +for round in range(7, -1, -1): + print("emulating weird round", round) + # create symbolic variables for this intermediate state + state = [claripy.BVS(f"test_r{round}_{i}", 32) for i in range(8)] + out_state = compression_step(state, w[round], init_k[round]) + for i in range(8): + # constrain the output to the wanted input state for the next round + solver.add(out_state[i] == wanted_state[i]) + + # transfer input state to next round's wanted state + wanted_state = [None]*8 + for i in range(8): + wanted_state[i] = state[i] + +# add the initial state constraints +for i in range(8): + solver.add(state[i] == init_h[i]) + +print("answers") +for i in range(len(secrets)): + res = solver.eval(secrets[i], 1) + sys.stdout.write(hex(res[0])) + if i < len(secrets) - 1: + sys.stdout.write(",") + sys.stdout.flush() + +sys.stdout.write("\n") +sys.stdout.flush() +print("SHONKS") +``` + +this works. for some inputs it takes a long time but most of the time it comes up with the answer +quickly. i was too lazy for an automatic script so i simply pasted the answer into netcat + +``` +$ python solve.py abdc6d366dd37bb452b56335cd8bfbc043e2284001891d358e04a9e76c3b2a49 +... +answers +0x9da77f78,0x24597709,0x4fbc375,0xaa5cd315,0xcd73dc57,0x12dafbf7,0x7e206bb4,0xb9097fba +SHONKS +``` + +the lesson here is if you want z3 to terminate sometime this millenium it's better to provide a +larger number of small constraints than a small number of Giant constraints even if they're +semantically equivalent. i would probably know this if i had ever been formally trained in symbolic +execution, + +--- + +since my team is [BLÅHAJ](https://blahaj.awoo.systems) and this challenge is named `sharky` it was +mandatory to get this flag, so i'm happy i figured it out even though crypto isn't really my strong +suit (yet,) diff --git a/2020/gctf/sharky/challenge.py b/2020/gctf/sharky/challenge.py new file mode 100644 index 0000000..7df4b1d --- /dev/null +++ b/2020/gctf/sharky/challenge.py @@ -0,0 +1,52 @@ +#! /usr/bin/python3 +import binascii +import os +import sha256 + +# Setup msg_secret and flag +FLAG_PATH = 'data/flag.txt' +NUM_KEYS = 8 +MSG = b'Encoded with random keys' + +with open(FLAG_PATH, 'rb') as f: + FLAG = f.read().strip().decode('utf-8') + + +def sha256_with_secret_round_keys(m: bytes, secret_round_keys: dict) -> bytes: + """Computes SHA256 with some secret round keys. + + Args: + m: the message to hash + secret_round_keys: a dictionary where secret_round_keys[i] is the value of + the round key k[i] used in SHA-256 + + Returns: + the digest + """ + sha = sha256.SHA256() + round_keys = sha.k[:] + for i, v in secret_round_keys.items(): + round_keys[i] = v + return sha.sha256(m, round_keys) + + +def generate_random_round_keys(cnt: int): + res = {} + for i in range(cnt): + rk = 0 + for b in os.urandom(4): + rk = rk * 256 + b + res[i] = rk + return res + +if __name__ == '__main__': + secret_round_keys = generate_random_round_keys(NUM_KEYS) + digest = sha256_with_secret_round_keys(MSG, secret_round_keys) + print('MSG Digest: {}'.format(binascii.hexlify(digest).decode())) + GIVEN_KEYS = list(map(lambda s: int(s, 16), input('Enter keys: ').split(','))) + assert len(GIVEN_KEYS) == NUM_KEYS, 'Wrong number of keys provided.' + + if all([GIVEN_KEYS[i] == secret_round_keys[i] for i in range(NUM_KEYS)]): + print('\nGood job, here\'s a flag: {0}'.format(FLAG)) + else: + print('\nSorry, that\'s not right.') diff --git a/2020/gctf/sharky/sha256.py b/2020/gctf/sharky/sha256.py new file mode 100644 index 0000000..a1e64ae --- /dev/null +++ b/2020/gctf/sharky/sha256.py @@ -0,0 +1,78 @@ +#! /usr/bin/python3 +import struct + +class SHA256: + + def __init__(self): + self.h = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, + 0x1f83d9ab, 0x5be0cd19 + ] + + self.k = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, + 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 + ] + + def rotate_right(self, v, n): + w = (v >> n) | (v << (32 - n)) + return w & 0xffffffff + + def compression_step(self, state, k_i, w_i): + a, b, c, d, e, f, g, h = state + s1 = self.rotate_right(e, 6) ^ self.rotate_right(e, 11) ^ self.rotate_right(e, 25) + ch = (e & f) ^ (~e & g) + tmp1 = (h + s1 + ch + k_i + w_i) & 0xffffffff + s0 = self.rotate_right(a, 2) ^ self.rotate_right(a, 13) ^ self.rotate_right(a, 22) + maj = (a & b) ^ (a & c) ^ (b & c) + tmp2 = (tmp1 + s0 + maj) & 0xffffffff + tmp3 = (d + tmp1) & 0xffffffff + return (tmp2, a, b, c, tmp3, e, f, g) + + def compression(self, state, w, round_keys = None): + if round_keys is None: + round_keys = self.k + for i in range(64): + state = self.compression_step(state, round_keys[i], w[i]) + return state + + def compute_w(self, m): + w = list(struct.unpack('>16L', m)) + for _ in range(16, 64): + a, b = w[-15], w[-2] + s0 = self.rotate_right(a, 7) ^ self.rotate_right(a, 18) ^ (a >> 3) + s1 = self.rotate_right(b, 17) ^ self.rotate_right(b, 19) ^ (b >> 10) + s = (w[-16] + w[-7] + s0 + s1) & 0xffffffff + w.append(s) + return w + + def padding(self, m): + lm = len(m) + lpad = struct.pack('>Q', 8 * lm) + lenz = -(lm + 9) % 64 + return m + bytes([0x80]) + bytes(lenz) + lpad + + def sha256_raw(self, m, round_keys = None): + if len(m) % 64 != 0: + raise ValueError('m must be a multiple of 64 bytes') + state = self.h + for i in range(0, len(m), 64): + block = m[i:i + 64] + w = self.compute_w(block) + s = self.compression(state, w, round_keys) + state = [(x + y) & 0xffffffff for x, y in zip(state, s)] + return state + + def sha256(self, m, round_keys = None): + m_padded = self.padding(m) + state = self.sha256_raw(m_padded, round_keys) + return struct.pack('>8L', *state) diff --git a/2020/gctf/sharky/solve.py b/2020/gctf/sharky/solve.py new file mode 100644 index 0000000..fd72d80 --- /dev/null +++ b/2020/gctf/sharky/solve.py @@ -0,0 +1,125 @@ +#! /usr/bin/python3 +import struct +import binascii +import claripy +import sys + +init_h = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, + 0x1f83d9ab, 0x5be0cd19 +] +init_h = [claripy.BVV(x, 32) for x in init_h] + +secrets = [claripy.BVS(f"rc_{i}", 32) for i in range(8)] + +init_k = secrets + [claripy.BVV(x, 32) for x in [ + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 +]] + +def rotate_right(v, n): + # w = (v >> n) | (v << (32 - n)) + # return w & 0xffffffff + return claripy.RotateRight(v, n) + +def compression_step(state, k_i, w_i): + a, b, c, d, e, f, g, h = state + s1 = rotate_right(e, 6) ^ rotate_right(e, 11) ^ rotate_right(e, 25) + ch = (e & f) ^ (~e & g) + tmp1 = (h + s1 + ch + k_i + w_i) + s0 = rotate_right(a, 2) ^ rotate_right(a, 13) ^ rotate_right(a, 22) + maj = (a & b) ^ (a & c) ^ (b & c) + tmp2 = (tmp1 + s0 + maj) + tmp3 = (d + tmp1) + return (tmp2, a, b, c, tmp3, e, f, g) + +def compression(state, w): + round_keys = init_k + for i in range(64): + state = compression_step(state, round_keys[i], w[i]) + return state + +w = [1164862319, 1684366368, 2003399784, 544366958, 1685024032, 1801812339, 2147483648, 0, 0, 0, 0, 0, 0, 0, 0, 192, 1522197188, 3891742175, 3836386829, 32341671, 928288908, 2364323079, 1515866404, 649785226, 1435989715, 250124094, 1469326411, 2429553944, 598071608, 1634056085, 4271828083, 4262132921, 2272436470, 39791740, 2337714294, 3555435891, 1519859327, 57013755, 2177157937, 1679613557, 2900649386, 612096658, 172526146, 2214036567, 3330460486, 1490972443, 1925782519, 4215628757, 2379791427, 2058888203, 1834962275, 3917548225, 2375084030, 1546202149, 3188006334, 4280719833, 726047027, 3650106516, 4058756591, 1443098026, 1972312730, 1218108430, 3428722156, 366022263] + +real_hash = sys.argv[1] +real_hash = binascii.unhexlify(real_hash) +real_hash = struct.unpack(">8L", real_hash) + +# emulate round 63 +solver = claripy.Solver() +state = [claripy.BVS(f"state_{i}", 32) for i in range(8)] +out_state = compression_step(state, w[63], init_k[63]) +for i in range(8): + solver.add(out_state[i] + init_h[i] == real_hash[i]) + +def get_wanted_state(): + global solver + global state + wanted_state = [None]*8 + for i in range(8): + res = solver.eval(state[i], 10) + if len(res) != 1: + print("ERROR", i, res) + raise SystemExit() + res = res[0] + wanted_state[i] = res + return wanted_state + +wanted_state = get_wanted_state() + +# emulate 62 downto 8 +for round in range(62, 8 - 1, -1): + print("emulating round", round) + + solver = claripy.Solver() + state = [claripy.BVS(f"state_{round}_{i}", 32) for i in range(8)] + out_state = compression_step(state, w[round], init_k[round]) + for i in range(8): + solver.add(out_state[i] == wanted_state[i]) + + wanted_state = get_wanted_state() + +print("round 8 input state: ", [hex(x) for x in wanted_state]) + +print("emulating rounds 0-7") +from claripy.backends.backend_z3_parallel import BackendZ3Parallel +solver = claripy.Solver(track=True) + +try: + for round in range(7, -1, -1): + print("emulating weird round", round) + state = [claripy.BVS(f"test_r{round}_{i}", 32) for i in range(8)] + out_state = compression_step(state, w[round], init_k[round]) + for i in range(8): + solver.add(out_state[i] == wanted_state[i]) + + wanted_state = [None]*8 + for i in range(8): + wanted_state[i] = state[i] +except claripy.errors.UnsatError: + print("unsat!!!") + print(solver.unsat_core()) + sys.exit() + +for i in range(8): + solver.add(state[i] == init_h[i]) + +print("answers") +for i in range(len(secrets)): + res = solver.eval(secrets[i], 1) + sys.stdout.write(hex(res[0])) + if i < len(secrets) - 1: + sys.stdout.write(",") + sys.stdout.flush() + +sys.stdout.write("\n") +sys.stdout.flush() +print("SHONKS")