xilinx/bram_decode_tool/decoder.py

268 lines
9.2 KiB
Python
Executable File

#!/usr/bin/env python3
import re
import struct
import os
import argparse
from difflib import SequenceMatcher
WORD_BE = struct.Struct(">I")
def _unpack_word(word: bytes) -> int:
return WORD_BE.unpack(word)[0]
WORD_LE = struct.Struct("<I")
def _unpack_word_le(word: bytes) -> int:
return WORD_LE.unpack(word)[0]
# Word range (start, end) of a section of the MicroBlaze code that is always
# the same (libc init stuff).  Values tried previously:
#   (0x50, 0x2f4), (0x50, 0x120), (0x50, 0x100)
KNOWN_CODE_BOUNDS = (0, 0x900)

# File name of the bit database read by load_db().
BRAM_DB_FNAME = "bram.db"

# One database line: "<feature name> <frame>_<bit>".
XRAY_DB_FMT = re.compile(r"(.*?) ([0-9]+)_([0-9]+)")

# Starting word offset of each BRAM segment inside a frame.  The offsets
# advance by 10 but jump from 40 to 51, i.e. word 50 is never covered
# (presumably the clock word of the frame -- TODO confirm).
XRAY_BRAM_WORD_OFFSETS = [*range(0, 50, 10), *range(51, 101, 10)]
def load_db(name):
    """Load the bit database that maps (frame, bit) positions to feature names.

    Every line matching ``XRAY_DB_FMT`` ("<feature> <frame>_<bit>") contributes
    one entry; lines that do not match are ignored.

    Args:
        name: path of the database text file.

    Returns:
        dict keyed by ``(frame, bit)`` integer tuples.  Each value is the first
        space-separated token of the stripped line (the feature name).  If the
        same key appears more than once, the last line wins.
    """
    db = {}
    with open(name, "r") as f:
        for line in f:
            m = XRAY_DB_FMT.search(line)
            if m is None:
                continue
            frame, bit = int(m.group(2)), int(m.group(3))
            # The stored name is taken from the raw line rather than from
            # group(1), which could contain spaces.
            db[(frame, bit)] = line.strip().split(" ", 1)[0]
    return db
def make_far_addr(bottom_top, row, col, minor):
    """Compose a frame address from its individual fields.

    Bit layout produced by this function:

    * constant ``0x00800000`` (bit 23 set)
    * ``bottom_top`` shifted to bit 22
    * ``row`` shifted to bit 17
    * ``col`` shifted to bit 7
    * ``minor`` in the low bits

    No masking is applied, so out-of-range fields bleed into their neighbours;
    callers are expected to pass in-range values.

    NOTE(review): this looks like the 7-series Frame Address Register layout
    with the block-type field set to 1 -- confirm against the device docs.
    """
    return 0x00800000 | (bottom_top << 22) | (row << 17) | (col << 7) | minor
def all_bram_cols():
    """Yield every ``(bottom_top, row, col)`` triple the tool scans.

    ``bottom_top`` takes 0 then 1, ``row`` is always 0 and ``col`` runs over
    0..4, with ``col`` varying fastest -- ten triples in total.
    """
    for bottom_top in (0, 1):
        for row in (0,):
            for col in range(5):
                yield (bottom_top, row, col)
