Commit

Merge branch 'main' of github.com:will-maclean/sb3-burn into main

will-maclean committed Jun 23, 2024
2 parents 01b768a + 6275da9 commit 51f25e9
Showing 11 changed files with 349 additions and 96 deletions.
File renamed without changes.
23 changes: 23 additions & 0 deletions .github/workflows/cov.yml
@@ -0,0 +1,23 @@
name: Coverage

on: [pull_request, push]

jobs:
  coverage:
    runs-on: ubuntu-latest
    env:
      CARGO_TERM_COLOR: always
    steps:
      - uses: actions/checkout@v4
      - name: Install Rust
        run: rustup update stable
      - name: Install cargo-llvm-cov
        uses: taiki-e/install-action@cargo-llvm-cov
      - name: Generate code coverage
        run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          files: lcov.info
          fail_ci_if_error: true
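
The same report the workflow uploads can be generated locally before pushing. A quick sketch, assuming cargo-llvm-cov is installed via cargo; the second command is the workflow's own coverage step run by hand:

cargo install cargo-llvm-cov
cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info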
5 changes: 1 addition & 4 deletions examples/dqn_gridworld.rs
@@ -55,10 +55,7 @@ fn main() {

    let buffer = ReplayBuffer::new(offline_params.memory_size);

-    let logger = CsvLogger::new(
-        PathBuf::from("logs/dqn_gridworld/dqn_gridworld.csv"),
-        false,
-    );
+    let logger = CsvLogger::new(PathBuf::from("logs/dqn_gridworld/dqn_gridworld.csv"), false);

    match logger.check_can_log(false) {
        Ok(_) => {}
39 changes: 39 additions & 0 deletions src/common/spaces.rs
@@ -1,3 +1,4 @@
+use burn::tensor::{backend::Backend, Distribution, Tensor};
use dyn_clone::DynClone;
use rand::{rngs::StdRng, Rng, SeedableRng};

@@ -118,6 +119,44 @@ impl Space<Vec<f32>> for BoxSpace<Vec<f32>> {
    }
}

impl<B: Backend, const D: usize> From<(Tensor<B, D>, Tensor<B, D>)> for BoxSpace<Tensor<B, D>> {
    fn from(value: (Tensor<B, D>, Tensor<B, D>)) -> Self {
        Self {
            low: value.0,
            high: value.1,
            rng: StdRng::from_entropy(),
        }
    }
}

impl<B: Backend, const D: usize> Space<Tensor<B, D>> for BoxSpace<Tensor<B, D>> {
    fn contains(&self, sample: &Tensor<B, D>) -> bool {
        if sample.shape() != self.low.shape() {
            return false;
        }

        // A sample is contained when every element lies between the per-dimension
        // low and high bounds.
        sample.clone().greater_equal(self.low.clone()).all().into_scalar()
            & sample.clone().lower_equal(self.high.clone()).all().into_scalar()
    }

    fn sample(&mut self) -> Tensor<B, D> {
        let shape = self.low.shape();
        let sample: Tensor<B, D> =
            Tensor::random(shape, Distribution::Uniform(0.0, 1.0), &self.low.device());
        let range = self.high.clone().sub(self.low.clone());

        sample.mul(range).add(self.low.clone())
    }

    fn seed(&mut self, seed: [u8; 32]) {
        self.rng = StdRng::from_seed(seed);
    }

    fn shape(&self) -> Tensor<B, D> {
        self.low.clone()
    }
}

#[cfg(test)]
mod test {
    use crate::common::spaces::{BoxSpace, Discrete, Space};
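
A minimal usage sketch of the new tensor-backed space, assuming the NdArray backend and burn's Tensor::from_floats constructor; the bounds and the function name here are illustrative, not from the crate:

use burn::backend::NdArray;
use burn::tensor::Tensor;

use crate::common::spaces::{BoxSpace, Space};

fn box_space_demo() {
    let device = Default::default();

    // Per-dimension bounds; the From impl above turns (low, high) into a BoxSpace.
    let low: Tensor<NdArray, 1> = Tensor::from_floats([0.0, -1.0], &device);
    let high: Tensor<NdArray, 1> = Tensor::from_floats([1.0, 1.0], &device);
    let mut space: BoxSpace<Tensor<NdArray, 1>> = (low, high).into();

    // Samples are drawn uniformly in [low, high) and should pass contains.
    let sample = space.sample();
    assert!(space.contains(&sample));
}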
75 changes: 75 additions & 0 deletions src/common/to_tensor.rs
@@ -68,3 +68,78 @@ impl ToTensorB<1> for Vec<bool> {
        Tensor::<B, 1, Int>::from_data(Data::new(self, Shape::new([n])).convert(), device).bool()
    }
}

#[cfg(test)]
mod test {
    use burn::{
        backend::NdArray,
        tensor::{Bool, Int, Tensor},
    };

    use crate::common::to_tensor::{ToTensorB, ToTensorI};

    use super::ToTensorF;

    #[test]
    fn test_to_tensor_f32() {
        let d: f32 = 1.1;
        let t: Tensor<NdArray, 1> = d.to_tensor(&Default::default());

        assert_eq!(t.shape().dims.len(), 1);
        assert_eq!(t.shape().dims, [1]);
        assert_eq!(t.into_scalar(), d);
    }

    #[test]
    fn test_to_tensor_vec_f32() {
        let d: Vec<f32> = vec![1.1, 2.2];
        let t: Tensor<NdArray, 1> = d.to_tensor(&Default::default());

        assert_eq!(t.shape().dims.len(), 1);
        assert_eq!(t.shape().dims, [2]);
    }

    #[test]
    fn test_to_tensor_vec_vec_f32() {
        let d: Vec<Vec<f32>> = vec![vec![1.1, 2.2], vec![3.3, 4.4], vec![1.0, 0.0]];
        let t: Tensor<NdArray, 2> = d.to_tensor(&Default::default());

        assert_eq!(t.shape().dims.len(), 2);
        assert_eq!(t.shape().dims, [3, 2]);
    }

    #[test]
    fn test_to_tensor_usize() {
        let d: usize = 1;
        let t: Tensor<NdArray, 1, Int> = d.to_tensor(&Default::default());

        assert_eq!(t.shape().dims.len(), 1);
        assert_eq!(t.shape().dims, [1]);
        assert_eq!(t.into_scalar() as usize, d);
    }

    #[test]
    fn test_to_tensor_vec_usize() {
        let d: Vec<usize> = vec![1, 2];
        let t: Tensor<NdArray, 1, Int> = d.to_tensor(&Default::default());

        assert_eq!(t.shape().dims.len(), 1);
        assert_eq!(t.shape().dims, [2]);
    }

    #[test]
    fn test_to_tensor_bool() {
        let d: bool = true;
        let t: Tensor<NdArray, 1, Bool> = d.to_tensor(&Default::default());

        assert_eq!(t.shape().dims.len(), 1);
        assert_eq!(t.shape().dims, [1]);
        assert_eq!(t.into_scalar(), d);
    }

    #[test]
    fn test_to_tensor_vec_bool() {
        let d: Vec<bool> = vec![false, true];
        let t: Tensor<NdArray, 1, Bool> = d.to_tensor(&Default::default());

        assert_eq!(t.shape().dims.len(), 1);
        assert_eq!(t.shape().dims, [2]);
    }
}
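
Outside the tests, these traits exist so that replay-buffer-style data converts to backend tensors in one call each. A hedged sketch of that use; the function and parameter names are illustrative, not from the crate:

use burn::backend::NdArray;
use burn::tensor::{Bool, Int, Tensor};

use crate::common::to_tensor::{ToTensorB, ToTensorF, ToTensorI};

// Convert one sampled batch into the tensors a training step consumes.
fn batch_to_tensors(
    obs: Vec<Vec<f32>>,  // [batch, obs_dim]
    actions: Vec<usize>, // [batch]
    dones: Vec<bool>,    // [batch]
) -> (
    Tensor<NdArray, 2>,
    Tensor<NdArray, 1, Int>,
    Tensor<NdArray, 1, Bool>,
) {
    let device = Default::default();
    (
        obs.to_tensor(&device),
        actions.to_tensor(&device),
        dones.to_tensor(&device),
    )
}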
17 changes: 10 additions & 7 deletions src/common/utils/mod.rs
@@ -50,9 +50,12 @@ pub fn vec_usize_to_one_hot<B: Backend>(
mod test {
    use assert_approx_eq::assert_approx_eq;

-    use burn::{backend::{ndarray::NdArrayDevice, NdArray}, tensor::Tensor};
+    use burn::{
+        backend::{ndarray::NdArrayDevice, NdArray},
+        tensor::Tensor,
+    };

-    use crate::common::utils::{mean, linear_decay, generate_random_vector, vec_usize_to_one_hot};
+    use crate::common::utils::{generate_random_vector, linear_decay, mean, vec_usize_to_one_hot};

    #[test]
    fn test_mean() {
@@ -66,29 +69,29 @@ mod test {
    }

    #[test]
-    fn test_linear_decay(){
+    fn test_linear_decay() {
        assert_approx_eq!(linear_decay(0.0, 1.0, 0.01, 0.8), 1.0, 1e-3f32);
        assert_approx_eq!(linear_decay(0.8, 1.0, 0.01, 0.8), 0.01, 1e-3f32);
        assert_approx_eq!(linear_decay(1.0, 1.0, 0.01, 0.8), 0.01, 1e-3f32);
    }

    #[test]
-    fn test_gen_rand_vec(){
+    fn test_gen_rand_vec() {
        let sample = generate_random_vector(vec![0.0, 0.0, 0.0], vec![1.0, 1.0, 1.0]);

-        for s in sample{
+        for s in sample {
            assert!((s >= 0.0) & (s <= 1.0));
        }
    }

    #[should_panic]
    #[test]
-    fn test_gen_rand_vec_bad(){
+    fn test_gen_rand_vec_bad() {
        generate_random_vector(vec![1.0], vec![0.0]);
    }

    #[test]
-    fn test_usize_to_one_hot(){
+    fn test_usize_to_one_hot() {
        let ins = vec![0, 1, 2];
        let classes = 4;
1 change: 0 additions & 1 deletion src/common/utils/module_update.rs
@@ -84,7 +84,6 @@ mod test {
    }

    impl<B: Backend> Policy<B> for LinearPolicy<B> {
-
        fn update(&mut self, from: &Self, tau: Option<f32>) {
            self.layer = update_linear(&from.layer, self.layer.clone(), tau);
        }
142 changes: 63 additions & 79 deletions src/dqn/mod.rs
@@ -184,82 +184,66 @@
    }
}

// #[cfg(test)]
// mod test {
//     use std::path::PathBuf;

//     use burn::{
//         backend::{Autodiff, NdArray},
//         optim::{Adam, AdamConfig},
//         tensor::backend::AutodiffBackend,
//     };

//     use crate::{
//         algorithm::{OfflineAlgParams, OfflineAlgorithm},
//         buffer::ReplayBuffer,
//         dqn::{DQNAgent, DQNConfig, DQNNet},
//         env::{base::Env, gridworld::GridWorldEnv},
//         eval::EvalConfig,
//         logger::CsvLogger,
//     };

//     use super::OfflineTrainer;

//     #[test]
//     fn test_dqn_lightweight() {
//         type TrainingBacked = Autodiff<NdArray>;
//         let device = Default::default();
//         let config_optimizer = AdamConfig::new();
//         let optim = config_optimizer.init();
//         let offline_params = OfflineAlgParams::new()
//             .with_n_steps(10)
//             .with_batch_size(2)
//             .with_memory_size(5)
//             .with_warmup_steps(2);
//         let env = GridWorldEnv::default();
//         let q = DQNNet::<TrainingBacked>::init(
//             &device,
//             env.observation_space().clone(),
//             env.action_space().clone(),
//             2,
//         );
//         let agent = DQNAgent::new(
//             q.clone(),
//             q,
//             optim,
//             DQNConfig::new(),
//             env.observation_space(),
//             env.action_space()
//         );
//         let dqn_alg = OfflineAlgorithm::DQN(agent);
//         let buffer = ReplayBuffer::new(offline_params.memory_size);

//         // create the logs dir
//         let mut log_dir = std::env::current_dir().unwrap();
//         log_dir.push("tmp_logs");
//         let _ = std::fs::create_dir(&log_dir);

//         let logger = CsvLogger::new(
//             PathBuf::from("tmp_logs/log.csv"),
//             true,
//             Some("global_step".to_string()),
//         );

//         let mut trainer = OfflineTrainer::new(
//             offline_params,
//             Box::new(env),
//             Box::<GridWorldEnv>::default(),
//             dqn_alg,
//             buffer,
//             Box::new(logger),
//             None,
//             EvalConfig::new(),
//             &device,
//             &device,
//         );

//         trainer.train();

//         let _ = std::fs::remove_dir_all(log_dir);
//     }
// }
#[cfg(test)]
mod test {
    use std::path::PathBuf;

    use burn::{
        backend::{Autodiff, NdArray},
        optim::{Adam, AdamConfig},
    };

    use crate::{
        common::{
            algorithm::{OfflineAlgParams, OfflineTrainer},
            buffer::ReplayBuffer,
            eval::EvalConfig,
            logger::CsvLogger,
        },
        dqn::{module::LinearDQNNet, DQNAgent, DQNConfig},
        env::{base::Env, gridworld::GridWorldEnv},
    };

    #[test]
    fn test_dqn_lightweight() {
        type TrainingBacked = Autodiff<NdArray>;
        let device = Default::default();
        let config_optimizer = AdamConfig::new();
        let optim = config_optimizer.init();
        let offline_params = OfflineAlgParams::new()
            .with_n_steps(10)
            .with_batch_size(2)
            .with_memory_size(5)
            .with_warmup_steps(2);
        let env = GridWorldEnv::default();
        let q = LinearDQNNet::<TrainingBacked>::init(
            &device,
            env.observation_space().shape().len(),
            env.action_space().shape(),
            2,
        );
        let agent = DQNAgent::new(
            q.clone(),
            q,
            optim,
            DQNConfig::new(),
            env.observation_space(),
            env.action_space(),
        );
        let buffer = ReplayBuffer::new(offline_params.memory_size);

        // create the logs dir
        let mut log_dir = std::env::current_dir().unwrap();
        log_dir.push("tmp_logs");
        let _ = std::fs::create_dir(&log_dir);

        let logger = CsvLogger::new(PathBuf::from("tmp_logs/log.csv"), false);

        let mut trainer: OfflineTrainer<_, Adam<NdArray>, _, _, _> = OfflineTrainer::new(
            offline_params,
            Box::new(env),
            Box::<GridWorldEnv>::default(),
            agent,
            buffer,
            Box::new(logger),
            None,
            EvalConfig::new(),
            &device,
        );

        trainer.train();

        let _ = std::fs::remove_dir_all(log_dir);
    }
}