From c0a06ecd005bf4928869c9eec67f6966dc70ec8c Mon Sep 17 00:00:00 2001 From: milo Date: Thu, 22 Feb 2024 18:53:37 -0500 Subject: [PATCH] create dedicated Rating type to replace awkward (bool,i32) pair --- fish/src/bot.rs | 42 ++++++++++++++++++------------------------ fish/src/bot/node.rs | 24 ++++++++++++++---------- fish/src/eval.rs | 38 ++++++++++++++++++++++++++++++++++++++ fish/src/lib.rs | 1 + tidepool/src/output.rs | 8 +++++--- 5 files changed, 76 insertions(+), 37 deletions(-) diff --git a/fish/src/bot.rs b/fish/src/bot.rs index 61f8d2b..784bb64 100644 --- a/fish/src/bot.rs +++ b/fish/src/bot.rs @@ -14,9 +14,8 @@ mod trans; use crate::bot::node::{Node, RawNodePtr}; use crate::bot::trans::TransTable; -use crate::eval::{features, Features, Weights}; - -use bumpalo::Bump as Arena; +use crate::eval::{features, Features, Rating, Weights}; +use crate::Arena; /// Encompasses an instance of the algorithm. pub struct Bot { @@ -35,7 +34,7 @@ pub struct Metrics { pub start_heuristic: i32, pub end_features: Features, pub end_heuristic: i32, - pub end_rating: (bool, i32), + pub end_rating: Rating, pub end_iteration: u32, // TODO(?) memory usage metrics // TODO(?) 
transposition table metrics @@ -127,7 +126,7 @@ impl Bot { struct Evaluator { weights: Weights, - root_score: i32, + root_heuristic: i32, root_queue_len: usize, } @@ -135,30 +134,26 @@ impl Evaluator { fn new(weights: &Weights, root: &Node) -> Self { Self { weights: *weights, - root_score: features(root.matrix(), 0).evaluate(weights), + root_heuristic: features(root.matrix(), 0).evaluate(weights), root_queue_len: root.queue().len(), } } - fn evaluate(&self, mat: &Mat, queue: Queue<'_>, cleared: &Range) -> (bool, i32) { - let pcnt = self.root_queue_len.saturating_sub(queue.len()); + fn evaluate(&self, mat: &Mat, queue: Queue<'_>, cleared: &Range) -> Rating { + debug_assert!(queue.len() < self.root_queue_len); + let pcnt = self.root_queue_len - queue.len(); - if self.greed() && cleared.contains(&0) { - // cleared the bottom row of the matrix, which must be the last line of cheese - // in the race. piece count is negated so that less pieces is better (larger - // value). - return (true, -(pcnt as i32)); + // check if we cleared the bottom row of the matrix, which is assumed to be the + // last line of cheese in the race. if so, then consider this node to be a solve + // and use the piece count as its rating. 
+ // + // TODO: make this condition configurable since it's a bit of a hack + if cleared.contains(&0) { + return Rating::Solve(pcnt as u32); } - let score = features(mat, pcnt).evaluate(&self.weights); - - // larger (further below the root score) is better - (false, self.root_score - score) - } - - fn greed(&self) -> bool { - // TODO: make this parameter configurable on `Bot` initialization - true + let heuristic = features(mat, pcnt).evaluate(&self.weights); + Rating::Score(self.root_heuristic - heuristic) } } @@ -269,8 +264,7 @@ impl SegmentedAStar { let cand = open_set.pop().ok_or(None)?; let cand = unsafe { cand.0.as_node() }; - if cand.queue().is_empty() || cand.rating().0 { - // terminal node; end search and back up its rating + if cand.is_terminal() { return Err(Some(cand)); } diff --git a/fish/src/bot/node.rs b/fish/src/bot/node.rs index d6d764c..9b19a21 100644 --- a/fish/src/bot/node.rs +++ b/fish/src/bot/node.rs @@ -6,8 +6,9 @@ use mino::matrix::{Mat, MatBuf}; use mino::srs::{Piece, PieceType, Queue}; use crate::bot::trans::TransTable; -use crate::bot::Arena; +use crate::eval::Rating; use crate::find::find_locations; +use crate::Arena; /// Represents a node in the search tree. A node basically just consists of a board state /// (incl. queue) and some extra metadata relating it to previous nodes in the tree. @@ -15,7 +16,7 @@ pub struct Node { matrix: *const Mat, queue: RawQueue, edge: Option, - rating: (bool, i32), + rating: Rating, // currently there is no need to store a node's children, but maybe this could change // in the future. 
} @@ -40,8 +41,7 @@ impl Node { pub fn alloc_root<'a>(arena: &'a Arena, matrix: &Mat, queue: Queue<'_>) -> &'a Self { let matrix = copy_matrix(arena, matrix); let queue = copy_queue(arena, queue); - let rating = (false, i32::MIN); - Node::alloc(arena, matrix, queue, rating, None) + Node::alloc(arena, matrix, queue, Rating::default(), None) } // `matrix` and `queue` must be allocated inside `arena` @@ -49,7 +49,7 @@ impl Node { arena: &'a Arena, matrix: &'a Mat, queue: Queue<'a>, - rating: (bool, i32), + rating: Rating, edge: Option, ) -> &'a Self { let matrix = matrix as *const Mat; @@ -70,10 +70,14 @@ impl Node { unsafe { self.queue.as_queue() } } - pub fn rating(&self) -> (bool, i32) { + pub fn rating(&self) -> Rating { self.rating } + pub fn is_terminal(&self) -> bool { + matches!(self.rating, Rating::Solve(_)) || self.queue().is_empty() + } + /// Get the initial placement made after the root node which eventually arrives at /// this node. pub fn root_placement(&self) -> Option { @@ -97,7 +101,7 @@ impl Node { mut evaluate: E, ) -> impl Iterator + 'a where - E: FnMut(&Mat, Queue<'_>, &Range) -> (bool, i32) + 'a, + E: FnMut(&Mat, Queue<'_>, &Range) -> Rating + 'a, { let mut matrix = MatBuf::new(); @@ -147,9 +151,9 @@ impl core::fmt::Debug for Node { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "Node {{ ")?; match self.rating { - (false, h) => write!(f, "rating: {}", h), - (true, n) => write!(f, "solution: {}", -n), - }?; + Rating::Score(v) => write!(f, "score: {}", v)?, + Rating::Solve(n) => write!(f, "solve: {}", n)?, + } if let Some(pc) = self.root_placement() { write!(f, ", root_placement: {:?}", pc)?; } diff --git a/fish/src/eval.rs b/fish/src/eval.rs index 7e13cc6..e6f3b60 100644 --- a/fish/src/eval.rs +++ b/fish/src/eval.rs @@ -65,3 +65,41 @@ impl Features { score } } + +/// Solutions can have two types of ratings depending on if they are considered to be a +/// "solve", which means they are rated by how few pieces are 
required, or they are not a +/// solve, so they are rated by their "score", which is measured as the difference between +/// the evaluation of a terminal node and the evaluation of the root. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum Rating { + Score(i32), + Solve(u32), +} + +impl core::cmp::Ord for Rating { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + use self::Rating::*; + use core::cmp::Ordering::*; + + match (self, other) { + (Score(x), Score(y)) => x.cmp(y), // greater (difference from root) is better + (Solve(x), Solve(y)) => y.cmp(x), // less (piece count) is better + (Solve(_), _) => Greater, // solve always better than non-solve + (_, Solve(_)) => Less, + } + } +} + +impl core::cmp::PartialOrd for Rating { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl core::default::Default for Rating { + /// The default rating is the worst possible rating which is ranked worse than any + /// other rating besides itself. + fn default() -> Self { + Rating::Score(i32::MIN) + } +} diff --git a/fish/src/lib.rs b/fish/src/lib.rs index 8d7310b..97273c7 100644 --- a/fish/src/lib.rs +++ b/fish/src/lib.rs @@ -9,3 +9,4 @@ pub mod find; type HashMap = hashbrown::HashMap; type HashSet = hashbrown::HashSet; type HashBuilder = core::hash::BuildHasherDefault; +type Arena = bumpalo::Bump; diff --git a/tidepool/src/output.rs b/tidepool/src/output.rs index 289d946..45f5076 100644 --- a/tidepool/src/output.rs +++ b/tidepool/src/output.rs @@ -1,4 +1,5 @@ use fish::bot::Metrics; +use fish::eval::Rating; use mino::srs::Piece; use serde::Serialize; @@ -184,8 +185,9 @@ mod ser { #[serde(untagged)] enum RatingVariant { NotApplicable, + // confusing: 'rating' is called 'score' actually Rating { rating: i32 }, - Solution { solution: i32 }, + Solution { solution: u32 }, } metrics @@ -202,8 +204,8 @@ mod ser { features: m.end_features.1.to_vec(), heuristic: m.end_heuristic, rating: match m.end_rating { - (false, v) => 
RatingVariant::Rating { rating: v }, - (true, v) => RatingVariant::Solution { solution: -v }, + Rating::Score(v) => RatingVariant::Rating { rating: v }, + Rating::Solve(n) => RatingVariant::Solution { solution: n }, }, iteration: Some(m.end_iteration), },