refactor(gry): refactor reward model #636

Open
wants to merge 63 commits into base: main
Changes from 2 commits
Commits
63 commits
c372c07
refactor network and red reward model
ruoyuGao Apr 5, 2023
6718e4a
create reward model utils
ruoyuGao Apr 5, 2023
be7039a
polish network and reward model utils, provide test for them
ruoyuGao Apr 6, 2023
a4de466
refactor network for two method: learn and forward
ruoyuGao Apr 10, 2023
d615c14
Merge branch 'main' into ruoyugao
ruoyuGao Apr 10, 2023
7a8ec6e
refactor rnd
ruoyuGao Apr 11, 2023
55c7be8
refactor gail
ruoyuGao Apr 11, 2023
ff60716
fix gail for unit test
ruoyuGao Apr 11, 2023
6b80392
refactor icm
ruoyuGao Apr 12, 2023
25d49b5
fix wrong unit test in test_reward_model_utils
ruoyuGao Apr 12, 2023
c081ff0
refactor gcl and pwil
ruoyuGao Apr 13, 2023
f1218cd
refactor pdeil
ruoyuGao Apr 13, 2023
d9060c2
add hidden_size_list to gail
ruoyuGao Apr 13, 2023
179182a
change gail test for new config
ruoyuGao Apr 13, 2023
d067731
refactor trex network
ruoyuGao Apr 14, 2023
29f0d55
fix style and wrong import
ruoyuGao Apr 14, 2023
4ec0bd3
fix style for trex
ruoyuGao Apr 14, 2023
800f090
Merge branch 'main' into ruoyugao
ruoyuGao Apr 14, 2023
c64b5c7
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Apr 14, 2023
660af32
fix unit test for trex onppo
ruoyuGao Apr 17, 2023
1b0d579
Merge branch 'main' into ruoyugao
ruoyuGao Apr 21, 2023
b4e81dd
refactor ngu and provide cartpole config file
ruoyuGao Apr 21, 2023
eddc80d
change reward entry
ruoyuGao Apr 26, 2023
6e2b867
change trex entry to new entry, combine old trex test to new test
ruoyuGao Apr 26, 2023
e25d265
Merge branch 'main' into ruoyugao
ruoyuGao Apr 28, 2023
97634dc
refactor trex config file
ruoyuGao Apr 28, 2023
f099cac
refactor trex config file
ruoyuGao Apr 28, 2023
0c48c08
refactor trex config file
ruoyuGao Apr 28, 2023
594d619
add gail to new reward entry
ruoyuGao May 3, 2023
58a2bff
remove preferenced based irl entry(used for trex, drex before)
ruoyuGao May 3, 2023
e9db652
Merge branch 'main' into ruoyugao
ruoyuGao May 3, 2023
822d7a4
remove unuse code in gcl
ruoyuGao May 3, 2023
be03aa9
change clear data from pipeline to RM && add ngu to new entry
ruoyuGao May 4, 2023
d3ce3e2
remove ngu old entry
ruoyuGao May 4, 2023
4c19aa3
fix env pool test bug
ruoyuGao May 4, 2023
0cc2149
add drex to new entry
ruoyuGao May 4, 2023
5b4e4cc
fix unit test for trex and gail
ruoyuGao May 4, 2023
ff4de47
fix style
ruoyuGao May 4, 2023
9e63ef1
fix style for drex unittest
ruoyuGao May 5, 2023
ca2e2db
fix drex unittest
ruoyuGao May 5, 2023
8716afe
fix bug in minigrid env
ruoyuGao May 5, 2023
9036141
add explain for rm utils
ruoyuGao May 6, 2023
6b9754a
move RM unittest into one file
ruoyuGao May 6, 2023
a52a1c0
Merge branch 'main' into ruoyugao
ruoyuGao May 6, 2023
a5c7989
add drex config
ruoyuGao May 8, 2023
d631237
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao May 8, 2023
f42d131
fix ngu wrapper bug in minigrid
ruoyuGao May 9, 2023
edff260
fix ngu wrapper bug in minigrid
ruoyuGao May 9, 2023
6ab66e1
Merge branch 'main' into ruoyugao
ruoyuGao May 10, 2023
cb0c627
refactor gcl, add it to reward entry
ruoyuGao May 22, 2023
016fbb3
refactor gcl config and bash format other config
ruoyuGao May 22, 2023
cf50148
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao May 22, 2023
919c01b
fix bug for test, remove wrong comment
ruoyuGao May 22, 2023
a1d0b3a
polish code for ngu, drex, base rm and entry
ruoyuGao Jun 6, 2023
0a0af3c
Merge branch 'main' into ruoyugao
ruoyuGao Jun 6, 2023
e310b4c
polish code for all rm
ruoyuGao Jun 6, 2023
92dc227
fix style for ngu
ruoyuGao Jun 6, 2023
a4f364d
polish comment for config files
ruoyuGao Jun 6, 2023
1f06dec
add gcl unit test
ruoyuGao Jun 9, 2023
a547b3b
polish RM
ruoyuGao Jun 19, 2023
97da5c6
fix style for rnd and icm
ruoyuGao Jun 20, 2023
774b2a4
fix style for rnd and icm
ruoyuGao Jun 20, 2023
b78e36c
fix style for icm
ruoyuGao Jun 20, 2023
5 changes: 3 additions & 2 deletions ding/entry/tests/test_serial_entry_reward_model.py
@@ -32,8 +32,9 @@
{
'type': 'red',
'sample_size': 5000,
'input_size': 5,
'hidden_size': 64,
'obs_shape': 4,
'action_shape': 1,
'hidden_size_list': [64, 1],
'update_per_collect': 200,
'batch_size': 128,
},
2 changes: 2 additions & 0 deletions ding/reward_model/__init__.py
@@ -13,3 +13,5 @@
from .guided_cost_reward_model import GuidedCostRewardModel
from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
from .icm_reward_model import ICMRewardModel
from .network import FeatureNetwork, RndNetwork, RedNetwork
from .reword_model_utils import concat_state_action_pairs
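
With these two lines, the shared building blocks become importable from the package root. A minimal sketch of the resulting import surface (assuming no other changes to the package layout):

    # Sketch only: imports the symbols exported in ding/reward_model/__init__.py above.
    # Note that the module name "reword_model_utils" is spelled as in this PR.
    from ding.reward_model import FeatureNetwork, RndNetwork, RedNetwork
    from ding.reward_model import concat_state_action_pairs
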
53 changes: 53 additions & 0 deletions ding/reward_model/network.py
@@ -0,0 +1,53 @@
from typing import Union, Tuple, List, Dict
from easydict import EasyDict

import torch
import torch.nn as nn

from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from ding.model import FCEncoder, ConvEncoder
from ding.torch_utils.data_helper import to_tensor
import numpy as np


class FeatureNetwork(nn.Module):

def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
super(FeatureNetwork, self).__init__()
if isinstance(obs_shape, int) or len(obs_shape) == 1:
self.feature = FCEncoder(obs_shape, hidden_size_list)
elif len(obs_shape) == 3:
self.feature = ConvEncoder(obs_shape, hidden_size_list)
else:
raise KeyError(
"not support obs_shape for pre-defined encoder: {}, please customize your own RND model".
format(obs_shape)
)

def forward(self, obs: torch.Tensor) -> torch.Tensor:
feature_output = self.feature(obs)
return feature_output


class RndNetwork(nn.Module):

def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
super(RndNetwork, self).__init__()
self.target = FeatureNetwork(obs_shape, hidden_size_list)
self.predictor = FeatureNetwork(obs_shape, hidden_size_list)

for param in self.target.parameters():
param.requires_grad = False

def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
predict_feature = self.predictor(obs)
with torch.no_grad():
target_feature = self.target(obs)
return predict_feature, target_feature


class RedNetwork(RndNetwork):

def __init__(self, obs_shape: int, action_shape: int, hidden_size_list: SequenceType) -> None:
# RED network does not support high dimension obs
super().__init__(obs_shape + action_shape, hidden_size_list)
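
A minimal usage sketch of the networks defined above; the shapes and hidden sizes are illustrative and roughly mirror the test config earlier in this PR:

    import torch

    from ding.reward_model.network import RndNetwork, RedNetwork

    # RND: the predictor is trainable, the target is frozen (requires_grad=False above).
    rnd = RndNetwork(obs_shape=4, hidden_size_list=[64, 64])
    obs = torch.randn(8, 4)                      # batch of 8 flat observations
    predict_feature, target_feature = rnd(obs)   # target is computed under no_grad
    exploration_bonus = ((predict_feature - target_feature) ** 2).mean(dim=1)

    # RED: identical structure, but the encoder input is the concatenated (obs, action)
    # pair, so the first layer sees obs_shape + action_shape features.
    red = RedNetwork(obs_shape=4, action_shape=1, hidden_size_list=[64, 1])
    obs_action = torch.randn(8, 4 + 1)
    predict_feature, target_feature = red(obs_action)
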
119 changes: 43 additions & 76 deletions ding/reward_model/red_irl_model.py
@@ -1,29 +1,16 @@
from typing import Dict, List
import pickle
import random
from collections.abc import Iterable

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from ding.utils import REWARD_MODEL_REGISTRY, one_time_warning
from .base_reward_model import BaseRewardModel


class SENet(nn.Module):
"""support estimation network"""

def __init__(self, input_size: int, hidden_size: int, output_dims: int) -> None:
super(SENet, self).__init__()
self.l_1 = nn.Linear(input_size, hidden_size)
self.l_2 = nn.Linear(hidden_size, output_dims)
self.act = nn.Tanh()

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.l_1(x)
out = self.act(out)
out = self.l_2(out)
out = self.act(out)
return out
from .network import RedNetwork
from .reword_model_utils import concat_state_action_pairs


@REWARD_MODEL_REGISTRY.register('red')
@@ -35,38 +22,39 @@ class RedRewardModel(BaseRewardModel):
``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_date``, \
``__init__``, ``_train``
Config:
== ================== ===== ============= ======================================= =======================
ID Symbol Type Default Value Description Other(Shape)
== ================== ===== ============= ======================================= =======================
1 ``type`` str red | Reward model register name, refer |
| to registry ``REWARD_MODEL_REGISTRY`` |
2 | ``expert_data_`` str expert_data | Path to the expert dataset | Should be a '.pkl'
| ``path`` .pkl | | file
3 | ``sample_size`` int 1000 | sample data from expert dataset |
| with fixed size |
4 | ``sigma`` int 5 | hyperparameter of r(s,a) | r(s,a) = exp(
== ================== ====== ============= ======================================= =======================
ID Symbol Type Default Value Description Other(Shape)
== ================== ====== ============= ======================================= =======================
1 ``type`` str red | Reward model register name, refer |
| to registry ``REWARD_MODEL_REGISTRY`` |
2 | ``expert_data_`` str expert_data | Path to the expert dataset | Should be a '.pkl'
| ``path`` .pkl | | file
3 | ``sample_size`` int 1000 | sample data from expert dataset |
| with fixed size |
4 | ``sigma`` int 5 | hyperparameter of r(s,a) | r(s,a) = exp(
| -sigma* L(s,a))
5 | ``batch_size`` int 64 | Training batch size |
6 | ``hidden_size`` int 128 | Linear model hidden size |
7 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
8 | ``clear_buffer`` int 1 | clear buffer per fixed iters | make sure replay
5 | ``batch_size`` int 64 | Training batch size |
6 | ``hidden`` list [64, 64, | Sequence of ``hidden_size`` |
| ``_size_list`` (int) 128] | of reward network |
7 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
8 | ``clear_buffer`` int 1 | clear buffer per fixed iters | make sure replay
``_per_iters`` | buffer's data count
| isn't too few.
| (code work in entry)
== ================== ===== ============= ======================================= =======================
Properties:
- online_net (:obj: `SENet`): The reward model, in default initialized once as the training begins.
== ================== ====== ============= ======================================= =======================
"""
config = dict(
# (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
type='red',
# (int) Linear model input size.
# input_size=4,
# (int) observation shape
# obs_shape=4,
# (int) action shape
# action_shape=1,
# (int) Sample data from expert dataset with fixed size.
sample_size=1000,
# (int) Linear model hidden size.
hidden_size=128,
# (list(int)) Sequence of ``hidden_size`` of reward network.
hidden_size_list=[128, 1],
# (float) The step size of gradient descent.
learning_rate=1e-3,
# (int) How many updates(iterations) to train after collector's one collection.
@@ -99,11 +87,9 @@ def __init__(self, config: Dict, device: str, tb_logger: 'SummaryWriter') -> None:
self.device = device
assert device in ["cpu", "cuda"] or "cuda" in device
self.tb_logger = tb_logger
self.target_net: SENet = SENet(config.input_size, config.hidden_size, 1)
self.online_net: SENet = SENet(config.input_size, config.hidden_size, 1)
self.target_net.to(device)
self.online_net.to(device)
self.opt: optim.Adam = optim.Adam(self.online_net.parameters(), config.learning_rate)
self.reward_model = RedNetwork(config.obs_shape, config.action_shape, config.hidden_size_list)
self.reward_model.to(self.device)
self.opt = optim.Adam(self.reward_model.predictor.parameters(), config.learning_rate)
self.train_once_flag = False

self.load_expert_data()
@@ -121,19 +107,18 @@ def load_expert_data(self) -> None:
self.expert_data = random.sample(self.expert_data, sample_size)
print('the expert data size is:', len(self.expert_data))

def _train(self, batch_data: torch.Tensor) -> float:
def _train(self) -> float:
"""
Overview:
Helper function for ``train`` which caclulates loss for train data and expert data.
Arguments:
- batch_data (:obj:`torch.Tensor`): Data used for training
Returns:
- Combined loss calculated of reward model from using ``batch_data`` in both target and reward models.
- Combined loss calculated of reward model from using ``states_actions_tensor``.
"""
with torch.no_grad():
target = self.target_net(batch_data)
hat: torch.Tensor = self.online_net(batch_data)
loss: torch.Tensor = ((hat - target) ** 2).mean()
sample_batch = random.sample(self.expert_data, self.cfg.batch_size)
states_actions_tensor = concat_state_action_pairs(sample_batch)
states_actions_tensor = states_actions_tensor.to(self.device)
predict_feature, target_feature = self.reward_model(states_actions_tensor)
loss = F.mse_loss(predict_feature, target_feature.detach())
self.opt.zero_grad()
loss.backward()
self.opt.step()
Expand All @@ -150,17 +135,7 @@ def train(self) -> None:
one_time_warning('RED model should be trained once, we do not train it anymore')
else:
for i in range(self.cfg.update_per_collect):
sample_batch = random.sample(self.expert_data, self.cfg.batch_size)
states_data = []
actions_data = []
for item in sample_batch:
states_data.append(item['obs'])
actions_data.append(item['action'])
states_tensor: torch.Tensor = torch.stack(states_data).float()
actions_tensor: torch.Tensor = torch.stack(actions_data).float()
states_actions_tensor: torch.Tensor = torch.cat([states_tensor, actions_tensor], dim=1)
states_actions_tensor = states_actions_tensor.to(self.device)
loss = self._train(states_actions_tensor)
loss = self._train()
self.tb_logger.add_scalar('reward_model/red_loss', loss, i)
self.train_once_flag = True

Expand All @@ -177,20 +152,12 @@ def estimate(self, data: list) -> List[Dict]:
# NOTE: deepcopy reward part of data is very important,
# otherwise the reward of data in the replay buffer will be incorrectly modified.
train_data_augmented = self.reward_deepcopy(data)
states_data = []
actions_data = []
for item in train_data_augmented:
states_data.append(item['obs'])
actions_data.append(item['action'])
states_tensor = torch.stack(states_data).float()
actions_tensor = torch.stack(actions_data).float()
states_actions_tensor = torch.cat([states_tensor, actions_tensor], dim=1)
states_actions_tensor = concat_state_action_pairs(train_data_augmented)
states_actions_tensor = states_actions_tensor.to(self.device)
with torch.no_grad():
hat_1 = self.online_net(states_actions_tensor)
hat_2 = self.target_net(states_actions_tensor)
c = ((hat_1 - hat_2) ** 2).mean(dim=1)
r = torch.exp(-self.cfg.sigma * c)
predict_feature, target_feature = self.reward_model(states_actions_tensor)
mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
r = torch.exp(-self.cfg.sigma * mse)
for item, rew in zip(train_data_augmented, r):
item['reward'] = rew
return train_data_augmented
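
For reference, the reward mapping that the refactored estimate() performs can be summarized in a few lines. This is a sketch of the same recipe as the diff above; the helper name red_reward is illustrative, not part of the PR:

    import torch
    import torch.nn.functional as F

    def red_reward(predict_feature: torch.Tensor, target_feature: torch.Tensor, sigma: float) -> torch.Tensor:
        # Per-sample support-estimation error, mapped to a reward in (0, 1]:
        # r(s, a) = exp(-sigma * L(s, a)), matching the config table above.
        per_sample_mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
        return torch.exp(-sigma * per_sample_mse)

    # e.g. following the RedNetwork sketch earlier:
    # predict_feature, target_feature = reward_model(states_actions_tensor)
    # rewards = red_reward(predict_feature, target_feature, sigma=5)
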
37 changes: 37 additions & 0 deletions ding/reward_model/reword_model_utils.py
@@ -0,0 +1,37 @@
from typing import Union, Optional, List, Any, Tuple
from collections.abc import Iterable

import torch
import torch.optim as optim
import torch.nn.functional as F


def concat_state_action_pairs(
data: list, action_size: Optional[int] = None, one_hot: Optional[bool] = False
) -> torch.Tensor:
"""
Overview:
Concatenate state and action pairs from input.
Arguments:
- data (:obj:`List`): List with at least ``obs`` and ``action`` keys.
Returns:
- state_actions_tensor (:obj:`Torch.tensor`): State and action pairs.
"""
states_data = []
actions_data = []
#check data(dict) has key obs and action
Collaborator review comment: spacing; use bash format.sh ding to format the code.

assert isinstance(data, Iterable)
assert "obs" in data[0] and "action" in data[0]
for item in data:
states_data.append(item['obs'].flatten()) # to allow 3d obs and actions concatenation
if one_hot and action_size:
action = torch.Tensor([int(i == item['action']) for i in range(action_size)])
actions_data.append(action)
else:
actions_data.append(item['action'])

states_tensor: torch.Tensor = torch.stack(states_data).float()
actions_tensor: torch.Tensor = torch.stack(actions_data).float()
states_actions_tensor: torch.Tensor = torch.cat([states_tensor, actions_tensor], dim=1)

return states_actions_tensor
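
A quick usage sketch for concat_state_action_pairs; the transition dicts below are made-up examples of the minimal format the helper expects:

    import torch

    from ding.reward_model import concat_state_action_pairs

    # Continuous action: flatten obs and concatenate the raw action vector.
    data = [
        {'obs': torch.randn(4), 'action': torch.tensor([0.5])},
        {'obs': torch.randn(4), 'action': torch.tensor([-0.3])},
    ]
    pairs = concat_state_action_pairs(data)  # shape: (2, 5)

    # Discrete action with one-hot encoding: pass action_size and one_hot=True.
    data_discrete = [
        {'obs': torch.randn(4), 'action': torch.tensor(1)},
        {'obs': torch.randn(4), 'action': torch.tensor(2)},
    ]
    pairs_one_hot = concat_state_action_pairs(data_discrete, action_size=3, one_hot=True)  # shape: (2, 7)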