# Listing metadata from the code viewer: 394 lines, 13 KiB, Python
from typing import Tuple, Any, List, Set, Optional, Dict
|
|
from collections import defaultdict, Counter
|
|
from enum import IntEnum, auto
|
|
from dataclasses import dataclass, field
|
|
import copy
|
|
|
|
import networkx
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class CodeLoc:
|
|
bbl_addr: int
|
|
stmt_idx: int
|
|
ins_addr: int
|
|
|
|
@dataclass(frozen=True, slots=True)
class Atom:
    """
    Base class for a piece of data pinned to a code location.

    Atoms are frozen (hence hashable); elsewhere in this file they are used as graph nodes and as the
    first element of each LiveData source pair.
    """
    loc: CodeLoc  # where this atom lives
    size: int  # LiveData.new_atom copies this into LiveData.size, which is documented as bytes
|
|
|
|
@dataclass(frozen=True, slots=True)
class RegisterAtom(Atom):
    """
    An atom naming a register.

    NOTE(review): BlockInfo keys its outputs by "slot names" and its inputs by "alias names"; presumably
    ``slot_name`` and ``name`` are those two keys respectively - confirm against the code that fills BlockInfo.
    """
    name: str  # shown in the repr
    slot_name: str

    def __repr__(self):
        """Render as ``<name> @ <instruction address in hex>``."""
        return f'{self.name} @ {self.loc.ins_addr:#x}'
|
|
|
|
@dataclass(frozen=True, slots=True)
class MemoryAtom(Atom):
    """An atom naming a memory access."""
    endness: str  # presumably the byte order of the access - TODO confirm the accepted values

    def __repr__(self):
        """Render as ``MEM @ <instruction address in hex>``."""
        return f'MEM @ {self.loc.ins_addr:#x}'
|
|
|
|
@dataclass(frozen=True, slots=True)
class TmpAtom(Atom):
    """An atom naming a numbered temporary."""
    tmp: int  # the temporary's number, shown in the repr

    def __repr__(self):
        """Render as ``TMP#<tmp> @ <instruction address in hex>``."""
        return f'TMP#{self.tmp} @ {self.loc.ins_addr:#x}'
|
|
|
|
@dataclass(frozen=True, slots=True)
class ConstAtom(Atom):
    """An atom holding a constant; LiveData.new_const creates one as the single source of a constant."""
    value: int  # the constant itself, shown in hex in the repr

    def __repr__(self):
        """Render as ``CONST#<value in hex> @ <instruction address in hex>``."""
        return f'CONST#{self.value:#x} @ {self.loc.ins_addr:#x}'
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class Op:
|
|
def invert(self) -> 'Op':
|
|
raise NotImplementedError
|
|
|
|
@dataclass(frozen=True, slots=True)
class ConstOffsetOp(Op):
    """
    Offset a value by a constant.

    The stored constant is always normalised to the signed 64-bit range ``[-2**63, 2**63)``, so two
    offsets that are congruent modulo ``2**64`` compare (and hash) equal.
    """
    const: int

    def __init__(self, const: int):
        """
        :param const: The offset. Any integer is accepted and wrapped into signed 64-bit range.
        """
        # The operand width is not tracked by this op, so 64 bits is assumed. A single modular reduction
        # replaces the former add/subtract loops and also wraps the boundary value 2**63 (to -2**63),
        # which used to be left as-is and so had two distinct representations.
        const = (const + 2**63) % 2**64 - 2**63
        # The dataclass is frozen, so the field has to be set through object.__setattr__.
        object.__setattr__(self, 'const', const)

    def invert(self):
        """Return the op applying the negated offset."""
        return ConstOffsetOp(-self.const)
|
|
|
|
@dataclass(frozen=True, slots=True)
class StrideOffsetOp(Op):
    """
    Offset a value by a stride.

    NOTE(review): presumably this models "plus an unknown multiple of ``stride``" (array indexing);
    OpSequence.compute_unifications turns each stride into an (offset, offset + stride) pair. Confirm.
    """
    stride: int

    def invert(self):
        """Return ``self``; this op is treated as its own inverse."""
        return self
|
|
|
|
@dataclass(frozen=True, slots=True)
class NegOp(Op):
    """Negation op. Two adjacent NegOps cancel out in simplify_op_sequence."""

    def invert(self):
        """Return ``self``; this op is treated as its own inverse."""
        return self
