
Commit 33a1180 (initial commit, 0 parents)

adding files

karkuspeter committed Oct 10, 2017

Showing 14 changed files with 2,309 additions and 0 deletions.
51 changes: 51 additions & 0 deletions .gitignore
@@ -0,0 +1,51 @@
*.pyc
*.o

# Windows image file caches
Thumbs.db
ehthumbs.db

# Folder config file
Desktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msm
*.msp

# Windows shortcuts
*.lnk

# =========================
# Operating System Files
# =========================

# OSX
# =========================

.DS_Store
.AppleDouble
.LSOverride

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

71 changes: 71 additions & 0 deletions README.md
@@ -0,0 +1,71 @@
# QMDP-net

Implementation of the NIPS 2017 paper:

QMDP-Net: Deep Learning for Planning under Partial Observability
Peter Karkus, David Hsu, Wee Sun Lee
National University of Singapore
https://arxiv.org/abs/1703.06692

The code implements the 2D grid navigation domain and a QMDP-net with a 2D state space in TensorFlow.

### Requirements

- Python 2.7
- TensorFlow 1.3.0
- Python packages: numpy, scipy, tables, pymdptoolbox, tensorpack

To install these packages using pip:
```
pip install tensorflow==1.3.0
pip install numpy scipy tables pymdptoolbox tensorpack
```

Optional: to speed up data generation, download and install the latest version of pymdptoolbox from source:
```
git clone https://github.com/sawcordwell/pymdptoolbox.git pymdptoolbox
cd ./pymdptoolbox
python setup.py install
```


### Train and evaluate a QMDP-net

The folder ./data/grid10 contains training and test data for the deterministic 10x10 grid navigation domain
(10,000 environments with 5 trajectories each for training; 500 environments with 1 trajectory each for testing).


Train a network using only the first 4 steps of each training trajectory:
```
python train.py ./data/grid10/ --logpath ./data/grid10/output-lim4/ --lim_traj_len 4
```
The learned model will be saved to ./data/grid10/output-lim4/final.chk


Load the previously saved model and train further using the full trajectories:
```
python train.py ./data/grid10/ --logpath ./data/grid10/output-lim100/ --loadmodel ./data/grid10/output-lim4/final.chk --lim_traj_len 100
```


For help on arguments, execute:
```
python train.py --help
```


### Generate data

Generate data for the deterministic 18x18 grid navigation domain: 10,000 environments for training and 500 for testing, with 5 and 1 trajectories per environment, respectively.

```
python grid.py ./data/grid18/ 10000 500 --N 18 --train_trajs 5 --test_trajs 1
```


For help on arguments, execute:
```
python grid.py --help
```

2 changes: 2 additions & 0 deletions __init__.py
@@ -0,0 +1,2 @@


72 changes: 72 additions & 0 deletions arguments.py
@@ -0,0 +1,72 @@
import pickle, argparse, os
import numpy as np
from utils.dotdict import dotdict


def parse_args(arglist):

    parser = argparse.ArgumentParser(description='Run training on gridworld')

    parser.add_argument('path',
                        help='Path to data folder containing train and test subfolders')
    parser.add_argument('--logpath', default='./log/',
                        help='Path to save log and trained model')

    parser.add_argument('--loadmodel', nargs='*',
                        help='Load model weights from checkpoint')

    parser.add_argument('--eval_samples', type=int,
                        default=100,
                        help='Number of samples to evaluate the learned policy on')
    parser.add_argument('--eval_repeats', type=int,
                        default=1,
                        help='Repeat simulating policy for a given number of times. Use 5 for stochastic domains')

    parser.add_argument('--batch_size', type=int, default=100,
                        help='Size of minibatches for training')
    parser.add_argument('--training_envs', type=float, default=0.9,
                        help='Proportion of training data used for training. Remainder will be used for validation')
    parser.add_argument('--step_size', type=int, default=4,
                        help='Number of maximum steps for backpropagation through time')
    parser.add_argument('--lim_traj_len', type=int, default=100,
                        help='Clip trajectories to a maximum length')
    parser.add_argument('--includefailed', action='store_true',
                        help='Include unsuccessful demonstrations in the training and validation set.')

    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='Initial learning rate')
    parser.add_argument('--patience_first', type=int,
                        default=30,
                        help='Start decaying learning rate if no improvement for a given number of steps')
    parser.add_argument('--patience_rest', type=int,
                        default=5,
                        help='Patience after decay started')
    parser.add_argument('--decaystep', type=int,
                        default=15,
                        help='Total number of learning rate decay steps')
    parser.add_argument('--epochs', type=int,
                        default=1000,
                        help='Maximum number of epochs')

    parser.add_argument('--cache', nargs='*',
                        default=['steps', 'envs', 'bs'],
                        help='Cache nodes from pytable dataset. Default: steps, envs, bs')

    parser.add_argument('-K', '--K', type=int,
                        default=-1,
                        help='Number of iterations of value iteration in QMDPNet. Compute from grid size if negative.')

    args = parser.parse_args(args=arglist)

    # load domain parameters
    params = dotdict(pickle.load(open(os.path.join(args.path, 'train/params.pickle'), 'rb')))

    # set default K
    if args.K < 0:
        args.K = 3 * params.grid_n

    # combine all parameters into a single dotdict
    for key in vars(args):
        params[key] = getattr(args, key)

    return params
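A minimal usage sketch for parse_args, assuming a data folder like ./data/grid10/ whose train/params.pickle defines grid_n (as the K default above requires); the returned dotdict merges the loaded domain parameters with the command-line options:
```
from arguments import parse_args

# Hypothetical call mirroring the README's first training command;
# assumes ./data/grid10/train/params.pickle exists and defines grid_n.
params = parse_args(['./data/grid10/',
                     '--logpath', './data/grid10/output-lim4/',
                     '--lim_traj_len', '4'])

print(params.grid_n)        # domain parameter loaded from params.pickle
print(params.lim_traj_len)  # 4, taken from the argument list
print(params.K)             # defaults to 3 * grid_n when -K is negative
```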
92 changes: 92 additions & 0 deletions database.py
@@ -0,0 +1,92 @@
import numpy as np
import tables
import collections

try:
    import ipdb as pdb
except Exception:
    import pdb


class Datatable(collections.Sequence):
    """
    Wrapper around a pytables node that supports non-unique indexing
    """
    def __init__(self, node):
        self.node = node

    def __len__(self):
        return len(self.node)

    def __getitem__(self, sliced):
        return self.node.__getitem__(sliced)

    def __repr__(self):
        return str(self.node)

    def unique_index(self, idx):
        # pytables fancy indexing requires sorted, unique indices:
        # read each distinct row once, then expand back to the original
        # (possibly repeated, unsorted) order via the inverse mapping
        # returned by np.unique.
        i_unique, remap = np.unique(idx, return_inverse=True)
        if self.node.ndim > 1:
            vals_unique = self.node[i_unique, :]
        else:
            vals_unique = self.node[i_unique]
        return vals_unique[remap]


class Database(collections.defaultdict):
    """
    Wrapper around a pytables database with a cache
    """
    def __init__(self, filename=None, cache=None):
        super(Database, self).__init__()
        self.db = None
        self.filename = filename
        self.cache = ({} if cache is None else cache)

    def __getattr__(self, attr):
        return self.get(attr)

    def get(self, attr, **kwargs):
        try:
            return self.cache[str(attr)]
        except KeyError:
            return Datatable(self.db.get_node("/" + str(attr)))

    def get_all(self, attr):
        return self.db.get_node("/" + str(attr))[:]

    def __setitem__(self, key, value):
        raise NotImplementedError

    def __delitem__(self, key):
        raise NotImplementedError

    def __repr__(self):
        return str(self.db)

    def get_node(self, *args, **kwargs):
        # delegate to the pytables file handle; tables.File provides
        # get_node, the root group does not
        return self.db.get_node(*args, **kwargs)

    def open(self, filename=None, mode='r'):
        self.close()
        if filename is not None:
            self.filename = filename
        self.db = tables.open_file(self.filename, mode=mode)

    def close(self):
        if self.db is not None:
            self.db.close()
            self.db = None

    def build_cache(self, cache_nodes):
        self.open()
        cache = {}
        for node in cache_nodes:
            node = str(node)
            try:
                cache[node] = self.get_all(node)
                print("cached %s" % node)
            except tables.NoSuchNodeError:
                print("cannot cache %s" % node)
        self.close()
        return cache
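A minimal usage sketch for the wrapper, assuming an HDF5 file at a placeholder path whose node names follow the --cache defaults in arguments.py ('steps', 'envs', 'bs'):
```
import numpy as np
from database import Database

# Placeholder filename; node names 'steps', 'envs', 'bs' follow the
# --cache defaults in arguments.py.
db = Database(filename='./data/grid10/train/data.hdf5')

# Preload selected nodes into memory, then reopen the file read-only.
db.cache = db.build_cache(['steps', 'envs'])
db.open()

steps = db.steps       # attribute access; served from the cache
envs = db.get('envs')  # equivalent lookup through get()

# Uncached nodes are wrapped in Datatable, whose unique_index accepts
# repeated/unsorted indices that plain pytables indexing rejects.
rows = db.get('bs').unique_index(np.array([3, 0, 3, 1]))

db.close()
```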
(Diffs for the remaining 9 changed files are not shown.)
