# typetapper/typetapper/data.py
#
# (source listing metadata: 394 lines, 13 KiB, Python)

from typing import Tuple, Any, List, Set, Optional, Dict
from collections import defaultdict, Counter
from enum import IntEnum, auto
from dataclasses import dataclass, field
import copy
import networkx
@dataclass(frozen=True, slots=True)
class CodeLoc:
    """
    A location in the analyzed code. Frozen, and therefore hashable, so it can be embedded in other hashable
    records such as atoms.
    """
    # address of the containing basic block (presumably; confirm against the code that builds CodeLocs)
    bbl_addr: int
    # index of the statement within that block -- presumably a lifted-IR statement index; TODO confirm
    stmt_idx: int
    # address of the instruction; the atom reprs print this value in hex
    ins_addr: int
@dataclass(frozen=True, slots=True)
class Atom:
    """
    Base class for one tracked piece of data observed at a specific code location.

    Atoms are frozen and hashable; LiveData uses them as networkx graph nodes.
    """
    # where the atom was observed
    loc: CodeLoc
    # width of the data in bytes (LiveData.new_atom uses it as the LiveData size, and LiveData sizes are bytes)
    size: int
@dataclass(frozen=True, slots=True)
class RegisterAtom(Atom):
    """
    A register observed at a code location.

    ``name`` is the register name shown in the repr. ``slot_name`` is a second name for the register; BlockInfo
    keys its outputs by slot name. Presumably ``name`` may be an alias of the slot -- confirm.
    """
    name: str
    slot_name: str

    def __repr__(self):
        address = format(self.loc.ins_addr, '#x')
        return f'{self.name} @ {address}'
@dataclass(frozen=True, slots=True)
class MemoryAtom(Atom):
    """
    A memory access observed at a code location.

    ``endness`` records the byte order of the access as a string (its exact vocabulary is defined by the code
    that creates these atoms).
    """
    endness: str

    def __repr__(self):
        address = format(self.loc.ins_addr, '#x')
        return f'MEM @ {address}'
@dataclass(frozen=True, slots=True)
class TmpAtom(Atom):
    """
    A numbered temporary observed at a code location; ``tmp`` is the temporary's number as shown in the repr.
    """
    tmp: int

    def __repr__(self):
        address = format(self.loc.ins_addr, '#x')
        return f'TMP#{self.tmp} @ {address}'
@dataclass(frozen=True, slots=True)
class ConstAtom(Atom):
    """
    A constant integer appearing at a code location (created by LiveData.new_const).
    """
    value: int

    def __repr__(self):
        value_hex = format(self.value, '#x')
        address = format(self.loc.ins_addr, '#x')
        return f'CONST#{value_hex} @ {address}'
@dataclass(frozen=True, slots=True)
class Op:
    """
    Base class for one step of an OpSequence, i.e. one transformation applied to a value on its way from a
    source atom to a target. Ops are frozen and hashable.
    """
    def invert(self) -> 'Op':
        """Return the op that undoes this one. Every concrete subclass must override this."""
        raise NotImplementedError
@dataclass(frozen=True, slots=True)
class ConstOffsetOp(Op):
    """
    Add the constant ``const`` to a value.

    The constant is normalized into the signed 64-bit range [-2**63, 2**63 - 1]. Offsets that are congruent
    modulo 2**64 therefore compare and hash equal, so they can merge and cancel in simplify_op_sequence.
    """
    const: int

    def invert(self):
        """Return the op that subtracts the same constant."""
        return ConstOffsetOp(-self.const)

    def __init__(self, const: int):
        # NOTE(review): the width is hardcoded to 64 bits because the op does not know the size of the value it
        # applies to ("missing size").
        # Python's % with a positive modulus yields a result in [0, 2**64), so after the shifts the value lies in
        # [-2**63, 2**63). The previous `while const > 2**63` loop let 2**63 through unnormalized, giving the same
        # offset two representations, and took time proportional to the magnitude of `const`.
        const = (const + 2**63) % 2**64 - 2**63
        # the dataclass is frozen, so bypass its __setattr__
        object.__setattr__(self, 'const', const)
@dataclass(frozen=True, slots=True)
class StrideOffsetOp(Op):
    """
    Offset a value by some multiple of ``stride``. The multiple is presumably an unknown index, e.g. array
    indexing -- confirm. simplify_op_sequence collapses adjacent equal strides, consistent with the multiple
    being an arbitrary integer.
    """
    stride: int

    def invert(self):
        # Undoing "add an arbitrary multiple of stride" is again "add an arbitrary multiple of stride".
        return self