def load_init_bits(inputfile):
    """Read a raw frame dump and return the position of every bit that is 1.

    The dump is cut into 1280 frames of 101 big-endian 32-bit words (404
    bytes), visited in the order given by ``all_bram_cols()`` with 128 minor
    frames per column.

    As a diagnostic, the address of every frame holding more than
    ``101*32/30`` set bits is collected, and the contiguous address ranges
    are printed to stdout.

    Args:
        inputfile: path of the binary dump ("old dump" layout, see below).

    Returns:
        set of ``(frame_addr, word_index, bit_index)`` tuples, one per set
        bit, where ``frame_addr`` comes from ``make_far_addr``.

    Raises:
        AssertionError: if the dump is too short to yield 1280 full frames.
    """
    frame_bytes = 101 * 4              # 404 bytes per frame
    half_bytes = frame_bytes * 128 * 5  # one half (bottom or top) of the dump

    with open(inputfile, "rb") as f:
        raw = f.read()

    # "Old dump" layout: the bottom half lacks its first frame (replaced here
    # by one frame of zeros) and the top half starts one frame after the
    # bottom half's nominal end.
    #
    # Other layouts used in the past, kept for reference:
    #   * top half starting after 2 pad frames, running to end of file:
    #       top = raw[101*4*(128*5 + 2):]
    #   * "new dump": bottom = raw[404 : 404 + half_bytes]; top begins 808
    #     bytes after the end of bottom and is half_bytes long.
    bottom = b"\x00" * frame_bytes + raw[:half_bytes - frame_bytes]
    top_start = frame_bytes * (128 * 5 + 1)
    top = raw[top_start:top_start + half_bytes]

    buf = bottom + top
    assert len(buf) == 1280 * 4 * 101, "dump too short for 1280 frames"

    init_bits = set()
    dense_addrs = set()
    dense_threshold = 101 * 32 / 30
    frame_idx = 0
    for bottom_top, row, col in all_bram_cols():
        for minor in range(128):
            frame = buf[frame_idx * frame_bytes:(frame_idx + 1) * frame_bytes]
            assert len(frame) == frame_bytes
            frame_idx += 1
            words = [_unpack_word(frame[j * 4:j * 4 + 4]) for j in range(101)]

            addr = make_far_addr(bottom_top, row, col, minor)
            num_bits = 0
            for word_idx, word in enumerate(words):
                if not word:
                    # Nothing set in this word; skip the 32 bit tests.
                    continue
                for bit in range(32):
                    if (word >> bit) & 1:
                        num_bits += 1
                        init_bits.add((addr, word_idx, bit))
            if num_bits > dense_threshold:
                dense_addrs.add(addr)

    # Report densely populated frames as runs of consecutive addresses.
    print("addrs:")
    run_start = run_end = None
    for addr in sorted(dense_addrs):
        if run_start is None:
            run_start = run_end = addr
        elif addr - 1 != run_end:
            print("Addrs: ", hex(run_start), "-", hex(run_end))
            run_start = run_end = addr
        else:
            run_end = addr
    if run_start is not None:
        print("Addrs: ", hex(run_start), "-", hex(run_end))

    # NOTE: an earlier loader read one 404-byte file per frame, named
    # "<prefix>.bit-frames-_bram_{bottom_top}_{row}_{col}_{minor}.dat", and
    # recorded bits under addr+1 to compensate for a mis-dumped test set.
    return init_bits
