Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
htdt committed Jan 20, 2022
1 parent ecffc7c commit 2f34bed
Show file tree
Hide file tree
Showing 18 changed files with 427 additions and 594 deletions.
39 changes: 39 additions & 0 deletions cfg/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
embedding:
pretrain:
steps: 1e5
epochs: 2500
optimizer:
lr: 2e-4
size: 32
epochs: 1
batch_size: 2048
tau: 0.1
temporal_shift: 4
spatial_shift: 4
rollouts_in_batch: 5

model:
num_obs: 16
obs_hidden: 4
history_fc: 128
instant_fc: 512

agent:
optimizer:
lr: 2e-4
clip_grad: 1
pi_clip: 0.1
gamma: 0.99
epochs: 3
batch_size: 256
ent_k: 0.01
val_loss_k: 1
gae_lambda: 0.95

train:
max_ep_steps: 108000
clip_rewards: True
total_steps: 1e7
rollout_size: 128
num_env: 8
log_every: 10
39 changes: 39 additions & 0 deletions cfg/history.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
embedding:
pretrain:
steps: 1e5
epochs: 2500
optimizer:
lr: 2e-4
size: 32
epochs: 1
batch_size: 2048
tau: 0.1
temporal_shift: 4
spatial_shift: 4
rollouts_in_batch: 5

model:
num_obs: 16
obs_hidden: 4
history_fc: 512
instant_fc: 0

agent:
optimizer:
lr: 2e-4
clip_grad: 1
pi_clip: 0.1
gamma: 0.99
epochs: 3
batch_size: 256
ent_k: 0.01
val_loss_k: 1
gae_lambda: 0.95

train:
max_ep_steps: 108000
clip_rewards: True
total_steps: 1e7
rollout_size: 128
num_env: 8
log_every: 10
30 changes: 0 additions & 30 deletions common/logger.py

This file was deleted.

53 changes: 39 additions & 14 deletions common/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,38 @@
import torch
import gym
from gym.spaces.box import Box
import cv2

from baselines import bench
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common import atari_wrappers
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv
from baselines.common.vec_env import VecEnvWrapper


def make_vec_envs(name, num, seed=0):
def make_vec_envs(
name, num, nstack, seed=0, clip_rewards=False, downsample=True, max_ep_steps=10000
):
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py

def make_env(rank):
def _thunk():
env = gym.make(name)
is_atari = hasattr(gym.envs, 'atari') and isinstance(
env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
if is_atari:
env = make_atari(name, max_episode_steps=10000)

env = atari_wrappers.make_atari(name, max_episode_steps=max_ep_steps)
env.seed(seed + rank)
env = bench.Monitor(env, None)
if is_atari:
env = wrap_deepmind(env, frame_stack=False)
# env = wrap_deepmind(env, frame_stack=False, clip_rewards=clip_rewards)

env = atari_wrappers.EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
env = atari_wrappers.FireResetEnv(env)
if downsample:
env = atari_wrappers.WarpFrame(env)
else:
env = Grayscale(env)
if clip_rewards:
env = atari_wrappers.ClipRewardEnv(env)
return env

return _thunk

random.seed(seed)
Expand All @@ -33,17 +43,18 @@ def _thunk():
torch.cuda.manual_seed_all(seed)

envs = [make_env(i) for i in range(num)]
envs = DummyVecEnv(envs) if num == 1 else ShmemVecEnv(envs, context='fork')
envs = FrameStack(VecPyTorch(envs))
envs = DummyVecEnv(envs) if num == 1 else ShmemVecEnv(envs, context="fork")
envs = FrameStack(VecPyTorch(envs), nstack=nstack)
return envs


class VecPyTorch(VecEnvWrapper):
def __init__(self, env):
super(VecPyTorch, self).__init__(env)
obs = self.observation_space.shape
self.observation_space = Box(0, 255, [obs[2], obs[0], obs[1]],
dtype=self.observation_space.dtype)
self.observation_space = Box(
0, 255, [obs[2], obs[0], obs[1]], dtype=self.observation_space.dtype
)

def reset(self):
return torch.from_numpy(self.venv.reset()).permute(0, 3, 1, 2)
Expand All @@ -60,6 +71,20 @@ def step_wait(self):
return obs, reward, done, info


class Grayscale(gym.ObservationWrapper):
    """Convert RGB frames to single-channel grayscale, keeping uint8 range."""

    def __init__(self, env):
        super().__init__(env)
        # Same spatial dims as the wrapped env, but a single channel.
        full_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(*full_shape[:-1], 1), dtype=np.uint8
        )

    def observation(self, frame):
        # cv2 drops the channel axis; restore it so shape stays (H, W, 1).
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return np.expand_dims(gray, -1)