@dataclass(frozen=True, slots=True)
class NegOp(Op):
    """
    Negate a value. Negation is its own inverse, which is why ``invert`` returns ``self`` and why
    simplify_op_sequence cancels adjacent pairs.
    """
    def invert(self):
        return self
@dataclass(frozen=True, slots=True)
class VarOffsetOp(Op):
    """
    Offset a value by a non-constant quantity ``var``. The type is not constrained here; presumably it is a
    LiveData -- confirm.
    """
    var: Any

    def invert(self):
        # NOTE(review): returning self is not a true inverse: adding var is undone by subtracting it, and this op
        # cannot express a negated var. In the code visible here that is harmless, because Prop.transform and
        # OpSequence.compute_unifications handle every VarOffsetOp the same way without looking at var.
        # TODO ????
        return self
@dataclass(frozen=True, slots=True)
class DerefOp(Op):
    """
    Replace a pointer with the ``size``-byte value stored at offset 0 of what it points to.

    Undone by a RefOp of the same size; simplify_op_sequence cancels such adjacent pairs.
    """
    size: int

    def invert(self):
        return RefOp(size=self.size)
@dataclass(frozen=True, slots=True)
class RefOp(Op):
    """
    Replace a ``size``-byte value with a pointer to it (the value sits at offset 0 of the pointee).

    Undone by a DerefOp of the same size; simplify_op_sequence cancels such adjacent pairs.
    """
    size: int

    def invert(self):
        return DerefOp(size=self.size)
#@dataclass(frozen=True, slots=True)
#class OtherOp(Op):
# def invert(self) -> 'Op':
# return self
@dataclass(frozen=True, slots=True)
class OpSequence:
    """
    An immutable sequence of ops describing how a value is derived from a source atom.

    Every way of building a longer sequence (``+``, ``appended``, ``concat``) passes the combined ops through
    simplify_op_sequence before freezing them into a new OpSequence.
    """
    ops: Tuple[Op, ...] = ()

    def __add__(self, other: 'OpSequence') -> 'OpSequence':
        """Return ``self`` followed by ``other``, simplified."""
        merged = [*self.ops, *other.ops]
        simplify_op_sequence(merged)
        return OpSequence(tuple(merged))

    def appended(self, *op: Op) -> 'OpSequence':
        """Return a new sequence with the given op(s) added at the end, simplified."""
        merged = [*self.ops, *op]
        simplify_op_sequence(merged)
        return OpSequence(tuple(merged))

    @staticmethod
    def concat(*sequences: 'OpSequence') -> 'OpSequence':
        """Return all ``sequences`` joined in order, simplified."""
        merged: List[Op] = [item for sequence in sequences for item in sequence.ops]
        simplify_op_sequence(merged)
        return OpSequence(tuple(merged))

    def invert(self) -> 'OpSequence':
        """Return the undoing sequence: each op inverted, in reverse order (not re-simplified)."""
        return OpSequence(tuple(item.invert() for item in self.ops[::-1]))

    def compute_unifications(self) -> List[Tuple[int, int]]:
        """
        Return ``(offset, offset + stride)`` pairs implied by the trailing run of constant/stride offsets.

        Constant offsets are summed and strides collected while walking the ops. Any other op discards what
        has been accumulated so far. One pair is produced per collected stride, using the final summed offset.
        """
        offset = 0
        strides: List[int] = []
        for item in self.ops:
            if isinstance(item, ConstOffsetOp):
                offset += item.const
            elif isinstance(item, StrideOffsetOp):
                strides.append(item.stride)
            else:
                # any other op restarts the accumulation
                offset, strides = 0, []
        return [(offset, offset + stride) for stride in strides]