|
|
|
|
@dataclass(frozen=True, slots=True)
class VarOffsetOp(Op):
    """Offset a value by a non-constant amount, identified by ``var``."""
    var: Any  # opaque here; it only needs to be hashable for the dataclass-generated __hash__ to work

    def invert(self):
        """Return ``self``."""
        # TODO: the original author was unsure whether returning self is the right inverse for this op.
        return self
|
|
|
|
@dataclass(frozen=True, slots=True)
class DerefOp(Op):
    """
    Dereference op of a given size.

    Prop.transform handles it by taking the struct data recorded at offset 0 with this size and making
    it the value's own data.
    """
    size: int

    def invert(self):
        """Return the RefOp of the same size."""
        return RefOp(self.size)
|
|
|
|
@dataclass(frozen=True, slots=True)
class RefOp(Op):
    """
    Reference op of a given size; the counterpart of DerefOp.

    Prop.transform handles it by recording the value's own data as struct data at offset 0 with this size.
    """
    size: int

    def invert(self):
        """Return the DerefOp of the same size."""
        return DerefOp(self.size)
|
|
|
|
#@dataclass(frozen=True, slots=True)
|
|
#class OtherOp(Op):
|
|
# def invert(self) -> 'Op':
|
|
# return self
|
|
|
|
@dataclass(frozen=True, slots=True)
class OpSequence:
    """
    An immutable, hashable sequence of Ops.

    Every method that builds a new sequence from pieces runs simplify_op_sequence on the result.
    Calling the constructor directly stores ``ops`` verbatim, without simplification.
    """
    ops: Tuple[Op, ...] = ()

    def __add__(self, other: 'OpSequence') -> 'OpSequence':
        """Return the simplified concatenation of ``self`` followed by ``other``."""
        return OpSequence.concat(self, other)

    def appended(self, *op: Op) -> 'OpSequence':
        """Return a new simplified sequence made of this one followed by the given op(s)."""
        seq = [*self.ops, *op]
        simplify_op_sequence(seq)
        return OpSequence(tuple(seq))

    @staticmethod
    def concat(*sequences: 'OpSequence') -> 'OpSequence':
        """Return the simplified concatenation of all given sequences, in order."""
        seq: List[Op] = [op for s in sequences for op in s.ops]
        simplify_op_sequence(seq)
        return OpSequence(tuple(seq))

    def invert(self) -> 'OpSequence':
        """Return the sequence undoing this one: every op inverted, in reverse order (not re-simplified)."""
        return OpSequence(tuple(x.invert() for x in reversed(self.ops)))

    def compute_unifications(self) -> List[Tuple[int, int]]:
        """
        Compute the offset pairs implied by the trailing run of offset ops.

        Constant offsets are summed and strides are collected; any other kind of op discards everything
        gathered so far. The result holds one ``(base, base + stride)`` pair per collected stride, where
        ``base`` is the final constant sum (including constants that come after the stride).
        """
        base_offset = 0
        strides = []
        for op in self.ops:
            if isinstance(op, ConstOffsetOp):
                base_offset += op.const
            elif isinstance(op, StrideOffsetOp):
                strides.append(op.stride)
            else:
                # Anything else starts a fresh run.
                base_offset = 0
                strides = []

        return [(base_offset, base_offset + stride) for stride in strides]
|
|
|
|
def simplify_op_sequence(seq: List[Op]):
    """
    Simplify a list of ops in place.

    Rewrite rules, tried in this order at each position:

    - a zero ConstOffsetOp is dropped
    - two adjacent ConstOffsetOps are merged into one holding their sum
    - an adjacent RefOp/DerefOp or DerefOp/RefOp pair of equal size is dropped
    - two adjacent NegOps are dropped
    - two adjacent StrideOffsetOps with the same stride collapse into one

    :param seq: The list to simplify. It is mutated; nothing is returned.
    """
    i = 0
    while i < len(seq):
        cur = seq[i]
        nex = seq[i + 1] if i + 1 < len(seq) else None

        # Decide how many elements starting at i get replaced, and by what.
        if isinstance(cur, ConstOffsetOp) and cur.const == 0:
            consumed, replacement = 1, []
        elif isinstance(cur, ConstOffsetOp) and isinstance(nex, ConstOffsetOp):
            consumed, replacement = 2, [ConstOffsetOp(cur.const + nex.const)]
        elif isinstance(cur, RefOp) and isinstance(nex, DerefOp) and cur.size == nex.size:
            consumed, replacement = 2, []
        elif isinstance(cur, DerefOp) and isinstance(nex, RefOp) and cur.size == nex.size:
            consumed, replacement = 2, []
        elif isinstance(cur, NegOp) and isinstance(nex, NegOp):
            consumed, replacement = 2, []
        elif isinstance(cur, StrideOffsetOp) and isinstance(nex, StrideOffsetOp) and cur.stride == nex.stride:
            consumed, replacement = 2, [cur]
        else:
            # Nothing applies here; move on.
            i += 1
            continue

        seq[i:i + consumed] = replacement
        # Step back one slot: the rewrite may have created a new simplifiable pair with the predecessor.
        if i > 0:
            i -= 1
