add weights configuration to tidepool config

This commit is contained in:
tali 2023-04-15 18:30:26 -04:00
parent 3445f9e44a
commit 6a08775a83
1 changed files with 30 additions and 7 deletions

View File

@ -1,8 +1,9 @@
//! Config file schema (serde) for tidepool. See `example-configs/` for usage samples.
use fish::eval::Weights;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[derive(Deserialize, PartialEq, Debug, Clone)]
pub struct Config {
pub game: GameConfig,
#[serde(default)]
@ -37,17 +38,21 @@ impl GameRulesConfig {
};
}
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[derive(Deserialize, PartialEq, Debug, Clone)]
pub struct BotConfig {
// TODO: weights
// TODO: algorithm
// TODO: capabililties
#[serde(default = "defaults::iters")]
pub iters: usize,
#[serde(default = "defaults::weights", deserialize_with = "de::weights")]
pub weights: Weights,
}
impl BotConfig {
pub const DEFAULT: Self = Self { iters: 10_000 };
pub const DEFAULT: Self = Self {
iters: 10_000,
weights: Weights::DEFAULT,
};
}
impl Default for BotConfig {
@ -78,6 +83,10 @@ mod defaults {
pub const fn iters() -> usize {
BotConfig::DEFAULT.iters
}
pub const fn weights() -> Weights {
BotConfig::DEFAULT.weights
}
}
mod de {
@ -105,6 +114,13 @@ mod de {
Variant::Custom(c) => c,
})
}
pub fn weights<'de, D>(deserializer: D) -> Result<Weights, D::Error>
where
D: serde::Deserializer<'de>,
{
<[i32; 4] as Deserialize>::deserialize(deserializer).map(Weights)
}
}
#[cfg(test)]
@ -121,6 +137,7 @@ goal = 100
[bot]
iters = 5_000
weights = [1, 2, 3, 4]
"
)
.unwrap(),
@ -129,7 +146,10 @@ iters = 5_000
goal: 100,
rules: GameRulesConfig::JSTRIS,
},
bot: BotConfig { iters: 5_000 }
bot: BotConfig {
iters: 5_000,
weights: Weights([1, 2, 3, 4]),
}
}
);
}
@ -141,7 +161,7 @@ iters = 5_000
"
game.goal = 100
game.rules = { min = 4, previews = 8 }
bot = {}
bot.iters = 7000
"
)
.unwrap(),
@ -154,7 +174,10 @@ bot = {}
previews: 8,
},
},
bot: BotConfig::default(),
bot: BotConfig {
iters: 7000,
weights: Weights::DEFAULT,
},
}
);
}