def simplify_op_sequence(seq: List[Op]):
    """
    Simplify ``seq`` in place by repeatedly rewriting adjacent ops. Returns None.

    Rewrites, checked left to right:

    * ``ConstOffsetOp(0)`` is dropped;
    * two adjacent ``ConstOffsetOp`` merge into one carrying the summed constant;
    * ``RefOp``/``DerefOp`` and ``DerefOp``/``RefOp`` pairs of equal size cancel;
    * two adjacent ``NegOp`` cancel;
    * two adjacent ``StrideOffsetOp`` with the same stride collapse into one.

    After any rewrite the scan steps back one position, so the pair made adjacent by the rewrite is re-examined.
    Every rewrite shortens the list, so the loop terminates.
    """
    pos = 0
    while pos < len(seq):
        head = seq[pos]
        if isinstance(head, ConstOffsetOp) and head.const == 0:
            del seq[pos]
        else:
            follower = seq[pos + 1] if pos + 1 < len(seq) else None
            if isinstance(head, ConstOffsetOp) and isinstance(follower, ConstOffsetOp):
                seq[pos:pos + 2] = [ConstOffsetOp(head.const + follower.const)]
            elif (
                (isinstance(head, RefOp) and isinstance(follower, DerefOp) and head.size == follower.size)
                or (isinstance(head, DerefOp) and isinstance(follower, RefOp) and head.size == follower.size)
                or (isinstance(head, NegOp) and isinstance(follower, NegOp))
            ):
                del seq[pos:pos + 2]
            elif (
                isinstance(head, StrideOffsetOp)
                and isinstance(follower, StrideOffsetOp)
                and head.stride == follower.stride
            ):
                del seq[pos]
            else:
                # nothing to rewrite here; move on
                pos += 1
                continue
        # a rewrite happened: step back so the new neighbouring pair is checked too
        pos = max(pos - 1, 0)
# noinspection PyArgumentList
class DataKind(IntEnum):
    """
    Coarse classification of the data a value holds; Prop uses these members as Counter keys.
    Values are assigned by auto(), i.e. 1 through 4 in declaration order.
    """
    GenericData = auto()
    Int = auto()
    Float = auto()
    Pointer = auto()
@dataclass(slots=True)
class Prop:
    """
    Accumulated type evidence about one value.

    Each field is a Counter mapping an observation to how often it was seen; ``subtract`` can drive counts to
    zero or below.

    * ``self_data`` -- kinds of data the value itself was seen holding.
    * ``struct_data`` -- ``(offset, size, kind)`` observations about the memory the value points to: a
      ``size``-byte field of kind ``kind`` at byte ``offset`` from the pointer.
    * ``unifications`` -- ``(offset1, offset2)`` pairs of pointee offsets to be unified, e.g. offsets one stride
      apart as produced by OpSequence.compute_unifications.
    """
    self_data: Counter[DataKind] = field(default_factory=Counter)
    struct_data: Counter[Tuple[int, int, DataKind]] = field(default_factory=Counter)
    unifications: Counter[Tuple[int, int]] = field(default_factory=Counter)

    def update(self, other: 'Prop'):
        """Add ``other``'s counts into this Prop, in place."""
        self.self_data.update(other.self_data)
        self.struct_data.update(other.struct_data)
        self.unifications.update(other.unifications)

    def subtract(self, other: 'Prop'):
        """Subtract ``other``'s counts from this Prop, in place (counts may become zero or negative)."""
        self.self_data.subtract(other.self_data)
        self.struct_data.subtract(other.struct_data)
        self.unifications.subtract(other.unifications)

    def maximize(self, other: 'Prop'):
        """Raise every count in this Prop to at least the matching count in ``other``, in place."""
        for key, val in other.self_data.items():
            self.self_data[key] = max(self.self_data[key], val)
        for key, val in other.struct_data.items():
            self.struct_data[key] = max(self.struct_data[key], val)
        for key, val in other.unifications.items():
            self.unifications[key] = max(self.unifications[key], val)

    def __or__(self, other: 'Prop'):
        """Return a new Prop with the element-wise maximum of both operands' counts (floored at 0)."""
        result = Prop()
        result.maximize(self)
        result.maximize(other)
        return result

    def transform(self, ops: OpSequence):
        """
        Return a copy of this Prop re-expressed for the value obtained by applying ``ops`` (in order) to the value
        this Prop describes. ``self`` is not modified.

        * RefOp: the value becomes a pointer to the old value, so the old ``self_data`` becomes the
          ``size``-byte field at offset 0 of the pointee.
        * DerefOp: the value becomes the ``size``-byte field at offset 0 of the old pointee.
        * ConstOffsetOp: pointee fields and unifications shift by ``-const``, and only the Pointer count of
          ``self_data`` is kept.
        * StrideOffsetOp: ``self_data`` is dropped and pointee observations are kept.
        * VarOffsetOp: everything except the Pointer count of ``self_data`` is dropped.
        * Any other op: everything is dropped.
        """
        result = copy.deepcopy(self)
        for op in ops.ops:
            if isinstance(op, RefOp):
                result.struct_data = Counter({(0, op.size, k): v for k, v in result.self_data.items()})
                result.self_data.clear()
                result.unifications.clear()
            elif isinstance(op, DerefOp):
                result.self_data = Counter(
                    {k[2]: v for k, v in result.struct_data.items() if k[0] == 0 and k[1] == op.size}
                )
                result.struct_data.clear()
                result.unifications.clear()
            elif isinstance(op, ConstOffsetOp):
                # If the new value is p + const, the field that sat at p + offset is now at
                # (p + const) + (offset - const) -- hence the subtraction.
                items = list(result.struct_data.items())
                result.struct_data.clear()
                for (offset, size, kind), v in items:
                    result.struct_data[(offset - op.const, size, kind)] = v
                # only the Pointer evidence about the value itself is carried across the offset
                saved = result.self_data.get(DataKind.Pointer, None)
                result.self_data.clear()
                if saved:
                    result.self_data[DataKind.Pointer] = saved
                # Shift unified offsets the same way, keeping their counts. Building the Counter from the keys
                # alone (as before) reset every count to 1.
                result.unifications = Counter(
                    {(x - op.const, y - op.const): v for (x, y), v in result.unifications.items()}
                )
            elif isinstance(op, StrideOffsetOp):
                result.self_data.clear()
            elif isinstance(op, VarOffsetOp):
                saved = result.self_data.get(DataKind.Pointer, None)
                result = Prop()
                if saved:
                    result.self_data[DataKind.Pointer] = saved
            else:
                result = Prop()
        return result
