Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
will-maclean committed Sep 1, 2024
1 parent c88ec55 commit 21eff40
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 100 deletions.
4 changes: 1 addition & 3 deletions examples/sac_pendulum.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::path::PathBuf;

use burn::{
backend::{
ndarray::NdArrayDevice, Autodiff, NdArray
},
backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
grad_clipping::GradientClippingConfig,
optim::{Adam, AdamConfig},
};
Expand Down
10 changes: 2 additions & 8 deletions examples/sac_probe.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::path::PathBuf;

use burn::{
backend::{
libtorch::LibTorchDevice,
Autodiff, LibTorch,
},
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
grad_clipping::GradientClippingConfig,
optim::{Adam, AdamConfig},
};
Expand Down Expand Up @@ -81,10 +78,7 @@ fn main() {

let buffer = ReplayBuffer::new(offline_params.memory_size);

let logger = CsvLogger::new(
PathBuf::from("logs/sac_probe/log_sac_probe.csv"),
false,
);
let logger = CsvLogger::new(PathBuf::from("logs/sac_probe/log_sac_probe.csv"), false);

match logger.check_can_log(false) {
Ok(_) => {}
Expand Down
26 changes: 21 additions & 5 deletions src/common/distributions/action_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use burn::{
};
use serde::de;

Check warning on line 6 in src/common/distributions/action_distribution.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused import: `serde::de`

Check warning on line 6 in src/common/distributions/action_distribution.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused import: `serde::de`

use crate::common::{agent::Policy, utils::{disp_tensorf, module_update::update_linear}};
use crate::common::{
agent::Policy,
utils::{disp_tensorf, module_update::update_linear},
};

use super::{distribution::BaseDistribution, normal::Normal};

Expand Down Expand Up @@ -113,7 +116,7 @@ impl<B: Backend> ActionDistribution<B> for DiagGaussianDistribution<B> {
.repeat_dim(0, obs.shape().dims[0]);

let loc = self.means.forward(obs);

if deterministic {
loc
} else {
Expand All @@ -137,7 +140,13 @@ pub struct SquashedDiagGaussianDistribution<B: Backend> {
}

impl<B: Backend> SquashedDiagGaussianDistribution<B> {
pub fn new(latent_dim: usize, action_dim: usize, log_std_init: f32, device: &B::Device, epsilon: f32) -> Self{
pub fn new(
latent_dim: usize,
action_dim: usize,
log_std_init: f32,
device: &B::Device,
epsilon: f32,
) -> Self {
Self {
diag_gaus_dist: DiagGaussianDistribution::new(
latent_dim,
Expand Down Expand Up @@ -165,7 +174,13 @@ impl<B: Backend> ActionDistribution<B> for SquashedDiagGaussianDistribution<B> {

// Squash correction (from original SAC implementation)
// this comes from the fact that tanh is bijective and differentiable
let out = log_prob - sample.powi_scalar(2).mul_scalar(-1).add_scalar(1.0 + self.epsilon).log().sum_dim(1);
let out = log_prob
- sample
.powi_scalar(2)
.mul_scalar(-1)
.add_scalar(1.0 + self.epsilon)
.log()
.sum_dim(1);

disp_tensorf("second log prob", &out);

Expand Down Expand Up @@ -256,7 +271,8 @@ mod test {
Shape::new([latent_size]),
Distribution::Normal(0.0, 1.0),
&Default::default(),
).unsqueeze_dim(0);
)
.unsqueeze_dim(0);

let action_sample = dist.actions_from_obs(dummy_obs, false);
let log_prob = dist.log_prob(action_sample);
Expand Down
11 changes: 5 additions & 6 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@ mod test {
// #[test]
// fn mean_can_debug_wgpu(){
// let t: Tensor<Wgpu, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &Default::default());

// println!("{}", t);
// println!("{}", t.mean());



// let t: Tensor<Autodiff<Wgpu>, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &Default::default());

// println!("{t}");
// println!("{}", t.mean());
// }

//TODO: seems to be a burn bug. Disabling for now so tests pass
// #[test]
// fn mean_can_debug_libtorch(){
Expand All @@ -44,4 +43,4 @@ mod test {
// // println!("{t}");
// // println!("{}", t.mean());
// }
}
}
5 changes: 1 addition & 4 deletions src/common/to_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ pub trait ToTensorB<const D: usize>: Clone {

impl ToTensorB<1> for bool {
fn to_tensor<B: Backend>(self, device: &<B as Backend>::Device) -> Tensor<B, 1, Bool> {
Tensor::<B, 1, Bool>::from_bool(
TensorData::from([self]),
device,
)
Tensor::<B, 1, Bool>::from_bool(TensorData::from([self]), device)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/common/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ pub fn angle_normalise(f: f32) -> f32 {
(f + PI) % (2.0 * PI) - PI
}

pub fn disp_tensorf<B: Backend, const D: usize>(name: &str, t: &Tensor<B, D>){
pub fn disp_tensorf<B: Backend, const D: usize>(name: &str, t: &Tensor<B, D>) {

Check warning on line 60 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `name`

Check warning on line 60 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `t`

Check warning on line 60 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `name`

Check warning on line 60 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `t`
// println!("{name}. {t}\n");
}

pub fn disp_tensorb<B: Backend, const D: usize>(name: &str, t: &Tensor<B, D, Bool>){
pub fn disp_tensorb<B: Backend, const D: usize>(name: &str, t: &Tensor<B, D, Bool>) {

Check warning on line 64 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `name`

Check warning on line 64 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `t`

Check warning on line 64 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `name`

Check warning on line 64 in src/common/utils/mod.rs

View workflow job for this annotation

GitHub Actions / build_and_test

unused variable: `t`
// println!("{name}. {t}\n");
}

Expand Down
35 changes: 11 additions & 24 deletions src/env/probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,19 +318,20 @@ impl Env<usize, usize> for ProbeEnvStateActionTest {
}

#[derive(Debug, Clone)]
pub struct ProbeEnvContinuousActions{
pub struct ProbeEnvContinuousActions {
state: f32,
rng: ThreadRng,
}

impl Default for ProbeEnvContinuousActions {
fn default() -> Self {
Self { state: 0.0, rng: Default::default() }
Self {
state: 0.0,
rng: Default::default(),
}
}
}



impl Env<Vec<f32>, Vec<f32>> for ProbeEnvContinuousActions {
fn step(&mut self, action: &Vec<f32>) -> EnvObservation<Vec<f32>> {
assert!(action.len() == 1);
Expand All @@ -339,7 +340,7 @@ impl Env<Vec<f32>, Vec<f32>> for ProbeEnvContinuousActions {

let reward = 1.0 - (a - self.state).abs();

EnvObservation{
EnvObservation {
obs: [0.0].to_vec(),
reward,
terminated: true,
Expand All @@ -355,41 +356,27 @@ impl Env<Vec<f32>, Vec<f32>> for ProbeEnvContinuousActions {
}

fn action_space(&self) -> Box<dyn Space<Vec<f32>>> {
Box::new(
BoxSpace::from((
[0.0].to_vec(),
[1.0].to_vec()
))
)
Box::new(BoxSpace::from(([0.0].to_vec(), [1.0].to_vec())))
}

fn observation_space(&self) -> Box<dyn Space<Vec<f32>>> {
Box::new(
BoxSpace::from((
[0.0].to_vec(),
[1.0].to_vec()
))
)
Box::new(BoxSpace::from(([0.0].to_vec(), [1.0].to_vec())))
}

fn reward_range(&self) -> RewardRange {
RewardRange{
RewardRange {
low: 0.0,
high: 1.0,
}
}

fn render(&self) {

}
fn render(&self) {}

fn renderable(&self) -> bool {
false
}

fn close(&mut self) {

}
fn close(&mut self) {}

fn unwrapped(&self) -> &dyn Env<Vec<f32>, Vec<f32>> {
self
Expand Down
19 changes: 8 additions & 11 deletions src/env/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ use crate::common::spaces::Space;

use super::base::{Env, EnvObservation, InfoData, RewardRange};


pub struct ScaleRewardWrapper<O, A>
where
where
O: Clone + Debug,
A: Clone + Debug,
{
Expand All @@ -15,17 +14,17 @@ where
}

impl<O, A> ScaleRewardWrapper<O, A>
where
where
O: Clone + Debug,
A: Clone + Debug,
{
pub fn new(env: Box<dyn Env<O, A>>, scaling: f32) -> Self {
pub fn new(env: Box<dyn Env<O, A>>, scaling: f32) -> Self {
Self { env, scaling }
}
}

impl<O, A> Env<O, A> for ScaleRewardWrapper<O, A>
where
where
O: Clone + Debug,
A: Clone + Debug,
{
Expand Down Expand Up @@ -70,27 +69,26 @@ where
}
}


pub struct SignRewardWrapper<O, A>
where
where
O: Clone + Debug,
A: Clone + Debug,
{
env: Box<dyn Env<O, A>>,
}

impl<O, A> SignRewardWrapper<O, A>
where
where
O: Clone + Debug,
A: Clone + Debug,
{
pub fn new(env: Box<dyn Env<O, A>>) -> Self {
pub fn new(env: Box<dyn Env<O, A>>) -> Self {
Self { env }
}
}

impl<O, A> Env<O, A> for SignRewardWrapper<O, A>
where
where
O: Clone + Debug,
A: Clone + Debug,
{
Expand Down Expand Up @@ -135,7 +133,6 @@ where
}
}


pub struct TimeLimitWrapper<O: Clone + Debug, A: Clone + Debug> {
env: Box<dyn Env<O, A>>,
max_steps: usize,
Expand Down
Loading

0 comments on commit 21eff40

Please sign in to comment.