-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 33a1180
Showing
14 changed files
with
2,309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
*.pyc | ||
*.o | ||
|
||
# Windows image file caches | ||
Thumbs.db | ||
ehthumbs.db | ||
|
||
# Folder config file | ||
Desktop.ini | ||
|
||
# Recycle Bin used on file shares | ||
$RECYCLE.BIN/ | ||
|
||
# Windows Installer files | ||
*.cab | ||
*.msi | ||
*.msm | ||
*.msp | ||
|
||
# Windows shortcuts | ||
*.lnk | ||
|
||
# ========================= | ||
# Operating System Files | ||
# ========================= | ||
|
||
# OSX | ||
# ========================= | ||
|
||
.DS_Store | ||
.AppleDouble | ||
.LSOverride | ||
|
||
# Thumbnails | ||
._* | ||
|
||
# Files that might appear in the root of a volume | ||
.DocumentRevisions-V100 | ||
.fseventsd | ||
.Spotlight-V100 | ||
.TemporaryItems | ||
.Trashes | ||
.VolumeIcon.icns | ||
|
||
# Directories potentially created on remote AFP share | ||
.AppleDB | ||
.AppleDesktop | ||
Network Trash Folder | ||
Temporary Items | ||
.apdisk | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# QMDP-net | ||
|
||
Implementation of the NIPS 2017 paper: | ||
|
||
QMDP-Net: Deep Learning for Planning under Partial Observability | ||
Peter Karkus, David Hsu, Wee Sun Lee | ||
National University of Singapore | ||
https://arxiv.org/abs/1703.06692 | ||
|
||
The code implements the 2D grid navigation domain, and a QMDP-net with 2D state space in tensorflow. | ||
|
||
### Requirements | ||
|
||
Python 2.7 | ||
Tensorflow 1.3.0 | ||
Python packages: numpy, scipy, tables, pymdptoolbox, tensorpack | ||
|
||
To install these packages using pip: | ||
``` | ||
pip install tensorflow | ||
pip install numpy scipy tables pymdptoolbox tensorpack | ||
``` | ||
|
||
Optional: to speed up data generation download and install the latest version of pymdptoolbox | ||
``` | ||
git clone https://github.com/sawcordwell/pymdptoolbox.git pymdptoolbox | ||
cd ./pymdptoolbox | ||
python setup.py install | ||
``` | ||
|
||
|
||
### Train and evaluate a QMDP-net | ||
|
||
The folder ./data/grid10 contains training and test data for the deterministic 10x10 grid navigation domain | ||
(10,000 environments, 5 trajectories each for training, 500 environments, 1 trajectory each for testing). | ||
|
||
|
||
Train network using only the first 4 steps of each training trajectory: | ||
``` | ||
python train.py ./data/grid10/ --logpath ./data/grid10/output-lim4/ --lim_traj_len 4 | ||
``` | ||
The learned model will be saved to ./data/grid10/output-lim4/final.chk | ||
|
||
|
||
Load the previously saved model and train further using the full trajectories: | ||
``` | ||
python train.py ./data/grid10/ --logpath ./data/grid10/output-lim100/ --loadmodel ./data/grid10/output-lim4/final.chk --lim_traj_len 100 | ||
``` | ||
|
||
|
||
For help on arguments execute: | ||
``` | ||
python train.py --help | ||
``` | ||
|
||
|
||
### Generate data | ||
|
||
Generate data for the 18x18 deterministic grid navigation domain. | ||
10,000 environments for training, 500 for testing, 5 and 1 trajectories per environment | ||
|
||
``` | ||
python grid.py ./data/grid18/ 10000 500 --N 18 --train_trajs 5 --test_trajs 1 | ||
``` | ||
|
||
|
||
For help on arguments execute: | ||
``` | ||
python grid.py --help | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pickle, argparse, os | ||
import numpy as np | ||
from utils.dotdict import dotdict | ||
|
||
|
||
def parse_args(arglist): | ||
|
||
parser = argparse.ArgumentParser(description='Run training on gridworld') | ||
|
||
parser.add_argument('path', | ||
help='Path to data folder containing train and test subfolders') | ||
parser.add_argument('--logpath', default='./log/', | ||
help='Path to save log and trained model') | ||
|
||
parser.add_argument('--loadmodel', nargs='*', | ||
help='Load model weights from checkpoint') | ||
|
||
parser.add_argument('--eval_samples', type=int, | ||
default=100, | ||
help='Number of samples to evaluate the learned policy on') | ||
parser.add_argument('--eval_repeats', type=int, | ||
default=1, | ||
help='Repeat simulating policy for a given number of times. Use 5 for stochastic domains') | ||
|
||
parser.add_argument('--batch_size', type=int, default=100, | ||
help='Size of minibatches for training') | ||
parser.add_argument('--training_envs', type=float, default=0.9, | ||
help='Proportion of training data used for trianing. Remainder will be used for validation') | ||
parser.add_argument('--step_size', type=int, default=4, | ||
help='Number of maximum steps for backpropagation through time') | ||
parser.add_argument('--lim_traj_len', type=int, default=100, | ||
help='Clip trajectories to a maximum length') | ||
parser.add_argument('--includefailed', action='store_true', | ||
help='Include unsuccessful demonstrations in the training and validation set.') | ||
|
||
parser.add_argument('--learning_rate', type=float, default=0.001, | ||
help='Initial learning rate') | ||
parser.add_argument('--patience_first', type=int, | ||
default=30, | ||
help='Start decaying learning rate if no improvement for a given number of steps') | ||
parser.add_argument('--patience_rest', type=int, | ||
default=5, | ||
help='Patience after decay started') | ||
parser.add_argument('--decaystep', type=int, | ||
default=15, | ||
help='Total number of learning rate decay steps') | ||
parser.add_argument('--epochs', type=int, | ||
default=1000, | ||
help='Maximum number of epochs') | ||
|
||
parser.add_argument('--cache', nargs='*', | ||
default=['steps', 'envs', 'bs'], | ||
help='Cache nodes from pytable dataset. Default: steps, envs, bs') | ||
|
||
parser.add_argument('-K', '--K', type=int, | ||
default=-1, | ||
help='Number of iterations of value iteration in QMDPNet. Compute from grid size if negative.') | ||
|
||
args = parser.parse_args(args=arglist) | ||
|
||
# load domain parameters | ||
params = dotdict(pickle.load(open(os.path.join(args.path, 'train/params.pickle'), 'rb'))) | ||
|
||
# set default K | ||
if args.K < 0: | ||
args.K = 3 * params.grid_n | ||
|
||
# combine all parameters to a single dotdict | ||
for key in vars(args): | ||
params[key] = getattr(args, key) | ||
|
||
return params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import numpy as np | ||
import tables | ||
import collections | ||
|
||
try: | ||
import ipdb as pdb | ||
except Exception: | ||
import pdb | ||
|
||
|
||
class Datatable(collections.Sequence): | ||
""" | ||
Wrapper around pytable node that supports non-unique indexing | ||
""" | ||
def __init__(self, node): | ||
self.node = node | ||
|
||
def __len__(self): | ||
return len(self.node) | ||
|
||
def __getitem__(self, sliced): | ||
return self.node.__getitem__(sliced) | ||
|
||
def __repr__(self): | ||
return str(self.node) | ||
|
||
def unique_index(self, idx): | ||
i_unique, remap = np.unique(idx, return_inverse=True) | ||
if self.node.ndim > 1: | ||
vals_unique = self.node[i_unique,:] | ||
else: | ||
vals_unique = self.node[i_unique] | ||
return vals_unique[remap] | ||
|
||
|
||
class Database(collections.defaultdict): | ||
""" | ||
Wrapper around pytable database with cache | ||
""" | ||
def __init__(self, filename=None, cache=None): | ||
super(Database, self).__init__() | ||
self.db = None | ||
self.filename = filename | ||
self.cache = ({} if cache is None else cache) | ||
|
||
def __getattr__(self, attr): | ||
return self.get(attr) | ||
|
||
def get(self, attr, **kwargs): | ||
try: | ||
return self.cache[str(attr)] | ||
except KeyError: | ||
return Datatable(self.db.get_node("/"+str(attr))) | ||
|
||
def get_all(self, attr): | ||
return self.db.get_node("/"+str(attr))[:] | ||
|
||
def __setitem__(self, key, value): | ||
raise NotImplementedError | ||
|
||
def __delitem__(self, key): | ||
raise NotImplementedError | ||
|
||
def __repr__(self): | ||
return str(self.db) | ||
|
||
def get_node(self, *args, **kwargs): | ||
return self.db.root.get_node(*args, **kwargs) | ||
|
||
def open(self, filename=None, mode='r'): | ||
self.close() | ||
if filename is not None: | ||
self.filename = filename | ||
self.db = tables.open_file(self.filename, mode=mode) | ||
|
||
def close(self): | ||
if self.db is not None: | ||
self.db.close() | ||
self.db = None | ||
|
||
def build_cache(self, cache_nodes): | ||
self.open() | ||
cache = {} | ||
for node in cache_nodes: | ||
node = str(node) | ||
try: | ||
cache[node] = self.get_all(node) | ||
print ("cached %s"%str(node)) | ||
except tables.NoSuchNodeError: | ||
print ("cannot cache %s"%str(node)) | ||
self.close() | ||
return cache |
Oops, something went wrong.