@dataclass(frozen=True, slots=True)
class LiveData:
    """
    The in-flight data representation for the analysis. All sizes are in bytes.

    A LiveData lists the atoms the value may derive from, each paired with the op sequence that derives it. When
    known, it also carries a linear description of the value in ``strides``.
    """
    loc: CodeLoc
    # (source atom, ops turning that atom into this value) for every atom the value may come from
    sources: Tuple[Tuple[Atom, OpSequence], ...]
    size: int
    # if this is non-empty it means the data is characterized SOLELY by the sum of a0*x + a1*y + a2*z + ...
    # a term whose variable is None is a constant term (see `const`)
    strides: Tuple[Tuple[Optional['LiveData'], int], ...]

    @property
    def const(self):
        """The constant this data is known to equal (a single ``None``-variable term), otherwise None."""
        if len(self.strides) == 1:
            term = self.strides[0]
            if term[0] is None:
                return term[1]
        return None

    @classmethod
    def new_null(cls, loc: CodeLoc, size: int, strides: Tuple[Tuple[Optional['LiveData'], int], ...]=()):
        """Build data with no source atoms."""
        return cls(loc=loc, sources=(), size=size, strides=strides)

    @classmethod
    def new_atom(cls, loc: CodeLoc, atom: Atom) -> 'LiveData':
        """Build data that is exactly ``atom``, with the atom's size and no known linear form."""
        identity = OpSequence()
        return cls(loc=loc, sources=((atom, identity),), size=atom.size, strides=())

    @classmethod
    def new_const(cls, loc: CodeLoc, value: int, size: int) -> 'LiveData':
        """Build data for the constant ``value``, sourced from a fresh ConstAtom at ``loc``."""
        atom = ConstAtom(loc=loc, size=size, value=value)
        return cls(loc=loc, sources=((atom, OpSequence()),), size=size, strides=((None, value),))

    def appended(self, loc: CodeLoc, op: Op, size: int, strides: Optional[Tuple[Tuple[Optional['LiveData'], int], ...]]=None) -> 'LiveData':
        """Return new data with ``op`` appended to every source's op sequence; ``strides`` defaults to ours."""
        new_sources = tuple((atom, seq.appended(op)) for atom, seq in self.sources)
        new_strides = self.strides if strides is None else strides
        return LiveData(loc, new_sources, size, new_strides)

    def unioned(
        self,
        loc: CodeLoc,
        other: 'LiveData',
        size: int,
        strides: Tuple[Tuple[Optional['LiveData'], int], ...]=(),
    ) -> 'LiveData':
        """Return new data whose sources are ours followed by ``other``'s."""
        combined = self.sources + other.sources
        return LiveData(loc, combined, size, strides)

    def commit(self, target: Atom, graph: networkx.DiGraph):
        """
        Record that this data flows into ``target``.

        Adds one graph edge per source (carrying its op sequence and an empty ``cf`` list), then merges the
        unifications implied by those sequences into ``target``'s prop.
        """
        prop = Prop()
        for src, seq in self.sources:
            for start, end in seq.compute_unifications():
                prop.unifications[(start, end)] += 1
            graph.add_edge(src, target, ops=seq, cf=[])
        self.atom_prop(target, prop, graph)

    def prop(self, prop: Prop, graph: networkx.DiGraph):
        """Push ``prop`` (describing this data) back to every source atom through its inverted op sequence."""
        for atom, ops in self.sources:
            self.atom_prop(atom, prop.transform(ops.invert()), graph)

    @staticmethod
    def atom_prop(atom, tprop: Prop, graph: networkx.DiGraph):
        """Merge ``tprop`` into the ``prop`` attribute of ``atom``'s node, creating the node if needed."""
        try:
            node_attrs = graph.nodes[atom]
        except KeyError:
            graph.add_node(atom, prop=tprop)
            return
        existing: Prop = node_attrs.get('prop')
        if existing:
            existing.update(tprop)
        else:
            # node exists (e.g. created by add_edge) but has no prop yet
            node_attrs['prop'] = tprop

    def prop_self(self, kind: DataKind, graph: networkx.DiGraph):
        """Record one observation that this data is of ``kind`` and push it to the sources."""
        observation = Prop()
        observation.self_data[kind] += 1
        self.prop(observation, graph)

    def prop_union(self, offset1: int, offset2: int, graph: networkx.DiGraph):
        """Record one request to unify pointee offsets ``offset1`` and ``offset2`` and push it to the sources."""
        observation = Prop()
        observation.unifications[(offset1, offset2)] += 1
        self.prop(observation, graph)
