WIP: modified A* with bumpalo managed nodes

This commit is contained in:
tali 2023-03-07 19:30:31 -05:00
parent d52e4da215
commit 15d69fe128
5 changed files with 314 additions and 93 deletions

7
Cargo.lock generated
View File

@ -26,6 +26,12 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bumpalo"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535"
[[package]]
name = "cc"
version = "1.0.78"
@ -102,6 +108,7 @@ version = "0.1.0"
dependencies = [
"ahash",
"anyhow",
"bumpalo",
"clap",
"hashbrown",
"mino",

View File

@ -20,6 +20,7 @@ mino = { path = "../mino" }
ahash = "0.8"
anyhow = { version = "1.0", optional = true }
bumpalo = "3.12"
clap = { version = "4.0", features = ["derive"], optional = true }
hashbrown = "0.13"
serde = { version = "1.0", features = ["derive"], optional = true }

View File

@ -1,117 +1,250 @@
//! AI engine.
use crate::{eval, find};
use alloc::vec::Vec;
use core::{future::Future, pin::Pin, task::Poll};
use core::cell::Cell;
use core::ops::Deref;
use core::pin::Pin;
use alloc::boxed::Box;
use mino::srs::{Piece, PieceType};
use mino::{Mat, MatBuf};
#[derive(Debug)]
#[non_exhaustive]
use alloc::vec::Vec;
use bumpalo::Bump;
use crate::eval::evaluate;
use crate::find::cap::All;
use crate::find::{FindLocations, FindLocationsBuffers};
use self::search::ModifiedAStar;
mod search;
// AI engine state.
//
// NOTE(review): this listing appears to interleave the fields of the previous
// greedy implementation (`root_matrix`, `root_previews`, `path`, `cur_mat`)
// with those of the search-based one (`search`, `best`, `_arena`) — confirm
// which set is meant to survive.
pub struct Ai {
// Root matrix copied from the caller in `start`.
root_matrix: MatBuf,
// Preview pieces copied in `start` (the hold piece is appended there too).
root_previews: Vec<PieceType>,
// Placements chosen so far by `think_1_cycle`.
path: Vec<Piece>,
// Matrix after applying `path`.
cur_mat: MatBuf,
// Search over placements; advanced one terminal node at a time by `think`.
search: search::ModifiedAStar<Graph>,
// Highest-ordered node returned by the search so far (see `think`).
best: Option<Node>,
// Boxed, pinned arena created in `new` and handed to `Queue::alloc` and
// `Graph::new`. `Graph` keeps a raw pointer to it, so it is stored here to
// keep the allocation alive — presumably for as long as `search` is used.
_arena: Pin<Box<Bump>>,
}
/// Rates a matrix as a weighted sum of heuristic terms.
///
/// The terms are the maximum height, the I-dependency count and the
/// "mystery MDSE" value reported by the `eval` module, plus a term
/// proportional to `depth` (the number of pieces placed so far).
fn evaluate(mat: &Mat, depth: usize) -> i32 {
    // Weights of the individual terms, in the order they are summed.
    let (height_weight, i_deps_weight, mdse_weight, piece_count_weight) = (5, 10, 10, 10);

    let height_term = eval::max_height(mat) * height_weight;
    let i_deps_term = eval::i_deps(mat) * i_deps_weight;
    let mdse_term = eval::mystery_mdse(mat) * mdse_weight;
    let piece_count_term = (depth as i32) * piece_count_weight;

    height_term + i_deps_term + mdse_term + piece_count_term
}
/// Error returned by `Ai::think` when the underlying search yields no further node.
pub struct Exhausted;
// NOTE(review): this block appears to interleave methods of the previous
// implementation (`new()`, `start`, `think_1_cycle`) with the search-based
// ones (`new(init_mat, ..)`, `think`, `suggestion`); several bodies are cut
// short. The comments below describe only what each visible line does.
impl Ai {
// TODO: personality config
pub fn new() -> Self {
// Creates the engine for a starting position: allocates a pinned bump arena,
// copies the preview queue into it, builds the search graph rooted at
// `init_mat`, and starts a `ModifiedAStar` search over that graph.
pub fn new(init_mat: &Mat, init_previews: &[PieceType], init_hold: Option<PieceType>) -> Self {
let arena = Box::pin(Bump::new());
let init_queue = Queue::alloc(&*arena, init_previews, init_hold);
let graph = Graph::new(&*arena, init_mat, init_queue);
let search = ModifiedAStar::new(graph);
Self {
root_matrix: MatBuf::new(),
root_previews: Vec::with_capacity(8),
path: Vec::with_capacity(8),
cur_mat: MatBuf::new(),
best: None,
search,
// The arena is moved into the struct so it is kept alongside `search`.
_arena: arena,
}
}
// Resets the root position and the per-search state from the given inputs.
pub fn start(
&mut self,
init_mat: &Mat,
init_previews: &[PieceType],
init_hold: Option<PieceType>,
) {
// init root node
self.root_matrix.copy_from(init_mat);
self.root_previews.clear();
self.root_previews.extend_from_slice(init_previews);
self.root_previews.extend(init_hold); // TODO: actual hold logic
// init search state
self.path.clear();
self.cur_mat.copy_from(&self.root_matrix);
// Advances the search to its next terminal node and remembers it if it
// compares greater than the current best under `Node`'s ordering.
// Returns `Err(Exhausted)` when the search has nothing left to yield.
pub fn think(&mut self) -> Result<(), Exhausted> {
let node = self.search.step().ok_or(Exhausted)?;
if self.best.map_or(true, |best| node > best) {
tracing::debug!("new best: {node:?} ({})", node.rating);
self.best = Some(node);
}
Ok(())
}
// Returns `true` once `path` is at least as long as the preview queue.
pub fn think_1_cycle(&mut self) -> bool {
let depth = self.path.len();
if depth >= self.root_previews.len() {
return true;
// Yields the placements leading to the best node found so far, in the order
// produced by `Node::trace`; yields nothing when no node has been found yet.
pub fn suggestion(&self) -> impl Iterator<Item = Piece> + '_ {
self.best.iter().flat_map(|n| n.trace())
}
}
// Search graph over placements; implements `search::Graph`.
struct Graph {
// Raw pointer to the bump arena that new nodes are allocated from (it is
// passed on to `Node::new_root` and `Node::succ`).
// NOTE(review): nothing in this type ties the pointer's validity to the
// arena's lifetime — confirm the owner keeps the arena alive.
arena: *const Bump,
// Root node, created in `Graph::new`.
root: Node,
// Reusable buffers for `FindLocations`; `None` until the first expansion
// and while an expansion has them checked out.
find_buf: Option<FindLocationsBuffers>,
// Scratch storage for the children returned by `expand`; cleared on each call.
children_buf: Vec<Node>,
}
impl Graph {
    /// Builds a graph whose root node holds `root_mat` and `root_queue`.
    ///
    /// The root node is allocated through `arena`, and the same pointer is
    /// stored for later node allocations. The scratch buffers start out empty.
    fn new(arena: *const Bump, root_mat: &Mat, root_queue: Queue) -> Self {
        Self {
            root: Node::new_root(arena, root_mat, root_queue),
            arena,
            children_buf: Vec::new(),
            find_buf: None,
        }
    }
}
impl search::Graph for Graph {
type Node = Node;
// Returns a copy of the root node handle.
fn root(&mut self) -> Self::Node {
self.root.clone()
}
// Generates every successor of `node`: for each piece type the queue offers
// (`Queue::current`), every location found by `FindLocations` on the node's
// matrix becomes one child. The returned slice borrows `children_buf`, so it
// is only valid until the next call.
fn expand(&mut self, node: Self::Node) -> &[Self::Node] {
self.children_buf.clear();
for ty in node.queue.current() {
// Check the location-finder buffers out of `self` (or create them on first use)...
let find_buf = self.find_buf.take().unwrap_or_default();
let mut locs = FindLocations::with_buffers(&node.mat, ty, All, find_buf);
for loc in &mut locs {
let piece = Piece { ty, loc };
let node = node.succ(self.arena, piece);
self.children_buf.push(node);
}
// ...and hand them back so the next piece type reuses the same buffers.
self.find_buf = Some(locs.into_buffers());
}
// NOTE(review): the next three statements reference `self.root_previews` and
// `depth`, neither of which exists on this `Graph` or in this function — they
// look like leftovers from the previous `think_1_cycle`; confirm and remove.
let ty = self.root_previews[depth];
let locs = find::find_locations(&self.cur_mat, ty, find::cap::All);
let pcs = locs.map(|loc| Piece { ty, loc });
tracing::trace!("expanded to create {} children", self.children_buf.len());
&self.children_buf
}
}
let mut rate_mat: MatBuf = MatBuf::new();
let mut best_mat: MatBuf = MatBuf::new();
let mut best: Option<(i32, Piece)> = None;
// Copyable handle to a `NodeData`, stored as a raw pointer. `Node::new` obtains
// the pointer from `NodeData::alloc`, and the `Deref` impl reads through it.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct Node(*const NodeData);
for pc in pcs {
rate_mat.copy_from(&self.cur_mat);
pc.cells().fill(&mut rate_mat);
rate_mat.clear_lines();
let rating = evaluate(&rate_mat, depth);
// Per-node search state, allocated by `NodeData::alloc`.
struct NodeData {
// Matrix at this node (for successors: after placing the piece and clearing lines).
mat: MatBuf,
// Pieces still available at this node.
queue: Queue,
// Number of pieces placed to reach this node; 0 at the root.
pcnt: usize,
// Result of `evaluate(mat, pcnt)`, computed once in `Node::new`.
rating: i32,
// Placement and predecessor this node was reached by; set by `Node::succ`
// after allocation, hence the `Cell`. Stays `None` for the root.
back_edge: Cell<Option<Edge>>,
}
let best_rating = best.clone().map_or(i32::MAX, |(r, _)| r);
if rating < best_rating {
best = Some((rating, pc));
best_mat.copy_from(&rate_mat);
// Link from a node back to the node it was generated from.
#[derive(Copy, Clone)]
struct Edge {
// Piece whose placement produced the node.
piece: Piece,
// Node the piece was placed on.
pred: Node,
}
impl search::Node for Node {
    /// A node is terminal once its queue reports that it is empty.
    fn is_terminal(&self) -> bool {
        let remaining = &self.queue;
        remaining.is_empty()
    }
}
impl Node {
// Creates the root node: the given matrix and queue with zero pieces placed.
fn new_root(arena: *const Bump, mat: &Mat, queue: Queue) -> Self {
Self::new(arena, mat, queue, 0)
}
// Creates the successor reached by placing `piece` on this node's matrix:
// the piece's cells are filled in, full lines are cleared, the queue is
// advanced past the piece's type, and the new node records a back edge to `self`.
fn succ(self, arena: *const Bump, piece: Piece) -> Self {
let mut mat: MatBuf = MatBuf::new();
mat.copy_from(&self.mat);
piece.cells().fill(&mut mat);
mat.clear_lines();
let queue = self.queue.succ(piece.ty);
let pcnt = self.pcnt + 1;
let succ = Self::new(arena, &mat, queue, pcnt);
succ.back_edge.set(Some(Edge {
piece,
pred: self.clone(),
}));
succ
}
// Rates `mat` and allocates the node data in the arena.
fn new(arena: *const Bump, mat: &Mat, queue: Queue, pcnt: usize) -> Self {
// NOTE(review): dereferences the raw arena pointer without any check; this
// is only sound if every caller passes a pointer to a live `Bump` — confirm.
let arena = unsafe { &*arena };
let rating = evaluate(mat, pcnt);
let node_data = NodeData::alloc(arena, mat, queue, pcnt, rating);
Self(node_data)
}
// Collects the pieces along the back edges from this node towards the root.
fn trace(self) -> Vec<Piece> {
let mut pieces = Vec::with_capacity(self.pcnt);
let mut parent = Some(self);
// `take()` leaves `parent` empty, so the walk stops at the first node
// without a back edge.
while let Some(node) = parent.take() {
if let Some(edge) = node.back_edge.get() {
pieces.push(edge.piece);
parent = Some(edge.pred);
}
}
// NOTE(review): the lines from here down to `self.path.clone()` reference
// `best`, `self.path`, `self.cur_mat` and `best_mat`, none of which exist in
// `trace` — they look like leftovers from the previous `think_1_cycle` /
// `suggestion`; confirm and remove.
let pc = match best {
Some((_, pc)) => pc,
None => return true, // no locations; game over
};
self.path.push(pc);
self.cur_mat.copy_from(&best_mat);
false
}
pub fn suggestion(&self) -> Vec<Piece> {
self.path.clone()
// Pieces were gathered walking backwards, so flip them into placement order.
pieces.reverse();
pieces
}
}
pub struct Think<'a>(&'a mut Ai);
impl core::cmp::Ord for Node {
    /// Orders nodes by `rating`, reversed: the node with the lower rating
    /// compares as the greater one.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let (mine, theirs) = (self.rating, other.rating);
        mine.cmp(&theirs).reverse()
    }
}
impl Future for Think<'_> {
type Output = ();
// Marker impl only; equality itself is defined by the `PartialEq` impl for `Node`.
impl core::cmp::Eq for Node {}
fn poll(mut self: Pin<&mut Self>, _cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
let ai: &mut Ai = &mut self.0;
impl core::cmp::PartialOrd for Node {
    /// Always `Some`: delegates to the total order given by `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        let ordering = core::cmp::Ord::cmp(self, other);
        Some(ordering)
    }
}
if ai.think_1_cycle() {
// TODO: if <limits reached> then return Poll::Ready)
Poll::Pending
impl core::cmp::PartialEq for Node {
    /// Two nodes are equal exactly when `cmp` reports `Ordering::Equal`.
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), core::cmp::Ordering::Equal)
    }
}
impl Deref for Node {
type Target = NodeData;
fn deref(&self) -> &Self::Target {
unsafe { &*self.0 }
}
}
impl NodeData {
    /// Allocates a `NodeData` in `arena` and returns a shared reference to it.
    ///
    /// The node starts with an empty matrix and no back edge; the contents of
    /// `mat` are then copied into the freshly allocated matrix.
    fn alloc<'a>(arena: &'a Bump, mat: &Mat, queue: Queue, pcnt: usize, rating: i32) -> &'a Self {
        let slot = arena.alloc_with(|| NodeData {
            mat: MatBuf::new(),
            queue,
            pcnt,
            rating,
            back_edge: Cell::new(None),
        });
        // Fill the matrix in place, after allocation.
        slot.mat.copy_from(mat);
        slot
    }
}
// Pieces available at a node: the remaining previews plus the hold slot.
struct Queue {
// Remaining preview pieces, as a raw slice pointer. `Queue::alloc` points it
// at a copy made in the arena; `succ` re-points it at the tail of that slice.
next: *const [PieceType],
// Piece currently in hold, if any.
held: Option<PieceType>,
}
impl Queue {
// Creates a queue whose previews are copied into `arena`.
fn alloc(arena: &Bump, previews: &[PieceType], hold: Option<PieceType>) -> Self {
Queue {
next: arena.alloc_slice_copy(previews),
held: hold,
}
}
// Remaining preview pieces.
fn next(&self) -> &[PieceType] {
// NOTE(review): reads through the raw slice pointer; only sound while the
// slice it points at is still allocated — confirm.
unsafe { &*self.next }
}
// Piece types that can be placed next: the front preview (if any) followed
// by the held piece (if any).
fn current(&self) -> impl Iterator<Item = PieceType> {
[self.next().first().copied(), self.held]
.into_iter()
.flatten()
}
// True when there are neither previews nor a held piece left.
fn is_empty(&self) -> bool {
self.next().is_empty() && self.held.is_none()
}
// Queue after a piece of type `ty` has been placed. The previews always
// advance by one. If `ty` is the held piece, the front preview takes its
// place in hold; otherwise `ty` must be the front preview (checked in debug
// builds only) and the hold slot is unchanged.
fn succ(&self, ty: PieceType) -> Self {
let (hd, tl) = match self.next() {
[hd, tl @ ..] => (Some(*hd), tl),
[] => (None, &[][..]),
};
if self.held == Some(ty) {
Self { next: tl, held: hd }
} else {
// NOTE(review): `Poll::Ready(())` below does not belong to this function —
// it looks like a leftover from the previous `Future` impl; confirm and remove.
Poll::Ready(())
debug_assert_eq!(hd, Some(ty));
Self {
next: tl,
held: self.held,
}
}
}
}