|
|
|
|
|
|
# noinspection PyArgumentList
class DataKind(IntEnum):
    """Kinds of data a value can be observed as; used as Counter keys in Prop."""
    GenericData = auto()
    Int = auto()
    Float = auto()
    Pointer = auto()  # the only kind Prop.transform keeps across constant/variable offset ops
|
|
|
|
@dataclass(slots=True)
class Prop:
    """
    Mutable bundle of observation counters attached to an atom.

    All three members are Counters, i.e. each observation carries a count.
    """
    # observations about the value itself, keyed by kind
    self_data: Counter[DataKind] = field(default_factory=Counter)
    # observations about what the value points to, keyed by (offset, size, kind)
    struct_data: Counter[Tuple[int, int, DataKind]] = field(default_factory=Counter)
    # pairs of offsets that have been unified with each other
    unifications: Counter[Tuple[int, int]] = field(default_factory=Counter)

    def update(self, other: 'Prop'):
        """Add all of ``other``'s counts to this Prop, in place."""
        self.self_data.update(other.self_data)
        self.struct_data.update(other.struct_data)
        self.unifications.update(other.unifications)

    def subtract(self, other: 'Prop'):
        """Subtract all of ``other``'s counts from this Prop, in place (counts may become zero or negative)."""
        self.self_data.subtract(other.self_data)
        self.struct_data.subtract(other.struct_data)
        self.unifications.subtract(other.unifications)

    def maximize(self, other: 'Prop'):
        """Raise each count in this Prop to at least the corresponding count in ``other``, in place."""
        for key, val in other.self_data.items():
            self.self_data[key] = max(self.self_data[key], val)
        for key, val in other.struct_data.items():
            self.struct_data[key] = max(self.struct_data[key], val)
        for key, val in other.unifications.items():
            self.unifications[key] = max(self.unifications[key], val)

    def __or__(self, other: 'Prop'):
        """Return a new Prop holding the per-key maximum of both operands; neither operand is modified."""
        result = Prop()
        result.maximize(self)
        result.maximize(other)
        return result

    def transform(self, ops: OpSequence):
        """
        Return a new Prop describing this Prop's observations as seen through ``ops``.

        ``self`` is left untouched; the ops are applied one after another to a deep copy.

        :param ops: The operations to push the observations through.
        :return:    The transformed Prop.
        """
        result = copy.deepcopy(self)
        for op in ops.ops:
            if isinstance(op, RefOp):
                # What was known about the value becomes knowledge about offset 0 of the pointee.
                result.struct_data = Counter({(0, op.size, k): v for k, v in result.self_data.items()})
                result.self_data.clear()
                result.unifications.clear()
            elif isinstance(op, DerefOp):
                # Only the pointee data at offset 0 with the matching size survives, as the value's own data.
                result.self_data = Counter({
                    kind: v
                    for (offset, size, kind), v in result.struct_data.items()
                    if offset == 0 and size == op.size
                })
                result.struct_data.clear()
                result.unifications.clear()
            elif isinstance(op, ConstOffsetOp):
                shift = op.const
                # NOTE: the original author flagged the sign of this shift as suspicious.
                result.struct_data = Counter({
                    (offset - shift, size, kind): v
                    for (offset, size, kind), v in result.struct_data.items()
                })
                # Of the value's own kinds only pointer-ness is kept across an offset.
                saved = result.self_data.get(DataKind.Pointer, None)
                result.self_data.clear()
                if saved:
                    result.self_data[DataKind.Pointer] = saved
                # Shift the unified offsets while keeping their counts, the same way struct_data is
                # handled above (iterating the Counter directly would reset every count to 1).
                result.unifications = Counter({
                    (x - shift, y - shift): v
                    for (x, y), v in result.unifications.items()
                })
            elif isinstance(op, StrideOffsetOp):
                result.self_data.clear()
            elif isinstance(op, VarOffsetOp):
                # Everything is dropped except pointer-ness of the value itself.
                saved = result.self_data.get(DataKind.Pointer, None)
                result = Prop()
                if saved:
                    result.self_data[DataKind.Pointer] = saved
            else:
                # Any other op: nothing can be carried across it.
                result = Prop()
        return result
