implement a node transposition table
This commit is contained in:
parent
588d673ad0
commit
0fb79906be
|
@ -8,8 +8,10 @@ use mino::matrix::Mat;
|
|||
use mino::srs::{Piece, Queue};
|
||||
|
||||
mod node;
|
||||
mod trans;
|
||||
|
||||
use self::node::{Node, RawNodePtr};
|
||||
use crate::bot::node::{Node, RawNodePtr};
|
||||
use crate::bot::trans::TransTable;
|
||||
use crate::eval::{features, Weights};
|
||||
|
||||
pub(crate) use bumpalo::Bump as Arena;
|
||||
|
@ -18,6 +20,7 @@ pub(crate) use bumpalo::Bump as Arena;
|
|||
pub struct Bot {
|
||||
iters: u32,
|
||||
evaluator: Evaluator,
|
||||
trans: TransTable,
|
||||
algorithm: SegmentedAStar,
|
||||
// IMPORTANT: `arena` must occur after `algorithm` so that it is dropped last.
|
||||
arena: Arena,
|
||||
|
@ -30,10 +33,12 @@ impl Bot {
|
|||
let arena = bumpalo::Bump::new();
|
||||
let root = Node::alloc_root(&arena, matrix, queue);
|
||||
let evaluator = Evaluator::new(weights, root);
|
||||
let trans = TransTable::new();
|
||||
let algorithm = SegmentedAStar::new(root);
|
||||
Self {
|
||||
iters: 0,
|
||||
evaluator,
|
||||
trans,
|
||||
algorithm,
|
||||
arena,
|
||||
}
|
||||
|
@ -54,9 +59,12 @@ impl Bot {
|
|||
// same as "bot.think_for(2500); bot.think_for(5000 - bot.iterations());"
|
||||
let max_iters = self.iters + gas;
|
||||
while self.iters < max_iters {
|
||||
let did_update = self
|
||||
.algorithm
|
||||
.step(&self.arena, &self.evaluator, &mut self.iters);
|
||||
let did_update = self.algorithm.step(
|
||||
&self.arena,
|
||||
&mut self.trans,
|
||||
&self.evaluator,
|
||||
&mut self.iters,
|
||||
);
|
||||
if did_update {
|
||||
tracing::debug!(
|
||||
"new suggestion @ {}: {:?}",
|
||||
|
@ -173,9 +181,15 @@ impl SegmentedAStar {
|
|||
self.best.map(|node| unsafe { node.as_node() })
|
||||
}
|
||||
|
||||
fn step(&mut self, arena: &Arena, eval: &Evaluator, iters: &mut u32) -> bool {
|
||||
fn step(
|
||||
&mut self,
|
||||
arena: &Arena,
|
||||
trans: &mut TransTable,
|
||||
eval: &Evaluator,
|
||||
iters: &mut u32,
|
||||
) -> bool {
|
||||
*iters += 1;
|
||||
match self.expand(arena, eval) {
|
||||
match self.expand(arena, trans, eval) {
|
||||
Ok(work) => {
|
||||
*iters += work;
|
||||
false
|
||||
|
@ -191,7 +205,12 @@ impl SegmentedAStar {
|
|||
}
|
||||
}
|
||||
|
||||
fn expand<'a>(&mut self, arena: &'a Arena, eval: &Evaluator) -> Result<u32, Option<&'a Node>> {
|
||||
fn expand<'a>(
|
||||
&mut self,
|
||||
arena: &'a Arena,
|
||||
trans: &mut TransTable,
|
||||
eval: &Evaluator,
|
||||
) -> Result<u32, Option<&'a Node>> {
|
||||
let open_set = self.open.get_mut(self.depth);
|
||||
let cand = open_set.map_or(None, |set| set.pop()).ok_or(None)?;
|
||||
let cand = unsafe { cand.0.as_node() };
|
||||
|
@ -212,7 +231,7 @@ impl SegmentedAStar {
|
|||
eval.evaluate(mat, queue)
|
||||
};
|
||||
|
||||
for suc in cand.expand(arena, evaluate) {
|
||||
for suc in cand.expand(arena, trans, evaluate) {
|
||||
self.open[self.depth].push(suc.into());
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
use mino::matrix::{Mat, MatBuf};
|
||||
use mino::srs::{Piece, PieceType, Queue};
|
||||
|
||||
use crate::bot::trans::TransTable;
|
||||
use crate::bot::Arena;
|
||||
use crate::find::find_locations;
|
||||
|
||||
|
@ -94,6 +95,7 @@ impl Node {
|
|||
pub fn expand<'a, E>(
|
||||
&'a self,
|
||||
arena: &'a Arena,
|
||||
trans: &'a mut TransTable,
|
||||
mut evaluate: E,
|
||||
) -> impl Iterator<Item = &'a Node> + 'a
|
||||
where
|
||||
|
@ -106,24 +108,33 @@ impl Node {
|
|||
|
||||
let mut matrix = MatBuf::new();
|
||||
|
||||
placements.map(move |placement| {
|
||||
placements.filter_map(move |placement| {
|
||||
matrix.copy_from(self.matrix());
|
||||
placement.cells().fill(&mut matrix);
|
||||
matrix.clear_lines();
|
||||
// TODO: the above call returns useful information about if this placement is
|
||||
// a combo, does it clear the bottom row of garbage. this should be used for
|
||||
// prioritizing nodes
|
||||
// a combo & does it clear the bottom row of garbage
|
||||
let queue = self.queue().remove(placement.ty);
|
||||
|
||||
let parent = RawNodePtr::from(self);
|
||||
let edge = Edge { placement, parent };
|
||||
let suc_matrix = copy_matrix(arena, &matrix);
|
||||
let suc_queue = self.queue().remove(placement.ty);
|
||||
trans
|
||||
.get_or_insert(&matrix, queue, || {
|
||||
Node::alloc(
|
||||
arena,
|
||||
copy_matrix(arena, &matrix),
|
||||
queue,
|
||||
evaluate(&matrix, queue),
|
||||
Some(Edge {
|
||||
placement,
|
||||
parent: self.into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
// ignore duplicates found in the transposition table
|
||||
.map(|n| tracing::trace!("duplicate: {n:?}"))
|
||||
.err()
|
||||
|
||||
// TODO: transposition table
|
||||
|
||||
let rating = evaluate(suc_matrix, suc_queue);
|
||||
|
||||
Node::alloc(arena, suc_matrix, suc_queue, rating, Some(edge))
|
||||
// TODO: when a duplicate node is encountered, reparent it so that it has
|
||||
// a more preferable initial path.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
//! Transposition table for deduplicating identical states.
|
||||
|
||||
// trans rights are human rights :D
|
||||
|
||||
use mino::srs::{PieceType, Queue};
|
||||
use mino::Mat;
|
||||
|
||||
use crate::bot::node::{Node, RawNodePtr};
|
||||
|
||||
pub(crate) struct TransTable {
|
||||
lookup: HashMap<QueueKey, HashMap<RawMatPtr, RawNodePtr>>,
|
||||
}
|
||||
|
||||
type HashMap<K, V> = hashbrown::HashMap<K, V, HashBuilder>;
|
||||
type HashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
|
||||
|
||||
impl TransTable {
|
||||
/// Constructs a new empty transposition table.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
lookup: HashMap::with_capacity(128),
|
||||
}
|
||||
}
|
||||
|
||||
/// Looks up a node in the transposition table matching the given state. If one is
|
||||
/// found, returns `Ok(existing_node_ptr)`. If not found, a node is created and
|
||||
/// inserted by calling the given closure, and returns `Err(newly_inserted_node)`.
|
||||
pub fn get_or_insert<'a>(
|
||||
&mut self,
|
||||
matrix: &Mat,
|
||||
queue: Queue<'_>,
|
||||
mut alloc: impl FnMut() -> &'a Node,
|
||||
) -> Result<RawNodePtr, &'a Node> {
|
||||
// two-phase lookup: first key off queue info, then matrix info
|
||||
|
||||
let map = self
|
||||
.lookup
|
||||
.entry(QueueKey::new(queue))
|
||||
.or_insert_with(HashMap::new);
|
||||
|
||||
// SAFETY: IMPORTANT! this ptr is ONLY used for lookup, NOT insert. it is not
|
||||
// necessarily long lived, we are only permitted to insert matrix coming from
|
||||
// `alloc()`
|
||||
let mat_ptr = unsafe { RawMatPtr::new(matrix) };
|
||||
|
||||
if let Some(&node_ptr) = map.get(&mat_ptr) {
|
||||
return Ok(node_ptr);
|
||||
}
|
||||
|
||||
let node = alloc();
|
||||
debug_assert!(*node.matrix() == *matrix);
|
||||
debug_assert!(node.queue() == queue);
|
||||
|
||||
// SAFETY: as noted above, this matrix ptr is OK to use for insert, since it comes
|
||||
// from the newly allocated node.
|
||||
let mat_ptr = unsafe { RawMatPtr::new(node.matrix()) };
|
||||
|
||||
map.insert(mat_ptr, RawNodePtr::from(node));
|
||||
Err(node)
|
||||
}
|
||||
}
|
||||
|
||||
// Rather than storing the actual next pieces of each queue, only store the number of
|
||||
// next pieces. This is because the suffixes will always be the same if they have the
|
||||
// same number of pieces, so its more efficient to only store & compare the length.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
|
||||
struct QueueKey(u16, Option<PieceType>);
|
||||
|
||||
impl QueueKey {
|
||||
fn new(queue: Queue<'_>) -> Self {
|
||||
Self(queue.next.len() as u16, queue.hold)
|
||||
}
|
||||
}
|
||||
|
||||
// share matrix data with the nodes themselves, since the nodes will outlive the
|
||||
// transposition table.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(transparent)]
|
||||
struct RawMatPtr(*const Mat);
|
||||
|
||||
impl RawMatPtr {
|
||||
// SAFETY: caller must guaruntee that `mat` lives longer than this value.
|
||||
unsafe fn new(mat: &Mat) -> Self {
|
||||
Self(mat)
|
||||
}
|
||||
|
||||
fn as_mat(&self) -> &Mat {
|
||||
unsafe { &*self.0 }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::cmp::PartialEq for RawMatPtr {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
*self.as_mat() == *other.as_mat()
|
||||
}
|
||||
}
|
||||
|
||||
impl core::cmp::Eq for RawMatPtr {}
|
||||
|
||||
impl core::hash::Hash for RawMatPtr {
|
||||
fn hash<H: core::hash::Hasher>(&self, h: &mut H) {
|
||||
self.as_mat().hash(h)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue