implement a node transposition table
This commit is contained in:
parent
588d673ad0
commit
0fb79906be
|
@ -8,8 +8,10 @@ use mino::matrix::Mat;
|
||||||
use mino::srs::{Piece, Queue};
|
use mino::srs::{Piece, Queue};
|
||||||
|
|
||||||
mod node;
|
mod node;
|
||||||
|
mod trans;
|
||||||
|
|
||||||
use self::node::{Node, RawNodePtr};
|
use crate::bot::node::{Node, RawNodePtr};
|
||||||
|
use crate::bot::trans::TransTable;
|
||||||
use crate::eval::{features, Weights};
|
use crate::eval::{features, Weights};
|
||||||
|
|
||||||
pub(crate) use bumpalo::Bump as Arena;
|
pub(crate) use bumpalo::Bump as Arena;
|
||||||
|
@ -18,6 +20,7 @@ pub(crate) use bumpalo::Bump as Arena;
|
||||||
pub struct Bot {
|
pub struct Bot {
|
||||||
iters: u32,
|
iters: u32,
|
||||||
evaluator: Evaluator,
|
evaluator: Evaluator,
|
||||||
|
trans: TransTable,
|
||||||
algorithm: SegmentedAStar,
|
algorithm: SegmentedAStar,
|
||||||
// IMPORTANT: `arena` must occur after `algorithm` so that it is dropped last.
|
// IMPORTANT: `arena` must occur after `algorithm` so that it is dropped last.
|
||||||
arena: Arena,
|
arena: Arena,
|
||||||
|
@ -30,10 +33,12 @@ impl Bot {
|
||||||
let arena = bumpalo::Bump::new();
|
let arena = bumpalo::Bump::new();
|
||||||
let root = Node::alloc_root(&arena, matrix, queue);
|
let root = Node::alloc_root(&arena, matrix, queue);
|
||||||
let evaluator = Evaluator::new(weights, root);
|
let evaluator = Evaluator::new(weights, root);
|
||||||
|
let trans = TransTable::new();
|
||||||
let algorithm = SegmentedAStar::new(root);
|
let algorithm = SegmentedAStar::new(root);
|
||||||
Self {
|
Self {
|
||||||
iters: 0,
|
iters: 0,
|
||||||
evaluator,
|
evaluator,
|
||||||
|
trans,
|
||||||
algorithm,
|
algorithm,
|
||||||
arena,
|
arena,
|
||||||
}
|
}
|
||||||
|
@ -54,9 +59,12 @@ impl Bot {
|
||||||
// same as "bot.think_for(2500); bot.think_for(5000 - bot.iterations());"
|
// same as "bot.think_for(2500); bot.think_for(5000 - bot.iterations());"
|
||||||
let max_iters = self.iters + gas;
|
let max_iters = self.iters + gas;
|
||||||
while self.iters < max_iters {
|
while self.iters < max_iters {
|
||||||
let did_update = self
|
let did_update = self.algorithm.step(
|
||||||
.algorithm
|
&self.arena,
|
||||||
.step(&self.arena, &self.evaluator, &mut self.iters);
|
&mut self.trans,
|
||||||
|
&self.evaluator,
|
||||||
|
&mut self.iters,
|
||||||
|
);
|
||||||
if did_update {
|
if did_update {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"new suggestion @ {}: {:?}",
|
"new suggestion @ {}: {:?}",
|
||||||
|
@ -173,9 +181,15 @@ impl SegmentedAStar {
|
||||||
self.best.map(|node| unsafe { node.as_node() })
|
self.best.map(|node| unsafe { node.as_node() })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn step(&mut self, arena: &Arena, eval: &Evaluator, iters: &mut u32) -> bool {
|
fn step(
|
||||||
|
&mut self,
|
||||||
|
arena: &Arena,
|
||||||
|
trans: &mut TransTable,
|
||||||
|
eval: &Evaluator,
|
||||||
|
iters: &mut u32,
|
||||||
|
) -> bool {
|
||||||
*iters += 1;
|
*iters += 1;
|
||||||
match self.expand(arena, eval) {
|
match self.expand(arena, trans, eval) {
|
||||||
Ok(work) => {
|
Ok(work) => {
|
||||||
*iters += work;
|
*iters += work;
|
||||||
false
|
false
|
||||||
|
@ -191,7 +205,12 @@ impl SegmentedAStar {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expand<'a>(&mut self, arena: &'a Arena, eval: &Evaluator) -> Result<u32, Option<&'a Node>> {
|
fn expand<'a>(
|
||||||
|
&mut self,
|
||||||
|
arena: &'a Arena,
|
||||||
|
trans: &mut TransTable,
|
||||||
|
eval: &Evaluator,
|
||||||
|
) -> Result<u32, Option<&'a Node>> {
|
||||||
let open_set = self.open.get_mut(self.depth);
|
let open_set = self.open.get_mut(self.depth);
|
||||||
let cand = open_set.map_or(None, |set| set.pop()).ok_or(None)?;
|
let cand = open_set.map_or(None, |set| set.pop()).ok_or(None)?;
|
||||||
let cand = unsafe { cand.0.as_node() };
|
let cand = unsafe { cand.0.as_node() };
|
||||||
|
@ -212,7 +231,7 @@ impl SegmentedAStar {
|
||||||
eval.evaluate(mat, queue)
|
eval.evaluate(mat, queue)
|
||||||
};
|
};
|
||||||
|
|
||||||
for suc in cand.expand(arena, evaluate) {
|
for suc in cand.expand(arena, trans, evaluate) {
|
||||||
self.open[self.depth].push(suc.into());
|
self.open[self.depth].push(suc.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
use mino::matrix::{Mat, MatBuf};
|
use mino::matrix::{Mat, MatBuf};
|
||||||
use mino::srs::{Piece, PieceType, Queue};
|
use mino::srs::{Piece, PieceType, Queue};
|
||||||
|
|
||||||
|
use crate::bot::trans::TransTable;
|
||||||
use crate::bot::Arena;
|
use crate::bot::Arena;
|
||||||
use crate::find::find_locations;
|
use crate::find::find_locations;
|
||||||
|
|
||||||
|
@ -94,6 +95,7 @@ impl Node {
|
||||||
pub fn expand<'a, E>(
|
pub fn expand<'a, E>(
|
||||||
&'a self,
|
&'a self,
|
||||||
arena: &'a Arena,
|
arena: &'a Arena,
|
||||||
|
trans: &'a mut TransTable,
|
||||||
mut evaluate: E,
|
mut evaluate: E,
|
||||||
) -> impl Iterator<Item = &'a Node> + 'a
|
) -> impl Iterator<Item = &'a Node> + 'a
|
||||||
where
|
where
|
||||||
|
@ -106,24 +108,33 @@ impl Node {
|
||||||
|
|
||||||
let mut matrix = MatBuf::new();
|
let mut matrix = MatBuf::new();
|
||||||
|
|
||||||
placements.map(move |placement| {
|
placements.filter_map(move |placement| {
|
||||||
matrix.copy_from(self.matrix());
|
matrix.copy_from(self.matrix());
|
||||||
placement.cells().fill(&mut matrix);
|
placement.cells().fill(&mut matrix);
|
||||||
matrix.clear_lines();
|
matrix.clear_lines();
|
||||||
// TODO: the above call returns useful information about if this placement is
|
// TODO: the above call returns useful information about if this placement is
|
||||||
// a combo, does it clear the bottom row of garbage. this should be used for
|
// a combo & does it clear the bottom row of garbage
|
||||||
// prioritizing nodes
|
let queue = self.queue().remove(placement.ty);
|
||||||
|
|
||||||
let parent = RawNodePtr::from(self);
|
trans
|
||||||
let edge = Edge { placement, parent };
|
.get_or_insert(&matrix, queue, || {
|
||||||
let suc_matrix = copy_matrix(arena, &matrix);
|
Node::alloc(
|
||||||
let suc_queue = self.queue().remove(placement.ty);
|
arena,
|
||||||
|
copy_matrix(arena, &matrix),
|
||||||
|
queue,
|
||||||
|
evaluate(&matrix, queue),
|
||||||
|
Some(Edge {
|
||||||
|
placement,
|
||||||
|
parent: self.into(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
// ignore duplicates found in the transposition table
|
||||||
|
.map(|n| tracing::trace!("duplicate: {n:?}"))
|
||||||
|
.err()
|
||||||
|
|
||||||
// TODO: transposition table
|
// TODO: when a duplicate node is encountered, reparent it so that it has
|
||||||
|
// a more preferable initial path.
|
||||||
let rating = evaluate(suc_matrix, suc_queue);
|
|
||||||
|
|
||||||
Node::alloc(arena, suc_matrix, suc_queue, rating, Some(edge))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
//! Transposition table for deduplicating identical states.
|
||||||
|
|
||||||
|
// trans rights are human rights :D
|
||||||
|
|
||||||
|
use mino::srs::{PieceType, Queue};
|
||||||
|
use mino::Mat;
|
||||||
|
|
||||||
|
use crate::bot::node::{Node, RawNodePtr};
|
||||||
|
|
||||||
|
pub(crate) struct TransTable {
|
||||||
|
lookup: HashMap<QueueKey, HashMap<RawMatPtr, RawNodePtr>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
type HashMap<K, V> = hashbrown::HashMap<K, V, HashBuilder>;
|
||||||
|
type HashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
|
||||||
|
|
||||||
|
impl TransTable {
|
||||||
|
/// Constructs a new empty transposition table.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
lookup: HashMap::with_capacity(128),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Looks up a node in the transposition table matching the given state. If one is
|
||||||
|
/// found, returns `Ok(existing_node_ptr)`. If not found, a node is created and
|
||||||
|
/// inserted by calling the given closure, and returns `Err(newly_inserted_node)`.
|
||||||
|
pub fn get_or_insert<'a>(
|
||||||
|
&mut self,
|
||||||
|
matrix: &Mat,
|
||||||
|
queue: Queue<'_>,
|
||||||
|
mut alloc: impl FnMut() -> &'a Node,
|
||||||
|
) -> Result<RawNodePtr, &'a Node> {
|
||||||
|
// two-phase lookup: first key off queue info, then matrix info
|
||||||
|
|
||||||
|
let map = self
|
||||||
|
.lookup
|
||||||
|
.entry(QueueKey::new(queue))
|
||||||
|
.or_insert_with(HashMap::new);
|
||||||
|
|
||||||
|
// SAFETY: IMPORTANT! this ptr is ONLY used for lookup, NOT insert. it is not
|
||||||
|
// necessarily long lived, we are only permitted to insert matrix coming from
|
||||||
|
// `alloc()`
|
||||||
|
let mat_ptr = unsafe { RawMatPtr::new(matrix) };
|
||||||
|
|
||||||
|
if let Some(&node_ptr) = map.get(&mat_ptr) {
|
||||||
|
return Ok(node_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
let node = alloc();
|
||||||
|
debug_assert!(*node.matrix() == *matrix);
|
||||||
|
debug_assert!(node.queue() == queue);
|
||||||
|
|
||||||
|
// SAFETY: as noted above, this matrix ptr is OK to use for insert, since it comes
|
||||||
|
// from the newly allocated node.
|
||||||
|
let mat_ptr = unsafe { RawMatPtr::new(node.matrix()) };
|
||||||
|
|
||||||
|
map.insert(mat_ptr, RawNodePtr::from(node));
|
||||||
|
Err(node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rather than storing the actual next pieces of each queue, only store the number of
|
||||||
|
// next pieces. This is because the suffixes will always be the same if they have the
|
||||||
|
// same number of pieces, so its more efficient to only store & compare the length.
|
||||||
|
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
|
||||||
|
struct QueueKey(u16, Option<PieceType>);
|
||||||
|
|
||||||
|
impl QueueKey {
|
||||||
|
fn new(queue: Queue<'_>) -> Self {
|
||||||
|
Self(queue.next.len() as u16, queue.hold)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// share matrix data with the nodes themselves, since the nodes will outlive the
|
||||||
|
// transposition table.
|
||||||
|
#[derive(Copy, Clone, Debug)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
struct RawMatPtr(*const Mat);
|
||||||
|
|
||||||
|
impl RawMatPtr {
|
||||||
|
// SAFETY: caller must guaruntee that `mat` lives longer than this value.
|
||||||
|
unsafe fn new(mat: &Mat) -> Self {
|
||||||
|
Self(mat)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_mat(&self) -> &Mat {
|
||||||
|
unsafe { &*self.0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl core::cmp::PartialEq for RawMatPtr {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
*self.as_mat() == *other.as_mat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl core::cmp::Eq for RawMatPtr {}
|
||||||
|
|
||||||
|
impl core::hash::Hash for RawMatPtr {
|
||||||
|
fn hash<H: core::hash::Hasher>(&self, h: &mut H) {
|
||||||
|
self.as_mat().hash(h)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue