create dedicated Rating type to replace awkward (bool,i32) pair

This commit is contained in:
milo 2024-02-22 18:53:37 -05:00
parent a6f893e448
commit c0a06ecd00
5 changed files with 76 additions and 37 deletions

View File

@ -14,9 +14,8 @@ mod trans;
use crate::bot::node::{Node, RawNodePtr};
use crate::bot::trans::TransTable;
use crate::eval::{features, Features, Weights};
use bumpalo::Bump as Arena;
use crate::eval::{features, Features, Rating, Weights};
use crate::Arena;
/// Encompasses an instance of the algorithm.
pub struct Bot {
@ -35,7 +34,7 @@ pub struct Metrics {
pub start_heuristic: i32,
pub end_features: Features,
pub end_heuristic: i32,
pub end_rating: (bool, i32),
pub end_rating: Rating,
pub end_iteration: u32,
// TODO(?) memory usage metrics
// TODO(?) transposition table metrics
@ -127,7 +126,7 @@ impl Bot {
struct Evaluator {
weights: Weights,
root_score: i32,
root_heuristic: i32,
root_queue_len: usize,
}
@ -135,30 +134,26 @@ impl Evaluator {
fn new(weights: &Weights, root: &Node) -> Self {
Self {
weights: *weights,
root_score: features(root.matrix(), 0).evaluate(weights),
root_heuristic: features(root.matrix(), 0).evaluate(weights),
root_queue_len: root.queue().len(),
}
}
fn evaluate(&self, mat: &Mat, queue: Queue<'_>, cleared: &Range<i16>) -> (bool, i32) {
let pcnt = self.root_queue_len.saturating_sub(queue.len());
fn evaluate(&self, mat: &Mat, queue: Queue<'_>, cleared: &Range<i16>) -> Rating {
debug_assert!(queue.len() < self.root_queue_len);
let pcnt = self.root_queue_len - queue.len();
if self.greed() && cleared.contains(&0) {
// cleared the bottom row of the matrix, which must be the last line of cheese
// in the race. piece count is negated so that less pieces is better (larger
// value).
return (true, -(pcnt as i32));
// check if we cleared the bottom row of the matrix, which is assumed to be the
// last line of cheese in the race. if so, then consider this node to be a solve
// and use the piece count as its rating.
//
// TODO: make this condition configurable since its a bit of a hack
if cleared.contains(&0) {
return Rating::Solve(pcnt as u32);
}
let score = features(mat, pcnt).evaluate(&self.weights);
// larger (further below the root score) is better
(false, self.root_score - score)
}
fn greed(&self) -> bool {
// TODO: make this parameter configurable on `Bot` initialization
true
let heuristic = features(mat, pcnt).evaluate(&self.weights);
Rating::Score(self.root_heuristic - heuristic)
}
}
@ -269,8 +264,7 @@ impl SegmentedAStar {
let cand = open_set.pop().ok_or(None)?;
let cand = unsafe { cand.0.as_node() };
if cand.queue().is_empty() || cand.rating().0 {
// terminal node; end search and back up its rating
if cand.is_terminal() {
return Err(Some(cand));
}

View File

@ -6,8 +6,9 @@ use mino::matrix::{Mat, MatBuf};
use mino::srs::{Piece, PieceType, Queue};
use crate::bot::trans::TransTable;
use crate::bot::Arena;
use crate::eval::Rating;
use crate::find::find_locations;
use crate::Arena;
/// Represents a node in the search tree. A node basically just consists of a board state
/// (incl. queue) and some extra metadata relating it to previous nodes in the tree.
@ -15,7 +16,7 @@ pub struct Node {
matrix: *const Mat,
queue: RawQueue,
edge: Option<Edge>,
rating: (bool, i32),
rating: Rating,
// currently there is no need to store a node's children, but maybe this could change
// in the future.
}
@ -40,8 +41,7 @@ impl Node {
pub fn alloc_root<'a>(arena: &'a Arena, matrix: &Mat, queue: Queue<'_>) -> &'a Self {
let matrix = copy_matrix(arena, matrix);
let queue = copy_queue(arena, queue);
let rating = (false, i32::MIN);
Node::alloc(arena, matrix, queue, rating, None)
Node::alloc(arena, matrix, queue, Rating::default(), None)
}
// `matrix` and `queue` must be allocated inside `arena`
@ -49,7 +49,7 @@ impl Node {
arena: &'a Arena,
matrix: &'a Mat,
queue: Queue<'a>,
rating: (bool, i32),
rating: Rating,
edge: Option<Edge>,
) -> &'a Self {
let matrix = matrix as *const Mat;
@ -70,10 +70,14 @@ impl Node {
unsafe { self.queue.as_queue() }
}
pub fn rating(&self) -> (bool, i32) {
pub fn rating(&self) -> Rating {
self.rating
}
pub fn is_terminal(&self) -> bool {
matches!(self.rating, Rating::Solve(_)) || self.queue().is_empty()
}
/// Get the initial placement made after the root node which eventually arrives at
/// this node.
pub fn root_placement(&self) -> Option<Piece> {
@ -97,7 +101,7 @@ impl Node {
mut evaluate: E,
) -> impl Iterator<Item = &'a Node> + 'a
where
E: FnMut(&Mat, Queue<'_>, &Range<i16>) -> (bool, i32) + 'a,
E: FnMut(&Mat, Queue<'_>, &Range<i16>) -> Rating + 'a,
{
let mut matrix = MatBuf::new();
@ -147,9 +151,9 @@ impl core::fmt::Debug for Node {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Node {{ ")?;
match self.rating {
(false, h) => write!(f, "rating: {}", h),
(true, n) => write!(f, "solution: {}", -n),
}?;
Rating::Score(v) => write!(f, "score: {}", v)?,
Rating::Solve(n) => write!(f, "solve: {}", n)?,
}
if let Some(pc) = self.root_placement() {
write!(f, ", root_placement: {:?}", pc)?;
}

View File

@ -65,3 +65,41 @@ impl Features {
score
}
}
/// Solutions can have two types of ratings depending on if they are considered to be a
/// "solve", which means they are rated by how few pieces are required, or they are not a
/// solve, so they are rated by their "score", which is measured as the difference between
/// the evaluation of a terminal node and the evaluation of the root.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Rating {
Score(i32),
Solve(u32),
}
impl core::cmp::Ord for Rating {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
use self::Rating::*;
use core::cmp::Ordering::*;
match (self, other) {
(Score(x), Score(y)) => x.cmp(y), // greater (difference from root) is better
(Solve(x), Solve(y)) => y.cmp(x), // less (piece count) is better
(Solve(_), _) => Greater, // solve always better than non-solve
(_, Solve(_)) => Less,
}
}
}
impl core::cmp::PartialOrd for Rating {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl core::default::Default for Rating {
/// The default rating is the worst possible rating which is ranked worse than any
/// other rating besides itself.
fn default() -> Self {
Rating::Score(i32::MIN)
}
}

View File

@ -9,3 +9,4 @@ pub mod find;
type HashMap<K, V> = hashbrown::HashMap<K, V, HashBuilder>;
type HashSet<T> = hashbrown::HashSet<T, HashBuilder>;
type HashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
type Arena = bumpalo::Bump;

View File

@ -1,4 +1,5 @@
use fish::bot::Metrics;
use fish::eval::Rating;
use mino::srs::Piece;
use serde::Serialize;
@ -184,8 +185,9 @@ mod ser {
#[serde(untagged)]
enum RatingVariant {
NotApplicable,
// confusing: 'rating' is called 'score' actually
Rating { rating: i32 },
Solution { solution: i32 },
Solution { solution: u32 },
}
metrics
@ -202,8 +204,8 @@ mod ser {
features: m.end_features.1.to_vec(),
heuristic: m.end_heuristic,
rating: match m.end_rating {
(false, v) => RatingVariant::Rating { rating: v },
(true, v) => RatingVariant::Solution { solution: -v },
Rating::Score(v) => RatingVariant::Rating { rating: v },
Rating::Solve(n) => RatingVariant::Solution { solution: n },
},
iteration: Some(m.end_iteration),
},