forked from torvalds/linux
-
Notifications
You must be signed in to change notification settings - Fork 435
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: net/tcp: add Rust implementation of CUBIC
CUBIC has been the default CCA since 2.6.19. This serves as an example that the abstractions can be used to implement real-world CCAs that are based on loss as a feedback mechanism.
- Loading branch information
Valentin Obst
committed
Feb 18, 2024
1 parent
1ecb4d6
commit 8399bea
Showing
3 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
//! TCP CUBIC congestion control algorithm. | ||
#![allow(dead_code)] | ||
#![allow(non_snake_case)] | ||
#![allow(unused_variables)] | ||
|
||
use core::cmp::max; | ||
use core::num::NonZeroU32; | ||
use hystart::HystartDetect; | ||
use kernel::net::tcp; | ||
use kernel::net::tcp::cong::{self, hystart}; | ||
use kernel::prelude::*; | ||
use kernel::time; | ||
use kernel::{c_str, module_cca}; | ||
|
||
/// Fixed-point denominator for `BETA`: the effective multiplicative-decrease
/// factor is `BETA / BICTCP_BETA_SCALE`.
const BICTCP_BETA_SCALE: u32 = 1024;

// TODO: Convert to module parameters once they are available. Currently these
// are the defaults from the C implementation.
// TODO: Use NonZeroU32 where appropriate.

/// Whether to use fast-convergence.
const FAST_CONVERGENCE: bool = true;

/// The factor for multiplicative decrease of cwnd upon a loss event. Will be
/// divided by `BICTCP_BETA_SCALE`, approximately 0.7.
const BETA: u32 = 717;

/// The initial value of ssthresh for new connections. Setting this to `None`
/// implies `i32::MAX`.
const INITIAL_SSTHRESH: Option<u32> = None;

/// Scaling constant that feeds into `CUBE_RTT_SCALE` and `CUBE_FACTOR`.
///
/// NOTE(review): presumably the `bic_scale` parameter of the C implementation
/// (scale of the cubic function, in units of 1/1024) -- confirm and complete
/// this description. (Was: TODO)
const BIC_SCALE: u32 = 41;

/// NOTE(review): presumably toggles the TCP-friendliness mode of the C
/// implementation; nothing visible in this file reads it yet -- confirm and
/// complete this description. (Was: TODO)
const TCP_FRIENDLINESS: bool = true;