83
fish/src/ai/search.rs Normal file
View File

@ -0,0 +1,83 @@
use alloc::{collections::BinaryHeap, vec::Vec};
/// A graph that `ModifiedAStar` can search.
pub trait Graph {
/// Node type of the graph.
type Node: Node;
/// Returns the node the search starts from.
fn root(&mut self) -> Self::Node;
/// Returns the successors of `node`. The slice borrows from the graph, so
/// the caller has to copy out what it needs before calling `expand` again.
fn expand(&mut self, node: Self::Node) -> &[Self::Node];
}
/// A graph node. `Ord` decides which node the search prefers: nodes are kept
/// in `BinaryHeap`s, so the greatest node of a level is taken first.
pub trait Node: Clone + core::fmt::Debug + Ord {
/// Returns `true` if the node is a result in its own right; the search
/// returns such nodes from `step` instead of expanding them.
fn is_terminal(&self) -> bool;
}
/// Best-first search that keeps one priority queue of open nodes per depth.
pub struct ModifiedAStar<G: Graph> {
// Graph being searched.
graph: G,
// Open (not yet expanded) nodes; index `d` holds the nodes at depth `d`.
fringe: Vec<BinaryHeap<G::Node>>,
// Depth whose heap the next `expand` pops from.
depth: usize,
}
// Internal error: there was no node to pop (`expand`) or to select (`select`).
struct NoneAvailable;
impl<G: Graph> ModifiedAStar<G> {
    /// Creates a search over `graph` whose fringe initially holds only the
    /// graph's root, at depth 0.
    pub fn new(mut graph: G) -> Self {
        Self {
            fringe: Vec::from_iter([BinaryHeap::from_iter([graph.root()])]),
            depth: 0,
            graph,
        }
    }

