diff --git a/src/dqn/module.rs b/src/dqn/module.rs index 270dbfc..30c684c 100644 --- a/src/dqn/module.rs +++ b/src/dqn/module.rs @@ -172,9 +172,9 @@ impl ConvDQNNet { //TODO: take in channels, calculate linear input size Self { - c1: nn::conv::Conv2dConfig::new([obs_shape.shape().dims[0], 4], [3, 3]).init(device), - c2: nn::conv::Conv2dConfig::new([4, 8], [3, 3]).init(device), - l1: nn::LinearConfig::new(64, hidden_size).init(device), + c1: nn::conv::Conv2dConfig::new([obs_shape.shape().dims[0], 2], [3, 3]).init(device), + c2: nn::conv::Conv2dConfig::new([2, 3], [3, 3]).init(device), + l1: nn::LinearConfig::new(432, hidden_size).init(device), l2: nn::LinearConfig::new(hidden_size, act_size).init(device), } }