writeups/2020/gctf/sharky/solve.py

126 lines
4.4 KiB
Python

#! /usr/bin/python3
import struct
import binascii
import claripy
import sys
# SHA-256 initial hash value H(0) (FIPS 180-4), as 32-bit claripy constants.
init_h = [
    claripy.BVV(word, 32)
    for word in (
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
        0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
    )
]
# The first eight round constants are unknown in this challenge: model them as
# symbolic 32-bit variables rc_0..rc_7, which the script solves for at the end.
secrets = [claripy.BVS(f"rc_{i}", 32) for i in range(8)]
# Complete round-constant table K[0..63]: the eight unknowns followed by the 56
# known constants K[8..63] (the standard SHA-256 values).
init_k = secrets + [
    claripy.BVV(word, 32)
    for word in (
        0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
        0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
        0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
        0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
        0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
        0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
        0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
        0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
        0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
        0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
        0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
        0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
    )
]
def rotate_right(v, n):
    """Rotate the 32-bit word *v* right by *n* bits.

    Symbolic claripy bitvectors are rotated with ``claripy.RotateRight``.
    Plain Python ints are also accepted and handled with the equivalent
    shift-and-mask formula, which keeps the helper usable on concrete values.
    """
    if isinstance(v, int):
        v &= 0xFFFFFFFF
        n %= 32  # a rotation by 32 is the identity
        return ((v >> n) | (v << (32 - n))) & 0xFFFFFFFF
    return claripy.RotateRight(v, n)
def compression_step(state, k_i, w_i):
    """Apply one SHA-256 compression round to the 8-word working *state*.

    ``k_i`` is the round constant and ``w_i`` the message-schedule word for
    this round.  Returns the next working state as the tuple
    (a', a, b, c, e', e, f, g).
    """
    a, b, c, d, e, f, g, h = state
    # Sigma1(e) and Ch(e, f, g) feed the first temporary.
    big_sigma1 = rotate_right(e, 6) ^ rotate_right(e, 11) ^ rotate_right(e, 25)
    choose = (e & f) ^ (~e & g)
    t1 = h + big_sigma1 + choose + k_i + w_i
    # Sigma0(a) and Maj(a, b, c) complete the new 'a'.
    big_sigma0 = rotate_right(a, 2) ^ rotate_right(a, 13) ^ rotate_right(a, 22)
    majority = (a & b) ^ (a & c) ^ (b & c)
    new_a = t1 + big_sigma0 + majority
    new_e = d + t1
    return (new_a, a, b, c, new_e, e, f, g)
def compression(state, w):
    """Run all 64 rounds of the compression function over *state*.

    Round i uses ``init_k[i]`` as its constant and ``w[i]`` as its schedule
    word.  Returns the state after round 63; the feed-forward addition of the
    initial hash value is not applied here.
    """
    idx = 0
    while idx < 64:
        state = compression_step(state, init_k[idx], w[idx])
        idx += 1
    return state
# Message schedule W[0..63] for the single block that was hashed.
# W[0..15] is the padded block itself: ASCII b"Encoded with random keys"
# (24 bytes), then the 0x80 padding byte, zeros, and the 192-bit message length
# in W[15].  W[16..63] are presumably the standard SHA-256 expansion of that
# block (they are hard-coded here, not recomputed).
w = [
    1164862319, 1684366368, 2003399784, 544366958, 1685024032, 1801812339, 2147483648, 0,
    0, 0, 0, 0, 0, 0, 0, 192,
    1522197188, 3891742175, 3836386829, 32341671, 928288908, 2364323079, 1515866404, 649785226,
    1435989715, 250124094, 1469326411, 2429553944, 598071608, 1634056085, 4271828083, 4262132921,
    2272436470, 39791740, 2337714294, 3555435891, 1519859327, 57013755, 2177157937, 1679613557,
    2900649386, 612096658, 172526146, 2214036567, 3330460486, 1490972443, 1925782519, 4215628757,
    2379791427, 2058888203, 1834962275, 3917548225, 2375084030, 1546202149, 3188006334, 4280719833,
    726047027, 3650106516, 4058756591, 1443098026, 1972312730, 1218108430, 3428722156, 366022263,
]
# The target digest comes from the command line as 64 hex characters.
if len(sys.argv) != 2:
    sys.exit(f"usage: {sys.argv[0]} <sha256-hex-digest>")
