diff --git a/fish/src/bot.rs b/fish/src/bot.rs index af66cb1..4d32f24 100644 --- a/fish/src/bot.rs +++ b/fish/src/bot.rs @@ -8,8 +8,10 @@ use mino::matrix::Mat; use mino::srs::{Piece, Queue}; mod node; +mod trans; -use self::node::{Node, RawNodePtr}; +use crate::bot::node::{Node, RawNodePtr}; +use crate::bot::trans::TransTable; use crate::eval::{features, Weights}; pub(crate) use bumpalo::Bump as Arena; @@ -18,6 +20,7 @@ pub(crate) use bumpalo::Bump as Arena; pub struct Bot { iters: u32, evaluator: Evaluator, + trans: TransTable, algorithm: SegmentedAStar, // IMPORTANT: `arena` must occur after `algorithm` so that it is dropped last. arena: Arena, @@ -30,10 +33,12 @@ impl Bot { let arena = bumpalo::Bump::new(); let root = Node::alloc_root(&arena, matrix, queue); let evaluator = Evaluator::new(weights, root); + let trans = TransTable::new(); let algorithm = SegmentedAStar::new(root); Self { iters: 0, evaluator, + trans, algorithm, arena, } @@ -54,9 +59,12 @@ impl Bot { // same as "bot.think_for(2500); bot.think_for(5000 - bot.iterations());" let max_iters = self.iters + gas; while self.iters < max_iters { - let did_update = self - .algorithm - .step(&self.arena, &self.evaluator, &mut self.iters); + let did_update = self.algorithm.step( + &self.arena, + &mut self.trans, + &self.evaluator, + &mut self.iters, + ); if did_update { tracing::debug!( "new suggestion @ {}: {:?}", @@ -173,9 +181,15 @@ impl SegmentedAStar { self.best.map(|node| unsafe { node.as_node() }) } - fn step(&mut self, arena: &Arena, eval: &Evaluator, iters: &mut u32) -> bool { + fn step( + &mut self, + arena: &Arena, + trans: &mut TransTable, + eval: &Evaluator, + iters: &mut u32, + ) -> bool { *iters += 1; - match self.expand(arena, eval) { + match self.expand(arena, trans, eval) { Ok(work) => { *iters += work; false @@ -191,7 +205,12 @@ impl SegmentedAStar { } } - fn expand<'a>(&mut self, arena: &'a Arena, eval: &Evaluator) -> Result> { + fn expand<'a>( + &mut self, + arena: &'a Arena, + trans: &mut TransTable, + eval: &Evaluator, + ) -> Result> { let open_set = self.open.get_mut(self.depth); let cand = open_set.map_or(None, |set| set.pop()).ok_or(None)?; let cand = unsafe { cand.0.as_node() }; @@ -212,7 +231,7 @@ impl SegmentedAStar { eval.evaluate(mat, queue) }; - for suc in cand.expand(arena, evaluate) { + for suc in cand.expand(arena, trans, evaluate) { self.open[self.depth].push(suc.into()); } diff --git a/fish/src/bot/node.rs b/fish/src/bot/node.rs index 8cfbb16..67a96fc 100644 --- a/fish/src/bot/node.rs +++ b/fish/src/bot/node.rs @@ -3,6 +3,7 @@ use mino::matrix::{Mat, MatBuf}; use mino::srs::{Piece, PieceType, Queue}; +use crate::bot::trans::TransTable; use crate::bot::Arena; use crate::find::find_locations; @@ -94,6 +95,7 @@ impl Node { pub fn expand<'a, E>( &'a self, arena: &'a Arena, + trans: &'a mut TransTable, mut evaluate: E, ) -> impl Iterator + 'a where @@ -106,24 +108,33 @@ impl Node { let mut matrix = MatBuf::new(); - placements.map(move |placement| { + placements.filter_map(move |placement| { matrix.copy_from(self.matrix()); placement.cells().fill(&mut matrix); matrix.clear_lines(); // TODO: the above call returns useful information about if this placement is - // a combo, does it clear the bottom row of garbage. this should be used for - // prioritizing nodes + // a combo & does it clear the bottom row of garbage + let queue = self.queue().remove(placement.ty); - let parent = RawNodePtr::from(self); - let edge = Edge { placement, parent }; - let suc_matrix = copy_matrix(arena, &matrix); - let suc_queue = self.queue().remove(placement.ty); + trans + .get_or_insert(&matrix, queue, || { + Node::alloc( + arena, + copy_matrix(arena, &matrix), + queue, + evaluate(&matrix, queue), + Some(Edge { + placement, + parent: self.into(), + }), + ) + }) + // ignore duplicates found in the transposition table + .map(|n| tracing::trace!("duplicate: {n:?}")) + .err() - // TODO: transposition table - - let rating = evaluate(suc_matrix, suc_queue); - - Node::alloc(arena, suc_matrix, suc_queue, rating, Some(edge)) + // TODO: when a duplicate node is encountered, reparent it so that it has + // a more preferable initial path. }) } } diff --git a/fish/src/bot/trans.rs b/fish/src/bot/trans.rs new file mode 100644 index 0000000..ab0c1dd --- /dev/null +++ b/fish/src/bot/trans.rs @@ -0,0 +1,104 @@ +//! Transposition table for deduplicating identical states. + +// trans rights are human rights :D + +use mino::srs::{PieceType, Queue}; +use mino::Mat; + +use crate::bot::node::{Node, RawNodePtr}; + +pub(crate) struct TransTable { + lookup: HashMap>, +} + +type HashMap = hashbrown::HashMap; +type HashBuilder = core::hash::BuildHasherDefault; + +impl TransTable { + /// Constructs a new empty transposition table. + pub fn new() -> Self { + Self { + lookup: HashMap::with_capacity(128), + } + } + + /// Looks up a node in the transposition table matching the given state. If one is + /// found, returns `Ok(existing_node_ptr)`. If not found, a node is created and + /// inserted by calling the given closure, and returns `Err(newly_inserted_node)`. + pub fn get_or_insert<'a>( + &mut self, + matrix: &Mat, + queue: Queue<'_>, + mut alloc: impl FnMut() -> &'a Node, + ) -> Result { + // two-phase lookup: first key off queue info, then matrix info + + let map = self + .lookup + .entry(QueueKey::new(queue)) + .or_insert_with(HashMap::new); + + // SAFETY: IMPORTANT! this ptr is ONLY used for lookup, NOT insert. it is not + // necessarily long lived, we are only permitted to insert matrix coming from + // `alloc()` + let mat_ptr = unsafe { RawMatPtr::new(matrix) }; + + if let Some(&node_ptr) = map.get(&mat_ptr) { + return Ok(node_ptr); + } + + let node = alloc(); + debug_assert!(*node.matrix() == *matrix); + debug_assert!(node.queue() == queue); + + // SAFETY: as noted above, this matrix ptr is OK to use for insert, since it comes + // from the newly allocated node. + let mat_ptr = unsafe { RawMatPtr::new(node.matrix()) }; + + map.insert(mat_ptr, RawNodePtr::from(node)); + Err(node) + } +} + +// Rather than storing the actual next pieces of each queue, only store the number of +// next pieces. This is because the suffixes will always be the same if they have the +// same number of pieces, so its more efficient to only store & compare the length. +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +struct QueueKey(u16, Option); + +impl QueueKey { + fn new(queue: Queue<'_>) -> Self { + Self(queue.next.len() as u16, queue.hold) + } +} + +// share matrix data with the nodes themselves, since the nodes will outlive the +// transposition table. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +struct RawMatPtr(*const Mat); + +impl RawMatPtr { + // SAFETY: caller must guaruntee that `mat` lives longer than this value. + unsafe fn new(mat: &Mat) -> Self { + Self(mat) + } + + fn as_mat(&self) -> &Mat { + unsafe { &*self.0 } + } +} + +impl core::cmp::PartialEq for RawMatPtr { + fn eq(&self, other: &Self) -> bool { + *self.as_mat() == *other.as_mat() + } +} + +impl core::cmp::Eq for RawMatPtr {} + +impl core::hash::Hash for RawMatPtr { + fn hash(&self, h: &mut H) { + self.as_mat().hash(h) + } +}