Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 30, 2024
1 parent 718434b commit 1706175
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
12 changes: 12 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,19 +2101,31 @@ def test_auto_num_threads(self):
RandomPolicy(ContinuousActionVecMockEnv().full_action_spec),
)
for _ in collector:
print("checking torch.get_num_threads()", torch.get_num_threads(), "expecting", init_threads - 1)
assert torch.get_num_threads() == init_threads - 1
break
collector.shutdown()
assert torch.get_num_threads() == init_threads
del collector
import gc
gc.collect()
finally:
torch.set_num_threads(init_threads)

try:
collector = MultiSyncDataCollector(
[ParallelEnv(2, ContinuousActionVecMockEnv)],
RandomPolicy(ContinuousActionVecMockEnv().full_action_spec.expand(2)),
)
for _ in collector:
print("checking torch.get_num_threads()", torch.get_num_threads(), "expecting", init_threads - 2)
assert torch.get_num_threads() == init_threads - 2
break
collector.shutdown()
assert torch.get_num_threads() == init_threads
del collector
import gc
gc.collect()
finally:
torch.set_num_threads(init_threads)

Expand Down
5 changes: 5 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import gc
import argparse
import os.path
import re
Expand Down Expand Up @@ -2375,10 +2376,14 @@ def test_auto_num_threads(self):
assert torch.get_num_threads() == max(1, init_threads - 5)

env2.close()
del env2
gc.collect()

assert torch.get_num_threads() == max(1, init_threads - 3)

env3.close()
del env3
gc.collect()

assert torch.get_num_threads() == init_threads
finally:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,7 +1437,7 @@ def _shutdown_main(self) -> None:
)
print(
"collectors.py:1439 torch.set_num_threads(torchrl._THREAD_POOL)",
torchrl._THREAD_POOL
torchrl._THREAD_POOL, self._total_workers_from_env(self.create_env_fn)
)
torch.set_num_threads(torchrl._THREAD_POOL)

Expand Down

0 comments on commit 1706175

Please sign in to comment.