@dataclass(frozen=True, slots=True)
class RegisterInputInfo:
    """
    A register input atom being traced backwards across control-flow edges, together with the call/return
    context accumulated along the way.
    """
    atom: RegisterAtom
    # when we go back through a ret, we push the callsite onto this stack. we may then only go back through calls if
    # they match the top of the stack, at which point they are popped off
    callsites: Tuple[int, ...]
    # when we go back through a call and there is nothing on the callstack, an entry is pushed onto this stack.
    # not sure what this indicates yet
    reverse_callsites: Tuple[int, ...]

    def step(self, pred: int, succ: int, jumpkind: str, callsite: Optional[int]) -> 'Optional[RegisterInputInfo]':
        """
        Return the info after going back across the edge ``pred -> succ`` of kind ``jumpkind``.

        Returns None when a call edge does not match the innermost pending callsite, meaning the path is not
        allowed. Raises TypeError if ``jumpkind`` is ``'Ijk_Ret'`` and ``callsite`` is None.
        """
        callsites = self.callsites
        reverse_callsites = self.reverse_callsites
        if jumpkind == 'Ijk_Ret':
            if callsite is None:
                raise TypeError("Must specify callsite if jumpkind is Ret")
            callsites = callsites + (callsite,)
        elif jumpkind == 'Ijk_Call':
            if not callsites:
                callsites = ()
                reverse_callsites = reverse_callsites + (pred,)
            elif callsites[-1] == pred:
                callsites = callsites[:-1]
            else:
                return None
        return RegisterInputInfo(atom=self.atom, callsites=callsites, reverse_callsites=reverse_callsites)

    def commit(self, graph: networkx.DiGraph, source: RegisterAtom):
        """
        Add an edge ``source -> self.atom`` with an empty op sequence. Its ``cf`` list holds one Pop per pending
        callsite followed by one Push per reverse callsite.
        """
        pops = [ControlFlowActionPop(site) for site in self.callsites]
        pushes = [ControlFlowActionPush(site) for site in self.reverse_callsites]
        actions: List[ControlFlowAction] = pops + pushes
        graph.add_edge(source, self.atom, ops=OpSequence(), cf=actions)
@dataclass(frozen=True, slots=True)
class ControlFlowAction:
    """
    Base class for the call-stack actions stored in the ``cf`` attribute of graph edges
    (RegisterInputInfo.commit builds these lists).
    """
    pass
@dataclass(frozen=True, slots=True)
class ControlFlowActionPush(ControlFlowAction):
    """A push of ``callsite``; RegisterInputInfo.commit emits one per entry of its ``reverse_callsites``."""
    callsite: int
@dataclass(frozen=True, slots=True)
class ControlFlowActionPop(ControlFlowAction):
    """A pop of ``callsite``; RegisterInputInfo.commit emits one per entry of its ``callsites``."""
    callsite: int
@dataclass(slots=True)
class BlockInfo:
    """
    Mutable per-block bookkeeping of register atoms and other atoms. Presumably there is one instance per basic
    block -- confirm against the code that builds these.
    """
    outputs: Dict[str, RegisterAtom] = field(default_factory=dict)  # register atoms keyed by slot names
    inputs: Dict[str, RegisterAtom] = field(default_factory=dict)  # register atoms keyed by alias names
    atoms: List[Atom] = field(default_factory=list)
    ready_inputs: Set[str] = field(default_factory=set)  # alias names