init from ppo
crush committed Sep 4, 2019
0 parents commit cfd82f4
Showing 20 changed files with 685 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
**/__pycache__
runs
*.pt
.vscode
.env
.DS_Store
UnitySDK.log
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Alexander Ermolov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
17 changes: 17 additions & 0 deletions README.md
@@ -0,0 +1,17 @@
## Proximal Policy Optimization
- Python 3.7, PyTorch 1.2
- Neat, simple and efficient code
- Atari MsPacman (`MsPacmanNoFrameskip-v4`) score ≈ 4200 after 24 hours of training on a T4 GPU

## Start
```
pip install -r requirements.txt
tensorboard --logdir runs
python -m train cartpole
```

## Dependencies
```
git clone https://github.com/openai/baselines.git
pip install -e baselines
```
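
Other configs are selected the same way by file name under `config/`; e.g. for the Atari run (assuming `train.py` resolves config names uniformly):
```
python -m train pacman
```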
87 changes: 87 additions & 0 deletions agent.py
@@ -0,0 +1,87 @@
from collections import defaultdict
from dataclasses import dataclass
import random
import torch

from common.optim import ParamOptim
from common.tools import log_grads
from model import ActorCritic


@dataclass
class Agent:
    model: ActorCritic
    optim: ParamOptim
    pi_clip: float
    epochs: int
    batch_size: int
    val_loss_k: float
    ent_k: float
    gamma: float
    gae_lambda: float

    def _gae(self, rollout, next_val):
        # masks are pre-multiplied by gamma: m[i] == gamma * (1 - done[i])
        m = rollout['masks'] * self.gamma
        r, v = rollout['rewards'], rollout['vals']
        adv, returns = torch.empty_like(v), torch.empty_like(v)
        gae = 0
        for i in reversed(range(adv.shape[0])):
            if i == adv.shape[0] - 1:
                next_return = next_val
            else:
                next_val = v[i + 1]
                next_return = returns[i + 1]

            delta = r[i] - v[i] + next_val * m[i]
            adv[i] = gae = delta + self.gae_lambda * m[i] * gae
            returns[i] = r[i] + next_return * m[i]

        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
        return adv, returns

    def update(self, rollout):
        num_step, num_env = rollout['log_probs'].shape[:2]
        with torch.no_grad():
            next_val = self.model(rollout['obs'][-1])[1]
            adv, returns = self._gae(rollout, next_val)

        logs, grads = defaultdict(list), defaultdict(list)
        for _ in range(self.epochs * num_step * num_env // self.batch_size):
            # sample flat indices, then map each to a (step, env) pair
            idx1d = random.sample(range(num_step * num_env), self.batch_size)
            idx = tuple(zip(*[(i % num_step, i // num_step) for i in idx1d]))

            dist, vals = self.model(rollout['obs'][idx])
            act = rollout['actions'][idx].squeeze(-1)
            log_probs = dist.log_prob(act).unsqueeze(-1)
            ent = dist.entropy().mean()

            old_lp = rollout['log_probs'][idx]
            ratio = torch.exp(log_probs - old_lp)
            surr1 = adv[idx] * ratio
            surr2 = adv[idx] * \
                torch.clamp(ratio, 1 - self.pi_clip, 1 + self.pi_clip)
            act_loss = -torch.min(surr1, surr2).mean()
            val_loss = .5 * (vals - returns[idx]).pow(2).mean()

            self.optim.step(-self.ent_k * ent + act_loss +
                            self.val_loss_k * val_loss)

            log_grads(self.model, grads)
            logs['ent'].append(ent)
            logs['clipfrac'].append(
                (torch.abs(ratio - 1) > self.pi_clip).float().mean())
            logs['loss/actor'].append(act_loss)
            logs['loss/critic'].append(val_loss)

        for name, val in grads.items():
            if '/max' in name:
                grads[name] = max(val)
            elif '/std' in name:
                grads[name] = sum(val) / (len(val) ** .5)
        return {
            'ent': torch.stack(logs['ent']).mean(),
            'clip/frac': torch.stack(logs['clipfrac']).mean(),
            'loss/actor': torch.stack(logs['loss/actor']).mean(),
            'loss/critic': torch.stack(logs['loss/critic']).mean(),
            **grads,
        }
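
The `_gae` loop above is the standard GAE recursion, delta_t = r_t + γ·mask_t·V_{t+1} - V_t and A_t = delta_t + γλ·mask_t·A_{t+1}, with the γ·mask product precomputed as `m`. A minimal sanity check on toy numbers (rewards, values, and the bootstrap are made up; γ and λ match `config/cartpole.yaml`):
```python
import torch

gamma, lam = 0.99, 0.95
r = torch.tensor([1., 1., 1.])     # rewards r_t
v = torch.tensor([.5, .6, .7])     # value estimates V(s_t)
mask = torch.tensor([1., 1., 0.])  # 0 where the episode ended
next_val = 0.                      # bootstrap value past the last step

adv, gae = torch.empty_like(v), 0.
for i in reversed(range(len(r))):
    nv = next_val if i == len(r) - 1 else v[i + 1]
    delta = r[i] + gamma * mask[i] * nv - v[i]
    gae = delta + gamma * lam * mask[i] * gae
    adv[i] = gae
print(adv)  # tensor([2.3873, 1.3752, 0.3000]), before normalization
```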
Empty file added common/__init__.py
30 changes: 30 additions & 0 deletions common/cfg.py
@@ -0,0 +1,30 @@
from pathlib import Path
import re
import yaml


def find_checkpoint(cfg):
    cp_iter = cfg['train']['checkpoint_every']
    steps = cfg['train']['steps']
    n_cp, fname_cp = 0, None
    for n_iter in range(cp_iter, steps + cp_iter, cp_iter):
        fname = cfg['train']['checkpoint_name'].format(n_iter=n_iter // cp_iter)
        if Path(fname).exists():
            n_cp, fname_cp = n_iter, fname
    return n_cp, fname_cp


def replace_e_float(d):
    p = re.compile(r'^-?\d+(\.\d+)?e-?\d+$')
    for name, val in d.items():
        if type(val) == dict:
            replace_e_float(val)
        elif type(val) == str and p.match(val):
            d[name] = float(val)


def load_cfg(name, prefix='.'):
    with open(f'{prefix}/config/{name}.yaml') as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)
    replace_e_float(cfg)
    return cfg
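
`replace_e_float` works around a PyYAML quirk: `SafeLoader` follows YAML 1.1, whose float resolver requires a decimal point, so bare scientific notation like `1e-3` (as used in the configs below) parses as a string. A quick demonstration:
```python
import yaml

print(yaml.safe_load('lr: 1e-3'))    # {'lr': '1e-3'}  -> a string
print(yaml.safe_load('lr: 1.0e-3'))  # {'lr': 0.001}   -> a float
```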
29 changes: 29 additions & 0 deletions common/conv.py
@@ -0,0 +1,29 @@
from typing import List, Tuple
import torch
import torch.nn as nn


class Conv(nn.Module):
    def __init__(
        self,
        input_size: Tuple[int],
        channels: List[int],
        kernel_size: List[int],
        stride: List[int],
    ):
        super(Conv, self).__init__()
        assert len(channels) == len(kernel_size) == len(stride)
        # observations arrive in HWC order; Conv2d wants CHW
        input_size = input_size[2], input_size[0], input_size[1]
        self.conv = nn.Sequential(*[
            nn.Sequential(nn.Conv2d(c_in, c_out, ker, st), nn.ReLU())
            for c_in, c_out, ker, st in zip([input_size[0]] + channels[:-1],
                                            channels, kernel_size, stride)])
        with torch.no_grad():
            tmp = torch.zeros((1,) + input_size)
            self.output_size = len(self.conv(tmp).view(-1))

    def forward(self, x):
        x = x.permute(0, 3, 1, 2).float() / 255
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        return x
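
Under the `config/pacman.yaml` settings, DeepMind-wrapped Atari frames arrive as 84×84 observations with 4 stacked channels (HWC), so the zero-tensor probe in `__init__` works out to 64 · 7 · 7 = 3136 flattened features. A usage sketch, assuming the usual `wrap_deepmind` output shape:
```python
conv = Conv(input_size=(84, 84, 4), channels=[32, 64, 64],
            kernel_size=[8, 4, 3], stride=[4, 2, 1])
print(conv.output_size)  # 3136: spatially 84 -> 20 -> 9 -> 7, with 64 channels
```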
30 changes: 30 additions & 0 deletions common/logger.py
@@ -0,0 +1,30 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter
try:
    import nvidia_smi
except ModuleNotFoundError:
    nvidia_smi = None


class Logger:
    def __init__(self, device='cpu'):
        self.log = SummaryWriter()
        if nvidia_smi and device != 'cpu':
            nvidia_smi.nvmlInit()
            self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
        else:
            self.handle = None

    def output(self, data_dict, n_iter):
        if self.handle is not None:
            res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle)
            self.log.add_scalar('nvidia/load', res.gpu, n_iter)
            res = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle)
            self.log.add_scalar(
                'nvidia/mem_gb', res.used / (1024 ** 3), n_iter)

        for key, val in data_dict.items():
            if hasattr(val, 'shape') and np.prod(val.shape) > 1:
                self.log.add_histogram(key, val, n_iter)
            else:
                self.log.add_scalar(key, val, n_iter)
45 changes: 45 additions & 0 deletions common/make_env.py
@@ -0,0 +1,45 @@
import torch
import gym
from baselines import bench
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv
from baselines.common.vec_env import VecEnvWrapper


def make_vec_envs(name, num, seed=0):
    def make_env(rank):
        def _thunk():
            env = gym.make(name)
            is_atari = hasattr(gym.envs, 'atari') and isinstance(
                env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
            if is_atari:
                env = make_atari(name)

            env.seed(seed + rank)
            env = bench.Monitor(env, None)
            if is_atari:
                env = wrap_deepmind(env, frame_stack=True)
            return env
        return _thunk

    envs = [make_env(i) for i in range(num)]
    envs = DummyVecEnv(envs) if num == 1 else ShmemVecEnv(envs, context='fork')
    envs = VecPyTorch(envs)
    return envs


class VecPyTorch(VecEnvWrapper):
    def reset(self):
        return torch.from_numpy(self.venv.reset())

    def step_async(self, actions):
        assert len(actions.shape) == 2
        self.venv.step_async(actions.squeeze(1).cpu().numpy())

    def step_wait(self):
        obs, reward, done, info = self.venv.step_wait()
        obs = torch.from_numpy(obs)
        reward = torch.from_numpy(reward).unsqueeze(dim=1)
        done = torch.tensor(done.tolist()).unsqueeze(dim=1)
        return obs, reward, done, info
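
A hedged usage sketch; it assumes `baselines` is installed per the README, and the shapes shown are for CartPole, whose observation vector has 4 entries:
```python
import torch

envs = make_vec_envs('CartPole-v0', num=4)
obs = envs.reset()                             # torch.Tensor, shape (4, 4)
actions = torch.zeros(4, 1, dtype=torch.long)  # one action column per env
obs, reward, done, info = envs.step(actions)   # reward and done: shape (4, 1)
```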
31 changes: 31 additions & 0 deletions common/optim.py
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import List
import torch
from torch.optim import AdamW, Optimizer
from torchcontrib.optim import SWA


@dataclass
class ParamOptim:
    params: List[torch.Tensor]
    lr: float = 1e-3
    eps: float = 1e-8
    clip_grad: float = None
    optimizer: Optimizer = AdamW

    def __post_init__(self):
        base_opt = self.optimizer(self.params, lr=self.lr, eps=self.eps)
        self.optim = SWA(base_opt)

    def set_lr(self, lr):
        for pg in self.optim.param_groups:
            pg['lr'] = lr
        return lr

    def step(self, loss):
        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.params, self.clip_grad)
        self.optim.step()
        return loss
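
A minimal sketch of the call pattern: build it over a module's parameters, then hand `step` a scalar loss; it zeroes grads, backpropagates, optionally clips the global norm, and steps the SWA-wrapped optimizer:
```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
optim = ParamOptim(params=list(net.parameters()), lr=1e-3, clip_grad=1)
loss = net(torch.randn(8, 4)).pow(2).mean()
optim.step(loss)  # returns the loss for logging
```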
30 changes: 30 additions & 0 deletions common/tools.py
@@ -0,0 +1,30 @@
from typing import List, Dict
import torch
import torch.nn as nn


def lerp_nn(source: nn.Module, target: nn.Module, tau: float):
    for t, s in zip(target.parameters(), source.parameters()):
        t.data.copy_(t.data * (1. - tau) + s.data * tau)


def flat_grads(params):
    x = [p.grad.data.flatten() for p in params if p.grad is not None]
    return torch.cat(x) if len(x) else None


def log_grads(model, outp: Dict[str, List[float]]):
    for name, net in dict(model.named_children()).items():
        fg = flat_grads(net.parameters())
        if fg is not None:
            outp[f'grad/{name}/max'].append(fg.max().item())
            outp[f'grad/{name}/std'].append(fg.std().item())


def onehot(x, num):
    r = [1] * (len(x.shape) - 1) + [num]
    return torch.zeros_like(x).float().repeat(*r).scatter(-1, x, 1)


class Identity(torch.nn.Module):
    def forward(self, x):
        return x
27 changes: 27 additions & 0 deletions config/cartpole.yaml
@@ -0,0 +1,27 @@
optimizer:
  lr: 1e-3
  eps: 1e-5
  clip_grad: 1

model:
  hidden_sizes: [32, 32]

agent:
  pi_clip: .2
  gamma: .99
  epochs: 4
  batch_size: 64
  ent_k: 0
  val_loss_k: .001
  gae_lambda: .95

env:
  name: CartPole-v0
  num: 4

train:
  steps: 100
  rollout_size: 128
  log_every: 1
  checkpoint_every: 100000
  checkpoint_name: models/cartpole_{n_iter}.pt
32 changes: 32 additions & 0 deletions config/pacman.yaml
@@ -0,0 +1,32 @@
optimizer:
  lr: 5e-5
  eps: 1e-5
  clip_grad: 1

conv:
  channels: [32, 64, 64]
  kernel_size: [8, 4, 3]
  stride: [4, 2, 1]

model:
  hidden_sizes: [512]

agent:
  pi_clip: .1
  gamma: .99
  epochs: 10
  batch_size: 512
  ent_k: .01
  val_loss_k: .1
  gae_lambda: .95

env:
  name: MsPacmanNoFrameskip-v4
  num: 100

train:
  steps: 20000
  rollout_size: 64
  log_every: 1
  checkpoint_every: 500
  checkpoint_name: models/pacman_{n_iter}.pt
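
For scale, assuming `train.py` advances every env by `rollout_size` steps per iteration, this config covers 20000 × 64 × 100 = 128M environment frames in total, versus 100 × 128 × 4 = 51,200 for cartpole.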