def preview_bitlane(bitlane, fmt_hex=True):
    """Render the first 128 entries of a bitlane as a compact string.

    Entries are consumed four at a time (first entry is the most significant
    bit of the group).

    Args:
        bitlane: sequence of 0/1 values (e.g. a ``bytearray``).
        fmt_hex: when true, each group of four becomes one lowercase hex
            digit; otherwise the raw "0"/"1" characters are emitted.

    Returns:
        The preview string (32 hex digits, or 128 bit characters, for a
        bitlane of at least 128 entries).

    Raises:
        ValueError: in hex mode, if the bitlane runs out before entry 128 on
            a group boundary (an empty group cannot be parsed).
    """
    pieces = []
    for start in range(0, 256 // 2, 4):
        group = "".join(str(value) for value in bitlane[start:start + 4])
        pieces.append(format(int(group, 2), "x") if fmt_hex else group)
    return "".join(pieces)
def decode_bitlanes(init_bits, db):
    """Turn raw set-bit positions into per-segment bitlanes.

    For every column from ``all_bram_cols()`` and every segment start in
    ``XRAY_BRAM_WORD_OFFSETS``, each set bit inside the 128-frame x 10-word
    window is looked up in ``db`` under ``(minor, word*32 + bit)``.  The
    resulting feature names are then expanded into a 32768-entry bitlane:
    entry ``i`` is 1 when ``BRAM_L.RAMB18_Y{i%2}.INIT_{XX}[{ddd}]`` is among
    them, with ``XX = (i//2)//256`` as two uppercase hex digits and
    ``ddd = (i//2)%256`` as three decimal digits.

    Args:
        init_bits: container of ``(frame_addr, word_index, bit)`` tuples, as
            returned by ``load_init_bits``.
        db: mapping from ``(frame, bit)`` to feature name, as returned by
            ``load_db``.

    Returns:
        list of ``bytearray`` bitlanes (values 0/1), one per segment that had
        at least one set bit, in scan order.
    """
    bitlanes = []
    for bottom_top, row, col in all_bram_cols():
        # The frame address depends only on the column and minor, so compute
        # it once per column instead of once per tested bit.
        frame_addrs = [
            make_far_addr(bottom_top, row, col, minor) for minor in range(128)
        ]
        for word_offs in XRAY_BRAM_WORD_OFFSETS:
            initstrings = set()
            for minor, addr in enumerate(frame_addrs):
                for word in range(10):
                    for bit in range(32):
                        if (addr, word_offs + word, bit) in init_bits:
                            # Key is (frame 0-127, bit within the 10-word
                            # segment).
                            # NOTE(review): a miss adds None, so a segment
                            # holding only unknown bits still produces an
                            # all-zero bitlane -- confirm this is intended.
                            initstrings.add(db.get((minor, word * 32 + bit)))
            if not initstrings:
                continue

            bitlane = bytearray(32768)
            for i in range(len(bitlane)):
                position, y_index = divmod(i, 2)
                init_row, init_bit = divmod(position, 256)
                feature = (
                    f"BRAM_L.RAMB18_Y{y_index}"
                    f".INIT_{init_row:02X}[{init_bit:03d}]"
                )
                if feature in initstrings:
                    bitlane[i] = 1
            bitlanes.append(bitlane)
    return bitlanes
def get_similarity(a, b):
    """Return the fraction of positions at which ``a`` and ``b`` differ.

    Despite the name, a *lower* value means a closer match: 0.0 for identical
    sequences, 1.0 when every position differs.  Two empty sequences have no
    differing positions and yield 0.0.

    Args:
        a, b: equal-length indexable sequences (e.g. ``bytearray``).

    Returns:
        float in [0.0, 1.0].

    Raises:
        AssertionError: if the sequences differ in length.
    """
    length = len(a)
    assert length == len(b), "sequences must have equal length"
    if length == 0:
        # Avoid dividing by zero; nothing to compare means nothing differs.
        return 0.0
    mismatches = sum(x != y for x, y in zip(a, b))
    return mismatches / length
def main(known_code, inputfile, outputfile):
    """Recover a memory image from a frame dump using a known code prefix.

    Steps:
      1. Load the bit database (``BRAM_DB_FNAME``) and the frame dump.
      2. Decode the dump into bitlanes.
      3. For each of the 32 bit positions of the known code's first 0x100
         big-endian words, pick the bitlane whose first 0x100 entries differ
         the least, accepting only mismatch fractions below 0.2.
      4. If all 32 bit positions were matched, interleave the chosen bitlanes
         into 32768 words and write them big-endian to ``outputfile``.
         Otherwise print diagnostics, including a preview of every bitlane,
         and return without writing.

    Args:
        known_code: path of a binary whose leading words are known to appear
            in the dumped memory.
        inputfile: path of the raw frame dump.
        outputfile: path the reconstructed image is written to.
    """
    print("[+] initializing database")
    db = load_db(BRAM_DB_FNAME)
    print("[+] loading file")
    init_bits = load_init_bits(inputfile)
    print("[+] decoding bitlanes")
    bitlanes = decode_bitlanes(init_bits, db)
    print("[+] brute forcing order")

    # Word range of the known code used for matching.  KNOWN_CODE_BOUNDS used
    # to be assigned here and was immediately overridden by this value, so
    # only the effective assignment is kept.
    bounds = (0, 0x100)
    with open(known_code, "rb") as f:
        data = f.read()[0:bounds[1] * 4]
    words = [_unpack_word(data[j * 4:j * 4 + 4]) for j in range(0, bounds[1])]

    # mapping[bit] = (bitlane index, mismatch fraction) of the best candidate.
    mapping = {}
    for i in range(32):
        bitlane = bytearray(bounds[1])
        for j in range(*bounds):
            bitlane[j] = (words[j] >> i) & 1
        expected = bitlane[bounds[0]:bounds[1]]
        for bn, rb_lane in enumerate(bitlanes):
            similarity = get_similarity(rb_lane[bounds[0]:bounds[1]], expected)
            if similarity >= 0.2:
                continue
            # Keep the first candidate with the strictly lowest mismatch.
            if i not in mapping or similarity < mapping[i][1]:
                mapping[i] = (bn, similarity)
        if i not in mapping:
            print("[!] ERROR: no candidate bitlane found for bit", i)

    for k, v in mapping.items():
        print(f"{k:02d} -> {v[0]:02d} (similarity {v[1]})")

    # Report unassigned bitlanes that still carry a meaningful number of set
    # bits, in ascending index order.
    used = {v[0] for v in mapping.values()}
    for bn in range(len(bitlanes)):
        if bn not in used and bitlanes[bn].count(1) > 200:
            print("candidate left", bn)

    if len(mapping) < 32:
        print("[!] brute force failure!")
        print("[*] bitlanes are shown below")
        # Actually show them, so the message above is true.
        for bn, bitlane in enumerate(bitlanes):
            print(bn, preview_bitlane(bitlane))
        return

    print("[+] brute force success, writing bin")
    words = [0] * 32768
    for bit in range(32):
        bitlane = bitlanes[mapping[bit][0]]
        for word_num in range(len(words)):
            words[word_num] |= (bitlane[word_num] & 1) << bit
    with open(outputfile, "wb") as f:
        f.write(b"".join(WORD_BE.pack(word) for word in words))
if __name__ == "__main__":
    # Command line: <knownfile> <inputfile> <outputfile>, all positional.
    parse = argparse.ArgumentParser()
    for positional in ("knownfile", "inputfile", "outputfile"):
        parse.add_argument(positional)
    args = parse.parse_args()
    main(args.knownfile, args.inputfile, args.outputfile)