feature(zjow): add Implicit Q-Learning (#821)
* Add IQL algo

* Polish IQL Algorithm

* polish iql
zjowowen authored Jan 27, 2025
1 parent bf258f8 commit dae7673
Showing 14 changed files with 1,549 additions and 0 deletions.
361 changes: 361 additions & 0 deletions ding/model/template/qvac.py

Large diffs are not rendered by default.
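The new `ding/model/template/qvac.py` adds a Q/V actor-critic ("QVAC") model template for IQL, which needs a state-value head V(s) in addition to the usual Q(s, a) critic and actor. Since the diff is not rendered, the following is a minimal sketch of that structure; the class and method names are illustrative assumptions, not the actual qvac.py interface.

```python
# Minimal sketch of a QVAC-style model for continuous control (illustrative,
# not the actual ding/model/template/qvac.py API).
import torch
import torch.nn as nn


def mlp(in_dim: int, out_dim: int, hidden: int = 256) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(in_dim, hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Linear(hidden, out_dim),
    )


class SimpleQVAC(nn.Module):
    """Actor plus two critics: Q(s, a) and the extra V(s) head that IQL requires."""

    def __init__(self, obs_shape: int, action_shape: int):
        super().__init__()
        self.actor = mlp(obs_shape, action_shape)          # deterministic mean for simplicity
        self.critic_q = mlp(obs_shape + action_shape, 1)   # state-action value Q(s, a)
        self.critic_v = mlp(obs_shape, 1)                  # state value V(s)

    def compute_q(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.critic_q(torch.cat([obs, action], dim=-1)).squeeze(-1)

    def compute_v(self, obs: torch.Tensor) -> torch.Tensor:
        return self.critic_v(obs).squeeze(-1)

    def compute_action(self, obs: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.actor(obs))  # squash output to the action range
```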

6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -43,6 +43,7 @@

from .d4pg import D4PGPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .iql import IQLPolicy
from .dt import DTPolicy
from .pdqn import PDQNPolicy
from .madqn import MADQNPolicy
@@ -322,6 +323,11 @@ class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('iql_command')
class IQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('discrete_cql_command')
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
    pass
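The `'iql_command'` registration mirrors the existing CQL entry: `DummyCommandModePolicy` supplies a no-op command-mode interface so the offline IQL policy can run in the serial pipeline. A minimal sketch of retrieving the registered class, assuming the registry exposes a plain `get()` lookup as other DI-engine registries do:

```python
# Hedged sketch: look up the command-mode IQL policy class registered above.
from ding.utils import POLICY_REGISTRY

IQLCommand = POLICY_REGISTRY.get('iql_command')  # returns IQLCommandModePolicy if registered
```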
646 changes: 646 additions & 0 deletions ding/policy/iql.py

Large diffs are not rendered by default.
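The new `ding/policy/iql.py` (not rendered here) implements Implicit Q-Learning: V is trained by expectile regression against Q, Q is trained on TD targets bootstrapped from V (so no out-of-distribution actions are evaluated), and the policy is extracted by advantage-weighted regression. The sketch below shows the three losses under those definitions; it is illustrative, not the actual ding implementation, and the `beta` convention (temperature vs. inverse temperature) is an assumption.

```python
# Minimal sketch of the three IQL losses (Kostrikov et al., 2021); tensor and
# callable names are illustrative, not the ding/policy/iql.py API.
import torch
import torch.nn.functional as F


def expectile_loss(diff: torch.Tensor, tau: float = 0.7) -> torch.Tensor:
    """Asymmetric L2: weight tau on positive errors, (1 - tau) on negative ones."""
    weight = torch.abs(tau - (diff < 0).float())
    return (weight * diff.pow(2)).mean()


def iql_losses(q_net, v_net, actor_log_prob, batch, tau=0.7, beta=0.05, discount=0.99):
    obs, act = batch['obs'], batch['action']
    rew, next_obs, done = batch['reward'], batch['next_obs'], batch['done']
    # 1) Value loss: expectile regression of V(s) toward Q(s, a), with Q detached.
    q_sa = q_net(obs, act).detach()
    v_loss = expectile_loss(q_sa - v_net(obs), tau)
    # 2) Q loss: TD target bootstrapped from V(s'), so no actions are sampled.
    target = rew + discount * (1.0 - done) * v_net(next_obs).detach()
    q_loss = F.mse_loss(q_net(obs, act), target)
    # 3) Policy loss: advantage-weighted regression toward the dataset actions.
    #    exp(adv / beta) with clamping is assumed here; conventions vary.
    adv = (q_sa - v_net(obs)).detach()
    weight = torch.clamp(torch.exp(adv / beta), max=100.0)
    policy_loss = -(weight * actor_log_prob(obs, act)).mean()
    return v_loss, q_loss, policy_loss
```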

32 changes: 32 additions & 0 deletions ding/utils/data/dataset.py
@@ -114,6 +114,38 @@ def __init__(self, cfg: dict) -> None:
        except (KeyError, AttributeError):
            # do not normalize
            pass
        if hasattr(cfg.env, "reward_norm"):
            if cfg.env.reward_norm == "normalize":
                dataset['rewards'] = (dataset['rewards'] - dataset['rewards'].mean()) / dataset['rewards'].std()
            elif cfg.env.reward_norm == "iql_antmaze":
                dataset['rewards'] = dataset['rewards'] - 1.0
            elif cfg.env.reward_norm == "iql_locomotion":

                def return_range(dataset, max_episode_steps):
                    returns, lengths = [], []
                    ep_ret, ep_len = 0.0, 0
                    for r, d in zip(dataset["rewards"], dataset["terminals"]):
                        ep_ret += float(r)
                        ep_len += 1
                        if d or ep_len == max_episode_steps:
                            returns.append(ep_ret)
                            lengths.append(ep_len)
                            ep_ret, ep_len = 0.0, 0
                    # returns.append(ep_ret)  # incomplete trajectory
                    lengths.append(ep_len)  # but still keep track of number of steps
                    assert sum(lengths) == len(dataset["rewards"])
                    return min(returns), max(returns)

                min_ret, max_ret = return_range(dataset, 1000)
                dataset['rewards'] /= max_ret - min_ret
                dataset['rewards'] *= 1000
            elif cfg.env.reward_norm == "cql_antmaze":
                dataset['rewards'] = (dataset['rewards'] - 0.5) * 4.0
            elif cfg.env.reward_norm == "antmaze":
                dataset['rewards'] = (dataset['rewards'] - 0.25) * 2.0
            else:
                raise NotImplementedError

        self._data = []
        self._load_d4rl(dataset)

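The `"iql_locomotion"` branch above rescales rewards so that the spread between the best and worst episode return in the dataset maps to 1000, which is the normalization used by the original IQL paper for MuJoCo locomotion tasks. A small standalone check of that arithmetic on a toy dataset (illustrative, not a real D4RL dataset):

```python
# Toy check of the "iql_locomotion" rescaling: rewards /= (max_ret - min_ret); rewards *= 1000.
import numpy as np

dataset = {
    "rewards": np.array([1.0, 1.0, 1.0, 5.0, 2.0, 2.0], dtype=np.float32),
    "terminals": np.array([0, 0, 1, 1, 0, 1], dtype=bool),
}
# Episode returns here are [3.0, 5.0, 4.0], so min_ret = 3.0 and max_ret = 5.0.
min_ret, max_ret = 3.0, 5.0
dataset["rewards"] /= max_ret - min_ret   # divide by the return range (2.0)
dataset["rewards"] *= 1000                # scale so the return range becomes 1000
print(dataset["rewards"])                 # [ 500.  500.  500. 2500. 1000. 1000.]
```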
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
    exp_name="halfcheetah_medium_expert_iql_seed0",
    env=dict(
        env_id='halfcheetah-medium-expert-v2',
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=6000,
        reward_norm="iql_locomotion",
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=17,
            action_shape=6,
        ),
        learn=dict(
            data_path=None,
            train_epoch=30000,
            batch_size=4096,
            learning_rate_q=3e-4,
            learning_rate_policy=1e-4,
            beta=0.05,
            tau=0.7,
        ),
        collect=dict(data_type='d4rl', ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
        other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
    ),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
    env=dict(
        type='d4rl',
        import_names=['dizoo.d4rl.envs.d4rl_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(
        type='iql',
        import_names=['ding.policy.iql'],
    ),
    replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
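Each config pairs `main_config` (experiment and hyperparameters) with `create_config` (component types to instantiate). The commit's own entry script, `d4rl_iql_main.py`, is not shown in this diff; the sketch below assumes the standard DI-engine offline entry point as one plausible way to launch this config.

```python
# Hedged launch sketch: run the IQL config with DI-engine's offline serial
# pipeline (the actual d4rl_iql_main.py entry may differ).
from ding.entry import serial_pipeline_offline
from dizoo.d4rl.config.halfcheetah_medium_expert_iql_config import main_config, create_config

if __name__ == "__main__":
    serial_pipeline_offline((main_config, create_config), seed=0)
```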
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
    exp_name="halfcheetah_medium_iql_seed0",
    env=dict(
        env_id='halfcheetah-medium-v2',
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=6000,
        reward_norm="iql_locomotion",
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=17,
            action_shape=6,
        ),
        learn=dict(
            data_path=None,
            train_epoch=30000,
            batch_size=4096,
            learning_rate_q=3e-4,
            learning_rate_policy=1e-4,
            beta=0.05,
            tau=0.7,
        ),
        collect=dict(data_type='d4rl', ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
        other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
    ),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
    env=dict(
        type='d4rl',
        import_names=['dizoo.d4rl.envs.d4rl_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(
        type='iql',
        import_names=['ding.policy.iql'],
    ),
    replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
    exp_name="halfcheetah_medium_replay_iql_seed0",
    env=dict(
        env_id='halfcheetah-medium-replay-v2',
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=6000,
        reward_norm="iql_locomotion",
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=17,
            action_shape=6,
        ),
        learn=dict(
            data_path=None,
            train_epoch=30000,
            batch_size=4096,
            learning_rate_q=3e-4,
            learning_rate_policy=1e-4,
            beta=0.05,
            tau=0.7,
        ),
        collect=dict(data_type='d4rl', ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
        other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
    ),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
    env=dict(
        type='d4rl',
        import_names=['dizoo.d4rl.envs.d4rl_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(
        type='iql',
        import_names=['ding.policy.iql'],
    ),
    replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_expert_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
    exp_name="hopper_medium_expert_iql_seed0",
    env=dict(
        env_id='hopper-medium-expert-v2',
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=6000,
        reward_norm="iql_locomotion",
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=11,
            action_shape=3,
        ),
        learn=dict(
            data_path=None,
            train_epoch=30000,
            batch_size=4096,
            learning_rate_q=3e-4,
            learning_rate_policy=1e-4,
            beta=0.05,
            tau=0.7,
        ),
        collect=dict(data_type='d4rl', ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
        other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
    ),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
    env=dict(
        type='d4rl',
        import_names=['dizoo.d4rl.envs.d4rl_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(
        type='iql',
        import_names=['ding.policy.iql'],
    ),
    replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
    exp_name="hopper_medium_iql_seed0",
    env=dict(
        env_id='hopper-medium-v2',
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=6000,
        reward_norm="iql_locomotion",
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=11,
            action_shape=3,
        ),
        learn=dict(
            data_path=None,
            train_epoch=30000,
            batch_size=4096,
            learning_rate_q=3e-4,
            learning_rate_policy=1e-4,
            beta=0.05,
            tau=0.7,
        ),
        collect=dict(data_type='d4rl', ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
        other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
    ),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
    env=dict(
        type='d4rl',
        import_names=['dizoo.d4rl.envs.d4rl_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(
        type='iql',
        import_names=['ding.policy.iql'],
    ),
    replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_replay_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct Experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
    exp_name="hopper_medium_replay_iql_seed0",
    env=dict(
        env_id='hopper-medium-replay-v2',
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=6000,
        reward_norm="iql_locomotion",
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=11,
            action_shape=3,
        ),
        learn=dict(
            data_path=None,
            train_epoch=30000,
            batch_size=4096,
            learning_rate_q=3e-4,
            learning_rate_policy=1e-4,
            beta=0.05,
            tau=0.7,
        ),
        collect=dict(data_type='d4rl', ),
        eval=dict(evaluator=dict(eval_freq=5000, )),
        other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
    ),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
    env=dict(
        type='d4rl',
        import_names=['dizoo.d4rl.envs.d4rl_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(
        type='iql',
        import_names=['ding.policy.iql'],
    ),
    replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config