Skip to content

Commit

Permalink
Fix action ref counting error (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Jun 15, 2022
1 parent d2e127d commit 5e21217
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/env/dm_control.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ BallInCupCatch-v1
- ``max_episode_steps``: 1000;


CartpoleBalance-v1, CartpoleBalanceSparse-v1, CarpoletSwingup-v1, CartpoleSwingupSparse-v1, CartpoleTwoPoles-v1, CartPoleThreePoles-v1
CartpoleBalance-v1, CartpoleBalanceSparse-v1, CartpoleSwingup-v1, CartpoleSwingupSparse-v1, CartpoleTwoPoles-v1, CartpoleThreePoles-v1
--------------------------------------------------------------------------------------------------------------------------------------

`dm_control suite cartpole source code
Expand Down
2 changes: 1 addition & 1 deletion envpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
register,
)

__version__ = "0.6.1.post1"
__version__ = "0.6.2"
__all__ = [
"register",
"make",
Expand Down
4 changes: 2 additions & 2 deletions envpool/core/py_envpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ void ToArray(const std::vector<py::array>& py_arrs,
std::apply(
[&](auto&&... spec) {
(ret->emplace_back(
NumpyToArray<typename Spec::dtype>(py_arrs[index++])),
NumpyToArrayIncRef<typename Spec::dtype>(py_arrs[index++])),
...);
},
specs);
Expand Down Expand Up @@ -268,7 +268,7 @@ class PyEnvPool : public EnvPool {
*/
void PyReset(const py::array& env_ids) {
// PyArray arr = PyArray::From<int>(env_ids);
auto arr = NumpyToArray<int>(env_ids);
auto arr = NumpyToArrayIncRef<int>(env_ids);
py::gil_scoped_release release;
EnvPool::Reset(arr);
}
Expand Down
19 changes: 14 additions & 5 deletions envpool/dummy/dummy_envpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ class DummyEnvFns {
*/
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
return MakeDict("players.action"_.Bind(Spec<int>({-1})),
return MakeDict("list_action"_.Bind(Spec<double>({6})),
"players.action"_.Bind(Spec<int>({-1})),
"players.id"_.Bind(Spec<int>({-1})));
}
};
Expand Down Expand Up @@ -175,10 +176,6 @@ class DummyEnv : public Env<DummyEnvSpec> {
int num_players =
max_num_players_ <= 1 ? 1 : state_ % (max_num_players_ - 1) + 1;

// Ask envpool to allocate a piece of memory where we can write the state
// after reset.
auto state = Allocate(num_players);

// Parse the action, and execute the env (dummy env has nothing to do)
int action_num = action["players.env_id"_].Shape(0);
for (int i = 0; i < action_num; ++i) {
Expand All @@ -187,6 +184,18 @@ class DummyEnv : public Env<DummyEnvSpec> {
}
}

// Check if actions can successfully pass into envpool
double x = action["list_action"_][0];

for (int i = 0; i < 6; ++i) {
double y = action["list_action"_][i];
CHECK_EQ(x, y);
}

// Ask envpool to allocate a piece of memory where we can write the state
// after reset.
auto state = Allocate(num_players);

// write the information of the next state into the state.
for (int i = 0; i < num_players; ++i) {
state["info:players.id"_][i] = i;
Expand Down
13 changes: 12 additions & 1 deletion envpool/dummy/dummy_envpool_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ TEST(DummyEnvPoolTest, SplitZeroAction) {
auto state_vec = envpool.Recv();
// construct action
std::vector<Array> raw_action({Array(Spec<int>({4})), Array(Spec<int>({8})),
Array(Spec<double>({4, 6})),
Array(Spec<int>({8})), Array(Spec<int>({8}))});
DummyAction action(&raw_action);
for (int i = 0; i < 4; ++i) {
action["env_id"_][i] = i;
for (int j = 0; j < 6; ++j) {
action["list_action"_][i][j] = 3.0 + i;
}
}
std::vector<int> player_env_id({1, 2, 0, 2, 0, 1, 1, 2});
for (int i = 0; i < 8; ++i) {
Expand Down Expand Up @@ -131,6 +135,12 @@ void Runner(int num_envs, int batch, int seed, int total_iter, int num_threads,
all_env_ids[i] = i;
}
envpool.Reset(all_env_ids);
auto list_action = Array(Spec<double>({num_envs, 6}));
for (int i = 0; i < num_envs; ++i) {
for (int j = 0; j < 6; ++j) {
list_action[i][j] = 5.0 + i;
}
}
auto start = std::chrono::system_clock::now();
for (int i = 0; i < total_iter; ++i) {
// recv
Expand Down Expand Up @@ -195,10 +205,11 @@ void Runner(int num_envs, int batch, int seed, int total_iter, int num_threads,
}
}
// construct action
std::vector<Array> raw_action(4);
std::vector<Array> raw_action(5);
DummyAction action(&raw_action);
action["env_id"_] = env_id;
action["players.env_id"_] = player_env_id;
action["list_action"_] = list_action;
action["players.action"_] = player_id;
action["players.id"_] = player_id;
// send
Expand Down
1 change: 1 addition & 0 deletions envpool/dummy/dummy_py_envpool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_envpool(self) -> None:
action = {
"env_id": state["info:env_id"],
"players.env_id": state["info:players.env_id"],
"list_action": np.zeros((batch, 6), dtype=np.float64),
"players.id": state["info:players.id"],
"players.action": state["info:players.id"],
}
Expand Down
10 changes: 10 additions & 0 deletions envpool/mujoco/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ py_library(
deps = ["//envpool/python:api"],
)

# C++ GoogleTest target for the MuJoCo gym env pool. It compiles
# gym/mujoco_gym_envpool_test.cc against the :mujoco_gym_env library and
# links gtest_main, so the test source does not define its own main().
cc_test(
name = "mujoco_envpool_test",
size = "enormous",
srcs = ["gym/mujoco_gym_envpool_test.cc"],
deps = [
":mujoco_gym_env",
"@com_google_googletest//:gtest_main",
],
)

py_test(
name = "mujoco_gym_deterministic_test",
size = "enormous",
Expand Down
53 changes: 53 additions & 0 deletions envpool/mujoco/gym/mujoco_gym_envpool_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2022 Garena Online Private Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <random>
#include <vector>

#include "envpool/mujoco/gym/half_cheetah.h"

using MjcAction = typename mujoco_gym::HalfCheetahEnv::Action;
using MjcState = typename mujoco_gym::HalfCheetahEnv::State;

// Smoke test: push one batch of continuous actions through a HalfCheetah
// env pool. Every env is reset, the initial states are received, a single
// action batch is sent, and the resulting states are received.
// The test contains no explicit expectations; it passes when the whole
// reset -> recv -> send -> recv round trip completes without aborting.
TEST(MjcEnvPoolTest, CheckAction) {
  int num_envs = 128;
  const int action_dim = 6;  // number of values in each env's action row

  // Build the pool from the default config, overriding only the env count.
  auto config = mujoco_gym::HalfCheetahEnvSpec::kDefaultConfig;
  config["num_envs"_] = num_envs;
  mujoco_gym::HalfCheetahEnvSpec spec(config);
  mujoco_gym::HalfCheetahEnvPool envpool(spec);

  // Reset every env: the id array is simply 0, 1, ..., num_envs - 1.
  Array reset_ids(Spec<int>({num_envs}));
  for (int env = 0; env < num_envs; ++env) {
    reset_ids[env] = env;
  }
  envpool.Reset(reset_ids);
  auto states = envpool.Recv();

  // Backing storage for the action, in the order it is addressed below:
  // "env_id", "players.env_id", then the (num_envs x action_dim) "action".
  std::vector<Array> action_buffers({Array(Spec<int>({num_envs})),
                                     Array(Spec<int>({num_envs})),
                                     Array(Spec<double>({num_envs, action_dim}))});
  MjcAction action(&action_buffers);
  for (int env = 0; env < num_envs; ++env) {
    action["env_id"_][env] = env;
    action["players.env_id"_][env] = env;
    // Give each (env, slot) pair a distinct, easily recognisable value.
    for (int slot = 0; slot < action_dim; ++slot) {
      action["action"_][env][slot] = (env + slot + 1) / 100.0;
    }
  }

  // Step all envs once with the constructed action and collect the result.
  envpool.Send(action);
  states = envpool.Recv();
}
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = envpool
version = 0.6.1.post1
version = 0.6.2
author = "EnvPool Contributors"
author_email = "[email protected]"
description = "C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments."
Expand Down

0 comments on commit 5e21217

Please sign in to comment.