diff --git a/.clang-tidy b/.clang-tidy index fa80f780..c2d1c2eb 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -33,6 +33,7 @@ CheckOptions: - { key: readability-identifier-naming.GlobalConstantCase, value: UPPER_CASE } - { key: readability-identifier-naming.MemberCase, value: lower_case } - { key: readability-identifier-naming.MemberSuffix, value: _ } + - { key: readability-identifier-naming.PublicMemberCase, value: lower_case } - { key: readability-identifier-naming.TypeAliasCase, value: CamelCase } - { key: readability-identifier-naming.ConstantMemberCase, value: CamelCase } - { key: readability-identifier-naming.ConstantMemberPrefix, value: k } diff --git a/.gitignore b/.gitignore index 6d79f4bd..a6678d8b 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,4 @@ pyflakes* *.swp train_dir log +_vizdoom* diff --git a/examples/benchmark.py b/benchmark/benchmark.py similarity index 99% rename from examples/benchmark.py rename to benchmark/benchmark.py index 3ceb21a5..5f0365eb 100644 --- a/examples/benchmark.py +++ b/benchmark/benchmark.py @@ -58,7 +58,7 @@ parser.add_argument("--total-iter", type=int, default=50000) args = parser.parse_args() env = envpool.make_gym( - "Pong-v5", + args.task, num_envs=args.num_envs, batch_size=args.batch_size, num_threads=args.num_threads, diff --git a/examples/numa_test.sh b/benchmark/numa_test.sh similarity index 100% rename from examples/numa_test.sh rename to benchmark/numa_test.sh diff --git a/docs/pages/build.rst b/docs/pages/build.rst index bb8c2724..4b59fccf 100644 --- a/docs/pages/build.rst +++ b/docs/pages/build.rst @@ -60,6 +60,8 @@ or `golang `_ with version >= 1.16: # check if successfully installed bazel + See `Issue #87 `_. + Install Other Dependencies -------------------------- @@ -113,6 +115,8 @@ This creates a wheel under ``bazel-bin/setup.runfiles/envpool/dist``. export HTTPS_PROXY=http://... # then run the command to build + See `Issue #87 `_. + Use Shortcut ------------ @@ -154,3 +158,5 @@ system via ``/host``. .. code-block:: bash make docker-dev-cn + + See `Issue #87 `_. diff --git a/docs/pages/env.rst b/docs/pages/env.rst index 661357d5..22f461f1 100644 --- a/docs/pages/env.rst +++ b/docs/pages/env.rst @@ -254,7 +254,7 @@ class for convenience, which follows the definition of ``CartPoleEnvSpec``. The following functions are required to override: - constructor, in this case it is ``CartPoleEnv(const Spec& spec, int env_id)``; - you can use ``spec.config_["max_episode_steps"_]`` to extract the value from + you can use ``spec.config["max_episode_steps"_]`` to extract the value from config; - ``bool IsDone()``: return a boolean that indicate whether the current episode is finished or not; diff --git a/envpool/atari/atari_env.h b/envpool/atari/atari_env.h index d47bf9c3..331235c9 100644 --- a/envpool/atari/atari_env.h +++ b/envpool/atari/atari_env.h @@ -103,29 +103,27 @@ class AtariEnv : public Env { AtariEnv(const Spec& spec, int env_id) : Env(spec, env_id), env_(new ale::ALEInterface()), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), - stack_num_(spec.config_["stack_num"_]), - frame_skip_(spec.config_["frame_skip"_]), + stack_num_(spec.config["stack_num"_]), + frame_skip_(spec.config["frame_skip"_]), fire_reset_(false), - reward_clip_(spec.config_["reward_clip"_]), - zero_discount_on_life_loss_( - spec.config_["zero_discount_on_life_loss"_]), - gray_scale_(spec.config_["gray_scale"_]), - episodic_life_(spec.config_["episodic_life"_]), - use_inter_area_resize_(spec.config_["use_inter_area_resize"_]), + reward_clip_(spec.config["reward_clip"_]), + zero_discount_on_life_loss_(spec.config["zero_discount_on_life_loss"_]), + gray_scale_(spec.config["gray_scale"_]), + episodic_life_(spec.config["episodic_life"_]), + use_inter_area_resize_(spec.config["use_inter_area_resize"_]), done_(true), raw_spec_({kRawHeight, kRawWidth, gray_scale_ ? 1 : 3}), - resize_spec_({spec.config_["img_height"_], spec.config_["img_width"_], + resize_spec_({spec.config["img_height"_], spec.config["img_width"_], gray_scale_ ? 1 : 3}), - transpose_spec_({gray_scale_ ? 1 : 3, spec.config_["img_height"_], - spec.config_["img_width"_]}), + transpose_spec_({gray_scale_ ? 1 : 3, spec.config["img_height"_], + spec.config["img_width"_]}), resize_img_(resize_spec_), - dist_noop_(0, spec.config_["noop_max"_] - 1), - rom_path_( - GetRomPath(spec.config_["base_path"_], spec.config_["task"_])) { + dist_noop_(0, spec.config["noop_max"_] - 1), + rom_path_(GetRomPath(spec.config["base_path"_], spec.config["task"_])) { env_->setFloat("repeat_action_probability", - spec.config_["repeat_action_probability"_]); + spec.config["repeat_action_probability"_]); env_->setInt("random_seed", seed_); env_->loadROM(rom_path_); action_set_ = env_->getMinimalActionSet(); @@ -258,7 +256,7 @@ class AtariEnv : public Env { auto* ptr = static_cast(maxpool_buf_[0].Data()); if (maxpool) { auto* ptr1 = static_cast(maxpool_buf_[1].Data()); - for (std::size_t i = 0; i < maxpool_buf_[0].size_; ++i) { + for (std::size_t i = 0; i < maxpool_buf_[0].size; ++i) { ptr[i] = std::max(ptr[i], ptr1[i]); } } @@ -282,7 +280,7 @@ class AtariEnv : public Env { } } } - std::size_t size = tgt.size_; + std::size_t size = tgt.size; stack_buf_.push_back(std::move(tgt)); if (push_all) { for (auto& s : stack_buf_) { diff --git a/envpool/classic_control/acrobot.h b/envpool/classic_control/acrobot.h index 45de07ea..24571e10 100644 --- a/envpool/classic_control/acrobot.h +++ b/envpool/classic_control/acrobot.h @@ -49,16 +49,15 @@ using AcrobotEnvSpec = EnvSpec; class AcrobotEnv : public Env { struct V5 { - double s0_{0}, s1_{0}, s2_{0}, s3_{0}, s4_{0}; + double s0{0}, s1{0}, s2{0}, s3{0}, s4{0}; V5() = default; V5(double s0, double s1, double s2, double s3, double s4) - : s0_(s0), s1_(s1), s2_(s2), s3_(s3), s4_(s4) {} + : s0(s0), s1(s1), s2(s2), s3(s3), s4(s4) {} V5 operator+(const V5& v) const { - return V5(s0_ + v.s0_, s1_ + v.s1_, s2_ + v.s2_, s3_ + v.s3_, - s4_ + v.s4_); + return V5(s0 + v.s0, s1 + v.s1, s2 + v.s2, s3 + v.s3, s4 + v.s4); } V5 operator*(double v) const { - return V5(s0_ * v, s1_ * v, s2_ * v, s3_ * v, s4_ * v); + return V5(s0 * v, s1 * v, s2 * v, s3 * v, s4 * v); } }; @@ -82,7 +81,7 @@ class AcrobotEnv : public Env { public: AcrobotEnv(const Spec& spec, int env_id) : Env(spec, env_id), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), dist_(-kInitRange, kInitRange), done_(true) {} @@ -90,11 +89,11 @@ class AcrobotEnv : public Env { bool IsDone() override { return done_; } void Reset() override { - s_.s0_ = dist_(gen_); - s_.s1_ = dist_(gen_); - s_.s2_ = dist_(gen_); - s_.s3_ = dist_(gen_); - s_.s4_ = 0; + s_.s0 = dist_(gen_); + s_.s1 = dist_(gen_); + s_.s2 = dist_(gen_); + s_.s3 = dist_(gen_); + s_.s4 = 0; done_ = false; elapsed_step_ = 0; WriteState(0.0); @@ -105,33 +104,33 @@ class AcrobotEnv : public Env { int act = action["action"_]; float reward = -1.0; - s_.s4_ = act - 1; + s_.s4 = act - 1; s_ = Rk4(s_); - while (s_.s0_ < -kPi) { - s_.s0_ += kPi * 2; + while (s_.s0 < -kPi) { + s_.s0 += kPi * 2; } - while (s_.s1_ < -kPi) { - s_.s1_ += kPi * 2; + while (s_.s1 < -kPi) { + s_.s1 += kPi * 2; } - while (s_.s0_ >= kPi) { - s_.s0_ -= kPi * 2; + while (s_.s0 >= kPi) { + s_.s0 -= kPi * 2; } - while (s_.s1_ >= kPi) { - s_.s1_ -= kPi * 2; + while (s_.s1 >= kPi) { + s_.s1 -= kPi * 2; } - if (s_.s2_ < -kMaxVel1) { - s_.s2_ = -kMaxVel1; + if (s_.s2 < -kMaxVel1) { + s_.s2 = -kMaxVel1; } - if (s_.s3_ < -kMaxVel2) { - s_.s3_ = -kMaxVel2; + if (s_.s3 < -kMaxVel2) { + s_.s3 = -kMaxVel2; } - if (s_.s2_ > kMaxVel1) { - s_.s2_ = kMaxVel1; + if (s_.s2 > kMaxVel1) { + s_.s2 = kMaxVel1; } - if (s_.s3_ > kMaxVel2) { - s_.s3_ = kMaxVel2; + if (s_.s3 > kMaxVel2) { + s_.s3 = kMaxVel2; } - if (-std::cos(s_.s0_) - std::cos(s_.s0_ + s_.s1_) > 1) { + if (-std::cos(s_.s0) - std::cos(s_.s0 + s_.s1) > 1) { done_ = true; reward = 0.0; } @@ -149,11 +148,11 @@ class AcrobotEnv : public Env { } [[nodiscard]] V5 Derivs(V5 s, double t) const { - double theta1 = s.s0_; - double theta2 = s.s1_; - double dtheta1 = s.s2_; - double dtheta2 = s.s3_; - double a = s.s4_; + double theta1 = s.s0; + double theta2 = s.s1; + double dtheta1 = s.s2; + double dtheta2 = s.s3; + double a = s.s4; double d1 = kM * kLC * kLC + kM * (kL * kL + kLC * kLC + 2 * kL * kLC * std::cos(theta2)) + kI * 2; @@ -172,14 +171,14 @@ class AcrobotEnv : public Env { void WriteState(float reward) { State state = Allocate(); - state["obs"_][0] = static_cast(std::cos(s_.s0_)); - state["obs"_][1] = static_cast(std::sin(s_.s0_)); - state["obs"_][2] = static_cast(std::cos(s_.s1_)); - state["obs"_][3] = static_cast(std::sin(s_.s1_)); - state["obs"_][4] = static_cast(s_.s2_); - state["obs"_][5] = static_cast(s_.s3_); - state["info:state"_][0] = static_cast(s_.s0_); - state["info:state"_][1] = static_cast(s_.s1_); + state["obs"_][0] = static_cast(std::cos(s_.s0)); + state["obs"_][1] = static_cast(std::sin(s_.s0)); + state["obs"_][2] = static_cast(std::cos(s_.s1)); + state["obs"_][3] = static_cast(std::sin(s_.s1)); + state["obs"_][4] = static_cast(s_.s2); + state["obs"_][5] = static_cast(s_.s3); + state["info:state"_][0] = static_cast(s_.s0); + state["info:state"_][1] = static_cast(s_.s1); state["reward"_] = reward; } }; diff --git a/envpool/classic_control/cartpole.h b/envpool/classic_control/cartpole.h index d0711b35..9a5da20e 100644 --- a/envpool/classic_control/cartpole.h +++ b/envpool/classic_control/cartpole.h @@ -70,7 +70,7 @@ class CartPoleEnv : public Env { public: CartPoleEnv(const Spec& spec, int env_id) : Env(spec, env_id), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), dist_(-kInitRange, kInitRange), done_(true) {} diff --git a/envpool/classic_control/mountain_car.h b/envpool/classic_control/mountain_car.h index b705b593..4c02d4ca 100644 --- a/envpool/classic_control/mountain_car.h +++ b/envpool/classic_control/mountain_car.h @@ -62,7 +62,7 @@ class MountainCarEnv : public Env { public: MountainCarEnv(const Spec& spec, int env_id) : Env(spec, env_id), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), dist_(-0.6, -0.4), done_(true) {} diff --git a/envpool/classic_control/mountain_car_continuous.h b/envpool/classic_control/mountain_car_continuous.h index b7db207a..81c1c2d9 100644 --- a/envpool/classic_control/mountain_car_continuous.h +++ b/envpool/classic_control/mountain_car_continuous.h @@ -62,7 +62,7 @@ class MountainCarContinuousEnv : public Env { public: MountainCarContinuousEnv(const Spec& spec, int env_id) : Env(spec, env_id), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), dist_(-0.6, -0.4), done_(true) {} diff --git a/envpool/classic_control/pendulum.h b/envpool/classic_control/pendulum.h index 843d1580..2dd98e50 100644 --- a/envpool/classic_control/pendulum.h +++ b/envpool/classic_control/pendulum.h @@ -60,7 +60,7 @@ class PendulumEnv : public Env { public: PendulumEnv(const Spec& spec, int env_id) : Env(spec, env_id), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), dist_(-kPi, kPi), dist_dot_(-1, 1), diff --git a/envpool/core/action_buffer_queue.h b/envpool/core/action_buffer_queue.h index 19eb98ad..4aacfdee 100644 --- a/envpool/core/action_buffer_queue.h +++ b/envpool/core/action_buffer_queue.h @@ -35,9 +35,9 @@ class ActionBufferQueue { public: struct ActionSlice { - int env_id_; - int order_; - bool force_reset_; + int env_id; + int order; + bool force_reset; }; protected: diff --git a/envpool/core/action_buffer_queue_test.cc b/envpool/core/action_buffer_queue_test.cc index 7f26c479..ed4a12e9 100644 --- a/envpool/core/action_buffer_queue_test.cc +++ b/envpool/core/action_buffer_queue_test.cc @@ -37,7 +37,7 @@ TEST(ActionBufferQueueTest, Concurrent) { // enqueue all envs for (std::size_t i = 0; i < num_envs; ++i) { actions.push_back(ActionSlice{ - .env_id_ = static_cast(i), .order_ = -1, .force_reset_ = false}); + .env_id = static_cast(i), .order = -1, .force_reset = false}); } queue.EnqueueBulk(actions); std::vector> flag(mul); @@ -53,9 +53,8 @@ TEST(ActionBufferQueueTest, Concurrent) { } actions.clear(); for (std::size_t i = 0; i < env_num[m]; ++i) { - actions.push_back(ActionSlice{.env_id_ = static_cast(i), - .order_ = -1, - .force_reset_ = false}); + actions.push_back(ActionSlice{ + .env_id = static_cast(i), .order = -1, .force_reset = false}); } queue.EnqueueBulk(actions); } diff --git a/envpool/core/array.h b/envpool/core/array.h index 04e6b6c6..07a0bfd1 100644 --- a/envpool/core/array.h +++ b/envpool/core/array.h @@ -29,9 +29,9 @@ class Array { public: - std::size_t size_; - std::size_t ndim_; - std::size_t element_size_; + std::size_t size; + std::size_t ndim; + std::size_t element_size; protected: std::vector shape_; @@ -40,17 +40,17 @@ class Array { template Array(char* ptr, Shape&& shape, std::size_t element_size, // NOLINT Deleter&& deleter) - : size_(Prod(shape.data(), shape.size())), - ndim_(shape.size()), - element_size_(element_size), + : size(Prod(shape.data(), shape.size())), + ndim(shape.size()), + element_size(element_size), shape_(std::forward(shape)), ptr_(ptr, std::forward(deleter)) {} template Array(std::shared_ptr ptr, Shape&& shape, std::size_t element_size) - : size_(Prod(shape.data(), shape.size())), - ndim_(shape.size()), - element_size_(element_size), + : size(Prod(shape.data(), shape.size())), + ndim(shape.size()), + element_size(element_size), shape_(std::forward(shape)), ptr_(std::move(ptr)) {} @@ -64,12 +64,11 @@ class Array { */ template Array(const ShapeSpec& spec, char* data, Deleter&& deleter) // NOLINT - : Array(data, spec.Shape(), spec.element_size_, + : Array(data, spec.Shape(), spec.element_size, std::forward(deleter)) {} Array(const ShapeSpec& spec, char* data) - : Array(data, spec.Shape(), spec.element_size_, [](char* /*unused*/) {}) { - } + : Array(data, spec.Shape(), spec.element_size, [](char* /*unused*/) {}) {} /** * Constructor an `Array` of shape defined by `spec`. This constructor @@ -77,7 +76,7 @@ class Array { */ explicit Array(const ShapeSpec& spec) : Array(spec, nullptr, [](char* /*unused*/) {}) { - ptr_.reset(new char[size_ * element_size_](), + ptr_.reset(new char[size * element_size](), [](const char* p) { delete[] p; }); } @@ -87,16 +86,16 @@ class Array { template inline Array operator()(Index... index) const { constexpr std::size_t num_index = sizeof...(Index); - DCHECK_GE(ndim_, num_index); + DCHECK_GE(ndim, num_index); std::size_t offset = 0; std::size_t i = 0; - for (((offset = offset * shape_[i++] + index), ...); i < ndim_; ++i) { + for (((offset = offset * shape_[i++] + index), ...); i < ndim; ++i) { offset *= shape_[i]; } return Array( - ptr_.get() + offset * element_size_, + ptr_.get() + offset * element_size, std::vector(shape_.begin() + num_index, shape_.end()), - element_size_, [](char* /*unused*/) {}); + element_size, [](char* /*unused*/) {}); } /** @@ -108,27 +107,27 @@ class Array { * Take a slice at the first axis of the Array. */ [[nodiscard]] Array Slice(std::size_t start, std::size_t end) const { - DCHECK_GT(ndim_, (std::size_t)0); + DCHECK_GT(ndim, (std::size_t)0); CHECK_GE(shape_[0], end); CHECK_GE(end, start); std::vector new_shape(shape_); new_shape[0] = end - start; std::size_t offset = 0; if (shape_[0] > 0) { - offset = start * size_ / shape_[0]; + offset = start * size / shape_[0]; } - return Array(ptr_.get() + offset * element_size_, std::move(new_shape), - element_size_, [](char* p) {}); + return Array(ptr_.get() + offset * element_size, std::move(new_shape), + element_size, [](char* p) {}); } /** * Copy the content of another Array to this Array. */ void Assign(const Array& value) const { - DCHECK_EQ(element_size_, value.element_size_) + DCHECK_EQ(element_size, value.element_size) << " element size doesn't match"; - DCHECK_EQ(size_, value.size_) << " ndim doesn't match"; - std::memcpy(ptr_.get(), value.ptr_.get(), size_ * element_size_); + DCHECK_EQ(size, value.size) << " ndim doesn't match"; + std::memcpy(ptr_.get(), value.ptr_.get(), size * element_size); } /** @@ -137,8 +136,8 @@ class Array { */ template void operator=(const T& value) const { - DCHECK_EQ(element_size_, sizeof(T)) << " element size doesn't match"; - DCHECK_EQ(size_, (std::size_t)1) << " assigning scalar to non-scalar array"; + DCHECK_EQ(element_size, sizeof(T)) << " element size doesn't match"; + DCHECK_EQ(size, (std::size_t)1) << " assigning scalar to non-scalar array"; *reinterpret_cast(ptr_.get()) = value; } @@ -148,8 +147,8 @@ class Array { */ template void Assign(const T* buff, std::size_t sz) const { - DCHECK_EQ(sz, size_) << " assignment size mismatch"; - DCHECK_EQ(sizeof(T), element_size_) << " element size mismatch"; + DCHECK_EQ(sz, size) << " assignment size mismatch"; + DCHECK_EQ(sizeof(T), element_size) << " element size mismatch"; std::memcpy(ptr_.get(), buff, sz * sizeof(T)); } @@ -159,8 +158,8 @@ class Array { */ template operator T() const { // NOLINT - DCHECK_EQ(element_size_, sizeof(T)) << " there could be a type mismatch"; - DCHECK_EQ(size_, (std::size_t)1) + DCHECK_EQ(element_size, sizeof(T)) << " there could be a type mismatch"; + DCHECK_EQ(size, (std::size_t)1) << " Array with a shape can't be used as a scalar"; return *reinterpret_cast(ptr_.get()); } @@ -191,11 +190,11 @@ class Array { [[nodiscard]] Array Truncate(std::size_t end) const { auto new_shape = std::vector(shape_); new_shape[0] = end; - Array ret(ptr_, std::move(new_shape), element_size_); + Array ret(ptr_, std::move(new_shape), element_size); return ret; } - void Zero() const { std::memset(ptr_.get(), 0, size_ * element_size_); } + void Zero() const { std::memset(ptr_.get(), 0, size * element_size); } [[nodiscard]] std::shared_ptr SharedPtr() const { return ptr_; } }; diff --git a/envpool/core/async_envpool.h b/envpool/core/async_envpool.h index 2c2e9b42..6da48849 100644 --- a/envpool/core/async_envpool.h +++ b/envpool/core/async_envpool.h @@ -63,18 +63,18 @@ class AsyncEnvPool : public EnvPool { explicit AsyncEnvPool(const Spec& spec) : EnvPool(spec), - num_envs_(spec.config_["num_envs"_]), - batch_(spec.config_["batch_size"_] <= 0 ? num_envs_ - : spec.config_["batch_size"_]), - max_num_players_(spec.config_["max_num_players"_]), - num_threads_(spec.config_["num_threads"_]), + num_envs_(spec.config["num_envs"_]), + batch_(spec.config["batch_size"_] <= 0 ? num_envs_ + : spec.config["batch_size"_]), + max_num_players_(spec.config["max_num_players"_]), + num_threads_(spec.config["num_threads"_]), is_sync_(batch_ == num_envs_ && max_num_players_ == 1), stop_(0), stepping_env_num_(0), action_buffer_queue_(new ActionBufferQueue(num_envs_)), state_buffer_queue_(new StateBufferQueue( batch_, num_envs_, max_num_players_, - spec.state_spec_.template AllValues())), + spec.state_spec.template AllValues())), envs_(num_envs_) { std::size_t processor_count = std::thread::hardware_concurrency(); ThreadPool init_pool(std::min(processor_count, num_envs_)); @@ -96,15 +96,14 @@ class AsyncEnvPool : public EnvPool { if (stop_ == 1) { break; } - int env_id = raw_action.env_id_; - int order = raw_action.order_; - bool reset = raw_action.force_reset_ || envs_[env_id]->IsDone(); + int env_id = raw_action.env_id; + int order = raw_action.order; + bool reset = raw_action.force_reset || envs_[env_id]->IsDone(); envs_[env_id]->EnvStep(state_buffer_queue_.get(), order, reset); } }); } - std::size_t thread_affinity_offset = - spec.config_["thread_affinity_offset"_]; + std::size_t thread_affinity_offset = spec.config["thread_affinity_offset"_]; if (thread_affinity_offset >= 0) { for (std::size_t tid = 0; tid < num_threads_; ++tid) { cpu_set_t cpuset; @@ -139,9 +138,9 @@ class AsyncEnvPool : public EnvPool { int eid = env_id[i]; envs_[eid]->SetAction(action_batch, i); actions.emplace_back(ActionSlice{ - .env_id_ = eid, - .order_ = is_sync_ ? i : -1, - .force_reset_ = false, + .env_id = eid, + .order = is_sync_ ? i : -1, + .force_reset = false, }); } if (is_sync_) { @@ -171,9 +170,9 @@ class AsyncEnvPool : public EnvPool { int shared_offset = env_ids.Shape(0); std::vector actions(shared_offset); for (int i = 0; i < shared_offset; ++i) { - actions[i].force_reset_ = true; - actions[i].env_id_ = env_ids[i]; - actions[i].order_ = is_sync_ ? i : -1; + actions[i].force_reset = true; + actions[i].env_id = env_ids[i]; + actions[i].order = is_sync_ ? i : -1; } if (is_sync_) { stepping_env_num_ += shared_offset; diff --git a/envpool/core/dict.h b/envpool/core/dict.h index 6ae2227c..a9c56224 100644 --- a/envpool/core/dict.h +++ b/envpool/core/dict.h @@ -37,8 +37,8 @@ class Value { public: using Key = K; using Type = D; - explicit Value(Type&& v) : v_(v) {} - Type v_; + explicit Value(Type&& v) : v(v) {} + Type v; }; template @@ -237,7 +237,7 @@ class Dict : public std::decay_t { template decltype(auto) MakeDict(Value... v) { return Dict(std::make_tuple(typename Value::Key()...), - std::make_tuple(v.v_...)); + std::make_tuple(v.v...)); } template < diff --git a/envpool/core/env.h b/envpool/core/env.h index 5b6dc84c..d5ffb7a6 100644 --- a/envpool/core/env.h +++ b/envpool/core/env.h @@ -54,18 +54,18 @@ class Env { using Action = NamedVector>; Env(const EnvSpec& spec, int env_id) - : max_num_players_(spec.config_["max_num_players"_]), + : max_num_players_(spec.config["max_num_players"_]), spec_(spec), env_id_(env_id), - seed_(spec.config_["seed"_] + env_id), + seed_(spec.config["seed"_] + env_id), gen_(seed_), current_step_(-1), is_single_player_(max_num_players_ == 1), - action_specs_(spec.action_spec_.template AllValues()), + action_specs_(spec.action_spec.template AllValues()), is_player_action_(Transform(action_specs_, [](const ShapeSpec& s) { - return (!s.shape_.empty() && s.shape_[0] == -1); + return (!s.shape.empty() && s.shape[0] == -1); })) { - slice_.done_write_ = [] { LOG(INFO) << "Use `Allocate` to write state."; }; + slice_.done_write = [] { LOG(INFO) << "Use `Allocate` to write state."; }; } void SetAction(std::shared_ptr> action_batch, @@ -109,7 +109,7 @@ class Env { if (continuous) { raw_action_.emplace_back((*action_batch_)[i].Slice(start, end)); } else { - action_specs_[i].shape_[0] = player_num; + action_specs_[i].shape[0] = player_num; Array arr(action_specs_[i]); for (int j = 0; j < player_num; ++j) { int player_index = env_player_index[j]; @@ -153,13 +153,13 @@ class Env { } void PostProcess() { - slice_.done_write_(); + slice_.done_write(); // action_batch_.reset(); } State Allocate(int player_num = 1) { slice_ = sbq_->Allocate(player_num, order_); - State state(&slice_.arr_); + State state(&slice_.arr); state["done"_] = IsDone(); state["info:env_id"_] = env_id_; state["elapsed_step"_] = current_step_; diff --git a/envpool/core/env_spec.h b/envpool/core/env_spec.h index e78e275c..afb1b087 100644 --- a/envpool/core/env_spec.h +++ b/envpool/core/env_spec.h @@ -55,23 +55,23 @@ class EnvSpec { using ActionKeys = typename ActionSpec::Keys; // For C++ - Config config_; - StateSpec state_spec_; - ActionSpec action_spec_; + Config config; + StateSpec state_spec; + ActionSpec action_spec; static inline const Config DEFAULT_CONFIG = ConcatDict(common_config, EnvFns::DefaultConfig()); EnvSpec() : EnvSpec(DEFAULT_CONFIG) {} explicit EnvSpec(const ConfigValues& conf) - : config_(conf), - state_spec_(ConcatDict(common_state_spec, EnvFns::StateSpec(config_))), - action_spec_( - ConcatDict(common_action_spec, EnvFns::ActionSpec(config_))) { - if (config_["batch_size"_] > config_["num_envs"_]) { + : config(conf), + state_spec(ConcatDict(common_state_spec, EnvFns::StateSpec(config))), + action_spec( + ConcatDict(common_action_spec, EnvFns::ActionSpec(config))) { + if (config["batch_size"_] > config["num_envs"_]) { throw std::invalid_argument( "It is required that batch_size <= num_envs, got num_envs = " + - std::to_string(config_["num_envs"_]) + - ", batch_size = " + std::to_string(config_["batch_size"_])); + std::to_string(config["num_envs"_]) + + ", batch_size = " + std::to_string(config["batch_size"_])); } } }; diff --git a/envpool/core/py_envpool.h b/envpool/core/py_envpool.h index 31e8cb03..2882a5ba 100644 --- a/envpool/core/py_envpool.h +++ b/envpool/core/py_envpool.h @@ -71,8 +71,8 @@ decltype(auto) ExportSpecs(const std::tuple& specs) { return std::apply( [&](auto&&... spec) { return std::make_tuple( - std::make_tuple(py::dtype::of(), spec.shape_, - spec.bounds_, spec.elementwise_bounds_)...); + std::make_tuple(py::dtype::of(), spec.shape, + spec.bounds, spec.elementwise_bounds)...); }, specs); } @@ -85,9 +85,9 @@ class PyEnvSpec : public EnvSpec { using ActionSpecT = decltype(ExportSpecs(std::declval())); - StateSpecT py_state_spec_; - ActionSpecT py_action_spec_; - typename EnvSpec::ConfigValues py_config_values_; + StateSpecT py_state_spec; + ActionSpecT py_action_spec; + typename EnvSpec::ConfigValues py_config_values; static std::vector py_config_keys; static std::vector py_state_keys; static std::vector py_action_keys; @@ -95,9 +95,9 @@ class PyEnvSpec : public EnvSpec { explicit PyEnvSpec(const typename EnvSpec::ConfigValues& conf) : EnvSpec(conf), - py_state_spec_(ExportSpecs(EnvSpec::state_spec_)), - py_action_spec_(ExportSpecs(EnvSpec::action_spec_)), - py_config_values_(EnvSpec::config_.AllValues()) {} + py_state_spec(ExportSpecs(EnvSpec::state_spec)), + py_action_spec(ExportSpecs(EnvSpec::action_spec)), + py_config_values(EnvSpec::config.AllValues()) {} }; template std::vector PyEnvSpec::py_config_keys = @@ -149,12 +149,12 @@ class PyEnvPool : public EnvPool { public: using PySpec = PyEnvSpec; - PySpec py_spec_; + PySpec py_spec; static std::vector py_state_keys; static std::vector py_action_keys; explicit PyEnvPool(const PySpec& py_spec) - : EnvPool(py_spec), py_spec_(py_spec) {} + : EnvPool(py_spec), py_spec(py_spec) {} /** * py api @@ -162,7 +162,7 @@ class PyEnvPool : public EnvPool { void PySend(const std::vector& action) { std::vector arr; arr.reserve(action.size()); - ToArray(action, py_spec_.action_spec_, &arr); + ToArray(action, py_spec.action_spec, &arr); py::gil_scoped_release release; EnvPool::Send(arr); // delegate to the c++ api } @@ -179,7 +179,7 @@ class PyEnvPool : public EnvPool { } std::vector ret; ret.reserve(EnvPool::State::SIZE); - ToNumpy(arr, py_spec_.state_spec_, &ret); + ToNumpy(arr, py_spec.state_spec, &ret); return ret; } @@ -211,9 +211,9 @@ py::object abc_meta = py::module::import("abc").attr("ABCMeta"); #define REGISTER(MODULE, SPEC, ENVPOOL) \ py::class_(MODULE, "_" #SPEC, py::metaclass(abc_meta)) \ .def(py::init()) \ - .def_readonly("_config_values", &SPEC::py_config_values_) \ - .def_readonly("_state_spec", &SPEC::py_state_spec_) \ - .def_readonly("_action_spec", &SPEC::py_action_spec_) \ + .def_readonly("_config_values", &SPEC::py_config_values) \ + .def_readonly("_state_spec", &SPEC::py_state_spec) \ + .def_readonly("_action_spec", &SPEC::py_action_spec) \ .def_readonly_static("_state_keys", &SPEC::py_state_keys) \ .def_readonly_static("_action_keys", &SPEC::py_action_keys) \ .def_readonly_static("_config_keys", &SPEC::py_config_keys) \ @@ -221,7 +221,7 @@ py::object abc_meta = py::module::import("abc").attr("ABCMeta"); &SPEC::py_default_config_values); \ py::class_(MODULE, "_" #ENVPOOL, py::metaclass(abc_meta)) \ .def(py::init()) \ - .def_readonly("_spec", &ENVPOOL::py_spec_) \ + .def_readonly("_spec", &ENVPOOL::py_spec) \ .def("_recv", &ENVPOOL::PyRecv) \ .def("_send", &ENVPOOL::PySend) \ .def("_reset", &ENVPOOL::PyReset) \ diff --git a/envpool/core/spec.h b/envpool/core/spec.h index ee8dfe39..eb52e141 100644 --- a/envpool/core/spec.h +++ b/envpool/core/spec.h @@ -36,20 +36,20 @@ std::size_t Prod(const std::size_t* shape, std::size_t ndim) { class ShapeSpec { public: - int element_size_; - std::vector shape_; + int element_size; + std::vector shape; ShapeSpec() = default; ShapeSpec(int element_size, std::vector shape_vec) - : element_size_(element_size), shape_(std::move(shape_vec)) {} + : element_size(element_size), shape(std::move(shape_vec)) {} [[nodiscard]] ShapeSpec Batch(int batch_size) const { std::vector new_shape = {batch_size}; - new_shape.insert(new_shape.end(), shape_.begin(), shape_.end()); - return ShapeSpec(element_size_, std::move(new_shape)); + new_shape.insert(new_shape.end(), shape.begin(), shape.end()); + return ShapeSpec(element_size, std::move(new_shape)); } [[nodiscard]] std::vector Shape() const { - auto s = std::vector(shape_.size()); - for (std::size_t i = 0; i < shape_.size(); ++i) { - s[i] = shape_[i]; + auto s = std::vector(shape.size()); + for (std::size_t i = 0; i < shape.size(); ++i) { + s[i] = shape[i]; } return s; } @@ -59,9 +59,9 @@ template class Spec : public ShapeSpec { public: using dtype = D; // NOLINT - std::tuple bounds_ = {std::numeric_limits::min(), - std::numeric_limits::max()}; - std::tuple, std::vector> elementwise_bounds_; + std::tuple bounds = {std::numeric_limits::min(), + std::numeric_limits::max()}; + std::tuple, std::vector> elementwise_bounds; explicit Spec(std::vector&& shape) : ShapeSpec(sizeof(dtype), std::move(shape)) {} explicit Spec(const std::vector& shape) @@ -69,25 +69,24 @@ class Spec : public ShapeSpec { /* init with constant bounds */ Spec(std::vector&& shape, std::tuple&& bounds) - : ShapeSpec(sizeof(dtype), std::move(shape)), - bounds_(std::move(bounds)) {} + : ShapeSpec(sizeof(dtype), std::move(shape)), bounds(std::move(bounds)) {} Spec(const std::vector& shape, const std::tuple& bounds) - : ShapeSpec(sizeof(dtype), shape), bounds_(bounds) {} + : ShapeSpec(sizeof(dtype), shape), bounds(bounds) {} /* init with elementwise bounds */ Spec(std::vector&& shape, std::tuple, std::vector>&& elementwise_bounds) : ShapeSpec(sizeof(dtype), std::move(shape)), - elementwise_bounds_(std::move(elementwise_bounds)) {} + elementwise_bounds(std::move(elementwise_bounds)) {} Spec(const std::vector& shape, const std::tuple, std::vector>& elementwise_bounds) : ShapeSpec(sizeof(dtype), shape), - elementwise_bounds_(elementwise_bounds) {} + elementwise_bounds(elementwise_bounds) {} [[nodiscard]] Spec Batch(int batch_size) const { std::vector new_shape = {batch_size}; - new_shape.insert(new_shape.end(), shape_.begin(), shape_.end()); + new_shape.insert(new_shape.end(), shape.begin(), shape.end()); return Spec(std::move(new_shape)); } }; diff --git a/envpool/core/state_buffer.h b/envpool/core/state_buffer.h index ac661cf7..4e3385b1 100644 --- a/envpool/core/state_buffer.h +++ b/envpool/core/state_buffer.h @@ -57,8 +57,8 @@ class StateBuffer { * invoke done write. */ struct WritableSlice { - std::vector arr_; - std::function done_write_; + std::vector arr; + std::function done_write; }; /** @@ -106,8 +106,8 @@ class StateBuffer { state.emplace_back(a[shared_offset]); } } - return WritableSlice{.arr_ = std::move(state), - .done_write_ = [this]() { Done(); }}; + return WritableSlice{.arr = std::move(state), + .done_write = [this]() { Done(); }}; } DLOG(INFO) << "Allocation failed, continue to the next block of memory"; throw std::out_of_range("StateBuffer out of storage"); diff --git a/envpool/core/state_buffer_queue.h b/envpool/core/state_buffer_queue.h index 509e9327..56f946c3 100644 --- a/envpool/core/state_buffer_queue.h +++ b/envpool/core/state_buffer_queue.h @@ -54,14 +54,14 @@ class StateBufferQueue { max_num_players_(max_num_players), is_player_state_(Transform(specs, [](const ShapeSpec& s) { - return (!s.shape_.empty() && - s.shape_[0] == -1); + return (!s.shape.empty() && + s.shape[0] == -1); })), specs_(Transform(specs, [=](ShapeSpec s) { - if (!s.shape_.empty() && s.shape_[0] == -1) { + if (!s.shape.empty() && s.shape[0] == -1) { // If first dim is num_players - s.shape_[0] = batch_ * max_num_players_; + s.shape[0] = batch_ * max_num_players_; return s; } return s.Batch(batch_); diff --git a/envpool/core/state_buffer_queue_test.cc b/envpool/core/state_buffer_queue_test.cc index 9ac5d052..caac0394 100644 --- a/envpool/core/state_buffer_queue_test.cc +++ b/envpool/core/state_buffer_queue_test.cc @@ -33,9 +33,9 @@ TEST(StateBufferQueueTest, Basic) { for (std::size_t i = 0; i < batch; ++i) { std::size_t num_players = 1; auto slice = queue.Allocate(num_players); - slice.done_write_(); - EXPECT_EQ(slice.arr_[0].Shape(0), 10); - EXPECT_EQ(slice.arr_[1].Shape(0), 1); + slice.done_write(); + EXPECT_EQ(slice.arr[0].Shape(0), 10); + EXPECT_EQ(slice.arr[1].Shape(0), 1); size += num_players; } std::vector out = queue.Wait(); @@ -63,9 +63,9 @@ TEST(StateBufferQueueTest, SinglePlayerSync) { std::shuffle(order.begin(), order.end(), gen); for (std::size_t i = 0; i < batch; ++i) { auto slice = queue.Allocate(1, order[i]); - EXPECT_EQ(slice.arr_[0].Shape(0), 1); - slice.arr_[0] = static_cast(i); - slice.done_write_(); + EXPECT_EQ(slice.arr[0].Shape(0), 1); + slice.arr[0] = static_cast(i); + slice.done_write(); } std::vector out = queue.Wait(); EXPECT_EQ(out[0].Shape(0), batch); @@ -80,8 +80,8 @@ TEST(StateBufferQueueTest, SinglePlayerSync) { env_id.pop_back(); for (std::size_t i = 0; i < env_id.size(); ++i) { auto slice = queue.Allocate(1, i); - slice.arr_[0] = env_id[i]; - slice.done_write_(); + slice.arr[0] = env_id[i]; + slice.done_write(); } std::vector out = queue.Wait(batch - env_id.size()); EXPECT_EQ(out[0].Shape(0), env_id.size()); @@ -104,9 +104,9 @@ TEST(StateBufferQueueTest, NumPlayers) { for (std::size_t i = 0; i < batch; ++i) { std::size_t num_players = 1 + std::rand() % max_num_players; auto slice = queue.Allocate(num_players); - slice.done_write_(); - EXPECT_EQ(slice.arr_[0].Shape(0), num_players); - EXPECT_EQ(slice.arr_[1].Shape(0), 1); + slice.done_write(); + EXPECT_EQ(slice.arr[0].Shape(0), num_players); + EXPECT_EQ(slice.arr[1].Shape(0), 1); size += num_players; } std::vector out = queue.Wait(); @@ -128,9 +128,9 @@ TEST(StateBufferQueueTest, MultipleTimes) { for (std::size_t i = 0; i < batch; ++i) { std::size_t num_players = 1 + std::rand() % max_num_players; auto slice = queue.Allocate(num_players); - slice.done_write_(); - EXPECT_EQ(slice.arr_[0].Shape(0), num_players); - EXPECT_EQ(slice.arr_[1].Shape(0), 1); + slice.done_write(); + EXPECT_EQ(slice.arr[0].Shape(0), num_players); + EXPECT_EQ(slice.arr[1].Shape(0), 1); size += num_players; } std::vector out = queue.Wait(); @@ -152,7 +152,7 @@ TEST(StateBufferQueueTest, ConcurrentSinglePlayer) { for (std::size_t i = 0; i < num_envs; ++i) { pool.enqueue([&] { auto slice = queue.Allocate(1); - slice.done_write_(); + slice.done_write(); }); } std::size_t total = 10000; @@ -164,7 +164,7 @@ TEST(StateBufferQueueTest, ConcurrentSinglePlayer) { auto slice = queue.Allocate(1); std::this_thread::sleep_for( std::chrono::nanoseconds(std::rand() % 1000 + 1)); - slice.done_write_(); + slice.done_write(); }); } } @@ -184,7 +184,7 @@ TEST(StateBufferQueueTest, ConcurrentMultiPlayer) { pool.enqueue([&] { std::size_t num_players = 1 + std::rand() % max_num_players; auto slice = queue.Allocate(num_players); - slice.done_write_(); + slice.done_write(); }); } std::size_t total = 1000; @@ -197,7 +197,7 @@ TEST(StateBufferQueueTest, ConcurrentMultiPlayer) { auto slice = queue.Allocate(num_players); std::this_thread::sleep_for( std::chrono::nanoseconds(std::rand() % 1000 + 1)); - slice.done_write_(); + slice.done_write(); }); } } diff --git a/envpool/core/state_buffer_test.cc b/envpool/core/state_buffer_test.cc index a2d653c0..95240f85 100644 --- a/envpool/core/state_buffer_test.cc +++ b/envpool/core/state_buffer_test.cc @@ -37,7 +37,7 @@ TEST(StateBufferTest, Basic) { auto r = buffer.Allocate(num); offset = buffer.Offsets(); EXPECT_EQ(std::get<0>(offset), std::get<1>(offset)); - r.done_write_(); + r.done_write(); } auto bs = buffer.Wait(); EXPECT_EQ(bs[0].Shape(0), total); @@ -60,10 +60,10 @@ TEST(StateBufferTest, SinglePlayerSync) { auto r = buffer.Allocate(num, batch - 1 - i); offset = buffer.Offsets(); EXPECT_EQ(std::get<0>(offset), std::get<1>(offset)); - EXPECT_EQ(r.arr_[0].Shape(), std::vector({10, 2, 2})); - EXPECT_EQ(r.arr_[1].Shape(), std::vector({1, 2, 2})); - r.arr_[1] = i; // only the first element is modified - r.done_write_(); + EXPECT_EQ(r.arr[0].Shape(), std::vector({10, 2, 2})); + EXPECT_EQ(r.arr[1].Shape(), std::vector({1, 2, 2})); + r.arr[1] = i; // only the first element is modified + r.done_write(); } auto bs = buffer.Wait(); EXPECT_EQ(bs[0].Shape(0), total); @@ -82,7 +82,7 @@ TEST(StateBufferTest, Truncate) { StateBuffer buffer(batch, max_num_players, specs, std::vector({false, true})); auto r = buffer.Allocate(player_num); - r.done_write_(); + r.done_write(); buffer.Done(batch - 1); auto bs = buffer.Wait(); EXPECT_EQ(bs[0].Shape(), std::vector({1, 10, 2, 2})); @@ -105,10 +105,10 @@ TEST(StateBufferTest, MultiPlayers) { total += num; auto r = buffer.Allocate(num); offset = buffer.Offsets(); - EXPECT_EQ(num, r.arr_[0].Shape()[0]); + EXPECT_EQ(num, r.arr[0].Shape()[0]); EXPECT_EQ(std::get<0>(offset), total); EXPECT_EQ(std::get<1>(offset), i + 1); - r.done_write_(); + r.done_write(); } auto bs = buffer.Wait(); EXPECT_EQ(bs[0].Shape(0), total); diff --git a/envpool/mujoco/ant.h b/envpool/mujoco/ant.h index d0176a75..b0b8f783 100644 --- a/envpool/mujoco/ant.h +++ b/envpool/mujoco/ant.h @@ -83,22 +83,22 @@ class AntEnv : public Env, public MujocoEnv { public: AntEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + "/mujoco/assets/ant.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - terminate_when_unhealthy_(spec.config_["terminate_when_unhealthy"_]), - no_pos_(spec.config_["exclude_current_positions_from_observation"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - contact_cost_weight_(spec.config_["contact_cost_weight"_]), - forward_reward_weight_(spec.config_["forward_reward_weight"_]), - healthy_reward_(spec.config_["healthy_reward"_]), - healthy_z_min_(spec.config_["healthy_z_min"_]), - healthy_z_max_(spec.config_["healthy_z_max"_]), - contact_force_min_(spec.config_["contact_force_min"_]), - contact_force_max_(spec.config_["contact_force_max"_]), - dist_qpos_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]), - dist_qvel_(0, spec.config_["reset_noise_scale"_]) {} + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/ant.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + terminate_when_unhealthy_(spec.config["terminate_when_unhealthy"_]), + no_pos_(spec.config["exclude_current_positions_from_observation"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + contact_cost_weight_(spec.config["contact_cost_weight"_]), + forward_reward_weight_(spec.config["forward_reward_weight"_]), + healthy_reward_(spec.config["healthy_reward"_]), + healthy_z_min_(spec.config["healthy_z_min"_]), + healthy_z_max_(spec.config["healthy_z_max"_]), + contact_force_min_(spec.config["contact_force_min"_]), + contact_force_max_(spec.config["contact_force_max"_]), + dist_qpos_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]), + dist_qvel_(0, spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/half_cheetah.h b/envpool/mujoco/half_cheetah.h index da29c7ab..2ff24e2c 100644 --- a/envpool/mujoco/half_cheetah.h +++ b/envpool/mujoco/half_cheetah.h @@ -70,16 +70,15 @@ class HalfCheetahEnv : public Env, public MujocoEnv { public: HalfCheetahEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv( - spec.config_["base_path"_] + "/mujoco/assets/half_cheetah.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - no_pos_(spec.config_["exclude_current_positions_from_observation"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - forward_reward_weight_(spec.config_["forward_reward_weight"_]), - dist_qpos_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]), - dist_qvel_(0, spec.config_["reset_noise_scale"_]) {} + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/half_cheetah.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + no_pos_(spec.config["exclude_current_positions_from_observation"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + forward_reward_weight_(spec.config["forward_reward_weight"_]), + dist_qpos_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]), + dist_qvel_(0, spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/hopper.h b/envpool/mujoco/hopper.h index 697705de..77aa9e7d 100644 --- a/envpool/mujoco/hopper.h +++ b/envpool/mujoco/hopper.h @@ -76,23 +76,23 @@ class HopperEnv : public Env, public MujocoEnv { public: HopperEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + "/mujoco/assets/hopper.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - terminate_when_unhealthy_(spec.config_["terminate_when_unhealthy"_]), - no_pos_(spec.config_["exclude_current_positions_from_observation"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - forward_reward_weight_(spec.config_["forward_reward_weight"_]), - healthy_reward_(spec.config_["healthy_reward"_]), - healthy_z_min_(spec.config_["healthy_z_min"_]), - velocity_min_(spec.config_["velocity_min"_]), - velocity_max_(spec.config_["velocity_max"_]), - healthy_state_min_(spec.config_["healthy_state_min"_]), - healthy_state_max_(spec.config_["healthy_state_max"_]), - healthy_angle_min_(spec.config_["healthy_angle_min"_]), - healthy_angle_max_(spec.config_["healthy_angle_max"_]), - dist_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]) {} + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/hopper.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + terminate_when_unhealthy_(spec.config["terminate_when_unhealthy"_]), + no_pos_(spec.config["exclude_current_positions_from_observation"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + forward_reward_weight_(spec.config["forward_reward_weight"_]), + healthy_reward_(spec.config["healthy_reward"_]), + healthy_z_min_(spec.config["healthy_z_min"_]), + velocity_min_(spec.config["velocity_min"_]), + velocity_max_(spec.config["velocity_max"_]), + healthy_state_min_(spec.config["healthy_state_min"_]), + healthy_state_max_(spec.config["healthy_state_max"_]), + healthy_angle_min_(spec.config["healthy_angle_min"_]), + healthy_angle_max_(spec.config["healthy_angle_max"_]), + dist_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/humanoid.h b/envpool/mujoco/humanoid.h index 82667a35..8260b401 100644 --- a/envpool/mujoco/humanoid.h +++ b/envpool/mujoco/humanoid.h @@ -81,22 +81,22 @@ class HumanoidEnv : public Env, public MujocoEnv { public: HumanoidEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + "/mujoco/assets/humanoid.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - terminate_when_unhealthy_(spec.config_["terminate_when_unhealthy"_]), - no_pos_(spec.config_["exclude_current_positions_from_observation"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - contact_cost_weight_(spec.config_["contact_cost_weight"_]), - contact_cost_max_(spec.config_["contact_cost_max"_]), - forward_reward_weight_(spec.config_["forward_reward_weight"_]), - healthy_reward_(spec.config_["healthy_reward"_]), - healthy_z_min_(spec.config_["healthy_z_min"_]), - healthy_z_max_(spec.config_["healthy_z_max"_]), + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/humanoid.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + terminate_when_unhealthy_(spec.config["terminate_when_unhealthy"_]), + no_pos_(spec.config["exclude_current_positions_from_observation"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + contact_cost_weight_(spec.config["contact_cost_weight"_]), + contact_cost_max_(spec.config["contact_cost_max"_]), + forward_reward_weight_(spec.config["forward_reward_weight"_]), + healthy_reward_(spec.config["healthy_reward"_]), + healthy_z_min_(spec.config["healthy_z_min"_]), + healthy_z_max_(spec.config["healthy_z_max"_]), mass_x_(0), mass_y_(0), - dist_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]) {} + dist_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/humanoid_standup.h b/envpool/mujoco/humanoid_standup.h index c9a9eee2..fa056a54 100644 --- a/envpool/mujoco/humanoid_standup.h +++ b/envpool/mujoco/humanoid_standup.h @@ -74,17 +74,17 @@ class HumanoidStandupEnv : public Env, HumanoidStandupEnv(const Spec& spec, int env_id) : Env(spec, env_id), MujocoEnv( - spec.config_["base_path"_] + "/mujoco/assets/humanoidstandup.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - no_pos_(spec.config_["exclude_current_positions_from_observation"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - contact_cost_weight_(spec.config_["contact_cost_weight"_]), - contact_cost_max_(spec.config_["contact_cost_max"_]), - forward_reward_weight_(spec.config_["forward_reward_weight"_]), - healthy_reward_(spec.config_["healthy_reward"_]), - dist_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]) {} + spec.config["base_path"_] + "/mujoco/assets/humanoidstandup.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + no_pos_(spec.config["exclude_current_positions_from_observation"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + contact_cost_weight_(spec.config["contact_cost_weight"_]), + contact_cost_max_(spec.config["contact_cost_max"_]), + forward_reward_weight_(spec.config["forward_reward_weight"_]), + healthy_reward_(spec.config["healthy_reward"_]), + dist_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/inverted_double_pendulum.h b/envpool/mujoco/inverted_double_pendulum.h index b6db4721..70f4b7d0 100644 --- a/envpool/mujoco/inverted_double_pendulum.h +++ b/envpool/mujoco/inverted_double_pendulum.h @@ -68,17 +68,17 @@ class InvertedDoublePendulumEnv : public Env, public: InvertedDoublePendulumEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/inverted_double_pendulum.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - healthy_reward_(spec.config_["healthy_reward"_]), - healthy_z_max_(spec.config_["healthy_z_max"_]), - observation_min_(spec.config_["observation_min"_]), - observation_max_(spec.config_["observation_max"_]), - dist_qpos_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]), - dist_qvel_(0, spec.config_["reset_noise_scale"_]) {} + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + healthy_reward_(spec.config["healthy_reward"_]), + healthy_z_max_(spec.config["healthy_z_max"_]), + observation_min_(spec.config["observation_min"_]), + observation_max_(spec.config["observation_max"_]), + dist_qpos_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]), + dist_qvel_(0, spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/inverted_pendulum.h b/envpool/mujoco/inverted_pendulum.h index 96d7eaea..aad39a53 100644 --- a/envpool/mujoco/inverted_pendulum.h +++ b/envpool/mujoco/inverted_pendulum.h @@ -66,14 +66,14 @@ class InvertedPendulumEnv : public Env, InvertedPendulumEnv(const Spec& spec, int env_id) : Env(spec, env_id), MujocoEnv( - spec.config_["base_path"_] + "/mujoco/assets/inverted_pendulum.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - healthy_reward_(spec.config_["healthy_reward"_]), - healthy_z_min_(spec.config_["healthy_z_min"_]), - healthy_z_max_(spec.config_["healthy_z_max"_]), - dist_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]) {} + spec.config["base_path"_] + "/mujoco/assets/inverted_pendulum.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + healthy_reward_(spec.config["healthy_reward"_]), + healthy_z_min_(spec.config["healthy_z_min"_]), + healthy_z_max_(spec.config["healthy_z_max"_]), + dist_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/pusher.h b/envpool/mujoco/pusher.h index 6f90fc2e..8bc0aea8 100644 --- a/envpool/mujoco/pusher.h +++ b/envpool/mujoco/pusher.h @@ -68,19 +68,19 @@ class PusherEnv : public Env, public MujocoEnv { public: PusherEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + "/mujoco/assets/pusher.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - dist_cost_weight_(spec.config_["dist_cost_weight"_]), - near_cost_weight_(spec.config_["near_cost_weight"_]), - cylinder_dist_min_(spec.config_["cylinder_dist_min"_]), - dist_qpos_x_(spec.config_["cylinder_x_min"_], - spec.config_["cylinder_x_max"_]), - dist_qpos_y_(spec.config_["cylinder_y_min"_], - spec.config_["cylinder_y_max"_]), - dist_qvel_(-spec.config_["reset_qvel_scale"_], - spec.config_["reset_qvel_scale"_]) {} + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/pusher.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + dist_cost_weight_(spec.config["dist_cost_weight"_]), + near_cost_weight_(spec.config["near_cost_weight"_]), + cylinder_dist_min_(spec.config["cylinder_dist_min"_]), + dist_qpos_x_(spec.config["cylinder_x_min"_], + spec.config["cylinder_x_max"_]), + dist_qpos_y_(spec.config["cylinder_y_min"_], + spec.config["cylinder_y_max"_]), + dist_qvel_(-spec.config["reset_qvel_scale"_], + spec.config["reset_qvel_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq - 4; ++i) { diff --git a/envpool/mujoco/reacher.h b/envpool/mujoco/reacher.h index a6bd8c6d..e58ac825 100644 --- a/envpool/mujoco/reacher.h +++ b/envpool/mujoco/reacher.h @@ -66,18 +66,18 @@ class ReacherEnv : public Env, public MujocoEnv { public: ReacherEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + "/mujoco/assets/reacher.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - dist_cost_weight_(spec.config_["dist_cost_weight"_]), - reset_goal_scale_(spec.config_["reset_goal_scale"_]), - dist_qpos_(-spec.config_["reset_qpos_scale"_], - spec.config_["reset_qpos_scale"_]), - dist_qvel_(-spec.config_["reset_qvel_scale"_], - spec.config_["reset_qvel_scale"_]), - dist_goal_(-spec.config_["reset_goal_scale"_], - spec.config_["reset_goal_scale"_]) {} + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/reacher.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + dist_cost_weight_(spec.config["dist_cost_weight"_]), + reset_goal_scale_(spec.config["reset_goal_scale"_]), + dist_qpos_(-spec.config["reset_qpos_scale"_], + spec.config["reset_qpos_scale"_]), + dist_qvel_(-spec.config["reset_qvel_scale"_], + spec.config["reset_qvel_scale"_]), + dist_goal_(-spec.config["reset_goal_scale"_], + spec.config["reset_goal_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq - 2; ++i) { diff --git a/envpool/mujoco/swimmer.h b/envpool/mujoco/swimmer.h index e026fe43..09a3eeb7 100644 --- a/envpool/mujoco/swimmer.h +++ b/envpool/mujoco/swimmer.h @@ -72,14 +72,14 @@ class SwimmerEnv : public Env, public MujocoEnv { public: SwimmerEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + "/mujoco/assets/swimmer.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - no_pos_(spec.config_["exclude_current_positions_from_observation"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - forward_reward_weight_(spec.config_["forward_reward_weight"_]), - dist_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]) {} + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/swimmer.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + no_pos_(spec.config["exclude_current_positions_from_observation"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + forward_reward_weight_(spec.config["forward_reward_weight"_]), + dist_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/mujoco/walker2d.h b/envpool/mujoco/walker2d.h index 9ab0b39d..6b841646 100644 --- a/envpool/mujoco/walker2d.h +++ b/envpool/mujoco/walker2d.h @@ -74,22 +74,22 @@ class Walker2dEnv : public Env, public MujocoEnv { public: Walker2dEnv(const Spec& spec, int env_id) : Env(spec, env_id), - MujocoEnv(spec.config_["base_path"_] + "/mujoco/assets/walker2d.xml", - spec.config_["frame_skip"_], spec.config_["post_constraint"_], - spec.config_["max_episode_steps"_]), - terminate_when_unhealthy_(spec.config_["terminate_when_unhealthy"_]), - no_pos_(spec.config_["exclude_current_positions_from_observation"_]), - ctrl_cost_weight_(spec.config_["ctrl_cost_weight"_]), - forward_reward_weight_(spec.config_["forward_reward_weight"_]), - healthy_reward_(spec.config_["healthy_reward"_]), - healthy_z_min_(spec.config_["healthy_z_min"_]), - healthy_z_max_(spec.config_["healthy_z_max"_]), - healthy_angle_min_(spec.config_["healthy_angle_min"_]), - healthy_angle_max_(spec.config_["healthy_angle_max"_]), - velocity_min_(spec.config_["velocity_min"_]), - velocity_max_(spec.config_["velocity_max"_]), - dist_(-spec.config_["reset_noise_scale"_], - spec.config_["reset_noise_scale"_]) {} + MujocoEnv(spec.config["base_path"_] + "/mujoco/assets/walker2d.xml", + spec.config["frame_skip"_], spec.config["post_constraint"_], + spec.config["max_episode_steps"_]), + terminate_when_unhealthy_(spec.config["terminate_when_unhealthy"_]), + no_pos_(spec.config["exclude_current_positions_from_observation"_]), + ctrl_cost_weight_(spec.config["ctrl_cost_weight"_]), + forward_reward_weight_(spec.config["forward_reward_weight"_]), + healthy_reward_(spec.config["healthy_reward"_]), + healthy_z_min_(spec.config["healthy_z_min"_]), + healthy_z_max_(spec.config["healthy_z_max"_]), + healthy_angle_min_(spec.config["healthy_angle_min"_]), + healthy_angle_max_(spec.config["healthy_angle_max"_]), + velocity_min_(spec.config["velocity_min"_]), + velocity_max_(spec.config["velocity_max"_]), + dist_(-spec.config["reset_noise_scale"_], + spec.config["reset_noise_scale"_]) {} void MujocoResetModel() override { for (int i = 0; i < model_->nq; ++i) { diff --git a/envpool/toy_text/blackjack.h b/envpool/toy_text/blackjack.h index b0095809..f2a1b3f0 100644 --- a/envpool/toy_text/blackjack.h +++ b/envpool/toy_text/blackjack.h @@ -56,8 +56,8 @@ class BlackjackEnv : public Env { public: BlackjackEnv(const Spec& spec, int env_id) : Env(spec, env_id), - natural_(spec.config_["natural"_]), - sab_(spec.config_["sab"_]), + natural_(spec.config["natural"_]), + sab_(spec.config["sab"_]), dist_(1, 13), done_(true) {} diff --git a/envpool/toy_text/catch.h b/envpool/toy_text/catch.h index 65f2ac42..7783ad60 100644 --- a/envpool/toy_text/catch.h +++ b/envpool/toy_text/catch.h @@ -53,8 +53,8 @@ class CatchEnv : public Env { public: CatchEnv(const Spec& spec, int env_id) : Env(spec, env_id), - height_(spec.config_["height"_]), - width_(spec.config_["width"_]), + height_(spec.config["height"_]), + width_(spec.config["width"_]), dist_(0, width_ - 1), done_(true) {} diff --git a/envpool/toy_text/frozen_lake.h b/envpool/toy_text/frozen_lake.h index 2be3a328..4472a930 100644 --- a/envpool/toy_text/frozen_lake.h +++ b/envpool/toy_text/frozen_lake.h @@ -58,8 +58,8 @@ class FrozenLakeEnv : public Env { public: FrozenLakeEnv(const Spec& spec, int env_id) : Env(spec, env_id), - size_(spec.config_["size"_]), - max_episode_steps_(spec.config_["max_episode_steps"_]), + size_(spec.config["size"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), dist_(-1, 1), done_(true) { if (size_ != 8) { diff --git a/envpool/toy_text/nchain.h b/envpool/toy_text/nchain.h index b6b99811..28df186a 100644 --- a/envpool/toy_text/nchain.h +++ b/envpool/toy_text/nchain.h @@ -55,7 +55,7 @@ class NChainEnv : public Env { public: NChainEnv(const Spec& spec, int env_id) : Env(spec, env_id), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), dist_(0, 1), done_(true) {} diff --git a/envpool/toy_text/taxi.h b/envpool/toy_text/taxi.h index d025c9aa..5005c4c0 100644 --- a/envpool/toy_text/taxi.h +++ b/envpool/toy_text/taxi.h @@ -58,7 +58,7 @@ class TaxiEnv : public Env { public: TaxiEnv(const Spec& spec, int env_id) : Env(spec, env_id), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), dist_car_(0, 3), dist_loc_(0, 4), done_(true), diff --git a/envpool/vizdoom/vizdoom_env.h b/envpool/vizdoom/vizdoom_env.h index 995516f4..0b17f96a 100644 --- a/envpool/vizdoom/vizdoom_env.h +++ b/envpool/vizdoom/vizdoom_env.h @@ -160,48 +160,47 @@ class VizdoomEnv : public Env { : Env(spec, env_id), info_index_({19, 20, 21, 22, 23, 24, 10, 7, 4, 3, 9, 5, 0, 15, 16, 73}), dg_(new DoomGame()), - lmp_dir_(spec.config_["lmp_save_dir"_]), + lmp_dir_(spec.config["lmp_save_dir"_]), save_lmp_(lmp_dir_.length() > 0), - episodic_life_(spec.config_["episodic_life"_]), - use_combined_action_(spec.config_["use_combined_action"_]), - use_inter_area_resize_(spec.config_["use_inter_area_resize"_]), + episodic_life_(spec.config["episodic_life"_]), + use_combined_action_(spec.config["use_combined_action"_]), + use_inter_area_resize_(spec.config["use_inter_area_resize"_]), done_(true), - max_episode_steps_(spec.config_["max_episode_steps"_]), + max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), - stack_num_(spec.config_["stack_num"_]), - frame_skip_(spec.config_["frame_skip"_]), + stack_num_(spec.config["stack_num"_]), + frame_skip_(spec.config["frame_skip"_]), episode_count_(0), last_deathcount_(0), last_hitcount_(0), last_damagecount_(0), - weapon_duration_(spec.config_["weapon_duration"_]), + weapon_duration_(spec.config["weapon_duration"_]), weapon_reward_(10) { if (save_lmp_) { - lmp_dir_ = spec.config_["lmp_save_dir"_] + "/env_" + - std::to_string(env_id) + "_"; + lmp_dir_ = + spec.config["lmp_save_dir"_] + "/env_" + std::to_string(env_id) + "_"; } dg_->setViZDoomPath( - MergePath(spec.config_["base_path"_], spec.config_["vzd_path"_])); + MergePath(spec.config["base_path"_], spec.config["vzd_path"_])); dg_->setDoomGamePath( - MergePath(spec.config_["base_path"_], spec.config_["iwad_path"_])); - dg_->loadConfig(spec.config_["cfg_path"_]); + MergePath(spec.config["base_path"_], spec.config["iwad_path"_])); + dg_->loadConfig(spec.config["cfg_path"_]); dg_->setWindowVisible(false); - dg_->addGameArgs(spec.config_["game_args"_]); + dg_->addGameArgs(spec.config["game_args"_]); dg_->setMode(PLAYER); dg_->setEpisodeTimeout((max_episode_steps_ + 1) * frame_skip_); - if (!spec.config_["wad_path"_].empty()) { - dg_->setDoomScenarioPath(spec.config_["wad_path"_]); + if (!spec.config["wad_path"_].empty()) { + dg_->setDoomScenarioPath(spec.config["wad_path"_]); } - dg_->setSeed(spec.config_["seed"_]); - dg_->setDoomMap(spec.config_["map_id"_]); + dg_->setSeed(spec.config["seed"_]); + dg_->setDoomMap(spec.config["map_id"_]); channel_ = dg_->getScreenChannels(); raw_buf_ = Array(FrameSpec({dg_->getScreenHeight(), dg_->getScreenWidth(), 1})); for (int i = 0; i < stack_num_; ++i) { - stack_buf_.emplace_back( - Array(FrameSpec({channel_, spec.config_["img_height"_], - spec.config_["img_width"_]}))); + stack_buf_.emplace_back(Array(FrameSpec( + {channel_, spec.config["img_height"_], spec.config["img_width"_]}))); } for (auto i : info_index_) { dg_->addAvailableGameVariable(static_cast(i)); @@ -239,19 +238,19 @@ class VizdoomEnv : public Env { button_list_ = dg_->getAvailableButtons(); std::vector> delta_config( button_string_list.size()); - for (const auto& i : spec.config_["delta_button_config"_]) { + for (const auto& i : spec.config["delta_button_config"_]) { int button_index = Str2Button(i.first); if (button_index != -1) { delta_config[button_index] = i.second; } } - action_set_ = BuildActionSet(button_list_, spec.config_["force_speed"_], - delta_config); + action_set_ = + BuildActionSet(button_list_, spec.config["force_speed"_], delta_config); // reward config pos_reward_.resize(gv_list_.size(), 0.0); neg_reward_.resize(gv_list_.size(), 0.0); - for (const auto& i : spec.config_["reward_config"_]) { + for (const auto& i : spec.config["reward_config"_]) { int gv_index = Str2GV(i.first); if (gv_index == -1) { continue; @@ -265,7 +264,7 @@ class VizdoomEnv : public Env { neg_reward_[index] = std::get<1>(i.second); } // weapon reward config - const auto& weapon_config = spec.config_["selected_weapon_reward_config"_]; + const auto& weapon_config = spec.config["selected_weapon_reward_config"_]; for (int i = 0; i < 8; ++i) { auto it = weapon_config.find(i); if (it != weapon_config.end()) { @@ -406,14 +405,14 @@ class VizdoomEnv : public Env { // get screen auto* raw_ptr = static_cast(raw_buf_.Data()); - std::size_t size = raw_buf_.size_; + std::size_t size = raw_buf_.size; for (int c = 0; c < channel_; ++c) { // gamestate->screenBuffer is channel-first image memcpy(raw_ptr, gamestate->screenBuffer->data() + c * size, size); auto slice = tgt[c]; Resize(raw_buf_, &slice, use_inter_area_resize_); } - size = tgt.size_; + size = tgt.size; stack_buf_.emplace_back(tgt); if (is_reset) { for (auto& s : stack_buf_) { diff --git a/examples/env_step.py b/examples/env_step.py index 5889769f..383d8b5c 100644 --- a/examples/env_step.py +++ b/examples/env_step.py @@ -62,6 +62,7 @@ def dm_sync_step() -> None: def async_step() -> None: num_envs = 8 batch_size = 4 + # Create an envpool that each step only 4 of 8 result will be out, # and left other "slow step" envs execute at background. env = envpool.make_dm("Pong-v5", num_envs=num_envs, batch_size=batch_size) @@ -73,13 +74,27 @@ def async_step() -> None: # generate action with len(action) == len(env_id) action = np.random.randint(action_num, size=batch_size) ts = env.step(action, env_id) + # Same as gym - env = envpool.make_gym("Pong-v5", num_envs=num_envs, batch_size=batch_size) - # But gym's reset() API cannot return env_id - obs = env.reset() + env = envpool.make_gym( + "Pong-v5", + num_envs=num_envs, + batch_size=batch_size, + gym_reset_return_info=True, + ) + # If you want gym's reset() API return env_id, + # just set gym_reset_return_info=True + obs, info = env.reset() assert obs.shape == (batch_size, 4, 84, 84) - # But we cannot get this observation's corresponding env_id, - # therefore we use a low-level API + env_id = info["env_id"] + for _ in range(1000): + action = np.random.randint(action_num, size=batch_size) + obs, rew, done, info = env.step(action, env_id) + env_id = info["env_id"] + assert len(env_id) == batch_size + assert obs.shape == (batch_size, 4, 84, 84) + + # We can also use a low-level API (send and recv) env = envpool.make_gym("Pong-v5", num_envs=num_envs, batch_size=batch_size) env.async_reset() # no return, just send `reset` signal to all envs for _ in range(1000):