diff --git a/fish/src/bot.rs b/fish/src/bot.rs index 81f1704..c1bb3aa 100644 --- a/fish/src/bot.rs +++ b/fish/src/bot.rs @@ -10,7 +10,7 @@ use mino::srs::{Piece, Queue}; mod node; use self::node::{Node, RawNodePtr}; -use crate::eval::evaluate; +use crate::eval::{features, Weights}; pub(crate) use bumpalo::Bump as Arena; @@ -25,10 +25,10 @@ pub struct Bot { impl Bot { /// Constructs a new bot from the given initial state (matrix and queue). // TODO: specify weights - pub fn new(matrix: &Mat, queue: Queue<'_>) -> Self { + pub fn new(weights: &Weights, matrix: &Mat, queue: Queue<'_>) -> Self { let arena = bumpalo::Bump::new(); let root = Node::alloc_root(&arena, matrix, queue); - let evaluator = Evaluator::new(root); + let evaluator = Evaluator::new(weights, root); let algorithm = SegmentedAStar::new(root); Self { evaluator, @@ -56,21 +56,21 @@ impl Bot { struct Evaluator { // TODO: weights + weights: Weights, root_score: i32, root_queue_len: usize, } impl Evaluator { - fn new(root: &Node) -> Self { + fn new(weights: &Weights, root: &Node) -> Self { Self { - root_score: evaluate(root.matrix(), 0), + weights: *weights, + root_score: features(root.matrix(), 0).evaluate(weights), root_queue_len: root.queue().len(), } } fn evaluate(&self, mat: &Mat, queue: Queue<'_>) -> i32 { - let pcnt = self.root_queue_len.saturating_sub(queue.len()); - // FIXME: the old blockfish has two special edge cases for rating nodes that is // not done here. // @@ -93,8 +93,11 @@ impl Evaluator { // avoided by holding the last piece. i think this improves the performance only // slightly, but it is also a bit of a hack that deserves further consideration. + let pcnt = self.root_queue_len.saturating_sub(queue.len()); + let score = features(mat, pcnt).evaluate(&self.weights); + // larger (i.e., further below the root score) is better - self.root_score - evaluate(mat, pcnt) + self.root_score - score } } diff --git a/fish/src/eval.rs b/fish/src/eval.rs index a9e99e3..eaa0791 100644 --- a/fish/src/eval.rs +++ b/fish/src/eval.rs @@ -1,226 +1,64 @@ -use mino::matrix::{COLUMNS, EMPTY_ROW, FULL_ROW}; use mino::Mat; +mod basic; mod downstacking; -pub use downstacking::mystery_mdse; +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct Features(pub usize, pub [i32; 4]); -pub fn evaluate(mat: &Mat, pcnt: usize) -> i32 { - // TODO: public interface for weights etc. - - struct Weights { - height: i32, - i_deps: i32, - mdse: i32, - pcnt: i32, - } - - const W: Weights = Weights { - height: 5, - i_deps: 10, - mdse: 10, - pcnt: 10, - }; - - let mut rating = 0; - rating += max_height(mat) * W.height; - rating += i_deps(mat) * W.i_deps; - rating += mystery_mdse(mat) * W.mdse; - rating += (pcnt as i32) * W.pcnt; - rating +/// Computes the feature vector for the given matrix. These are the underlying values +/// used to calculate the heuristic rating. +pub fn features(mat: &Mat, pcnt: usize) -> Features { + Features( + pcnt, + [ + basic::max_height(mat), + basic::i_deps(mat), + basic::row_trans(mat), + downstacking::mystery_mdse(mat), + ], + ) } -pub fn max_height(matrix: &Mat) -> i32 { - matrix.rows() as i32 +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct Weights(pub [i32; 4]); + +impl Weights { + /// Default weights (determined experimentally). + pub const DEFAULT: Self = Self([ + 512, // max_height + 1024, // i_deps + 0, // FIXME: row_trans + 1024, // mdse + ]); + + /// Constant penalty given to the number of pieces placed. Each element of the weight + /// vector is effectively a ratio to this constant, i.e. if this were to be doubled, + /// then all of the weights should be doubled as well to produce an identical + /// heuristic. + pub const PER_PIECE: i32 = 1024; } -pub fn i_deps(matrix: &Mat) -> i32 { - let mut depth = [0u8; COLUMNS as usize]; - let mut count = 0; - for y in 0..matrix.rows() { - // 012345689xxxx - // ↑ - // _______xxx___ - let mut mask = 0b111 << (COLUMNS - 2); - // _______x.x___ - let mut test = 0b101 << (COLUMNS - 2); - for x in (0..COLUMNS).rev() { - if matrix.test_row(y, mask, test) { - depth[x as usize] += 1; - depth[x as usize] %= 4; - if depth[x as usize] == 3 { - count += 1; - } - } else { - depth[x as usize] = 0; - } - mask >>= 1; - test >>= 1; - } - } - count -} - -pub fn row_trans(matrix: &Mat) -> i32 { - let mut prev_row = FULL_ROW; - let mut count = 0; - for &curr_row in &matrix[..] { - count += (curr_row ^ prev_row).count_ones(); - prev_row = curr_row; - } - count += (prev_row ^ EMPTY_ROW).count_ones(); - count as i32 -} - -#[cfg(test)] -mod test { - use super::*; - use mino::mat; - - #[test] - fn test_i_deps() { - assert_eq!(i_deps(Mat::EMPTY), 0); - assert_eq!( - i_deps(mat! { - "xxxxx...xx"; - "xxxxx..xxx"; - "xxxxx.xx.x"; - "xxxxx.xxxx"; - }), - 0 - ); - assert_eq!( - i_deps(mat! { - "xxx.....xx"; - "xxxx..xxxx"; - "xxxxx.xxxx"; - "x.xxx.xxxx"; - }), - 0 - ); - assert_eq!( - i_deps(mat! { - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - }), - 1 - ); - assert_eq!( - i_deps(mat! { - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - }), - 1 - ); - assert_eq!( - i_deps(mat! { - "xxxxxxxx.x"; - "xxxxxxxx.x"; - "xxxxxxxx.x"; - "xxxxxxxx.x"; - "x.xxxxxxxx"; - "x.xxxxxxxx"; - "x.xxxxxxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - }), - 3, - ); - assert_eq!( - i_deps(mat! { - // 6 rows - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - }), - 1 - ); - assert_eq!( - i_deps(mat! { - // 7 rows - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - }), - 2 - ); - assert_eq!( - i_deps(mat! { - "x.x......x"; - "x.xxx.xxxx"; - "x.xxx.xxxx"; - "x.xxx.xxxx"; - }), - 2 - ); - } - - #[test] - fn test_row_trans() { - assert_eq!(row_trans(Mat::EMPTY), 10); - assert_eq!( - row_trans(mat! { - "x........."; - "xx........"; - "xxx......."; - "xxxx......"; - "xxxxx....."; - "xxxxxx...."; - "xxxxxxx..."; - "xxxxxxxx.."; - "xxxxxxxxx."; - "xxxxxxxxxx"; - }), - 10, - ); - assert_eq!( - row_trans(mat! { - "xxxxx.xxxx"; - }), - 9 + 1, - ); - assert_eq!( - row_trans(mat! { - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - "xxxxx.xxxx"; - }), - 9 + 1, - ); - assert_eq!( - row_trans(mat! { - "xxxxx.xxxx"; - "xxxx.Xxxxx"; - }), - 9 + 2 + 1 - ); - assert_eq!( - row_trans(mat! { - "xxxxx.xxx."; - "xx....xxxx"; - }), - 8 + 4 + 4 - ); - assert_eq!( - row_trans(mat! { - "xxxx..xxxx"; - "xxxx...xxx"; - "xxxxx.xxxx"; - }), - 8 + 1 + 2 + 1, - ); +impl Default for Weights { + fn default() -> Self { + Self::DEFAULT + } +} + +impl Features { + /// Applies the given weights vector to the features. Returns the heuristic score. + #[inline] + pub fn evaluate(&self, weights: &Weights) -> i32 { + let pcnt = self.0 as i32; + let ft = &self.1; + let wt = &weights.0; + + let mut score = Weights::PER_PIECE * pcnt; + score += wt[0] * ft[0]; + score += wt[1] * ft[1]; + score += wt[2] * ft[2]; + score += wt[3] * ft[3]; + // TODO(?) quadratic weights, e.g. score += wt[i] * f[j]^2 + score } } diff --git a/fish/src/eval/basic.rs b/fish/src/eval/basic.rs new file mode 100644 index 0000000..e4de7ba --- /dev/null +++ b/fish/src/eval/basic.rs @@ -0,0 +1,196 @@ +use mino::matrix::{Mat, COLUMNS, EMPTY_ROW, FULL_ROW}; + +pub fn max_height(matrix: &Mat) -> i32 { + matrix.rows() as i32 +} + +pub fn i_deps(matrix: &Mat) -> i32 { + let mut depth = [0u8; COLUMNS as usize]; + let mut count = 0; + for y in 0..matrix.rows() { + // 012345689xxxx + // ↑ + // _______xxx___ + let mut mask = 0b111 << (COLUMNS - 2); + // _______x.x___ + let mut test = 0b101 << (COLUMNS - 2); + for x in (0..COLUMNS).rev() { + if matrix.test_row(y, mask, test) { + depth[x as usize] += 1; + depth[x as usize] %= 4; + if depth[x as usize] == 3 { + count += 1; + } + } else { + depth[x as usize] = 0; + } + mask >>= 1; + test >>= 1; + } + } + count +} + +pub fn row_trans(matrix: &Mat) -> i32 { + let mut prev_row = FULL_ROW; + let mut count = 0; + for &curr_row in &matrix[..] { + count += (curr_row ^ prev_row).count_ones(); + prev_row = curr_row; + } + count += (prev_row ^ EMPTY_ROW).count_ones(); + count as i32 +} + +#[cfg(test)] +mod test { + use super::*; + use mino::mat; + + #[test] + fn test_i_deps() { + assert_eq!(i_deps(Mat::EMPTY), 0); + assert_eq!( + i_deps(mat! { + "xxxxx...xx"; + "xxxxx..xxx"; + "xxxxx.xx.x"; + "xxxxx.xxxx"; + }), + 0 + ); + assert_eq!( + i_deps(mat! { + "xxx.....xx"; + "xxxx..xxxx"; + "xxxxx.xxxx"; + "x.xxx.xxxx"; + }), + 0 + ); + assert_eq!( + i_deps(mat! { + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + }), + 1 + ); + assert_eq!( + i_deps(mat! { + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + }), + 1 + ); + assert_eq!( + i_deps(mat! { + "xxxxxxxx.x"; + "xxxxxxxx.x"; + "xxxxxxxx.x"; + "xxxxxxxx.x"; + "x.xxxxxxxx"; + "x.xxxxxxxx"; + "x.xxxxxxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + }), + 3, + ); + assert_eq!( + i_deps(mat! { + // 6 rows + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + }), + 1 + ); + assert_eq!( + i_deps(mat! { + // 7 rows + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + }), + 2 + ); + assert_eq!( + i_deps(mat! { + "x.x......x"; + "x.xxx.xxxx"; + "x.xxx.xxxx"; + "x.xxx.xxxx"; + }), + 2 + ); + } + + #[test] + fn test_row_trans() { + assert_eq!(row_trans(Mat::EMPTY), 10); + assert_eq!( + row_trans(mat! { + "x........."; + "xx........"; + "xxx......."; + "xxxx......"; + "xxxxx....."; + "xxxxxx...."; + "xxxxxxx..."; + "xxxxxxxx.."; + "xxxxxxxxx."; + "xxxxxxxxxx"; + }), + 10, + ); + assert_eq!( + row_trans(mat! { + "xxxxx.xxxx"; + }), + 9 + 1, + ); + assert_eq!( + row_trans(mat! { + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + "xxxxx.xxxx"; + }), + 9 + 1, + ); + assert_eq!( + row_trans(mat! { + "xxxxx.xxxx"; + "xxxx.Xxxxx"; + }), + 9 + 2 + 1 + ); + assert_eq!( + row_trans(mat! { + "xxxxx.xxx."; + "xx....xxxx"; + }), + 8 + 4 + 4 + ); + assert_eq!( + row_trans(mat! { + "xxxx..xxxx"; + "xxxx...xxx"; + "xxxxx.xxxx"; + }), + 8 + 1 + 2 + 1, + ); + } +} diff --git a/tidepool/src/main.rs b/tidepool/src/main.rs index 70ce583..379469c 100644 --- a/tidepool/src/main.rs +++ b/tidepool/src/main.rs @@ -5,9 +5,9 @@ use std::io::{Read, Write}; use std::path::Path; use std::sync::{atomic, mpsc, Arc}; use std::time::{Duration, Instant}; -use tidepool::output::SummaryStats; -use fish::bot; +use fish::{bot, eval}; +use tidepool::output::SummaryStats; use tidepool::{cli, config, output, sim}; fn main() -> Result<()> { @@ -338,6 +338,7 @@ fn run_simulation( ) { let mut moves = data_args.list_moves.then(Vec::new); let seed = seed.unwrap_or_else(|| rand::thread_rng().next_u64()); + let weights = eval::Weights::default(); let sim_opts = sim::Options { goal: config.game.goal, @@ -354,6 +355,7 @@ fn run_simulation( } let mut bot = bot::Bot::new( + &weights, sim.matrix(), mino::Queue::new(sim.queue().hold(), sim.queue().next()), );