better handling of raw ptrs using NonNull

This commit is contained in:
milo 2024-03-09 18:00:03 -05:00
parent 16a1b12a7a
commit 8548784d45
4 changed files with 127 additions and 101 deletions

View File

@ -10,6 +10,7 @@ use mino::matrix::Mat;
use mino::srs::{Piece, Queue};
mod node;
mod raw;
mod trans;
use crate::bot::node::{Node, RawNodePtr};
@ -219,7 +220,7 @@ impl SegmentedAStar {
}
fn best(&self) -> &Node {
unsafe { self.best.as_node() }
unsafe { self.best.get() }
}
fn step(
@ -253,8 +254,9 @@ impl SegmentedAStar {
eval: &Evaluator,
) -> Result<u32, Option<&'a Node>> {
let open_set = self.open.get_mut(self.depth).ok_or(None)?;
let cand = open_set.pop().ok_or(None)?;
let cand = unsafe { cand.0.as_node() };
let cand: &'a Node = unsafe { cand.0.get() };
if cand.is_terminal() {
return Err(Some(cand));
@ -294,7 +296,7 @@ impl SegmentedAStar {
for (depth, set) in self.open.iter().enumerate() {
let Some(cand) = set.peek() else { continue };
let cand = unsafe { cand.0.as_node() };
let cand = unsafe { cand.0.get() };
if best.map_or(true, |best| cand.rating() > best.rating()) {
best = Some(cand);
self.depth = depth;
@ -315,8 +317,8 @@ impl From<&Node> for AStarNode {
impl core::cmp::Ord for AStarNode {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
let lhs = unsafe { self.0.as_node() };
let rhs = unsafe { other.0.as_node() };
let lhs = unsafe { self.0.get() };
let rhs = unsafe { other.0.get() };
// TODO: tiebreaker strategy
lhs.rating().cmp(&rhs.rating())
}

View File

@ -1,10 +1,11 @@
//! Graph data structures used by `Bot` in its search algorithm.
use core::ops::Range;
use core::ptr::NonNull;
use mino::matrix::{Mat, MatBuf};
use mino::srs::{Piece, PieceType, Queue};
use mino::srs::{Piece, Queue};
use crate::bot::raw::{RawMatPtr, RawQueue};
use crate::bot::trans::TransTable;
use crate::eval::Rating;
use crate::find::find_locations;
@ -13,7 +14,7 @@ use crate::Arena;
/// Represents a node in the search tree. A node basically just consists of a board state
/// (incl. queue) and some extra metadata relating it to previous nodes in the tree.
pub struct Node {
matrix: *const Mat,
matrix: RawMatPtr,
queue: RawQueue,
edge: Option<Edge>,
rating: Rating,
@ -22,12 +23,12 @@ pub struct Node {
}
// Reallocates the matrix into the arena.
fn copy_matrix<'a>(arena: &'a Arena, matrix: &Mat) -> &'a Mat {
fn alloc_matrix_copy<'a>(arena: &'a Arena, matrix: &Mat) -> &'a Mat {
Mat::new(arena.alloc_slice_copy(&matrix[..]))
}
// Reallocates the queue into the arena.
fn copy_queue<'a>(arena: &'a Arena, queue: Queue<'_>) -> Queue<'a> {
fn alloc_queue_copy<'a>(arena: &'a Arena, queue: Queue<'_>) -> Queue<'a> {
Queue {
hold: queue.hold,
next: arena.alloc_slice_copy(queue.next),
@ -39,8 +40,8 @@ impl Node {
/// matrix and queue are also allocated onto the arena, so you do not need to worry
/// about their lifetimes when managing the lifetime of the root.
pub fn alloc_root<'a>(arena: &'a Arena, matrix: &Mat, queue: Queue<'_>) -> &'a Self {
let matrix = copy_matrix(arena, matrix);
let queue = copy_queue(arena, queue);
let matrix = alloc_matrix_copy(arena, matrix);
let queue = alloc_queue_copy(arena, queue);
Node::alloc(arena, matrix, queue, Rating::default(), None)
}
@ -52,8 +53,8 @@ impl Node {
rating: Rating,
edge: Option<Edge>,
) -> &'a Self {
let matrix = matrix as *const Mat;
let queue = RawQueue::from(queue);
let matrix = matrix.into();
let queue = queue.into();
arena.alloc_with(|| Self {
matrix,
queue,
@ -63,11 +64,11 @@ impl Node {
}
pub fn matrix(&self) -> &Mat {
unsafe { &*self.matrix }
unsafe { self.matrix.get() }
}
pub fn queue(&self) -> Queue<'_> {
unsafe { self.queue.as_queue() }
unsafe { self.queue.get() }
}
pub fn rating(&self) -> Rating {
@ -110,6 +111,7 @@ impl Node {
locs.map(move |loc| Piece { ty, loc })
});
let parent = RawNodePtr::from(self);
let children = placements.map(move |placement| {
// compute new board state from placement
matrix.copy_from(self.matrix());
@ -121,14 +123,10 @@ impl Node {
trans.get_or_insert(&matrix, queue, || {
Node::alloc(
arena,
copy_matrix(arena, &matrix),
// `queue` is already allocated on the arena so don't need to copy it
alloc_matrix_copy(arena, &matrix),
queue,
evaluate(&matrix, queue, &cleared),
Some(Edge {
placement,
parent: self.into(),
}),
Some(Edge { placement, parent }),
)
})
});
@ -167,56 +165,30 @@ impl core::fmt::Debug for Node {
/// Represents an edge in the graph, pointing from a node to its parent. Particularly,
/// contains the placement made in order to arrive at the child from the parent.
struct Edge {
placement: Piece,
parent: RawNodePtr,
placement: Piece,
}
impl Edge {
fn parent(&self) -> &Node {
unsafe { self.parent.as_node() }
unsafe { self.parent.get() }
}
}
/// Wraps a raw pointer to a `Node`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct RawNodePtr(*const Node);
pub struct RawNodePtr(NonNull<Node>);
impl RawNodePtr {
pub unsafe fn as_node<'a>(self) -> &'a Node {
&*self.0
pub unsafe fn get<'a>(self) -> &'a Node {
self.0.as_ref()
}
}
impl From<&Node> for RawNodePtr {
fn from(node: &Node) -> Self {
Self(node)
}
}
/// Wraps the raw components of a `Queue`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
struct RawQueue {
hold: Option<PieceType>,
len: u16, // u16 to save space esp. considering padding
next: *const PieceType,
}
impl RawQueue {
pub unsafe fn as_queue<'a>(self) -> Queue<'a> {
let hold = self.hold;
let next = core::slice::from_raw_parts(self.next, self.len as usize);
Queue { hold, next }
}
}
impl From<Queue<'_>> for RawQueue {
fn from(queue: Queue<'_>) -> Self {
Self {
hold: queue.hold,
len: queue.next.len() as u16,
next: queue.next.as_ptr(),
}
Self(NonNull::from(node))
}
}
@ -224,6 +196,15 @@ impl From<Queue<'_>> for RawQueue {
mod test {
use super::*;
use mino::mat;
use mino::srs::PieceType;
#[test]
fn test_size_of() {
use core::mem::size_of;
assert_eq!(size_of::<RawNodePtr>(), size_of::<&Node>());
assert_eq!(size_of::<RawNodePtr>(), size_of::<Option<RawNodePtr>>());
assert_eq!(size_of::<Edge>(), size_of::<Option<Edge>>());
}
#[test]
fn test_copy_matrix() {
@ -232,7 +213,7 @@ mod test {
"..xxx..x.x";
"xxxxxx.xxx";
};
let mat1 = copy_matrix(&arena, mat0);
let mat1 = alloc_matrix_copy(&arena, mat0);
assert_eq!(mat0, mat1);
}
@ -241,7 +222,7 @@ mod test {
use PieceType::*;
let arena = Arena::new();
let q0 = Queue::new(None, &[I, L, J, O]);
let q1 = copy_queue(&arena, q0);
let q1 = alloc_queue_copy(&arena, q0);
assert_eq!(q0, q1);
}
@ -258,12 +239,12 @@ mod test {
use PieceType::*;
let q0 = Queue::new(None, &[I, L, J, O]);
let rq0 = RawQueue::from(q0);
let q1 = unsafe { rq0.as_queue() };
let q1 = unsafe { rq0.get() };
assert_eq!(q1, q0);
let q0 = Queue::new(None, &[]);
let rq0 = RawQueue::from(q0);
let q1 = unsafe { rq0.as_queue() };
let q1 = unsafe { rq0.get() };
assert_eq!(q1, q0);
}
}

79
fish/src/bot/raw.rs Normal file
View File

@ -0,0 +1,79 @@
use core::ptr::NonNull;
use mino::matrix::Mat;
use mino::srs::{PieceType, Queue};
/// Wraps the raw components of a `Queue`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Debug)]
pub struct RawQueue {
hold: Option<PieceType>,
len: u16, // u16 to save space esp. considering padding
next: *const PieceType,
}
impl RawQueue {
pub unsafe fn get<'a>(self) -> Queue<'a> {
let hold = self.hold;
let next = core::slice::from_raw_parts(self.next, self.len as usize);
Queue { hold, next }
}
}
impl From<Queue<'_>> for RawQueue {
fn from(queue: Queue<'_>) -> Self {
Self {
hold: queue.hold,
len: queue.next.len() as u16,
next: queue.next.as_ptr(),
}
}
}
/// Wraps a raw pointer to a `Mat`, requiring you to manage the lifetime yourself.
#[derive(Copy, Clone, Debug)]
pub struct RawMatPtr(NonNull<Mat>);
impl RawMatPtr {
pub unsafe fn get(&self) -> &Mat {
self.0.as_ref()
}
}
impl From<&Mat> for RawMatPtr {
fn from(mat: &Mat) -> Self {
Self(NonNull::from(mat))
}
}
impl hashbrown::Equivalent<RawMatPtr> for &Mat {
fn equivalent(&self, key: &RawMatPtr) -> bool {
RawMatPtr::from(*self) == *key
}
}
impl core::cmp::PartialEq for RawMatPtr {
fn eq(&self, other: &Self) -> bool {
*unsafe { self.get() } == *unsafe { other.get() }
}
}
impl core::cmp::Eq for RawMatPtr {}
impl core::hash::Hash for RawMatPtr {
fn hash<H: core::hash::Hasher>(&self, h: &mut H) {
unsafe { self.get() }.hash(h)
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_size_of() {
use core::mem::size_of;
assert!(size_of::<RawQueue>() <= size_of::<Queue<'_>>());
assert_eq!(size_of::<RawQueue>(), size_of::<Option<RawQueue>>());
assert_eq!(size_of::<RawMatPtr>(), size_of::<&Mat>());
assert_eq!(size_of::<RawMatPtr>(), size_of::<Option<RawMatPtr>>());
}
}

View File

@ -6,6 +6,7 @@ use mino::srs::{PieceType, Queue};
use mino::Mat;
use crate::bot::node::{Node, RawNodePtr};
use crate::bot::raw::RawMatPtr;
use crate::HashMap;
pub struct TransTable {
@ -30,18 +31,12 @@ impl TransTable {
mut alloc: impl FnMut() -> &'a Node,
) -> Result<&'a Node, RawNodePtr> {
// two-phase lookup: first key off queue info, then matrix info
let map = self
.lookup
.entry(QueueKey::new(queue))
.or_insert_with(HashMap::new);
// SAFETY: IMPORTANT! this ptr is ONLY used for lookup, NOT insert. it is not
// necessarily long lived, we are only permitted to insert matrix coming from
// `alloc()`
let mat_ptr = unsafe { RawMatPtr::new(matrix) };
if let Some(&node_ptr) = map.get(&mat_ptr) {
if let Some(&node_ptr) = map.get(&matrix) {
return Err(node_ptr);
}
@ -49,11 +44,11 @@ impl TransTable {
debug_assert_eq!(*node.matrix(), *matrix);
debug_assert_eq!(node.queue(), queue);
// SAFETY: as noted above, this matrix ptr is OK to use for insert, since it comes
// from the newly allocated node.
let mat_ptr = unsafe { RawMatPtr::new(node.matrix()) };
// IMPORTANT: it would be invalid to insert `matrix.into()` since we don't have
// guaruntees about its lifetime. we have to use a pointer to the matrix in the
// newly allocated `node`.
map.insert(node.matrix().into(), node.into());
map.insert(mat_ptr, RawNodePtr::from(node));
Ok(node)
}
}
@ -69,34 +64,3 @@ impl QueueKey {
Self(queue.next.len() as u16, queue.hold)
}
}
// share matrix data with the nodes themselves, since the nodes will outlive the
// transposition table.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct RawMatPtr(*const Mat);
impl RawMatPtr {
// SAFETY: caller must guaruntee that `mat` lives longer than this value.
unsafe fn new(mat: &Mat) -> Self {
Self(mat)
}
fn as_mat(&self) -> &Mat {
unsafe { &*self.0 }
}
}
impl core::cmp::PartialEq for RawMatPtr {
fn eq(&self, other: &Self) -> bool {
*self.as_mat() == *other.as_mat()
}
}
impl core::cmp::Eq for RawMatPtr {}
impl core::hash::Hash for RawMatPtr {
fn hash<H: core::hash::Hasher>(&self, h: &mut H) {
self.as_mat().hash(h)
}
}