This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Enable multi-process training with distributed Coach.
balajismaniam committed Mar 8, 2019
1 parent 9a895a1 commit 785b8eb
Showing 9 changed files with 166 additions and 106 deletions.
@@ -74,8 +74,8 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind
return server.target, device


def create_monitored_session(target: tf.train.Server, task_index: int,
checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session:
def create_monitored_session(target: tf.train.Server, task_index: int, checkpoint_dir: str, checkpoint_save_secs: int,
scaffold: tf.train.Scaffold, config: tf.ConfigProto=None) -> tf.Session:
"""
Create a monitored session for the worker
:param target: the target string for the tf.Session
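For context, the new `scaffold` argument lets the caller hand an explicit `tf.train.Saver` to the monitored session instead of relying on the implicitly created one. A minimal sketch of how such a helper is typically wired in the TF 1.x API used here (not the commit's exact body; hooks and other options are omitted, and a global step is assumed to exist in the graph):

```python
import tensorflow as tf

def create_monitored_session_sketch(target, task_index, checkpoint_dir,
                                    checkpoint_save_secs, scaffold, config=None):
    # The chief worker (task_index == 0) owns initialization and checkpointing;
    # non-chief workers wait for the chief to create the session.
    return tf.train.MonitoredTrainingSession(
        master=target,
        is_chief=(task_index == 0),
        checkpoint_dir=checkpoint_dir,
        save_checkpoint_secs=checkpoint_save_secs,
        scaffold=scaffold,  # scaffold.saver is used instead of an implicit Saver
        config=config)
```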
2 changes: 1 addition & 1 deletion rl_coach/architectures/tensorflow_components/savers.py
@@ -28,7 +28,7 @@ def __init__(self, name):
# if graph is finalized, savers must have already been added. This happens
# in the case of a MonitoredSession
self._variables = tf.global_variables()

# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
self._variables = [v for v in self._variables if '/target' not in v.name]
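The '/target' filter above is the heart of this saver; in plain TensorFlow terms it amounts to roughly the following (a sketch, assuming a graph whose variables are already built):

```python
import tensorflow as tf

# Save everything except the target network's variables; the target network is
# re-synced from the online network at the start of improve() rather than
# restored from a checkpoint.
variables = [v for v in tf.global_variables() if '/target' not in v.name]
saver = tf.train.Saver(var_list=variables)
```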
4 changes: 2 additions & 2 deletions rl_coach/base_parameters.py
@@ -583,8 +583,8 @@ def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_on


class DistributedTaskParameters(TaskParameters):
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
task_index: int, evaluate_only: int=None, num_tasks: int=None,
def __init__(self, framework_type: Frameworks=None, parameters_server_hosts: str=None, worker_hosts: str=None,
job_type: str=None, task_index: int=None, evaluate_only: int=None, num_tasks: int=None,
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
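For reference, making every constructor argument optional lets the launcher build the distributed parameters with only the fields it knows up front and fill in the per-process fields (job type, task index, host lists) when each worker is spawned. A rough usage sketch, assuming rl_coach with this commit applied is importable (host strings and paths are placeholders):

```python
from rl_coach.base_parameters import DistributedTaskParameters, Frameworks

# Fields known at launch time.
params = DistributedTaskParameters(
    framework_type=Frameworks.tensorflow,
    use_cpu=True,
    num_training_tasks=2,
    experiment_path='./experiments/demo')

# Per-process fields filled in just before each worker process starts.
params.job_type = 'worker'
params.task_index = 0
params.parameters_server_hosts = 'localhost:12345'
params.worker_hosts = 'localhost:12346,localhost:12347'
```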
135 changes: 53 additions & 82 deletions rl_coach/coach.py
@@ -34,7 +34,8 @@
from multiprocessing.managers import BaseManager
import subprocess
from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, \
get_base_dir, start_multi_threaded_learning
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
@@ -87,6 +88,17 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'

def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
ckpt_inside_container = "/checkpoint"
non_dist_task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
experiment_path=args.experiment_path,
seed=args.seed,
use_cpu=args.use_cpu,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)

memory_backend_params = None
if args.memory_backend_params:
@@ -102,15 +114,18 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
graph_manager.data_store_params = data_store_params

if args.distributed_coach_run_type == RunType.TRAINER:
if not args.distributed_training:
task_parameters = non_dist_task_parameters
task_parameters.checkpoint_save_dir = ckpt_inside_container
training_worker(
graph_manager=graph_manager,
task_parameters=task_parameters,
args=args,
is_multi_node_test=args.is_multi_node_test
)

if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
task_parameters.checkpoint_restore_dir = ckpt_inside_container
non_dist_task_parameters.checkpoint_restore_dir = ckpt_inside_container

data_store = None
if args.data_store_params:
@@ -120,7 +135,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
graph_manager=graph_manager,
data_store=data_store,
num_workers=args.num_workers,
task_parameters=task_parameters
task_parameters=non_dist_task_parameters
)


@@ -552,6 +567,11 @@ def get_argument_parser(self) -> argparse.ArgumentParser:
parser.add_argument('-dc', '--distributed_coach',
help="(flag) Use distributed Coach.",
action='store_true')
parser.add_argument('-dt', '--distributed_training',
help="(flag) Use distributed training with Coach."
"Used only with --distributed_coach flag."
"Ignored if --distributed_coach flag is not used.",
action='store_true')
parser.add_argument('-dcp', '--distributed_coach_config_path',
help="(string) Path to config file when using distributed rollout workers."
"Only distributed Coach parameters should be provided through this config file."
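A short sketch of how the new flag is meant to combine with the existing ones (flag names are taken from the parser above; the surrounding plumbing is simplified):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-dc', '--distributed_coach', action='store_true')
parser.add_argument('-dt', '--distributed_training', action='store_true')
parser.add_argument('-n', '--num_workers', type=int, default=1)
args = parser.parse_args(['-dc', '-dt', '-n', '2'])

# -dt only has an effect under -dc: the trainer then keeps its
# DistributedTaskParameters and runs multi-process training; without -dt it
# falls back to the plain TaskParameters built in handle_distributed_coach_tasks.
multi_process_trainer = args.distributed_coach and args.distributed_training
print(multi_process_trainer)  # True
```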
@@ -607,18 +627,31 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp
atexit.register(logger.summarize_experiment)
screen.change_terminal_title(args.experiment_name)

task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
experiment_path=args.experiment_path,
seed=args.seed,
use_cpu=args.use_cpu,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)
if args.num_workers == 1:
task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
experiment_path=args.experiment_path,
seed=args.seed,
use_cpu=args.use_cpu,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)
else:
task_parameters = DistributedTaskParameters(
framework_type=args.framework,
use_cpu=args.use_cpu,
num_training_tasks=args.num_workers,
experiment_path=args.experiment_path,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)

# open dashboard
if args.open_dashboard:
@@ -633,78 +666,16 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp

# Single-threaded runs
if args.num_workers == 1:
self.start_single_threaded(task_parameters, graph_manager, args)
self.start_single_threaded_learning(task_parameters, graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
global start_graph
start_multi_threaded_learning(start_graph, (graph_manager, task_parameters),
task_parameters, graph_manager, args)

def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
def start_single_threaded_learning(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
# Start the training or evaluation
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)

def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
total_tasks = args.num_workers
if args.evaluation_worker:
total_tasks += 1

ps_hosts = "localhost:{}".format(get_open_port())
worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(total_tasks)])

# Shared memory
class CommManager(BaseManager):
pass
CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
comm_manager = CommManager()
comm_manager.start()
shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()

def start_distributed_task(job_type, task_index, evaluation_worker=False,
shared_memory_scratchpad=shared_memory_scratchpad):
task_parameters = DistributedTaskParameters(
framework_type=args.framework,
parameters_server_hosts=ps_hosts,
worker_hosts=worker_hosts,
job_type=job_type,
task_index=task_index,
evaluate_only=0 if evaluation_worker else None, # 0 value for evaluation worker as it should run infinitely
use_cpu=args.use_cpu,
num_tasks=total_tasks, # training tasks + 1 evaluation task
num_training_tasks=args.num_workers,
experiment_path=args.experiment_path,
shared_memory_scratchpad=shared_memory_scratchpad,
seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)
# we assume that only the evaluation workers are rendering
graph_manager.visualization_parameters.render = args.render and evaluation_worker
p = Process(target=start_graph, args=(graph_manager, task_parameters))
# p.daemon = True
p.start()
return p

# parameter server
parameter_server = start_distributed_task("ps", 0)

# training workers
# wait a bit before spawning the non-chief workers in order to make sure the session is already created
workers = []
workers.append(start_distributed_task("worker", 0))
time.sleep(2)
for task_index in range(1, args.num_workers):
workers.append(start_distributed_task("worker", task_index))

# evaluation worker
if args.evaluation_worker or args.render:
evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)

# wait for all workers
[w.join() for w in workers]
if args.evaluation_worker:
evaluation_worker.terminate()
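
The spawning logic removed above is not gone: per the new import at the top of coach.py it now lives in rl_coach.utils.start_multi_threaded_learning. A hedged sketch of what such a helper can look like, reconstructed from the deleted block and the call site (the exact signature and internals in utils.py may differ):

```python
import copy
import time
from multiprocessing import Process

from rl_coach.utils import get_open_port


def start_multi_threaded_learning(start_fn, start_fn_args, task_parameters, graph_manager, args):
    """Sketch only: spawn a parameter server plus the training workers and wait for them."""
    _, base_params = start_fn_args  # same task_parameters object as the explicit argument
    total_tasks = args.num_workers + (1 if args.evaluation_worker else 0)

    # One open port for the parameter server and one per worker, all on localhost.
    ps_hosts = 'localhost:{}'.format(get_open_port())
    worker_hosts = ','.join('localhost:{}'.format(get_open_port()) for _ in range(total_tasks))

    def spawn(job_type, task_index):
        per_task = copy.copy(base_params)
        per_task.parameters_server_hosts = ps_hosts
        per_task.worker_hosts = worker_hosts
        per_task.job_type = job_type
        per_task.task_index = task_index
        p = Process(target=start_fn, args=(graph_manager, per_task))
        p.start()
        return p

    parameter_server = spawn('ps', 0)              # parameter server process
    workers = [spawn('worker', 0)]                 # the chief worker creates the session
    time.sleep(2)                                  # give the chief a head start
    workers += [spawn('worker', i) for i in range(1, args.num_workers)]
    [w.join() for w in workers]
    parameter_server.terminate()                   # evaluation-worker handling omitted
```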


def main():
launcher = CoachLauncher()
20 changes: 14 additions & 6 deletions rl_coach/data_stores/s3_data_store.py
@@ -88,16 +88,20 @@ def save_to_store(self):
# Acquire lock
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

ckpt_state_filename = CheckpointStateFile.checkpoint_state_filename
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
if state_file.exists():
ckpt_state = state_file.read()
ckpt_name_prefix = ckpt_state.name

if ckpt_state_filename is not None and ckpt_name_prefix is not None:
checkpoint_file = None
for root, dirs, files in os.walk(self.params.checkpoint_dir):
for filename in files:
if filename == CheckpointStateFile.checkpoint_state_filename:
if filename == ckpt_state_filename:
checkpoint_file = (root, filename)
continue
if filename.startswith(ckpt_state.name):
if filename.startswith(ckpt_name_prefix):
abs_name = os.path.abspath(os.path.join(root, filename))
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
@@ -131,6 +135,8 @@ def load_from_store(self):
"""
try:
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
ckpt_state_filename = state_file.filename
ckpt_state_file_path = state_file.path

# wait until lock is removed
while True:
@@ -139,7 +145,7 @@ def load_from_store(self):
if next(objects, None) is None:
try:
# fetch checkpoint state file from S3
self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
self.mc.fget_object(self.params.bucket_name, ckpt_state_filename, ckpt_state_file_path)
except Exception as e:
continue
break
@@ -156,10 +162,12 @@ def load_from_store(self):
)
except Exception as e:
pass
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
ckpt_state = state_file.read()
ckpt_name_prefix = ckpt_state.name

checkpoint_state = state_file.read()
if checkpoint_state is not None:
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
if ckpt_name_prefix is not None:
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=ckpt_name_prefix, recursive=True)
for obj in objects:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
if not os.path.exists(filename):
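For orientation, both hunks in this file follow the same lock-file protocol: save_to_store puts a LOCKFILE object, uploads the checkpoint state file plus every file sharing the current checkpoint's name prefix, and then releases the lock (below the fold), while load_from_store waits for the lock to disappear before mirroring the matching objects locally. A condensed sketch of the upload side using the same minio client calls (the lock and state-file names are placeholders for SyncFiles.LOCKFILE.value and CheckpointStateFile.checkpoint_state_filename):

```python
import io
import os

from minio import Minio

LOCKFILE = '.lock'          # placeholder for SyncFiles.LOCKFILE.value
STATE_FILE = 'checkpoint'   # placeholder for CheckpointStateFile.checkpoint_state_filename


def save_checkpoint_to_s3(mc: Minio, bucket: str, checkpoint_dir: str, ckpt_name_prefix: str):
    # Acquire the lock so readers wait until the upload is complete.
    mc.put_object(bucket, LOCKFILE, io.BytesIO(b''), 0)
    try:
        for root, _, files in os.walk(checkpoint_dir):
            for filename in files:
                if filename == STATE_FILE or filename.startswith(ckpt_name_prefix):
                    abs_name = os.path.abspath(os.path.join(root, filename))
                    rel_name = os.path.relpath(abs_name, checkpoint_dir)
                    mc.fput_object(bucket, rel_name, abs_name)
    finally:
        # Release the lock so rollout workers can start downloading.
        mc.remove_object(bucket, LOCKFILE)
```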
18 changes: 13 additions & 5 deletions rl_coach/graph_managers/graph_manager.py
@@ -226,11 +226,15 @@ def _create_session_tf(self, task_parameters: TaskParameters):
else:
checkpoint_dir = task_parameters.checkpoint_save_dir

self.checkpoint_saver = tf.train.Saver()
scaffold = tf.train.Scaffold(saver=self.checkpoint_saver)

self.sess = create_monitored_session(target=task_parameters.worker_target,
task_index=task_parameters.task_index,
checkpoint_dir=checkpoint_dir,
checkpoint_save_secs=task_parameters.checkpoint_save_secs,
config=config)
config=config,
scaffold=scaffold)
# set the session for all the modules
self.set_session(self.sess)
else:
@@ -258,9 +262,11 @@ def create_session(self, task_parameters: TaskParameters):
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))

# Create parameter saver
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
if not isinstance(task_parameters, DistributedTaskParameters):
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())

# restore from checkpoint if given
self.restore_checkpoint()

@@ -599,7 +605,9 @@ def save_checkpoint(self):
if not isinstance(self.task_parameters, DistributedTaskParameters):
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = checkpoint_path
# FIXME: Explicitly managing Saver inside monitored training session is not recommended.
# https://github.com/tensorflow/tensorflow/issues/8425#issuecomment-286927528.
saved_checkpoint_path = self.checkpoint_saver.save(self.sess._tf_sess(), checkpoint_path)

# this is required in order for agents to save additional information like a DND for example
[manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]
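The save_checkpoint change above drives an explicit Saver against the raw tf.Session wrapped by the monitored session (sess._tf_sess()), which the FIXME notes TensorFlow discourages. A small self-contained sketch of that pattern under the TF 1.x API (paths are placeholders; scaffold/hook-driven checkpointing remains the supported route):

```python
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
increment = tf.assign_add(global_step, 1)

saver = tf.train.Saver()                  # explicit Saver, as in the diff
scaffold = tf.train.Scaffold(saver=saver)

with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                       checkpoint_dir='/tmp/coach_ckpt_demo',
                                       save_checkpoint_secs=600) as sess:
    sess.run(increment)
    # Manual save through the underlying tf.Session; _tf_sess() is a private
    # helper, which is exactly why the FIXME above flags this pattern.
    saver.save(sess._tf_sess(), '/tmp/coach_ckpt_demo/manual',
               global_step=global_step)
```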
6 changes: 5 additions & 1 deletion rl_coach/tests/test_dist_coach.py
@@ -73,7 +73,11 @@ def get_tests():
"""
tests = [
'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1'
'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2',
'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2',
'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2 -dt',
'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2 -dt'
]
return tests