|
|
|
|
@dataclass(frozen=True, slots=True)
class LiveData:
    """
    The in-flight data representation for the analysis. All sizes are in bytes

    A LiveData records which atoms the data came from and, per atom, the sequence of operations that
    leads from the atom to this data.
    """
    loc: CodeLoc
    # (source atom, ops applied to it to obtain this data) pairs
    sources: Tuple[Tuple[Atom, OpSequence], ...]
    size: int
    # if this is non-empty it means the data is characterized SOLELY by the sum of a0*x + a1*y + a2*z + ...
    # each entry is (x, a); an entry whose x is None is a bare constant term (see `const`)
    strides: Tuple[Tuple[Optional['LiveData'], int], ...]

    @property
    def const(self):
        """The constant value of this data if it is exactly one bare constant term, else None."""
        if len(self.strides) == 1 and self.strides[0][0] is None:
            return self.strides[0][1]
        return None

    @classmethod
    def new_null(cls, loc: CodeLoc, size: int, strides: Tuple[Tuple[Optional['LiveData'], int], ...]=()):
        """Create data with no sources at all."""
        return cls(loc, (), size, strides)

    @classmethod
    def new_atom(cls, loc: CodeLoc, atom: Atom) -> 'LiveData':
        """Create data that is exactly ``atom``: one source, an empty op sequence, the atom's size."""
        return cls(loc, ((atom, OpSequence()),), atom.size, ())

    @classmethod
    def new_const(cls, loc: CodeLoc, value: int, size: int) -> 'LiveData':
        """Create constant data, backed by a fresh ConstAtom and a single bare constant stride term."""
        return cls(loc, ((ConstAtom(loc, size, value), OpSequence()),), size, ((None, value),))

    def appended(self, loc: CodeLoc, op: Op, size: int, strides: Optional[Tuple[Tuple[Optional['LiveData'], int], ...]]=None) -> 'LiveData':
        """
        Return new data obtained by applying ``op`` to this data.

        :param loc:     Location of the new data.
        :param op:      The op appended to every source's op sequence.
        :param size:    Size of the new data.
        :param strides: Strides of the new data; if None, this data's strides are carried over.
        """
        return LiveData(
            loc,
            tuple((atom, seq.appended(op)) for atom, seq in self.sources),
            size,
            self.strides if strides is None else strides,
        )

    def unioned(
        self,
        loc: CodeLoc,
        other: 'LiveData',
        size: int,
        strides: Tuple[Tuple[Optional['LiveData'], int], ...]=(),
    ) -> 'LiveData':
        """Return new data whose sources are this data's sources followed by ``other``'s."""
        return LiveData(loc, self.sources + other.sources, size, strides)

    def commit(self, target: Atom, graph: networkx.DiGraph):
        """
        Record that this data was stored into ``target``.

        Adds one edge per source to ``graph`` (``source -> target``, carrying the op sequence and an empty
        control-flow action list) and merges the unifications implied by the op sequences into the
        target node's Prop.
        """
        prop = Prop()
        for src, seq in self.sources:
            for start, end in seq.compute_unifications():
                prop.unifications[(start, end)] += 1
            graph.add_edge(src, target, ops=seq, cf=[])
        self.atom_prop(target, prop, graph)

    def prop(self, prop: Prop, graph: networkx.DiGraph):
        """
        Push an observation about this data back onto its source atoms.

        For every source, ``prop`` is transformed through the inverse of that source's op sequence and
        merged into the source atom's node.
        """
        for atom, ops in self.sources:
            tprop = prop.transform(ops.invert())
            self.atom_prop(atom, tprop, graph)

    @staticmethod
    def atom_prop(atom, tprop: Prop, graph: networkx.DiGraph):
        """
        Merge ``tprop`` into the Prop stored on ``atom``'s node in ``graph``.

        The node is created if it is missing. If the node has no Prop yet, ``tprop`` itself is stored
        (not a copy); otherwise its counts are added to the existing Prop.
        """
        try:
            attrs = graph.nodes[atom]
        except KeyError:
            graph.add_node(atom, prop=tprop)
            return

        eprop: Optional[Prop] = attrs.get('prop')
        if eprop is not None:
            eprop.update(tprop)
        else:
            attrs['prop'] = tprop

    def prop_self(self, kind: DataKind, graph: networkx.DiGraph):
        """Record one observation that this data is of the given kind."""
        prop = Prop()
        prop.self_data[kind] += 1
        self.prop(prop, graph)

    def prop_union(self, offset1: int, offset2: int, graph: networkx.DiGraph):
        """Record one observation that the two given offsets are unified."""
        prop = Prop()
        prop.unifications[(offset1, offset2)] += 1
        self.prop(prop, graph)
