better handling of raw ptrs using NonNull

This commit is contained in:
milo 2024-03-09 18:00:03 -05:00
parent 16a1b12a7a
commit 8548784d45
4 changed files with 127 additions and 101 deletions

View File

@ -10,6 +10,7 @@ use mino::matrix::Mat;
use mino::srs::{Piece, Queue}; use mino::srs::{Piece, Queue};
mod node; mod node;
mod raw;
mod trans; mod trans;
use crate::bot::node::{Node, RawNodePtr}; use crate::bot::node::{Node, RawNodePtr};
@ -219,7 +220,7 @@ impl SegmentedAStar {
} }
fn best(&self) -> &Node { fn best(&self) -> &Node {
unsafe { self.best.as_node() } unsafe { self.best.get() }
} }
fn step( fn step(
@ -253,8 +254,9 @@ impl SegmentedAStar {
eval: &Evaluator, eval: &Evaluator,
) -> Result<u32, Option<&'a Node>> { ) -> Result<u32, Option<&'a Node>> {
let open_set = self.open.get_mut(self.depth).ok_or(None)?; let open_set = self.open.get_mut(self.depth).ok_or(None)?;
let cand = open_set.pop().ok_or(None)?; let cand = open_set.pop().ok_or(None)?;
let cand = unsafe { cand.0.as_node() }; let cand: &'a Node = unsafe { cand.0.get() };
if cand.is_terminal() { if cand.is_terminal() {
return Err(Some(cand)); return Err(Some(cand));
@ -294,7 +296,7 @@ impl SegmentedAStar {
for (depth, set) in self.open.iter().enumerate() { for (depth, set) in self.open.iter().enumerate() {
let Some(cand) = set.peek() else { continue }; let Some(cand) = set.peek() else { continue };
let cand = unsafe { cand.0.as_node() }; let cand = unsafe { cand.0.get() };
if best.map_or(true, |best| cand.rating() > best.rating()) { if best.map_or(true, |best| cand.rating() > best.rating()) {
best = Some(cand); best = Some(cand);
self.depth = depth; self.depth = depth;
@ -315,8 +317,8 @@ impl From<&Node> for AStarNode {
impl core::cmp::Ord for AStarNode { impl core::cmp::Ord for AStarNode {
fn cmp(&self, other: &Self) -> core::cmp::Ordering { fn cmp(&self, other: &Self) -> core::cmp::Ordering {
let lhs = unsafe { self.0.as_node() }; let lhs = unsafe { self.0.get() };
let rhs = unsafe { other.0.as_node() }; let rhs = unsafe { other.0.get() };
// TODO: tiebreaker strategy // TODO: tiebreaker strategy
lhs.rating().cmp(&rhs.rating()) lhs.rating().cmp(&rhs.rating())
} }

View File

@ -1,10 +1,11 @@
//! Graph data structures used by `Bot` in its search algorithm. //! Graph data structures used by `Bot` in its search algorithm.
use core::ops::Range; use core::ops::Range;
use core::ptr::NonNull;
use mino::matrix::{Mat, MatBuf}; use mino::matrix::{Mat, MatBuf};
use mino::srs::{Piece, PieceType, Queue}; use mino::srs::{Piece, Queue};
use crate::bot::raw::{RawMatPtr, RawQueue};
use crate::bot::trans::TransTable; use crate::bot::trans::TransTable;
use crate::eval::Rating; use crate::eval::Rating;
use crate::find::find_locations; use crate::find::find_locations;
@ -13,7 +14,7 @@ use crate::Arena;
/// Represents a node in the search tree. A node basically just consists of a board state /// Represents a node in the search tree. A node basically just consists of a board state
/// (incl. queue) and some extra metadata relating it to previous nodes in the tree. /// (incl. queue) and some extra metadata relating it to previous nodes in the tree.
pub struct Node { pub struct Node {
matrix: *const Mat, matrix: RawMatPtr,
queue: RawQueue, queue: RawQueue,
edge: Option<Edge>, edge: Option<Edge>,
rating: Rating, rating: Rating,
@ -22,12 +23,12 @@ pub struct Node {
} }
// Reallocates the matrix into the arena. // Reallocates the matrix into the arena.
fn copy_matrix<'a>(arena: &'a Arena, matrix: &Mat) -> &'a Mat { fn alloc_matrix_copy<'a>(arena: &'a Arena, matrix: &Mat) -> &'a Mat {
Mat::new(arena.alloc_slice_copy(&matrix[..])) Mat::new(arena.alloc_slice_copy(&matrix[..]))
} }
// Reallocates the queue into the arena. // Reallocates the queue into the arena.
fn copy_queue<'a>(arena: &'a Arena, queue: Queue<'_>) -> Queue<'a> { fn alloc_queue_copy<'a>(arena: &'a Arena, queue: Queue<'_>) -> Queue<'a> {
Queue { Queue {
hold: queue.hold, hold: queue.hold,
next: arena.alloc_slice_copy(queue.next), next: arena.alloc_slice_copy(queue.next),
@ -39,8 +40,8 @@ impl Node {
/// matrix and queue are also allocated onto the arena, so you do not need to worry /// matrix and queue are also allocated onto the arena, so you do not need to worry
/// about their lifetimes when managing the lifetime of the root. /// about their lifetimes when managing the lifetime of the root.
pub fn alloc_root<'a>(arena: &'a Arena, matrix: &Mat, queue: Queue<'_>) -> &'a Self { pub fn alloc_root<'a>(arena: &'a Arena, matrix: &Mat, queue: Queue<'_>) -> &'a Self {
let matrix = copy_matrix(arena, matrix); let matrix = alloc_matrix_copy(arena, matrix);
let queue = copy_queue(arena, queue); let queue = alloc_queue_copy(arena, queue);
Node::alloc(arena, matrix, queue, Rating::default(), None) Node::alloc(arena, matrix, queue, Rating::default(), None)
} }
@ -52,8 +53,8 @@ impl Node {
rating: Rating, rating: Rating,
edge: Option<Edge>, edge: Option<Edge>,
) -> &'a Self { ) -> &'a Self {
let matrix = matrix as *const Mat; let matrix = matrix.into();
let queue = RawQueue::from(queue); let queue = queue.into();
arena.alloc_with(|| Self { arena.alloc_with(|| Self {
matrix, matrix,
queue, queue,
@ -63,11 +64,11 @@ impl Node {
} }
pub fn matrix(&self) -> &Mat { pub fn matrix(&self) -> &Mat {
unsafe { &*self.matrix } unsafe { self.matrix.get() }
} }
pub fn queue(&self) -> Queue<'_> { pub fn queue(&self) -> Queue<'_> {
unsafe { self.queue.as_queue() } unsafe { self.queue.get() }
} }
pub fn rating(&self) -> Rating { pub fn rating(&self) -> Rating {
@ -110,6 +111,7 @@ impl Node {
locs.map(move |loc| Piece { ty, loc }) locs.map(move |loc| Piece { ty, loc })
}); });
let parent = RawNodePtr::from(self);
let children = placements.map(move |placement| { let children = placements.map(move |placement| {
// compute new board state from placement // compute new board state from placement
matrix.copy_from(self.matrix()); matrix.copy_from(self.matrix());
@ -121,14 +123,10 @@ impl Node {
trans.get_or_insert(&matrix, queue, || { trans.get_or_insert(&matrix, queue, || {
Node::alloc( Node::alloc(
arena, arena,
copy_matrix(arena, &matrix), alloc_matrix_copy(arena, &matrix),
// `queue` is already allocated on the arena so don't need to copy it
queue, queue,
evaluate(&matrix, queue, &cleared), evaluate(&matrix, queue, &cleared),
Some(Edge { Some(Edge { placement, parent }),
placement,
parent: self.into(),
}),
) )
}) })
}); });
@ -167,56 +165,30 @@ impl core::fmt::Debug for Node {
/// Represents an edge in the graph, pointing from a node to its parent. Particularly, /// Represents an edge in the graph, pointing from a node to its parent. Particularly,
/// contains the placement made in order to arrive at the child from the parent. /// contains the placement made in order to arrive at the child from the parent.
struct Edge { struct Edge {
placement: Piece,
parent: RawNodePtr, parent: RawNodePtr,
placement: Piece,
} }
impl Edge { impl Edge {
fn parent(&self) -> &Node { fn parent(&self) -> &Node {
unsafe { self.parent.as_node() } unsafe { self.parent.get() }
} }
} }
/// Wraps a raw pointer to a `Node`, requiring you to manage the lifetime yourself. /// Wraps a raw pointer to a `Node`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] #[derive(Copy, Clone, Debug)]
#[repr(transparent)] #[repr(transparent)]
pub struct RawNodePtr(*const Node); pub struct RawNodePtr(NonNull<Node>);
impl RawNodePtr { impl RawNodePtr {
pub unsafe fn as_node<'a>(self) -> &'a Node { pub unsafe fn get<'a>(self) -> &'a Node {
&*self.0 self.0.as_ref()
} }
} }
impl From<&Node> for RawNodePtr { impl From<&Node> for RawNodePtr {
fn from(node: &Node) -> Self { fn from(node: &Node) -> Self {
Self(node) Self(NonNull::from(node))
}
}
/// Wraps the raw components of a `Queue`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
struct RawQueue {
hold: Option<PieceType>,
len: u16, // u16 to save space esp. considering padding
next: *const PieceType,
}
impl RawQueue {
pub unsafe fn as_queue<'a>(self) -> Queue<'a> {
let hold = self.hold;
let next = core::slice::from_raw_parts(self.next, self.len as usize);
Queue { hold, next }
}
}
impl From<Queue<'_>> for RawQueue {
fn from(queue: Queue<'_>) -> Self {
Self {
hold: queue.hold,
len: queue.next.len() as u16,
next: queue.next.as_ptr(),
}
} }
} }
@ -224,6 +196,15 @@ impl From<Queue<'_>> for RawQueue {
mod test { mod test {
use super::*; use super::*;
use mino::mat; use mino::mat;
use mino::srs::PieceType;
#[test]
fn test_size_of() {
use core::mem::size_of;
assert_eq!(size_of::<RawNodePtr>(), size_of::<&Node>());
assert_eq!(size_of::<RawNodePtr>(), size_of::<Option<RawNodePtr>>());
assert_eq!(size_of::<Edge>(), size_of::<Option<Edge>>());
}
#[test] #[test]
fn test_copy_matrix() { fn test_copy_matrix() {
@ -232,7 +213,7 @@ mod test {
"..xxx..x.x"; "..xxx..x.x";
"xxxxxx.xxx"; "xxxxxx.xxx";
}; };
let mat1 = copy_matrix(&arena, mat0); let mat1 = alloc_matrix_copy(&arena, mat0);
assert_eq!(mat0, mat1); assert_eq!(mat0, mat1);
} }
@ -241,7 +222,7 @@ mod test {
use PieceType::*; use PieceType::*;
let arena = Arena::new(); let arena = Arena::new();
let q0 = Queue::new(None, &[I, L, J, O]); let q0 = Queue::new(None, &[I, L, J, O]);
let q1 = copy_queue(&arena, q0); let q1 = alloc_queue_copy(&arena, q0);
assert_eq!(q0, q1); assert_eq!(q0, q1);
} }
@ -258,12 +239,12 @@ mod test {
use PieceType::*; use PieceType::*;
let q0 = Queue::new(None, &[I, L, J, O]); let q0 = Queue::new(None, &[I, L, J, O]);
let rq0 = RawQueue::from(q0); let rq0 = RawQueue::from(q0);
let q1 = unsafe { rq0.as_queue() }; let q1 = unsafe { rq0.get() };
assert_eq!(q1, q0); assert_eq!(q1, q0);
let q0 = Queue::new(None, &[]); let q0 = Queue::new(None, &[]);
let rq0 = RawQueue::from(q0); let rq0 = RawQueue::from(q0);
let q1 = unsafe { rq0.as_queue() }; let q1 = unsafe { rq0.get() };
assert_eq!(q1, q0); assert_eq!(q1, q0);
} }
} }

79
fish/src/bot/raw.rs Normal file
View File

@ -0,0 +1,79 @@
use core::ptr::NonNull;
use mino::matrix::Mat;
use mino::srs::{PieceType, Queue};
/// Wraps the raw components of a `Queue`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Debug)]
pub struct RawQueue {
hold: Option<PieceType>,
len: u16, // u16 to save space esp. considering padding
next: *const PieceType,
}
impl RawQueue {
pub unsafe fn get<'a>(self) -> Queue<'a> {
let hold = self.hold;
let next = core::slice::from_raw_parts(self.next, self.len as usize);
Queue { hold, next }
}
}
impl From<Queue<'_>> for RawQueue {
fn from(queue: Queue<'_>) -> Self {
Self {
hold: queue.hold,
len: queue.next.len() as u16,
next: queue.next.as_ptr(),
}
}
}
/// Wraps a raw pointer to a `Mat`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Debug)]
pub struct RawMatPtr(NonNull<Mat>);
impl RawMatPtr {
pub unsafe fn get(&self) -> &Mat {
self.0.as_ref()
}
}
impl From<&Mat> for RawMatPtr {
fn from(mat: &Mat) -> Self {
Self(NonNull::from(mat))
}
}
impl hashbrown::Equivalent<RawMatPtr> for &Mat {
fn equivalent(&self, key: &RawMatPtr) -> bool {
RawMatPtr::from(*self) == *key
}
}
impl core::cmp::PartialEq for RawMatPtr {
fn eq(&self, other: &Self) -> bool {
*unsafe { self.get() } == *unsafe { other.get() }
}
}
impl core::cmp::Eq for RawMatPtr {}
impl core::hash::Hash for RawMatPtr {
fn hash<H: core::hash::Hasher>(&self, h: &mut H) {
unsafe { self.get() }.hash(h)
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_size_of() {
use core::mem::size_of;
assert!(size_of::<RawQueue>() <= size_of::<Queue<'_>>());
assert_eq!(size_of::<RawQueue>(), size_of::<Option<RawQueue>>());
assert_eq!(size_of::<RawMatPtr>(), size_of::<&Mat>());
assert_eq!(size_of::<RawMatPtr>(), size_of::<Option<RawMatPtr>>());
}
}

View File

@ -6,6 +6,7 @@ use mino::srs::{PieceType, Queue};
use mino::Mat; use mino::Mat;
use crate::bot::node::{Node, RawNodePtr}; use crate::bot::node::{Node, RawNodePtr};
use crate::bot::raw::RawMatPtr;
use crate::HashMap; use crate::HashMap;
pub struct TransTable { pub struct TransTable {
@ -30,18 +31,12 @@ impl TransTable {
mut alloc: impl FnMut() -> &'a Node, mut alloc: impl FnMut() -> &'a Node,
) -> Result<&'a Node, RawNodePtr> { ) -> Result<&'a Node, RawNodePtr> {
// two-phase lookup: first key off queue info, then matrix info // two-phase lookup: first key off queue info, then matrix info
let map = self let map = self
.lookup .lookup
.entry(QueueKey::new(queue)) .entry(QueueKey::new(queue))
.or_insert_with(HashMap::new); .or_insert_with(HashMap::new);
// SAFETY: IMPORTANT! this ptr is ONLY used for lookup, NOT insert. it is not if let Some(&node_ptr) = map.get(&matrix) {
// necessarily long lived, we are only permitted to insert matrix coming from
// `alloc()`
let mat_ptr = unsafe { RawMatPtr::new(matrix) };
if let Some(&node_ptr) = map.get(&mat_ptr) {
return Err(node_ptr); return Err(node_ptr);
} }
@ -49,11 +44,11 @@ impl TransTable {
debug_assert_eq!(*node.matrix(), *matrix); debug_assert_eq!(*node.matrix(), *matrix);
debug_assert_eq!(node.queue(), queue); debug_assert_eq!(node.queue(), queue);
// SAFETY: as noted above, this matrix ptr is OK to use for insert, since it comes // IMPORTANT: it would be invalid to insert `matrix.into()` since we don't have
// from the newly allocated node. // guaruntees about its lifetime. we have to use a pointer to the matrix in the
let mat_ptr = unsafe { RawMatPtr::new(node.matrix()) }; // newly allocated `node`.
map.insert(node.matrix().into(), node.into());
map.insert(mat_ptr, RawNodePtr::from(node));
Ok(node) Ok(node)
} }
} }
@ -69,34 +64,3 @@ impl QueueKey {
Self(queue.next.len() as u16, queue.hold) Self(queue.next.len() as u16, queue.hold)
} }
} }
// share matrix data with the nodes themselves, since the nodes will outlive the
// transposition table.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct RawMatPtr(*const Mat);
impl RawMatPtr {
// SAFETY: caller must guaruntee that `mat` lives longer than this value.
unsafe fn new(mat: &Mat) -> Self {
Self(mat)
}
fn as_mat(&self) -> &Mat {
unsafe { &*self.0 }
}
}
impl core::cmp::PartialEq for RawMatPtr {
fn eq(&self, other: &Self) -> bool {
*self.as_mat() == *other.as_mat()
}
}
impl core::cmp::Eq for RawMatPtr {}
impl core::hash::Hash for RawMatPtr {
fn hash<H: core::hash::Hasher>(&self, h: &mut H) {
self.as_mat().hash(h)
}
}