try:
    real_hash = binascii.unhexlify(sys.argv[1])
except binascii.Error as err:
    sys.exit(f"invalid hex digest: {err}")
if len(real_hash) != 32:
    sys.exit("expected a 32-byte (64 hex character) SHA-256 digest")
# Split into the eight big-endian 32-bit digest words.
real_hash = struct.unpack(">8L", real_hash)
# emulate round 63: for a one-block message, digest word i equals
# (round-63 output word i) + H(0)[i] modulo 2**32.
solver = claripy.Solver()
state = [claripy.BVS(f"state_{i}", 32) for i in range(8)]
out_state = compression_step(state, init_k[63], w[63])
for i in range(8):
    solver.add(out_state[i] + init_h[i] == real_hash[i])
def get_wanted_state(cur_solver=None, cur_state=None):
    """Return the concrete values of the eight symbolic state words.

    By default this reads the module-level ``solver`` (the constraints for
    the round just emulated) and ``state`` (that round's symbolic input
    words).  Both can also be passed explicitly.

    Each word is evaluated for up to 10 candidate solutions.  If the solver
    returns anything other than exactly one candidate, the inversion is
    ambiguous.  In that case the candidates are reported on stderr and the
    script exits with status 1.
    """
    if cur_solver is None:
        cur_solver = solver
    if cur_state is None:
        cur_state = state
    wanted_state = []
    for i in range(8):
        candidates = cur_solver.eval(cur_state[i], 10)
        if len(candidates) != 1:
            print("ERROR", i, candidates, file=sys.stderr)
            # Non-zero status: a failed inversion must not look like success.
            raise SystemExit(1)
        wanted_state.append(candidates[0])
    return wanted_state
wanted_state = get_wanted_state()
# emulate rounds 62 down to 8 (inclusive).  Their round constants are known,
# so each round can be inverted on its own with a fresh, small solver.
# get_wanted_state() called without arguments reads the module-level `solver`
# and `state`, which is why those names are rebound on every iteration.
for rnd in range(62, 7, -1):
    print("emulating round", rnd)
    solver = claripy.Solver()
    state = [claripy.BVS(f"state_{rnd}_{i}", 32) for i in range(8)]
    out_state = compression_step(state, init_k[rnd], w[rnd])
    for i in range(8):
        solver.add(out_state[i] == wanted_state[i])
    wanted_state = get_wanted_state()
print("round 8 input state: ", [hex(x) for x in wanted_state])
print("emulating rounds 0-7")
# NOTE(review): BackendZ3Parallel is never used in this script and may not
# exist in every claripy release; it looks safe to drop — confirm first.
from claripy.backends.backend_z3_parallel import BackendZ3Parallel
# Rounds 0-7 use the unknown constants in `secrets`, so they cannot be inverted
# one at a time.  All eight are chained into a single solver instead.
# track=True keeps constraint provenance so unsat_core() can be reported below.
solver = claripy.Solver(track=True)
try:
    for rnd in range(7, -1, -1):
        print("emulating weird round", rnd)
        state = [claripy.BVS(f"test_r{rnd}_{i}", 32) for i in range(8)]
        out_state = compression_step(state, init_k[rnd], w[rnd])
        for i in range(8):
            solver.add(out_state[i] == wanted_state[i])
        # This round's symbolic inputs are what the previous round must output.
        wanted_state = list(state)
    # The input to round 0 must be the SHA-256 initial hash value.
    for i in range(8):
        solver.add(state[i] == init_h[i])
    print("answers")
    # UnsatError can surface from the add() calls or from eval(), so the
    # evaluation also sits inside the try.  Each constant is streamed out as
    # soon as it is evaluated.
    for idx, secret in enumerate(secrets):
        if idx:
            sys.stdout.write(",")
        sys.stdout.write(hex(solver.eval(secret, 1)[0]))
        sys.stdout.flush()
except claripy.errors.UnsatError:
    print("unsat!!!")
    print(solver.unsat_core())
    sys.exit(1)  # failure must not report success
sys.stdout.write("\n")
sys.stdout.flush()
print("SHONKS")