From 6275da963deef172dc0e96f528a0b37f99e7bcce Mon Sep 17 00:00:00 2001 From: will-maclean <41996719+will-maclean@users.noreply.github.com> Date: Sun, 23 Jun 2024 23:06:03 +1000 Subject: [PATCH] testing cov (#43) adding coverage CI and tests --- .github/{workflows => _deactivated}/rust.yml | 0 .github/workflows/cov.yml | 23 +++ examples/dqn_gridworld.rs | 5 +- src/common/spaces.rs | 39 +++++ src/common/to_tensor.rs | 75 ++++++++++ src/common/utils/mod.rs | 17 ++- src/common/utils/module_update.rs | 1 - src/dqn/mod.rs | 142 ++++++++----------- src/dqn/module.rs | 62 +++++++- src/env/probe.rs | 62 ++++++++ src/env/wrappers.rs | 19 ++- 11 files changed, 349 insertions(+), 96 deletions(-) rename .github/{workflows => _deactivated}/rust.yml (100%) create mode 100644 .github/workflows/cov.yml diff --git a/.github/workflows/rust.yml b/.github/_deactivated/rust.yml similarity index 100% rename from .github/workflows/rust.yml rename to .github/_deactivated/rust.yml diff --git a/.github/workflows/cov.yml b/.github/workflows/cov.yml new file mode 100644 index 0000000..d0838b2 --- /dev/null +++ b/.github/workflows/cov.yml @@ -0,0 +1,23 @@ +name: Coverage + +on: [pull_request, push] + +jobs: + coverage: + runs-on: ubuntu-latest + env: + CARGO_TERM_COLOR: always + steps: + - uses: actions/checkout@v4 + - name: Install Rust + run: rustup update stable + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + - name: Generate code coverage + run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: lcov.info + fail_ci_if_error: true \ No newline at end of file diff --git a/examples/dqn_gridworld.rs b/examples/dqn_gridworld.rs index 6a5bbac..4a9cba7 100644 --- a/examples/dqn_gridworld.rs +++ b/examples/dqn_gridworld.rs @@ -55,10 +55,7 @@ fn main() { let buffer = ReplayBuffer::new(offline_params.memory_size); - let logger = CsvLogger::new( - PathBuf::from("logs/dqn_gridworld/dqn_gridworld.csv"), - false, - ); + let logger = CsvLogger::new(PathBuf::from("logs/dqn_gridworld/dqn_gridworld.csv"), false); match logger.check_can_log(false) { Ok(_) => {} diff --git a/src/common/spaces.rs b/src/common/spaces.rs index 7a8d4c1..657c5e1 100644 --- a/src/common/spaces.rs +++ b/src/common/spaces.rs @@ -1,3 +1,4 @@ +use burn::tensor::{backend::Backend, Distribution, Tensor}; use dyn_clone::DynClone; use rand::{rngs::StdRng, Rng, SeedableRng}; @@ -118,6 +119,44 @@ impl Space> for BoxSpace> { } } +impl From<(Tensor, Tensor)> for BoxSpace> { + fn from(value: (Tensor, Tensor)) -> Self { + Self { + low: value.0, + high: value.1, + rng: StdRng::from_entropy(), + } + } +} + +impl Space> for BoxSpace> { + fn contains(&self, sample: &Tensor) -> bool { + if sample.shape() != self.low.shape() { + return false; + } + + sample.clone().greater_equal(self.low.clone()).all().into_scalar() & + sample.clone().lower_equal(self.low.clone()).all().into_scalar() + } + + fn sample(&mut self) -> Tensor { + let shape = self.low.shape(); + let sample: Tensor = Tensor::random(shape, Distribution::Uniform(0.0, 1.0), &self.low.device()); + let range = self.high.clone().sub(self.low.clone()); + let sample = sample.mul(range).add(self.low.clone()); + + sample + } + + fn seed(&mut self, seed: [u8; 32]) { + self.rng = StdRng::from_seed(seed); + } + + fn shape(&self) -> Tensor { + self.low.clone() + } +} + #[cfg(test)] mod test { use crate::common::spaces::{BoxSpace, Discrete, Space}; diff --git a/src/common/to_tensor.rs b/src/common/to_tensor.rs index 6d258e0..a08775e 100644 --- a/src/common/to_tensor.rs +++ b/src/common/to_tensor.rs @@ -68,3 +68,78 @@ impl ToTensorB<1> for Vec { Tensor::::from_data(Data::new(self, Shape::new([n])).convert(), device).bool() } } + +#[cfg(test)] +mod test { + use burn::{backend::NdArray, tensor::{Bool, Int, Tensor}}; + + use crate::common::to_tensor::{ToTensorB, ToTensorI}; + + use super::ToTensorF; + + #[test] + fn test_to_tensor_f32(){ + let d: f32 = 1.1; + let t: Tensor = d.to_tensor(&Default::default()); + + assert_eq!(t.shape().dims.len(), 1); + assert_eq!(t.shape().dims, [1]); + assert_eq!(t.into_scalar(), d); + } + + #[test] + fn test_to_tensor_vec_f32(){ + let d: Vec = vec![1.1, 2.2]; + let t: Tensor = d.to_tensor(&Default::default()); + + assert_eq!(t.shape().dims.len(), 1); + assert_eq!(t.shape().dims, [2]); + } + + #[test] + fn test_to_tensor_vec_vec_f32(){ + let d: Vec> = vec![vec![1.1, 2.2], vec![3.3, 4.4], vec![1.0, 0.0]]; + let t: Tensor = d.to_tensor(&Default::default()); + + assert_eq!(t.shape().dims.len(), 2); + assert_eq!(t.shape().dims, [3, 2]); + } + + #[test] + fn test_to_tensor_usize(){ + let d: usize = 1; + let t: Tensor = d.to_tensor(&Default::default()); + + assert_eq!(t.shape().dims.len(), 1); + assert_eq!(t.shape().dims, [1]); + assert_eq!(t.into_scalar() as usize, d); + } + + #[test] + fn test_to_tensor_vec_usize(){ + let d: Vec = vec![1, 2]; + let t: Tensor = d.to_tensor(&Default::default()); + + assert_eq!(t.shape().dims.len(), 1); + assert_eq!(t.shape().dims, [2]); + } + + #[test] + fn test_to_tensor_bool(){ + let d: bool = true; + let t: Tensor = d.to_tensor(&Default::default()); + + assert_eq!(t.shape().dims.len(), 1); + assert_eq!(t.shape().dims, [1]); + assert_eq!(t.into_scalar(), d); + } + + #[test] + fn test_to_tensor_vec_bool(){ + let d: Vec = vec![false, true]; + let t: Tensor = d.to_tensor(&Default::default()); + + assert_eq!(t.shape().dims.len(), 1); + assert_eq!(t.shape().dims, [2]); + } +} \ No newline at end of file diff --git a/src/common/utils/mod.rs b/src/common/utils/mod.rs index 38cb2d6..7732e07 100644 --- a/src/common/utils/mod.rs +++ b/src/common/utils/mod.rs @@ -50,9 +50,12 @@ pub fn vec_usize_to_one_hot( mod test { use assert_approx_eq::assert_approx_eq; - use burn::{backend::{ndarray::NdArrayDevice, NdArray}, tensor::Tensor}; + use burn::{ + backend::{ndarray::NdArrayDevice, NdArray}, + tensor::Tensor, + }; - use crate::common::utils::{mean, linear_decay, generate_random_vector, vec_usize_to_one_hot}; + use crate::common::utils::{generate_random_vector, linear_decay, mean, vec_usize_to_one_hot}; #[test] fn test_mean() { @@ -66,29 +69,29 @@ mod test { } #[test] - fn test_linear_decay(){ + fn test_linear_decay() { assert_approx_eq!(linear_decay(0.0, 1.0, 0.01, 0.8), 1.0, 1e-3f32); assert_approx_eq!(linear_decay(0.8, 1.0, 0.01, 0.8), 0.01, 1e-3f32); assert_approx_eq!(linear_decay(1.0, 1.0, 0.01, 0.8), 0.01, 1e-3f32); } #[test] - fn test_gen_rand_vec(){ + fn test_gen_rand_vec() { let sample = generate_random_vector(vec![0.0, 0.0, 0.0], vec![1.0, 1.0, 1.0]); - for s in sample{ + for s in sample { assert!((s >= 0.0) & (s <= 1.0)); } } #[should_panic] #[test] - fn test_gen_rand_vec_bad(){ + fn test_gen_rand_vec_bad() { generate_random_vector(vec![1.0], vec![0.0]); } #[test] - fn test_usize_to_one_hot(){ + fn test_usize_to_one_hot() { let ins = vec![0, 1, 2]; let classes = 4; diff --git a/src/common/utils/module_update.rs b/src/common/utils/module_update.rs index 7be81e4..b990d43 100644 --- a/src/common/utils/module_update.rs +++ b/src/common/utils/module_update.rs @@ -84,7 +84,6 @@ mod test { } impl Policy for LinearPolicy { - fn update(&mut self, from: &Self, tau: Option) { self.layer = update_linear(&from.layer, self.layer.clone(), tau); } diff --git a/src/dqn/mod.rs b/src/dqn/mod.rs index 60cdec8..ef57ede 100644 --- a/src/dqn/mod.rs +++ b/src/dqn/mod.rs @@ -184,82 +184,66 @@ where } } -// #[cfg(test)] -// mod test { -// use std::path::PathBuf; - -// use burn::{ -// backend::{Autodiff, NdArray}, -// optim::{Adam, AdamConfig}, -// tensor::backend::AutodiffBackend, -// }; - -// use crate::{ -// algorithm::{OfflineAlgParams, OfflineAlgorithm}, -// buffer::ReplayBuffer, -// dqn::{DQNAgent, DQNConfig, DQNNet}, -// env::{base::Env, gridworld::GridWorldEnv}, -// eval::EvalConfig, -// logger::CsvLogger, -// }; - -// use super::OfflineTrainer; - -// #[test] -// fn test_dqn_lightweight() { -// type TrainingBacked = Autodiff; -// let device = Default::default(); -// let config_optimizer = AdamConfig::new(); -// let optim = config_optimizer.init(); -// let offline_params = OfflineAlgParams::new() -// .with_n_steps(10) -// .with_batch_size(2) -// .with_memory_size(5) -// .with_warmup_steps(2); -// let env = GridWorldEnv::default(); -// let q = DQNNet::::init( -// &device, -// env.observation_space().clone(), -// env.action_space().clone(), -// 2, -// ); -// let agent = DQNAgent::new( -// q.clone(), -// q, -// optim, -// DQNConfig::new(), -// env.observation_space(), -// env.action_space() -// ); -// let dqn_alg = OfflineAlgorithm::DQN(agent); -// let buffer = ReplayBuffer::new(offline_params.memory_size); - -// // create the logs dir -// let mut log_dir = std::env::current_dir().unwrap(); -// log_dir.push("tmp_logs"); -// let _ = std::fs::create_dir(&log_dir); - -// let logger = CsvLogger::new( -// PathBuf::from("tmp_logs/log.csv"), -// true, -// Some("global_step".to_string()), -// ); - -// let mut trainer = OfflineTrainer::new( -// offline_params, -// Box::new(env), -// Box::::default(), -// dqn_alg, -// buffer, -// Box::new(logger), -// None, -// EvalConfig::new(), -// &device, -// &device, -// ); - -// trainer.train(); - -// let _ = std::fs::remove_dir_all(log_dir); -// } -// } +#[cfg(test)] +mod test { + use std::path::PathBuf; + + use burn::{backend::{Autodiff, NdArray}, optim::{Adam, AdamConfig}}; + + use crate::{common::{algorithm::{OfflineAlgParams, OfflineTrainer}, buffer::ReplayBuffer, eval::EvalConfig, logger::CsvLogger}, dqn::{module::LinearDQNNet, DQNAgent, DQNConfig}, env::{base::Env, gridworld::GridWorldEnv}}; + + #[test] + fn test_dqn_lightweight() { + type TrainingBacked = Autodiff; + let device = Default::default(); + let config_optimizer = AdamConfig::new(); + let optim = config_optimizer.init(); + let offline_params = OfflineAlgParams::new() + .with_n_steps(10) + .with_batch_size(2) + .with_memory_size(5) + .with_warmup_steps(2); + let env = GridWorldEnv::default(); + let q = LinearDQNNet::::init( + &device, + env.observation_space().shape().len(), + env.action_space().shape(), + 2, + ); + let agent = DQNAgent::new( + q.clone(), + q, + optim, + DQNConfig::new(), + env.observation_space(), + env.action_space() + ); + let buffer = ReplayBuffer::new(offline_params.memory_size); + + // create the logs dir + let mut log_dir = std::env::current_dir().unwrap(); + log_dir.push("tmp_logs"); + let _ = std::fs::create_dir(&log_dir); + + let logger = CsvLogger::new( + PathBuf::from("tmp_logs/log.csv"), + false, + ); + + let mut trainer: OfflineTrainer<_, Adam, _, _, _> = OfflineTrainer::new( + offline_params, + Box::new(env), + Box::::default(), + agent, + buffer, + Box::new(logger), + None, + EvalConfig::new(), + &device + ); + + trainer.train(); + + let _ = std::fs::remove_dir_all(log_dir); + } +} diff --git a/src/dqn/module.rs b/src/dqn/module.rs index 7a89c50..30c684c 100644 --- a/src/dqn/module.rs +++ b/src/dqn/module.rs @@ -169,10 +169,12 @@ impl ConvDQNNet { act_size: usize, hidden_size: usize, ) -> Self { + //TODO: take in channels, calculate linear input size + Self { - c1: nn::conv::Conv2dConfig::new([obs_shape.shape().dims[0], 4], [3, 3]).init(device), - c2: nn::conv::Conv2dConfig::new([4, 8], [3, 3]).init(device), - l1: nn::LinearConfig::new(64, hidden_size).init(device), + c1: nn::conv::Conv2dConfig::new([obs_shape.shape().dims[0], 2], [3, 3]).init(device), + c2: nn::conv::Conv2dConfig::new([2, 3], [3, 3]).init(device), + l1: nn::LinearConfig::new(432, hidden_size).init(device), l2: nn::LinearConfig::new(hidden_size, act_size).init(device), } } @@ -205,7 +207,11 @@ impl Policy for ConvDQNNet { #[cfg(test)] mod test { - use burn::{backend::NdArray, tensor::Tensor}; + use burn::{backend::NdArray, tensor::{Shape, Tensor}}; + + use crate::common::spaces::{BoxSpace, Discrete, Space}; + + use super::{ConvDQNNet, DQNNet, LinearAdvDQNNet, LinearDQNNet}; #[test] fn test_broadcast_sanity() { @@ -225,4 +231,52 @@ mod test { assert_eq!(c, vec![-1.0, 0.0, 1.0]); } + + #[test] + fn test_linear_net_usize_usize(){ + let mut obs_space = Discrete::from(3); + + let dqn: LinearDQNNet = LinearDQNNet::init(&Default::default(), obs_space.shape(), 1, 2); + + dqn.forward(vec![obs_space.sample()], Box::new(obs_space), &Default::default()); + } + + #[test] + fn test_linear_net_vecf32_usize(){ + let mut obs_space = BoxSpace::from((vec![0.0, 1.0], vec![0.5, 8.1])); + + let dqn: LinearDQNNet = LinearDQNNet::init(&Default::default(), obs_space.shape().len(), 1, 2); + + dqn.forward(vec![obs_space.sample()], Box::new(obs_space), &Default::default()); + } + + #[test] + fn test_linear_adv_net_usize_usize(){ + let mut obs_space = Discrete::from(3); + + let dqn: LinearAdvDQNNet = LinearAdvDQNNet::init(&Default::default(), obs_space.shape(), 1, 2); + + dqn.forward(vec![obs_space.sample()], Box::new(obs_space), &Default::default()); + } + + #[test] + fn test_linear_adv_net_vecf32_usize(){ + let mut obs_space = BoxSpace::from((vec![0.0, 1.0], vec![0.5, 8.1])); + + let dqn: LinearAdvDQNNet = LinearAdvDQNNet::init(&Default::default(), obs_space.shape().len(), 1, 2); + + dqn.forward(vec![obs_space.sample()], Box::new(obs_space), &Default::default()); + } + + #[test] + fn test_conv_usize(){ + let shape = Shape::new([3, 16, 16]); + let low: Tensor = Tensor::zeros(shape.clone(), &Default::default()); + let high: Tensor = Tensor::zeros(shape, &Default::default()); + let mut obs_space = BoxSpace::from((low, high)); + + let dqn: ConvDQNNet = ConvDQNNet::init(&Default::default(), obs_space.shape(), 1, 2); + + dqn.forward(vec![obs_space.sample()], Box::new(obs_space), &Default::default()); + } } diff --git a/src/env/probe.rs b/src/env/probe.rs index ab37876..cea384c 100644 --- a/src/env/probe.rs +++ b/src/env/probe.rs @@ -316,3 +316,65 @@ impl Env for ProbeEnvStateActionTest { self } } + +#[cfg(test)] +mod test { + use crate::env::base::Env; + + use super::{ProbeEnvActionTest, ProbeEnvBackpropTest, ProbeEnvDiscountingTest, ProbeEnvStateActionTest, ProbeEnvValueTest}; + + #[test] + fn test_probe_env_value_test(){ + let mut env = ProbeEnvValueTest::default(); + + let mut done = false; + while !done { + let res = env.step(&env.action_space().sample()); + done = res.truncated | res.terminated; + } + } + + #[test] + fn test_probe_env_backprop_test(){ + let mut env = ProbeEnvBackpropTest::default(); + + let mut done = false; + while !done { + let res = env.step(&env.action_space().sample()); + done = res.truncated | res.terminated; + } + } + + #[test] + fn test_probe_env_action_test(){ + let mut env = ProbeEnvActionTest::default(); + + let mut done = false; + while !done { + let res = env.step(&env.action_space().sample()); + done = res.truncated | res.terminated; + } + } + + #[test] + fn test_probe_env_state_action_test(){ + let mut env = ProbeEnvStateActionTest::default(); + + let mut done = false; + while !done { + let res = env.step(&env.action_space().sample()); + done = res.truncated | res.terminated; + } + } + + #[test] + fn test_probe_env_discounting_test(){ + let mut env = ProbeEnvDiscountingTest::default(); + + let mut done = false; + while !done { + let res = env.step(&env.action_space().sample()); + done = res.truncated | res.terminated; + } + } +} \ No newline at end of file diff --git a/src/env/wrappers.rs b/src/env/wrappers.rs index 4c01cba..46a76c0 100644 --- a/src/env/wrappers.rs +++ b/src/env/wrappers.rs @@ -144,7 +144,7 @@ impl Env for AutoResetWrapper { mod test { use crate::env::{base::Env, classic_control::cartpole::CartpoleEnv}; - use super::TimeLimitWrapper; + use super::{AutoResetWrapper, TimeLimitWrapper}; #[test] fn test_time_limit_wrapper() { @@ -164,4 +164,21 @@ mod test { assert_eq!(truncate_steps, ep_len); } + + #[test] + fn test_auto_reset(){ + let env = CartpoleEnv::default(); + let mut env = AutoResetWrapper::new(Box::new(env)); + + let mut done = false; + env.reset(None, None); + + while !done { + let res = env.step(&env.action_space().sample()); + done = res.terminated | res.truncated; + } + + // wouldn't be able to do this without the wrapper + env.step(&env.action_space().sample()); + } }