# hatrpo_policy.py (forked from PKU-MARL/DexterousHands)
import torch
from algorithms.marl.actor_critic import Actor, Critic
from utils.util import update_linear_schedule


class HATRPO_Policy:
    """
    HATRPO Policy class. Wraps actor and critic networks to compute actions and value function predictions.

    :param config: (dict) configuration containing relevant model and policy information.
    :param obs_space: (gym.Space) observation space.
    :param cent_obs_space: (gym.Space) value function input space.
    :param act_space: (gym.Space) action space.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, config, obs_space, cent_obs_space, act_space, device=torch.device("cpu")):
        self.args = config
        self.device = device
        self.lr = config["lr"]
        self.critic_lr = config["critic_lr"]
        self.opti_eps = config["opti_eps"]
        self.weight_decay = config["weight_decay"]

        self.obs_space = obs_space
        self.share_obs_space = cent_obs_space
        self.act_space = act_space

        self.actor = Actor(config, self.obs_space, self.act_space, self.device)
        ###################################### Please Note ########################################
        ##### We create one critic per agent, but all critics are trained on the same data    #####
        ##### with the same update settings, so their parameters remain identical; you can    #####
        ##### regard them as a single shared critic.                                          #####
        ############################################################################################
        self.critic = Critic(config, self.share_obs_space, self.device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.lr, eps=self.opti_eps,
                                                weight_decay=self.weight_decay)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr,
                                                 eps=self.opti_eps,
                                                 weight_decay=self.weight_decay)

    def lr_decay(self, episode, episodes):
        """
        Decay the actor and critic learning rates.

        :param episode: (int) current training episode.
        :param episodes: (int) total number of training episodes.
        """
        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
        update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)

    def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,
                    deterministic=False):
        """
        Compute actions and value function predictions for the given inputs.

        :param cent_obs: (np.ndarray) centralized input to the critic.
        :param obs: (np.ndarray) local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is an RNN, RNN states for the actor.
        :param rnn_states_critic: (np.ndarray) if critic is an RNN, RNN states for the critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to the agent
                                  (if None, all actions are available).
        :param deterministic: (bool) whether to return the mode of the action distribution or a sample from it.

        :return values: (torch.Tensor) value function predictions.
        :return actions: (torch.Tensor) actions to take.
        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.
        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
        :return rnn_states_critic: (torch.Tensor) updated critic network RNN states.
        """
        actions, action_log_probs, rnn_states_actor = self.actor(obs,
                                                                 rnn_states_actor,
                                                                 masks,
                                                                 available_actions,
                                                                 deterministic)
        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)
        return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic

    def get_values(self, cent_obs, rnn_states_critic, masks):
        """
        Get value function predictions.

        :param cent_obs: (np.ndarray) centralized input to the critic.
        :param rnn_states_critic: (np.ndarray) if critic is an RNN, RNN states for the critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.

        :return values: (torch.Tensor) value function predictions.
        """
        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values

    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks,
                         available_actions=None, active_masks=None):
        """
        Get action log probabilities / entropy and value function predictions for the actor update.

        :param cent_obs: (np.ndarray) centralized input to the critic.
        :param obs: (np.ndarray) local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is an RNN, RNN states for the actor.
        :param rnn_states_critic: (np.ndarray) if critic is an RNN, RNN states for the critic.
        :param action: (np.ndarray) actions whose log probabilities and entropy to compute.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to the agent
                                  (if None, all actions are available).
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return values: (torch.Tensor) value function predictions.
        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        action_log_probs, dist_entropy, action_mu, action_std, all_probs = \
            self.actor.evaluate_actions(obs,
                                        rnn_states_actor,
                                        action,
                                        masks,
                                        available_actions,
                                        active_masks)
        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values, action_log_probs, dist_entropy, action_mu, action_std, all_probs

    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):
        """
        Compute actions using the given inputs.

        :param obs: (np.ndarray) local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is an RNN, RNN states for the actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to the agent
                                  (if None, all actions are available).
        :param deterministic: (bool) whether to return the mode of the action distribution or a sample from it.

        :return actions: (torch.Tensor) actions to take.
        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
        """
        actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)
        return actions, rnn_states_actor
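

# ----------------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). The exact keys required in `config`
# and the tensor shapes expected by Actor/Critic depend on the rest of the repository, so this is
# a hypothetical example kept as comments rather than executable code:
#
#   policy = HATRPO_Policy(config, obs_space, cent_obs_space, act_space,
#                          device=torch.device("cuda:0"))
#   # Rollout: sample actions and values, threading the RNN states and reset masks through.
#   values, actions, action_log_probs, rnn_states_actor, rnn_states_critic = \
#       policy.get_actions(cent_obs, obs, rnn_states_actor, rnn_states_critic, masks)
#   # Update: re-evaluate stored actions to get log-probs, entropy, and distribution parameters.
#   values, action_log_probs, dist_entropy, action_mu, action_std, all_probs = \
#       policy.evaluate_actions(cent_obs, obs, rnn_states_actor, rnn_states_critic, actions, masks)
# ----------------------------------------------------------------------------------------------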