/// Whether to use the HyStart slow start algorithm.
const HYSTART: bool = true;
|
||
impl hystart::HyStart for Cubic { | ||
/// Which mechanism to use for deciding when it is time to exit slow start. | ||
const DETECT: HystartDetect = HystartDetect::Both; | ||
/// Lower bound for cwnd during hybrid slow start. | ||
const LOW_WINDOW: u32 = 16; | ||
/// Spacing between ACKs indicating an ACK-train. | ||
/// (Dimension: time. Unit: microseconds). | ||
const ACK_DELTA: time::Usecs32 = 2000; | ||
} | ||
|
||
// TODO: Those are computed based on the module parameters in the init. Even | ||
// with module parameters available this will be a bit tricky to do in Rust. | ||
/// Factor of `8/3 * (1 + beta) / (1 - beta)` that is used in various | ||
/// calculations. (Dimension: none) | ||
const BETA_SCALE: u32 = ((8 * (BICTCP_BETA_SCALE + BETA)) / 3) / (BICTCP_BETA_SCALE - BETA); | ||
/// TODO | ||
const CUBE_RTT_SCALE: u32 = BIC_SCALE * 10; | ||
/// TODO | ||
const CUBE_FACTOR: u64 = (1u64 << 40) / (CUBE_RTT_SCALE as u64); | ||
|
||
module_cca! { | ||
type: Cubic, | ||
name: "tcp_cubic_rust", | ||
author: "Rust for Linux Contributors", | ||
description: "TCP CUBIC congestion control algorithm, Rust implementation", | ||
license: "GPL v2", | ||
} | ||
|
||
/// Type implementing the CUBIC congestion control algorithm.
///
/// It has no fields; per-connection state is kept in `CubicState`.
struct Cubic {}
|
||
#[vtable] | ||
impl cong::Algorithm for Cubic { | ||
type Data = CubicState; | ||
|
||
const NAME: &'static CStr = c_str!("bic_rust"); | ||
|
||
fn init(sk: &mut cong::Sock<'_, Self>) { | ||
if HYSTART { | ||
<Self as hystart::HyStart>::reset(sk) | ||
} else if let Some(ssthresh) = INITIAL_SSTHRESH { | ||
sk.tcp_sk_mut().set_snd_ssthresh(ssthresh); | ||
} | ||
|
||
// TODO: remove | ||
pr_info!("Socket created: start {}", sk.inet_csk_ca().start_time); | ||
} | ||
|
||
// TODO: remove | ||
fn release(sk: &mut cong::Sock<'_, Self>) { | ||
pr_info!( | ||
"Socket destroyed: start {}, end {}", | ||
sk.inet_csk_ca().start_time, | ||
(time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, | ||
); | ||
} | ||
|
||
fn cwnd_event(sk: &mut cong::Sock<'_, Self>, ev: cong::Event) { | ||
if matches!(ev, cong::Event::TxStart) { | ||
// Here we cannot avoid jiffies as the `lsndtime` field is measured | ||
// in jiffies. | ||
let now = time::jiffies32(); | ||
let delta: time::Jiffies32 = now.wrapping_sub(sk.tcp_sk().lsndtime()); | ||
|
||
if (delta as i32) <= 0 { | ||
return; | ||
} | ||
|
||
let ca = sk.inet_csk_ca_mut(); | ||
// Ok, lets switch to SI time units. | ||
let now = time::jiffies_to_msecs(now as time::Jiffies); | ||
let delta = time::jiffies_to_msecs(delta as time::Jiffies); | ||
if ca.epoch_start != 0 { | ||
ca.epoch_start += delta; | ||
if tcp::after(ca.epoch_start, now) { | ||
ca.epoch_start = now; | ||
} | ||
}; | ||
} | ||
} | ||
|
||
fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) { | ||
if matches!(new_state, cong::State::Loss) { | ||
pr_info!( | ||
// TODO: remove | ||
"Retransmission timeout fired: time {}, start {}", | ||
(time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, | ||
sk.inet_csk_ca().start_time | ||
); | ||
sk.inet_csk_ca_mut().reset(); | ||
<Self as hystart::HyStart>::reset(sk); | ||
} | ||
} | ||
|
||
fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) { | ||
todo!() | ||
} | ||
|
||
fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 { | ||
let cwnd = sk.tcp_sk().snd_cwnd(); | ||
let ca = sk.inet_csk_ca_mut(); | ||
|
||
pr_info!( | ||
// TODO: remove | ||
"Enter fast retransmit: time {}, start {}", | ||
(time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, | ||
ca.start_time | ||
); | ||
|
||
// Epoch has ended. | ||
ca.epoch_start = 0; | ||
ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE { | ||
(cwnd * (BETA_SCALE + BETA)) / (2 * BETA_SCALE) | ||
} else { | ||
cwnd | ||
}; | ||
|
||
max((cwnd * BETA) / BETA_SCALE, 2) | ||
} | ||
|
||
fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 { | ||
pr_info!( | ||
// TODO: remove | ||
"Undo cwnd reduction: time {}, start {}", | ||
(time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, | ||
sk.inet_csk_ca().start_time | ||
); | ||
|
||
cong::reno::undo_cwnd(sk) | ||
} | ||
|
||
fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) { | ||
if !sk.tcp_is_cwnd_limited() { | ||
return; | ||
} | ||
|
||
let tp = sk.tcp_sk_mut(); | ||
|
||
if tp.in_slow_start() { | ||
acked = tp.slow_start(acked); | ||
if acked == 0 { | ||
pr_info!( | ||
// TODO: remove | ||
"New cwnd {}, time {}, ssthresh {}, start {}, ss 1", | ||
sk.tcp_sk().snd_cwnd(), | ||
(time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, | ||
sk.tcp_sk().snd_ssthresh(), | ||
sk.inet_csk_ca().start_time | ||
); | ||
return; | ||
} | ||
} | ||
|
||
let cwnd = tp.snd_cwnd(); | ||
let cnt = sk.inet_csk_ca_mut().update(cwnd, acked); | ||
sk.tcp_sk_mut().cong_avoid_ai(cnt, acked); | ||
|
||
pr_info!( | ||
// TODO: remove | ||
"New cwnd {}, time {}, ssthresh {}, start {}, ss 0", | ||
sk.tcp_sk().snd_cwnd(), | ||
(time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, | ||
sk.tcp_sk().snd_ssthresh(), | ||
sk.inet_csk_ca().start_time | ||
); | ||
} | ||
} | ||
|
||
#[allow(non_snake_case)] | ||
struct CubicState { | ||
/// Increase the cwnd by one step after `cnt` ACKs. | ||
cnt: NonZeroU32, | ||
/// W__last_max | ||
last_max_cwnd: u32, | ||
last_cwnd: u32, | ||
/// Time when `last_cwnd` was updated. | ||
last_time: time::Msecs32, | ||
origin_point: u32, | ||
K: time::Msecs32, | ||
/// Time when the current epoch has started. | ||
epoch_start: time::Msecs32, | ||
ack_cnt: u32, | ||
/// Estimate for the cwnd of TCP Reno. | ||
tcp_cwnd: u32, | ||
/// State of the HyStart slow start algorithm. | ||
hystart_state: hystart::HyStartState, | ||
/// Time when the connection was created. | ||
// TODO: remove | ||
start_time: time::Usecs32, | ||
} | ||
|
||
impl hystart::HasHyStartState for CubicState { | ||
fn hy(&self) -> &hystart::HyStartState { | ||
&self.hystart_state | ||
} | ||
|
||
fn hy_mut(&mut self) -> &mut hystart::HyStartState { | ||
&mut self.hystart_state | ||
} | ||
} | ||
|
||
impl Default for CubicState { | ||
fn default() -> Self { | ||
Self { | ||
// NOTE: Initializing this to 1 deviates from the C code. It does | ||
// not change the behavior. | ||
cnt: NonZeroU32::MIN, | ||
last_max_cwnd: 0, | ||
last_cwnd: 0, | ||
last_time: 0, | ||
origin_point: 0, | ||
K: 0, | ||
epoch_start: 0, | ||
ack_cnt: 0, | ||
tcp_cwnd: 0, | ||
hystart_state: hystart::HyStartState::default(), | ||
// TODO: remove | ||
start_time: (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, | ||
} | ||
} | ||
} | ||
|
||
impl CubicState { | ||
fn update(&mut self, cwnd: u32, acked: u32) -> NonZeroU32 { | ||
todo!() | ||
} | ||
|
||
fn reset(&mut self) { | ||
// TODO: remove | ||
let tmp = self.start_time; | ||
|
||
*self = Self::default(); | ||
|
||
// TODO: remove | ||
self.start_time = tmp; | ||
} | ||
} |