|
|
|
|
@dataclass(frozen=True, slots=True)
class RegisterInputInfo:
    """
    State of a backwards walk over control-flow edges looking for the producer of a register input.

    Instances are immutable; `step` returns a new instance (or None) for each edge walked.
    """
    atom: RegisterAtom
    # when we go back through a ret, we push the callsite onto this stack. we may then only go back through calls if
    # they match the top of the stack, at which point they are popped off
    callsites: Tuple[int, ...]
    # when we go back through a call and there is nothing on the callstack, an entry is pushed onto this stack.
    # not sure what this indicates yet
    reverse_callsites: Tuple[int, ...]

    def step(self, pred: int, succ: int, jumpkind: str, callsite: Optional[int]) -> 'Optional[RegisterInputInfo]':
        """
        Walk backwards over one control-flow edge.

        :param pred:     The edge's source; for a call edge it is compared against the top of `callsites`.
        :param succ:     The edge's destination. Currently unused.
        :param jumpkind: ``'Ijk_Ret'`` and ``'Ijk_Call'`` get special treatment; anything else is a plain edge.
        :param callsite: The callsite being returned to. Required when ``jumpkind`` is ``'Ijk_Ret'``.
        :return:         The state after the step, or None if the edge is a call that does not match the
                         callsite on top of the stack.
        :raises TypeError: If ``jumpkind`` is ``'Ijk_Ret'`` and ``callsite`` is None.
        """
        if jumpkind == 'Ijk_Ret':
            if callsite is None:
                raise TypeError("Must specify callsite if jumpkind is Ret")
            return RegisterInputInfo(atom=self.atom, callsites=self.callsites + (callsite,), reverse_callsites=self.reverse_callsites)

        if jumpkind == 'Ijk_Call':
            if not self.callsites:
                # Leaving the starting function through its caller.
                return RegisterInputInfo(atom=self.atom, callsites=(), reverse_callsites=self.reverse_callsites + (pred,))
            if self.callsites[-1] == pred:
                return RegisterInputInfo(atom=self.atom, callsites=self.callsites[:-1], reverse_callsites=self.reverse_callsites)
            # A call from somewhere other than the callsite we returned through: not a feasible path.
            return None

        return RegisterInputInfo(atom=self.atom, callsites=self.callsites, reverse_callsites=self.reverse_callsites)

    def commit(self, graph: networkx.DiGraph, source: RegisterAtom):
        """
        Add the edge ``source -> self.atom`` to ``graph``.

        The edge carries an empty op sequence and, as ``cf``, one pop action per entry of `callsites`
        followed by one push action per entry of `reverse_callsites`.
        """
        actions: List[ControlFlowAction] = [ControlFlowActionPop(i) for i in self.callsites]
        actions += [ControlFlowActionPush(i) for i in self.reverse_callsites]
        graph.add_edge(source, self.atom, ops=OpSequence(), cf=actions)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ControlFlowAction:
|
|
pass
|
|
|
|
@dataclass(frozen=True, slots=True)
class ControlFlowActionPush(ControlFlowAction):
    """Push action; RegisterInputInfo.commit emits one per entry of its ``reverse_callsites``."""
    callsite: int
|
|
|
|
@dataclass(frozen=True, slots=True)
class ControlFlowActionPop(ControlFlowAction):
    """Pop action; RegisterInputInfo.commit emits one per entry of its ``callsites``."""
    callsite: int
|
|
|
|
|
|
@dataclass(slots=True)
class BlockInfo:
    """Mutable per-block bookkeeping of the atoms found in a block; every field starts out empty."""
    outputs: Dict[str, RegisterAtom] = field(default_factory=dict)  # slot names
    inputs: Dict[str, RegisterAtom] = field(default_factory=dict)  # alias names
    atoms: List[Atom] = field(default_factory=list)
    ready_inputs: Set[str] = field(default_factory=set)  # alias names
|