From d924594b20c663d3349fc81eed6d3bd0b937878b Mon Sep 17 00:00:00 2001 From: Will Maclean Date: Sun, 1 Sep 2024 20:56:11 +1000 Subject: [PATCH] one bug at a time --- src/common/distributions/action_distribution.rs | 6 ++++-- src/env/classic_control/pendulum.rs | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/common/distributions/action_distribution.rs b/src/common/distributions/action_distribution.rs index 94bbb9b..286f1f3 100644 --- a/src/common/distributions/action_distribution.rs +++ b/src/common/distributions/action_distribution.rs @@ -80,7 +80,7 @@ impl DiagGaussianDistribution { log_std: Param::from_tensor( Tensor::ones(Shape::new([action_dim]), device).mul_scalar(log_std_init), ), - dist: dist.no_grad(), + dist, } } } @@ -116,10 +116,12 @@ impl ActionDistribution for DiagGaussianDistribution { let loc = self.means.forward(obs); + self.dist = Normal::new(loc.clone(), scale); + if deterministic { loc } else { - Normal::new(loc, scale).sample() + self.dist.sample() } } } diff --git a/src/env/classic_control/pendulum.rs b/src/env/classic_control/pendulum.rs index 20d39a4..76fa427 100644 --- a/src/env/classic_control/pendulum.rs +++ b/src/env/classic_control/pendulum.rs @@ -142,9 +142,9 @@ impl Env, Vec> for PendulumEnv { pub fn make_pendulum(max_steps: Option) -> Box, Vec>> { let env = make_pendulum_eval(max_steps); - let env = ScaleRewardWrapper::new(env, 0.01); + // let env = ScaleRewardWrapper::new(env, 0.01); - Box::new(env) + env } pub fn make_pendulum_eval(max_steps: Option) -> Box, Vec>> {