Merge pull request #253 from lichess-org/chunking
Chunking
niklasf authored Jan 4, 2024
2 parents 87b6830 + 986b9e0 commit 92ae186
Showing 6 changed files with 248 additions and 152 deletions.
7 changes: 5 additions & 2 deletions src/api.rs
@@ -155,7 +155,7 @@ impl Work {
}
}

pub fn timeout(&self) -> Duration {
pub fn timeout_per_position(&self) -> Duration {
match *self {
Work::Analysis { timeout, .. } => timeout,
Work::Move { .. } => Duration::from_secs(2),
@@ -288,6 +288,9 @@ impl From<Centis> for Duration {
}
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize)]
pub struct PositionIndex(pub usize);

#[serde_as]
#[derive(Debug, Deserialize)]
pub struct AcquireResponseBody {
@@ -303,7 +306,7 @@ pub struct AcquireResponseBody {
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, Uci>")]
pub moves: Vec<Uci>,
#[serde(rename = "skipPositions", default)]
pub skip_positions: Vec<usize>,
pub skip_positions: Vec<PositionIndex>,
}

impl AcquireResponseBody {
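For readers outside the codebase: the point of the new PositionIndex newtype is to give batch position indices their own type instead of a bare usize, while still deserializing from plain integers in the acquire response. A minimal standalone sketch of that behavior (the Body struct and JSON are simplified stand-ins for AcquireResponseBody, and it assumes serde with the derive feature plus serde_json):

```rust
use serde::Deserialize;

// Newtype over the raw index, mirroring the struct added in src/api.rs.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Deserialize)]
pub struct PositionIndex(pub usize);

// Simplified stand-in for the acquire response body.
#[derive(Debug, Deserialize)]
struct Body {
    #[serde(rename = "skipPositions", default)]
    skip_positions: Vec<PositionIndex>,
}

fn main() {
    // Serde newtype structs deserialize transparently from their inner value,
    // so the wire format stays a plain JSON array of integers.
    let body: Body = serde_json::from_str(r#"{"skipPositions": [0, 2, 5]}"#).unwrap();
    assert_eq!(
        body.skip_positions,
        vec![PositionIndex(0), PositionIndex(2), PositionIndex(5)]
    );
}
```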
45 changes: 30 additions & 15 deletions src/ipc.rs
@@ -5,30 +5,45 @@ use tokio::sync::oneshot;
use url::Url;

use crate::{
api::{AnalysisPart, BatchId, Score, Work},
api::{AnalysisPart, BatchId, PositionIndex, Score, Work},
assets::EngineFlavor,
};

/// Uniquely identifies a position within a batch.
#[derive(Debug, Copy, Clone)]
pub struct PositionId(pub usize);
#[derive(Debug)]
pub struct Chunk {
pub work: Work,
pub variant: Variant,
pub flavor: EngineFlavor,
pub positions: Vec<Position>,
}

impl Chunk {
pub const MAX_POSITIONS: usize = 6;

pub fn timeout(&self) -> Duration {
self.positions
.iter()
.filter(|pos| pos.position_index.is_some())
.count() as u32
* self.work.timeout_per_position()
}
}

#[derive(Debug, Clone)]
pub struct Position {
pub work: Work,
pub position_id: PositionId,
pub flavor: EngineFlavor,
pub position_index: Option<PositionIndex>,
pub url: Option<Url>,
pub skip: bool,

pub variant: Variant,
pub root_fen: Fen,
pub moves: Vec<Uci>,
}

#[derive(Debug, Clone)]
pub struct PositionResponse {
pub work: Work,
pub position_id: PositionId,
pub position_index: Option<PositionIndex>,
pub url: Option<Url>,

pub scores: Matrix<Score>,
@@ -87,29 +102,29 @@ impl<T> Matrix<T> {

pub fn best(&self) -> Option<&T> {
self.matrix
.get(0)
.first()
.and_then(|row| row.last().and_then(|v| v.as_ref()))
}
}

#[derive(Debug)]
pub struct PositionFailed {
pub struct ChunkFailed {
pub batch_id: BatchId,
}

#[derive(Debug)]
pub struct Pull {
pub response: Option<Result<PositionResponse, PositionFailed>>,
pub callback: oneshot::Sender<Position>,
pub responses: Result<Vec<PositionResponse>, ChunkFailed>,
pub callback: oneshot::Sender<Chunk>,
}

impl Pull {
pub fn split(
self,
) -> (
Option<Result<PositionResponse, PositionFailed>>,
oneshot::Sender<Position>,
Result<Vec<PositionResponse>, ChunkFailed>,
oneshot::Sender<Chunk>,
) {
(self.response, self.callback)
(self.responses, self.callback)
}
}
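The core idea of the chunking change is visible in Chunk::timeout above: a worker now pulls a whole chunk of up to MAX_POSITIONS positions per engine round trip, and the chunk's timeout scales with the number of real (non-placeholder) positions. A self-contained sketch of that arithmetic with simplified types (these are not the actual fishnet structs):

```rust
use std::time::Duration;

const MAX_POSITIONS: usize = 6;

struct Chunk {
    // Per-position allowance, corresponding to Work::timeout_per_position().
    timeout_per_position: Duration,
    // None marks a placeholder slot (e.g. a skipped position) that costs no engine time.
    position_indices: Vec<Option<usize>>,
}

impl Chunk {
    fn timeout(&self) -> Duration {
        let real_positions = self
            .position_indices
            .iter()
            .filter(|index| index.is_some())
            .count() as u32;
        self.timeout_per_position * real_positions
    }
}

fn main() {
    let chunk = Chunk {
        timeout_per_position: Duration::from_secs(6),
        position_indices: vec![Some(0), None, Some(2), Some(3)],
    };
    assert!(chunk.position_indices.len() <= MAX_POSITIONS);
    // 3 real positions * 6s per position = 18s for the whole chunk.
    assert_eq!(chunk.timeout(), Duration::from_secs(18));
}
```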
24 changes: 17 additions & 7 deletions src/logger.rs
@@ -10,9 +10,9 @@ use shakmaty::variant::Variant;
use url::Url;

use crate::{
api::BatchId,
api::{BatchId, PositionIndex},
configure::Verbose,
ipc::{Position, PositionId, PositionResponse},
ipc::{Chunk, Position, PositionResponse},
util::NevermindExt as _,
};

@@ -108,33 +108,43 @@ impl Logger {
pub struct ProgressAt {
pub batch_id: BatchId,
pub batch_url: Option<Url>,
pub position_id: Option<PositionId>,
pub position_index: Option<PositionIndex>,
}

impl fmt::Display for ProgressAt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(ref batch_url) = self.batch_url {
let mut url = batch_url.clone();
if let Some(PositionId(positon_id)) = self.position_id {
if let Some(PositionIndex(positon_id)) = self.position_index {
url.set_fragment(Some(&positon_id.to_string()));
}
fmt::Display::fmt(&url, f)
} else {
write!(f, "{}", self.batch_id)?;
if let Some(PositionId(positon_id)) = self.position_id {
if let Some(PositionIndex(positon_id)) = self.position_index {
write!(f, "#{positon_id}")?;
}
Ok(())
}
}
}

impl From<&Chunk> for ProgressAt {
fn from(chunk: &Chunk) -> ProgressAt {
ProgressAt {
batch_id: chunk.work.id(),
batch_url: chunk.positions.last().and_then(|pos| pos.url.clone()),
position_index: chunk.positions.last().and_then(|pos| pos.position_index),
}
}
}

impl From<&Position> for ProgressAt {
fn from(pos: &Position) -> ProgressAt {
ProgressAt {
batch_id: pos.work.id(),
batch_url: pos.url.clone(),
position_id: Some(pos.position_id),
position_index: pos.position_index,
}
}
}
@@ -144,7 +154,7 @@ impl From<&PositionResponse> for ProgressAt {
ProgressAt {
batch_id: pos.work.id(),
batch_url: pos.url.clone(),
position_id: Some(pos.position_id),
position_index: pos.position_index,
}
}
}
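The logger changes mirror the data model: progress for a chunk is reported at its last position, and when a batch URL is known the position index is appended as a URL fragment. A reduced sketch of that fragment handling (BatchId and the other fields are omitted for brevity; this assumes the url crate, which fishnet already depends on):

```rust
use std::fmt;
use url::Url;

struct ProgressAt {
    batch_url: Option<Url>,
    position_index: Option<usize>,
}

impl fmt::Display for ProgressAt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.batch_url {
            Some(batch_url) => {
                let mut url = batch_url.clone();
                if let Some(index) = self.position_index {
                    // Point at the specific position within the batch page.
                    url.set_fragment(Some(&index.to_string()));
                }
                fmt::Display::fmt(&url, f)
            }
            None => match self.position_index {
                Some(index) => write!(f, "#{index}"),
                None => Ok(()),
            },
        }
    }
}

fn main() {
    let progress = ProgressAt {
        batch_url: Url::parse("https://lichess.org/abcd1234").ok(),
        position_index: Some(5),
    };
    assert_eq!(progress.to_string(), "https://lichess.org/abcd1234#5");
}
```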
37 changes: 22 additions & 15 deletions src/main.rs
@@ -35,7 +35,7 @@ use tokio::{
use crate::{
assets::{Assets, ByEngineFlavor, Cpu, EngineFlavor},
configure::{Command, Cores, CpuPriority, Opt},
ipc::{Position, PositionFailed, Pull},
ipc::{Chunk, ChunkFailed, Pull},
logger::{Logger, ProgressAt},
util::RandomizedBackoff,
};
@@ -249,7 +249,7 @@ async fn run(opt: Opt, logger: &Logger) {
}
}

// Shutdown queue to abort remaining jobs.
// Shutdown queue to abort remaining chunks.
queue.shutdown().await;

// Wait for all workers.
@@ -266,21 +266,21 @@ async fn run(opt: Opt, logger: &Logger) {
async fn worker(i: usize, assets: Arc<Assets>, tx: mpsc::Sender<Pull>, logger: Logger) {
logger.debug(&format!("Started worker {i}."));

let mut job: Option<Position> = None;
let mut chunk: Option<Chunk> = None;
let mut engine = ByEngineFlavor {
official: None,
multi_variant: None,
};
let mut engine_backoff = RandomizedBackoff::default();

let default_budget = Duration::from_secs(60);
let default_budget = Duration::from_secs(30);
let mut budget = default_budget;

loop {
let response = if let Some(job) = job.take() {
let responses = if let Some(chunk) = chunk.take() {
// Ensure engine process is ready.
let flavor = job.flavor;
let context = ProgressAt::from(&job);
let flavor = chunk.flavor;
let context = ProgressAt::from(&chunk);
let (mut sf, join_handle) =
if let Some((sf, join_handle)) = engine.get_mut(flavor).take() {
(sf, join_handle)
@@ -310,19 +310,19 @@ async fn worker(i: usize, assets: Arc<Assets>, tx: mpsc::Sender<Pull>, logger: L
};

// Provide time budget.
budget = min(default_budget, budget) + job.work.timeout();
budget = min(default_budget, budget) + chunk.timeout();

// Analyse or play.
let timer = Instant::now();
let batch_id = job.work.id();
let batch_id = chunk.work.id();
let res = tokio::select! {
_ = tx.closed() => {
logger.debug(&format!("Worker {i} shutting down engine early"));
drop(sf);
join_handle.await.expect("join");
break;
}
res = sf.go(job) => {
res = sf.go_multiple(chunk) => {
match res {
Ok(res) => {
*engine.get_mut(flavor) = Some((sf, join_handle));
@@ -344,7 +344,7 @@ async fn worker(i: usize, assets: Arc<Assets>, tx: mpsc::Sender<Pull>, logger: L
});
drop(sf);
join_handle.await.expect("join");
Err(PositionFailed { batch_id })
Err(ChunkFailed { batch_id })
}
};

@@ -354,14 +354,21 @@ async fn worker(i: usize, assets: Arc<Assets>, tx: mpsc::Sender<Pull>, logger: L
logger.debug(&format!("Low engine timeout budget: {budget:?}"));
}

Some(res)
res
} else {
None
Ok(Vec::new())
};

let (callback, waiter) = oneshot::channel();

if tx.send(Pull { response, callback }).await.is_err() {
if tx
.send(Pull {
responses,
callback,
})
.await
.is_err()
{
logger.debug(&format!(
"Worker {i} was about to send result, but shutting down"
));
@@ -372,7 +379,7 @@ async fn worker(i: usize, assets: Arc<Assets>, tx: mpsc::Sender<Pull>, logger: L
_ = tx.closed() => break,
res = waiter => {
match res {
Ok(next_job) => job = Some(next_job),
Ok(next_chunk) => chunk = Some(next_chunk),
Err(_) => break,
}
}
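The worker's time budget also changes shape: the carried-over budget is capped at the (now 30 second) default before each chunk, topped up with the chunk's own timeout, and drawn down by the time the engine actually used. A standalone sketch of that accounting with made-up numbers (the subtraction of elapsed time and the low-budget warning threshold are assumptions about the surrounding worker code, not shown in this diff):

```rust
use std::cmp::min;
use std::time::Duration;

fn main() {
    let default_budget = Duration::from_secs(30);
    let mut budget = default_budget;

    // Each entry: (timeout granted for the chunk, time the engine actually took).
    let chunks = [
        (Duration::from_secs(18), Duration::from_secs(10)),
        (Duration::from_secs(24), Duration::from_secs(40)),
    ];

    for (chunk_timeout, elapsed) in chunks {
        // Carry over at most the default budget, then add this chunk's allowance.
        budget = min(default_budget, budget) + chunk_timeout;

        // After the engine call, subtract the time that was actually spent.
        budget = budget.saturating_sub(elapsed);
        if budget < default_budget / 2 {
            // Hypothetical threshold; the real check lives in the worker loop.
            println!("Low engine timeout budget: {budget:?}");
        }
    }

    // First chunk:  min(30, 30) + 18 - 10 = 38s remaining.
    // Second chunk: min(30, 38) + 24 - 40 = 14s remaining (triggers the warning).
    assert_eq!(budget, Duration::from_secs(14));
}
```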