class FrameStack(VecEnvWrapper):
def __init__(self, venv, nstack=4):
self.venv = venv
Expand Down
70 changes: 0 additions & 70 deletions common/make_obstacle_tower.py

This file was deleted.

9 changes: 6 additions & 3 deletions common/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ class ParamOptim:
lr: float = 1e-3
eps: float = 1e-8
clip_grad: float = None
anneal: bool = True
anneal: bool = False
weight_decay: float = 0

def __post_init__(self):
self.optim = torch.optim.Adam(self.params, lr=self.lr, eps=self.eps)
self.optim = torch.optim.Adam(
self.params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay
)

def set_lr(self, lr):
for pg in self.optim.param_groups:
pg['lr'] = lr
pg["lr"] = lr
return lr

def update(self, progress):
Expand Down
41 changes: 0 additions & 41 deletions common/tools.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,9 @@
from typing import List, Dict
import torch
import torch.nn as nn


def lerp_nn(source: nn.Module, target: nn.Module, tau: float):
    """Polyak/soft update: blend `source` parameters into `target` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    """
    for tgt, src in zip(target.parameters(), source.parameters()):
        tgt.data.mul_(1.0 - tau).add_(src.data, alpha=tau)


def flat_grads(params):
    """Concatenate the flattened gradients of all params that have one.

    Returns a 1-D tensor, or None when no parameter carries a gradient.
    """
    pieces = []
    for p in params:
        if p.grad is not None:
            pieces.append(p.grad.data.flatten())
    if not pieces:
        return None
    return torch.cat(pieces)


def log_grads(model, outp: Dict[str, List[float]]):
    """Append per-child gradient statistics (max and std) into `outp`.

    Keys are ``grad/<child_name>/max`` and ``grad/<child_name>/std``;
    children without gradients are skipped.  `outp` is expected to already
    map these keys to lists (e.g. a defaultdict(list)).
    """
    for name, child in model.named_children():
        grads = flat_grads(child.parameters())
        if grads is None:
            continue
        outp[f'grad/{name}/max'].append(grads.max().item())
        outp[f'grad/{name}/std'].append(grads.std().item())


def onehot(x, num):
    """One-hot encode an integer index tensor into `num` classes as floats.

    `x` is expected to have size 1 in its last dimension; the result replaces
    that dimension with `num` and puts 1.0 at each index position.
    """
    repeats = [1] * (len(x.shape) - 1) + [num]
    canvas = torch.zeros_like(x).float().repeat(*repeats)
    return canvas.scatter(-1, x, 1)


class Identity(torch.nn.Module):
    """No-op module: forwards its input unchanged."""

    def forward(self, x):
        return x


class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension.

    Uses ``flatten`` instead of ``view`` so non-contiguous inputs
    (e.g. tensors produced by ``permute``, as done elsewhere in this
    codebase) do not raise a runtime error; output values are identical
    for contiguous inputs.
    """

    def forward(self, x):
        return x.flatten(1)


def init_ortho(module, gain=1):
    """Orthogonally initialize a layer's weight and zero its bias.

    Args:
        module: layer with a ``weight`` parameter (and optionally ``bias``).
        gain: numeric scale, or a nonlinearity name accepted by
            ``nn.init.calculate_gain`` (e.g. ``'relu'``).

    Returns:
        The same module, so it can be used inline:
        ``init_ortho(nn.Linear(a, b), 'relu')``.
    """
    if isinstance(gain, str):
        gain = nn.init.calculate_gain(gain)
    nn.init.orthogonal_(module.weight.data, gain=gain)
    # Guard: layers built with bias=False have module.bias == None and
    # previously crashed here.
    if module.bias is not None:
        nn.init.constant_(module.bias.data, 0)
    return module


def init_ortho_multi(module):
    """Initialize a multi-parameter module (e.g. an RNN) in place.

    Every parameter whose name contains ``bias`` is zeroed; every one whose
    name contains ``weight`` is orthogonally initialized.
    """
    for name, param in module.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param, 0)
            continue
        if 'weight' in name:
            nn.init.orthogonal_(param)
32 changes: 0 additions & 32 deletions default.yaml

This file was deleted.

Loading

0 comments on commit 2f34bed

Please sign in to comment.