create dedicated Rating type to replace awkward (bool,i32) pair

This commit is contained in:
milo 2024-02-22 18:53:37 -05:00
parent a6f893e448
commit c0a06ecd00
5 changed files with 76 additions and 37 deletions

View File

@ -14,9 +14,8 @@ mod trans;
use crate::bot::node::{Node, RawNodePtr}; use crate::bot::node::{Node, RawNodePtr};
use crate::bot::trans::TransTable; use crate::bot::trans::TransTable;
use crate::eval::{features, Features, Weights}; use crate::eval::{features, Features, Rating, Weights};
use crate::Arena;
use bumpalo::Bump as Arena;
/// Encompasses an instance of the algorithm. /// Encompasses an instance of the algorithm.
pub struct Bot { pub struct Bot {
@ -35,7 +34,7 @@ pub struct Metrics {
pub start_heuristic: i32, pub start_heuristic: i32,
pub end_features: Features, pub end_features: Features,
pub end_heuristic: i32, pub end_heuristic: i32,
pub end_rating: (bool, i32), pub end_rating: Rating,
pub end_iteration: u32, pub end_iteration: u32,
// TODO(?) memory usage metrics // TODO(?) memory usage metrics
// TODO(?) transposition table metrics // TODO(?) transposition table metrics
@ -127,7 +126,7 @@ impl Bot {
struct Evaluator { struct Evaluator {
weights: Weights, weights: Weights,
root_score: i32, root_heuristic: i32,
root_queue_len: usize, root_queue_len: usize,
} }
@ -135,30 +134,26 @@ impl Evaluator {
fn new(weights: &Weights, root: &Node) -> Self { fn new(weights: &Weights, root: &Node) -> Self {
Self { Self {
weights: *weights, weights: *weights,
root_score: features(root.matrix(), 0).evaluate(weights), root_heuristic: features(root.matrix(), 0).evaluate(weights),
root_queue_len: root.queue().len(), root_queue_len: root.queue().len(),
} }
} }
fn evaluate(&self, mat: &Mat, queue: Queue<'_>, cleared: &Range<i16>) -> (bool, i32) { fn evaluate(&self, mat: &Mat, queue: Queue<'_>, cleared: &Range<i16>) -> Rating {
let pcnt = self.root_queue_len.saturating_sub(queue.len()); debug_assert!(queue.len() < self.root_queue_len);
let pcnt = self.root_queue_len - queue.len();
if self.greed() && cleared.contains(&0) { // check if we cleared the bottom row of the matrix, which is assumed to be the
// cleared the bottom row of the matrix, which must be the last line of cheese // last line of cheese in the race. if so, then consider this node to be a solve
// in the race. piece count is negated so that less pieces is better (larger // and use the piece count as its rating.
// value). //
return (true, -(pcnt as i32)); // TODO: make this condition configurable since its a bit of a hack
if cleared.contains(&0) {
return Rating::Solve(pcnt as u32);
} }
let score = features(mat, pcnt).evaluate(&self.weights); let heuristic = features(mat, pcnt).evaluate(&self.weights);
Rating::Score(self.root_heuristic - heuristic)
// larger (further below the root score) is better
(false, self.root_score - score)
}
fn greed(&self) -> bool {
// TODO: make this parameter configurable on `Bot` initialization
true
} }
} }
@ -269,8 +264,7 @@ impl SegmentedAStar {
let cand = open_set.pop().ok_or(None)?; let cand = open_set.pop().ok_or(None)?;
let cand = unsafe { cand.0.as_node() }; let cand = unsafe { cand.0.as_node() };
if cand.queue().is_empty() || cand.rating().0 { if cand.is_terminal() {
// terminal node; end search and back up its rating
return Err(Some(cand)); return Err(Some(cand));
} }

View File

@ -6,8 +6,9 @@ use mino::matrix::{Mat, MatBuf};
use mino::srs::{Piece, PieceType, Queue}; use mino::srs::{Piece, PieceType, Queue};
use crate::bot::trans::TransTable; use crate::bot::trans::TransTable;
use crate::bot::Arena; use crate::eval::Rating;
use crate::find::find_locations; use crate::find::find_locations;
use crate::Arena;
/// Represents a node in the search tree. A node basically just consists of a board state /// Represents a node in the search tree. A node basically just consists of a board state
/// (incl. queue) and some extra metadata relating it to previous nodes in the tree. /// (incl. queue) and some extra metadata relating it to previous nodes in the tree.
@ -15,7 +16,7 @@ pub struct Node {
matrix: *const Mat, matrix: *const Mat,
queue: RawQueue, queue: RawQueue,
edge: Option<Edge>, edge: Option<Edge>,
rating: (bool, i32), rating: Rating,
// currently there is no need to store a node's children, but maybe this could change // currently there is no need to store a node's children, but maybe this could change
// in the future. // in the future.
} }
@ -40,8 +41,7 @@ impl Node {
pub fn alloc_root<'a>(arena: &'a Arena, matrix: &Mat, queue: Queue<'_>) -> &'a Self { pub fn alloc_root<'a>(arena: &'a Arena, matrix: &Mat, queue: Queue<'_>) -> &'a Self {
let matrix = copy_matrix(arena, matrix); let matrix = copy_matrix(arena, matrix);
let queue = copy_queue(arena, queue); let queue = copy_queue(arena, queue);
let rating = (false, i32::MIN); Node::alloc(arena, matrix, queue, Rating::default(), None)
Node::alloc(arena, matrix, queue, rating, None)
} }
// `matrix` and `queue` must be allocated inside `arena` // `matrix` and `queue` must be allocated inside `arena`
@ -49,7 +49,7 @@ impl Node {
arena: &'a Arena, arena: &'a Arena,
matrix: &'a Mat, matrix: &'a Mat,
queue: Queue<'a>, queue: Queue<'a>,
rating: (bool, i32), rating: Rating,
edge: Option<Edge>, edge: Option<Edge>,
) -> &'a Self { ) -> &'a Self {
let matrix = matrix as *const Mat; let matrix = matrix as *const Mat;
@ -70,10 +70,14 @@ impl Node {
unsafe { self.queue.as_queue() } unsafe { self.queue.as_queue() }
} }
pub fn rating(&self) -> (bool, i32) { pub fn rating(&self) -> Rating {
self.rating self.rating
} }
pub fn is_terminal(&self) -> bool {
matches!(self.rating, Rating::Solve(_)) || self.queue().is_empty()
}
/// Get the initial placement made after the root node which eventually arrives at /// Get the initial placement made after the root node which eventually arrives at
/// this node. /// this node.
pub fn root_placement(&self) -> Option<Piece> { pub fn root_placement(&self) -> Option<Piece> {
@ -97,7 +101,7 @@ impl Node {
mut evaluate: E, mut evaluate: E,
) -> impl Iterator<Item = &'a Node> + 'a ) -> impl Iterator<Item = &'a Node> + 'a
where where
E: FnMut(&Mat, Queue<'_>, &Range<i16>) -> (bool, i32) + 'a, E: FnMut(&Mat, Queue<'_>, &Range<i16>) -> Rating + 'a,
{ {
let mut matrix = MatBuf::new(); let mut matrix = MatBuf::new();
@ -147,9 +151,9 @@ impl core::fmt::Debug for Node {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Node {{ ")?; write!(f, "Node {{ ")?;
match self.rating { match self.rating {
(false, h) => write!(f, "rating: {}", h), Rating::Score(v) => write!(f, "score: {}", v)?,
(true, n) => write!(f, "solution: {}", -n), Rating::Solve(n) => write!(f, "solve: {}", n)?,
}?; }
if let Some(pc) = self.root_placement() { if let Some(pc) = self.root_placement() {
write!(f, ", root_placement: {:?}", pc)?; write!(f, ", root_placement: {:?}", pc)?;
} }

View File

@ -65,3 +65,41 @@ impl Features {
score score
} }
} }
/// Solutions can have two types of ratings depending on if they are considered to be a
/// "solve", which means they are rated by how few pieces are required, or they are not a
/// solve, so they are rated by their "score", which is measured as the difference between
/// the evaluation of a terminal node and the evaluation of the root.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Rating {
Score(i32),
Solve(u32),
}
impl core::cmp::Ord for Rating {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
use self::Rating::*;
use core::cmp::Ordering::*;
match (self, other) {
(Score(x), Score(y)) => x.cmp(y), // greater (difference from root) is better
(Solve(x), Solve(y)) => y.cmp(x), // less (piece count) is better
(Solve(_), _) => Greater, // solve always better than non-solve
(_, Solve(_)) => Less,
}
}
}
impl core::cmp::PartialOrd for Rating {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl core::default::Default for Rating {
/// The default rating is the worst possible rating which is ranked worse than any
/// other rating besides itself.
fn default() -> Self {
Rating::Score(i32::MIN)
}
}

View File

@ -9,3 +9,4 @@ pub mod find;
type HashMap<K, V> = hashbrown::HashMap<K, V, HashBuilder>; type HashMap<K, V> = hashbrown::HashMap<K, V, HashBuilder>;
type HashSet<T> = hashbrown::HashSet<T, HashBuilder>; type HashSet<T> = hashbrown::HashSet<T, HashBuilder>;
type HashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>; type HashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
type Arena = bumpalo::Bump;

View File

@ -1,4 +1,5 @@
use fish::bot::Metrics; use fish::bot::Metrics;
use fish::eval::Rating;
use mino::srs::Piece; use mino::srs::Piece;
use serde::Serialize; use serde::Serialize;
@ -184,8 +185,9 @@ mod ser {
#[serde(untagged)] #[serde(untagged)]
enum RatingVariant { enum RatingVariant {
NotApplicable, NotApplicable,
// confusing: 'rating' is called 'score' actually
Rating { rating: i32 }, Rating { rating: i32 },
Solution { solution: i32 }, Solution { solution: u32 },
} }
metrics metrics
@ -202,8 +204,8 @@ mod ser {
features: m.end_features.1.to_vec(), features: m.end_features.1.to_vec(),
heuristic: m.end_heuristic, heuristic: m.end_heuristic,
rating: match m.end_rating { rating: match m.end_rating {
(false, v) => RatingVariant::Rating { rating: v }, Rating::Score(v) => RatingVariant::Rating { rating: v },
(true, v) => RatingVariant::Solution { solution: -v }, Rating::Solve(n) => RatingVariant::Solution { solution: n },
}, },
iteration: Some(m.end_iteration), iteration: Some(m.end_iteration),
}, },