    /// Runs the search until the next terminal node is reached and returns it.
    ///
    /// Returns `None` once every level of the fringe is empty.
    pub fn step(&mut self) -> Option<G::Node> {
        loop {
            match self.expand() {
                Ok(Some(term_node)) => break Some(term_node),
                Ok(None) => continue,
                // Nothing to pop at the current depth: jump to the level that
                // holds the best open node, or stop if there is none at all.
                Err(NoneAvailable) => match self.select() {
                    Ok(()) => continue,
                    Err(NoneAvailable) => break None,
                },
            }
        }
    }

    /// Pops the best node at the current depth and moves one level deeper.
    ///
    /// A terminal node is returned as `Ok(Some(node))`. Any other node is
    /// expanded, its children are added to the next level, and `Ok(None)` is
    /// returned. Fails with `NoneAvailable` if the current level does not
    /// exist or is empty.
    fn expand(&mut self) -> Result<Option<G::Node>, NoneAvailable> {
        tracing::trace!("expand depth = {}", self.depth);
        let popped = self.fringe.get_mut(self.depth).and_then(|set| set.pop());
        // The depth advances even when nothing was popped; `select` overwrites
        // it before the next successful expansion.
        self.depth += 1;
        let node = popped.ok_or(NoneAvailable)?;
        if node.is_terminal() {
            tracing::trace!("found terminal node {node:?}");
            return Ok(Some(node));
        }
        let children = self.graph.expand(node).iter().cloned();
        // Only ever grow the fringe. After `select` has jumped back to a
        // shallower depth, an unconditional `resize_with` would truncate the
        // vector and silently drop every open node stored at deeper levels.
        if self.fringe.len() <= self.depth {
            self.fringe.resize_with(self.depth + 1, BinaryHeap::new);
        }
        self.fringe[self.depth].extend(children);
        Ok(None)
    }

