diff --git a/bin/airflow b/bin/airflow
index 948456d..07fc84a 100755
--- a/bin/airflow
+++ b/bin/airflow
@@ -7,7 +7,7 @@ export PYTHONPATH=$(dirname "$SCRIPT_DIR")/shared
 
 if [ "$1" == "setup" ]
 then
-    python -m venv "$VENV_DIR"
+    /usr/bin/python3 -m venv "$VENV_DIR"
     "$VENV_DIR"/bin/pip3 install "apache-airflow[celery]==2.9.1" \
         apache-airflow-providers-slack[common.sql] \
         apache-airflow-providers-google \
diff --git a/rollout-dashboard/server/src/airflow_client.rs b/rollout-dashboard/server/src/airflow_client.rs
index 9d6c798..e660d8a 100644
--- a/rollout-dashboard/server/src/airflow_client.rs
+++ b/rollout-dashboard/server/src/airflow_client.rs
@@ -5,7 +5,7 @@ use regex::Regex;
 use reqwest::cookie::Jar;
 use reqwest::header::{ACCEPT, CONTENT_TYPE, REFERER};
 use serde::de::Error;
-use serde::{Deserialize, Deserializer};
+use serde::{Deserialize, Deserializer, Serialize};
 use std::cmp::min;
 use std::collections::HashMap;
 use std::convert::TryFrom;
@@ -166,7 +166,7 @@ pub struct XComEntryResponse {
     pub value: String,
 }
 
-#[derive(Debug, Deserialize, Clone, PartialEq, Display)]
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Display)]
 #[serde(rename_all = "snake_case")]
 pub enum TaskInstanceState {
     Success,
@@ -216,7 +216,7 @@ where
     })
 }
 
-#[derive(Debug, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct TaskInstancesResponseItem {
     pub task_id: String,
     #[allow(dead_code)]
diff --git a/rollout-dashboard/server/src/frontend_api.rs b/rollout-dashboard/server/src/frontend_api.rs
index 1127eb7..ac18d12 100644
--- a/rollout-dashboard/server/src/frontend_api.rs
+++ b/rollout-dashboard/server/src/frontend_api.rs
@@ -7,7 +7,7 @@ use chrono::{DateTime, Utc};
 use futures::future::join_all;
 use indexmap::IndexMap;
 use lazy_static::lazy_static;
-use log::{debug, trace, warn};
+use log::{debug, info, trace, warn};
 use regex::Regex;
 use rollout_dashboard::types::{
     Batch, Rollout, RolloutState, Rollouts, Subnet, SubnetRolloutState,
@@ -253,11 +253,12 @@ impl From<AirflowError> for RolloutDataGatherError {
     }
 }
 
+#[derive(Clone, Serialize)]
 enum ScheduleCache {
     Empty,
-    Invalid,
-    Valid(String),
+    Valid(usize, String),
 }
+
 struct RolloutDataCache {
     task_instances: HashMap<String, HashMap<Option<usize>, TaskInstancesResponseItem>>,
     dispatch_time: DateTime<Utc>,
@@ -266,6 +267,15 @@ struct RolloutDataCache {
     last_update_time: Option<DateTime<Utc>>,
 }
 
+#[derive(Serialize)]
+pub struct RolloutDataCacheResponse {
+    rollout_id: String,
+    dispatch_time: DateTime<Utc>,
+    schedule: ScheduleCache,
+    last_update_time: Option<DateTime<Utc>>,
+    linearized_task_instances: Vec<TaskInstancesResponseItem>,
+}
+
 struct RolloutApiCache {
     /// Map from DAG run ID to task instance ID (with / without index)
     /// to task instance.
@@ -349,6 +359,32 @@ impl RolloutApi {
         }
     }
 
+    pub async fn get_cache(&self) -> Vec<RolloutDataCacheResponse> {
+        let cache = self.cache.lock().await;
+        let mut result: Vec<_> = cache
+            .by_dag_run
+            .iter()
+            .map(|(k, v)| {
+                let linearized_tasks = v
+                    .task_instances
+                    .iter()
+                    .flat_map(|(_, tasks)| tasks.iter().map(|(_, task)| task.clone()))
+                    .collect();
+                RolloutDataCacheResponse {
+                    rollout_id: k.clone(),
+                    linearized_task_instances: linearized_tasks,
+                    dispatch_time: v.dispatch_time,
+                    last_update_time: v.last_update_time,
+                    schedule: v.schedule.clone(),
+                }
+            })
+            .collect();
+        drop(cache);
+        result.sort_by_key(|v| v.dispatch_time);
+        result.reverse();
+        result
+    }
+
     /// Retrieve all rollout data, using a cache to avoid
     /// re-fetching task instances not updated since last time.
     ///
@@ -524,9 +560,6 @@ impl RolloutApi {
         // Let's update the cache to incorporate the most up-to-date task instances.
         for task_instance in all_task_instances.into_iter() {
             let task_instance_id = task_instance.task_id.clone();
-            if task_instance_id == "schedule" {
-                cache_entry.schedule = ScheduleCache::Invalid;
-            }
 
             let by_name = cache_entry
                 .task_instances
@@ -580,6 +613,10 @@ impl RolloutApi {
            // any non-subnet-related task is running / pending.
            // * handle tasks corresponding to a batch/subnet in a special way
            //   (commented below in its pertinent section).
+            debug!(
+                target: "frontend_api", "Processing task {}.{:?} in state {:?}",
+                task_instance.task_id, task_instance.map_index, task_instance.state,
+            );
             if task_instance.task_id == "schedule" {
                 match task_instance.state {
                     Some(TaskInstanceState::Skipped) | Some(TaskInstanceState::Removed) => (),
@@ -598,9 +635,16 @@ impl RolloutApi {
                     | Some(TaskInstanceState::Scheduled)
                     | None => rollout.state = min(rollout.state, RolloutState::Preparing),
                     Some(TaskInstanceState::Success) => {
+                        if let ScheduleCache::Valid(try_number, _) = cache_entry.schedule {
+                            if try_number != task_instance.try_number {
+                                info!(target: "frontend_api", "{}: resetting schedule cache", dag_run.dag_run_id);
+                                // Another task run of the same task has executed.  We must clear the cache entry.
+                                cache_entry.schedule = ScheduleCache::Empty;
+                            }
+                        }
                         let schedule_string = match &cache_entry.schedule {
-                            ScheduleCache::Valid(s) => s,
-                            ScheduleCache::Invalid | ScheduleCache::Empty => {
+                            ScheduleCache::Valid(_, s) => s,
+                            ScheduleCache::Empty => {
                                 let value = self
                                     .airflow_api
                                     .xcom_entry(
@@ -613,8 +657,11 @@ impl RolloutApi {
                                     .await;
                                 let schedule = match value {
                                     Ok(schedule) => {
-                                        cache_entry.schedule =
-                                            ScheduleCache::Valid(schedule.value.clone());
+                                        cache_entry.schedule = ScheduleCache::Valid(
+                                            task_instance.try_number,
+                                            schedule.value.clone(),
+                                        );
+                                        info!(target: "frontend_api", "{}: saving schedule cache", dag_run.dag_run_id);
                                         schedule.value
                                     }
                                     Err(AirflowError::StatusCode(
@@ -623,6 +670,7 @@ impl RolloutApi {
                                         // There is no schedule to be found.
                                         // Or there was no schedule to be found last time
                                         // it was queried.
+                                        warn!(target: "frontend_api", "{}: no schedule despite schedule task finished", dag_run.dag_run_id);
                                         cache_entry.schedule = ScheduleCache::Empty;
                                         continue;
                                     }
diff --git a/rollout-dashboard/server/src/main.rs b/rollout-dashboard/server/src/main.rs
index 9a3b322..af456d7 100644
--- a/rollout-dashboard/server/src/main.rs
+++ b/rollout-dashboard/server/src/main.rs
@@ -5,6 +5,7 @@ use axum::response::Sse;
 use axum::Json;
 use axum::{routing::get, Router};
 use chrono::{DateTime, Utc};
+use frontend_api::RolloutDataCacheResponse;
 use futures::stream::Stream;
 use log::{debug, error, info};
 use reqwest::Url;
@@ -162,6 +163,13 @@ impl Server {
             Err(e) => Err(e),
         }
     }
+
+    // #[debug_handler]
+    async fn get_cache(&self) -> Result<Json<Vec<RolloutDataCacheResponse>>, (StatusCode, String)> {
+        let m = self.rollout_api.get_cache().await;
+        Ok(Json(m))
+    }
+
     fn produce_rollouts_sse_stream(&self) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
         debug!(target: "sse", "New client connected.");
 
@@ -251,13 +259,16 @@ async fn main() -> ExitCode {
     let (finish_loop_tx, mut finish_loop_rx) = watch::channel(());
 
     let server_for_rollouts_handler = server.clone();
+    let server_for_cache_handler = server.clone();
     let server_for_sse_handler = server.clone();
     let rollouts_handler =
         move || async move { server_for_rollouts_handler.get_rollout_data().await };
+    let cached_data_handler = move || async move { server_for_cache_handler.get_cache().await };
     let rollouts_sse_handler =
         move || async move { server_for_sse_handler.produce_rollouts_sse_stream() };
 
     let mut tree = Router::new();
     tree = tree.route("/api/v1/rollouts", get(rollouts_handler));
+    tree = tree.route("/api/v1/cache", get(cached_data_handler));
     tree = tree.route("/api/v1/rollouts/sse", get(rollouts_sse_handler));
     tree = tree.nest_service("/", ServeDir::new(frontend_static_dir));
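
A quick way to exercise the new /api/v1/cache endpoint added above (a minimal sketch; localhost:4174 is a placeholder for wherever the rollout-dashboard server is listening):

    # List the cached rollouts, newest dispatch_time first.
    curl -s http://localhost:4174/api/v1/cache | jq '.[].rollout_id'

Each array element is a serialized RolloutDataCacheResponse, so the per-rollout schedule cache and linearized task instances can be inspected without querying Airflow directly.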