implement a node transposition table

This commit is contained in:
tali 2023-04-16 15:06:33 -04:00
parent 588d673ad0
commit 0fb79906be
3 changed files with 154 additions and 20 deletions

View File

@ -8,8 +8,10 @@ use mino::matrix::Mat;
use mino::srs::{Piece, Queue}; use mino::srs::{Piece, Queue};
mod node; mod node;
mod trans;
use self::node::{Node, RawNodePtr}; use crate::bot::node::{Node, RawNodePtr};
use crate::bot::trans::TransTable;
use crate::eval::{features, Weights}; use crate::eval::{features, Weights};
pub(crate) use bumpalo::Bump as Arena; pub(crate) use bumpalo::Bump as Arena;
@ -18,6 +20,7 @@ pub(crate) use bumpalo::Bump as Arena;
pub struct Bot { pub struct Bot {
iters: u32, iters: u32,
evaluator: Evaluator, evaluator: Evaluator,
trans: TransTable,
algorithm: SegmentedAStar, algorithm: SegmentedAStar,
// IMPORTANT: `arena` must occur after `algorithm` so that it is dropped last. // IMPORTANT: `arena` must occur after `algorithm` so that it is dropped last.
arena: Arena, arena: Arena,
@ -30,10 +33,12 @@ impl Bot {
let arena = bumpalo::Bump::new(); let arena = bumpalo::Bump::new();
let root = Node::alloc_root(&arena, matrix, queue); let root = Node::alloc_root(&arena, matrix, queue);
let evaluator = Evaluator::new(weights, root); let evaluator = Evaluator::new(weights, root);
let trans = TransTable::new();
let algorithm = SegmentedAStar::new(root); let algorithm = SegmentedAStar::new(root);
Self { Self {
iters: 0, iters: 0,
evaluator, evaluator,
trans,
algorithm, algorithm,
arena, arena,
} }
@ -54,9 +59,12 @@ impl Bot {
// same as "bot.think_for(2500); bot.think_for(5000 - bot.iterations());" // same as "bot.think_for(2500); bot.think_for(5000 - bot.iterations());"
let max_iters = self.iters + gas; let max_iters = self.iters + gas;
while self.iters < max_iters { while self.iters < max_iters {
let did_update = self let did_update = self.algorithm.step(
.algorithm &self.arena,
.step(&self.arena, &self.evaluator, &mut self.iters); &mut self.trans,
&self.evaluator,
&mut self.iters,
);
if did_update { if did_update {
tracing::debug!( tracing::debug!(
"new suggestion @ {}: {:?}", "new suggestion @ {}: {:?}",
@ -173,9 +181,15 @@ impl SegmentedAStar {
self.best.map(|node| unsafe { node.as_node() }) self.best.map(|node| unsafe { node.as_node() })
} }
fn step(&mut self, arena: &Arena, eval: &Evaluator, iters: &mut u32) -> bool { fn step(
&mut self,
arena: &Arena,
trans: &mut TransTable,
eval: &Evaluator,
iters: &mut u32,
) -> bool {
*iters += 1; *iters += 1;
match self.expand(arena, eval) { match self.expand(arena, trans, eval) {
Ok(work) => { Ok(work) => {
*iters += work; *iters += work;
false false
@ -191,7 +205,12 @@ impl SegmentedAStar {
} }
} }
fn expand<'a>(&mut self, arena: &'a Arena, eval: &Evaluator) -> Result<u32, Option<&'a Node>> { fn expand<'a>(
&mut self,
arena: &'a Arena,
trans: &mut TransTable,
eval: &Evaluator,
) -> Result<u32, Option<&'a Node>> {
let open_set = self.open.get_mut(self.depth); let open_set = self.open.get_mut(self.depth);
let cand = open_set.map_or(None, |set| set.pop()).ok_or(None)?; let cand = open_set.map_or(None, |set| set.pop()).ok_or(None)?;
let cand = unsafe { cand.0.as_node() }; let cand = unsafe { cand.0.as_node() };
@ -212,7 +231,7 @@ impl SegmentedAStar {
eval.evaluate(mat, queue) eval.evaluate(mat, queue)
}; };
for suc in cand.expand(arena, evaluate) { for suc in cand.expand(arena, trans, evaluate) {
self.open[self.depth].push(suc.into()); self.open[self.depth].push(suc.into());
} }

View File

@ -3,6 +3,7 @@
use mino::matrix::{Mat, MatBuf}; use mino::matrix::{Mat, MatBuf};
use mino::srs::{Piece, PieceType, Queue}; use mino::srs::{Piece, PieceType, Queue};
use crate::bot::trans::TransTable;
use crate::bot::Arena; use crate::bot::Arena;
use crate::find::find_locations; use crate::find::find_locations;
@ -94,6 +95,7 @@ impl Node {
pub fn expand<'a, E>( pub fn expand<'a, E>(
&'a self, &'a self,
arena: &'a Arena, arena: &'a Arena,
trans: &'a mut TransTable,
mut evaluate: E, mut evaluate: E,
) -> impl Iterator<Item = &'a Node> + 'a ) -> impl Iterator<Item = &'a Node> + 'a
where where
@ -106,24 +108,33 @@ impl Node {
let mut matrix = MatBuf::new(); let mut matrix = MatBuf::new();
placements.map(move |placement| { placements.filter_map(move |placement| {
matrix.copy_from(self.matrix()); matrix.copy_from(self.matrix());
placement.cells().fill(&mut matrix); placement.cells().fill(&mut matrix);
matrix.clear_lines(); matrix.clear_lines();
// TODO: the above call returns useful information about if this placement is // TODO: the above call returns useful information about if this placement is
// a combo, does it clear the bottom row of garbage. this should be used for // a combo & does it clear the bottom row of garbage
// prioritizing nodes let queue = self.queue().remove(placement.ty);
let parent = RawNodePtr::from(self); trans
let edge = Edge { placement, parent }; .get_or_insert(&matrix, queue, || {
let suc_matrix = copy_matrix(arena, &matrix); Node::alloc(
let suc_queue = self.queue().remove(placement.ty); arena,
copy_matrix(arena, &matrix),
queue,
evaluate(&matrix, queue),
Some(Edge {
placement,
parent: self.into(),
}),
)
})
// ignore duplicates found in the transposition table
.map(|n| tracing::trace!("duplicate: {n:?}"))
.err()
// TODO: transposition table // TODO: when a duplicate node is encountered, reparent it so that it has
// a more preferable initial path.
let rating = evaluate(suc_matrix, suc_queue);
Node::alloc(arena, suc_matrix, suc_queue, rating, Some(edge))
}) })
} }
} }

104
fish/src/bot/trans.rs Normal file
View File

@ -0,0 +1,104 @@
//! Transposition table for deduplicating identical states.
// trans rights are human rights :D
use mino::srs::{PieceType, Queue};
use mino::Mat;
use crate::bot::node::{Node, RawNodePtr};
pub(crate) struct TransTable {
lookup: HashMap<QueueKey, HashMap<RawMatPtr, RawNodePtr>>,
}
type HashMap<K, V> = hashbrown::HashMap<K, V, HashBuilder>;
type HashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
impl TransTable {
/// Constructs a new empty transposition table.
pub fn new() -> Self {
Self {
lookup: HashMap::with_capacity(128),
}
}
/// Looks up a node in the transposition table matching the given state. If one is
/// found, returns `Ok(existing_node_ptr)`. If not found, a node is created and
/// inserted by calling the given closure, and returns `Err(newly_inserted_node)`.
pub fn get_or_insert<'a>(
&mut self,
matrix: &Mat,
queue: Queue<'_>,
mut alloc: impl FnMut() -> &'a Node,
) -> Result<RawNodePtr, &'a Node> {
// two-phase lookup: first key off queue info, then matrix info
let map = self
.lookup
.entry(QueueKey::new(queue))
.or_insert_with(HashMap::new);
// SAFETY: IMPORTANT! this ptr is ONLY used for lookup, NOT insert. it is not
// necessarily long lived, we are only permitted to insert matrix coming from
// `alloc()`
let mat_ptr = unsafe { RawMatPtr::new(matrix) };
if let Some(&node_ptr) = map.get(&mat_ptr) {
return Ok(node_ptr);
}
let node = alloc();
debug_assert!(*node.matrix() == *matrix);
debug_assert!(node.queue() == queue);
// SAFETY: as noted above, this matrix ptr is OK to use for insert, since it comes
// from the newly allocated node.
let mat_ptr = unsafe { RawMatPtr::new(node.matrix()) };
map.insert(mat_ptr, RawNodePtr::from(node));
Err(node)
}
}
// Rather than storing the actual next pieces of each queue, only store the number of
// next pieces. This is because the suffixes will always be the same if they have the
// same number of pieces, so its more efficient to only store & compare the length.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
struct QueueKey(u16, Option<PieceType>);
impl QueueKey {
fn new(queue: Queue<'_>) -> Self {
Self(queue.next.len() as u16, queue.hold)
}
}
// share matrix data with the nodes themselves, since the nodes will outlive the
// transposition table.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct RawMatPtr(*const Mat);
impl RawMatPtr {
// SAFETY: caller must guaruntee that `mat` lives longer than this value.
unsafe fn new(mat: &Mat) -> Self {
Self(mat)
}
fn as_mat(&self) -> &Mat {
unsafe { &*self.0 }
}
}
impl core::cmp::PartialEq for RawMatPtr {
fn eq(&self, other: &Self) -> bool {
*self.as_mat() == *other.as_mat()
}
}
impl core::cmp::Eq for RawMatPtr {}
impl core::hash::Hash for RawMatPtr {
fn hash<H: core::hash::Hasher>(&self, h: &mut H) {
self.as_mat().hash(h)
}
}