    /// Sets the current depth to the level whose top node is the greatest
    /// among all levels. On ties the shallowest level wins.
    ///
    /// Fails with `NoneAvailable` if every level is empty; the depth is left
    /// untouched in that case.
    fn select(&mut self) -> Result<(), NoneAvailable> {
        let mut best: Option<&G::Node> = None;
        for (depth, set) in self.fringe.iter().enumerate() {
            if let Some(node) = set.peek() {
                if best.map_or(true, |best| node > best) {
                    best = Some(node);
                    self.depth = depth;
                }
            }
        }
        if let Some(best) = best {
            tracing::trace!("selected depth = {}, best = {:?}", self.depth, best);
            Ok(())
        } else {
            tracing::trace!("fringe exhausted; no nodes remaining");
            Err(NoneAvailable)
        }
    }
}

View File

@ -112,24 +112,21 @@ fn print_best_move(_settings: Settings) -> anyhow::Result<()> {
let input: fish::io::InputState =
serde_json::from_reader(std::io::stdin()).context("error parsing input state")?;
let mut mat = input.matrix.to_mat();
mat.clear_lines();
let mat = input.matrix.to_mat();
// mat.clear_lines();
// TODO: ai init config, e.g. personality
let mut ai = fish::Ai::new();
// TODO: hold
// TODO: attack state
ai.start(&mat, &input.queue.previews, input.queue.hold);
// TODO: attack state (combo/b2b)
let mut ai = fish::Ai::new(&mat, &input.queue.previews, input.queue.hold);
// TODO: resource limits (cycles,nodes,time)
let mut cycles = 0;
loop {
if ai.think_1_cycle() {
tracing::trace!("thinking... ({cycles})");
if matches!(ai.think(), Err(fish::ai::Exhausted)) || cycles > 100_000 {
break;
}
cycles += 1;
tracing::trace!("thinking ({cycles})...");
}
// print suggestions trace