I'd like to see the SAC example train_eval.py converge on a simple environment other than mujoco. Can anyone share an example?
Here's my environment, which simply returns the most recent action as the observation. The action is a size-2 vector with minimum the_minimum and maximum the_maximum. The loss is the distance from the action to the 2D point the_target (normalized by the action range), and the reward is the negative of that loss.
Furthermore, I've removed all of the dense layers and activations on the actor network. I think this means the actor network connects the size-2 observation directly to a MultivariateNormalDiag projection (see tf_agents.networks.normal_projection_network.py). The critic network takes the size-4 input (observation concatenated with action) into a 2-neuron dense layer, followed by a custom activation that squares its inputs. This is meant to let the dense layer capture the meaning of the loss function. The observation is also connected to those two neurons, although the network must learn to "ignore" it in order to converge correctly.
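To make that concrete, here is a minimal NumPy sketch (not part of the script below; the hand-picked weights are purely hypothetical) of how a 2-neuron dense layer followed by squaring can represent the squared distance to the_target, provided the observation weights go to zero. The normalization by the action range could similarly be folded into the weights.

import numpy as np

the_target = np.array([-.2, .2], dtype=np.float32)

obs = np.array([0.7, -0.3], dtype=np.float32)     # arbitrary observation
action = np.array([0.1, 0.5], dtype=np.float32)   # arbitrary action
x = np.concatenate([obs, action])                 # critic input, size 4

# Hand-picked (hypothetical) weights: zero out the observation, pass the action through.
W = np.array([[0., 0.],    # obs[0]  -> ignored
              [0., 0.],    # obs[1]  -> ignored
              [1., 0.],    # action[0]
              [0., 1.]],   # action[1]
             dtype=np.float32)
b = -the_target                                   # bias subtracts the target

hidden = x @ W + b                                # equals action - the_target
squared = hidden ** 2                             # the x**2 activation
# A unit-weight output layer summing these two units yields the squared
# distance that the environment's loss is built from.
print(squared.sum())                              # 0.09 + 0.09 = 0.18
print(np.sum((action - the_target) ** 2))         # same value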
If you feel that actor_fc_layers=None is bad, that setting activation_fn=lambda x: x**2 on the critic_net is bad, or that critic_joint_fc_layers is set badly, please change it. My main goal is just to have this network converge correctly for the environment; after that I'd like to expand the second dimension of the_maximum to a much higher number, for example the commented-out values in the code below (the_minimum = [0.00001, 30.], the_maximum = [2., 20000.], the_target = [0.1, 17000.]). I'm also curious about the potential to use NormalProjectionNetwork instead of TanhNormalProjectionNetwork. For the code below, I encourage you to use an online diff to compare against https://github.com/tensorflow/agents/blob/master/tf_agents/agents/sac/examples/v2/train_eval.py
The code prints the actions taken during evaluation (e.g. action: -0.1234, .5678), but I want it to converge to the right answer, action: -.2, .2.
This is also how I run TensorBoard on Windows (launch_tensorboard.bat):

echo 'Launching Tensorboard...' && python -m tensorboard.main --logdir=./runs --port=6006
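As a side note for anyone trying different hyperparameters: since train_eval below is @gin.configurable and the script defines a gin_param flag, I believe its defaults can be overridden from the command line without editing the file. A minimal sketch (the binding values here are placeholders I haven't tuned):

# Hypothetical invocation (values are placeholders, not tuned):
#   python train_eval.py --gin_param='train_eval.learning_rate=1e-3' \
#                        --gin_param='train_eval.critic_joint_fc_layers=(4,)'
#
# main() forwards those strings to gin, roughly equivalent to:
import gin

gin.parse_config_files_and_bindings(
    None,
    ['train_eval.learning_rate=1e-3',
     'train_eval.critic_joint_fc_layers=(4,)'])
# (This only works once the module defining train_eval has been imported,
# so gin knows about the configurable.)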
# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
r"""Train and Eval SAC.

All hyperparameters come from the SAC paper
https://arxiv.org/pdf/1812.05905.pdf

To run:

#```bash
tensorboard --logdir $HOME/tmp/sac/gym/HalfCheetah-v2/ --port 2223 &
python tf_agents/agents/sac/examples/v2/train_eval.py \
  --root_dir=$HOME/tmp/sac/gym/HalfCheetah-v2/ \
  --alsologtostderr
#```
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

from absl import app
from absl import flags
from absl import logging
import gin
from six.moves import range
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac.tanh_normal_projection_network import TanhNormalProjectionNetwork
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common

import tensorflow_probability as tfp
from tf_agents.networks import normal_projection_network
from tf_agents.environments import py_environment
from tf_agents.environments import parallel_py_environment
from tf_agents.environments.py_environment import PyEnvironment
from tf_agents.typing import types
from tf_agents.trajectories import time_step as ts
from tf_agents.specs.tensor_spec import TensorSpec, BoundedTensorSpec
from tf_agents.environments import TimeLimit
from functools import partial
from datetime import datetime
import numpy as np
import typing
import tqdm
from tf_agents.system.system_multiprocessing import enable_interactive_mode

# tf.config.run_functions_eagerly(True)  # if you want to see numerical Tensor values.

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
  tf.config.experimental.set_memory_growth(device, True)

# necessary for having parallel environments.
enable_interactive_mode()

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the trainer config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding to pass through.')
FLAGS = flags.FLAGS

# the_minimum = np.array([0.00001, 30.], dtype=np.float32)
# the_maximum = np.array([2., 20000.], dtype=np.float32)
# the_target = np.array([0.1, 17000.], dtype=np.float32)
the_minimum = np.array([-1.0, -1.0], dtype=np.float32)
the_maximum = np.array([1., 1.], dtype=np.float32)
the_target = np.array([-.2, .2], dtype=np.float32)


class MyEnvironment(PyEnvironment):

  def __init__(self, env_name):
    super(MyEnvironment, self).__init__()
    self._env_name = env_name
    self._action_spec = {
        'CONTINUOUS': BoundedTensorSpec(
            shape=(2,),
            dtype=np.float32,
            minimum=the_minimum,
            maximum=the_maximum,
            name='CONTINUOUS'
        )
    }
    action_obs_shape = self._action_spec['CONTINUOUS'].shape.dims
    self._observation_spec = TensorSpec(shape=action_obs_shape, dtype=np.float32, name='continuous_pars')
    self._reward_spec = TensorSpec((), np.dtype('float32'), 'reward')
    self._discount_spec = BoundedTensorSpec(shape=(1,), dtype=np.float32, minimum=0., maximum=1.,
                                            name='discount')

  def reward_spec(self):
    return self._reward_spec

  def discount_spec(self) -> types.NestedArraySpec:
    return self._discount_spec

  def time_step_spec(self) -> ts.TimeStep:
    return ts.time_step_spec(self.observation_spec(), self.reward_spec())

  def action_spec(self):
    return self._action_spec

  def observation_spec(self):
    return self._observation_spec

  @property
  def batch_size(self) -> typing.Optional[int]:
    return 1

  @property
  def batched(self) -> bool:
    return False

  def _step(self, action):
    state = action['CONTINUOUS']
    logging.debug('action: ' + ', '.join(['{:2.2f}'.format(v) for v in state]))
    # normalize according to how big the action space is in each dimension.
    normalized = (state - the_target) / (the_maximum - the_minimum)
    loss = np.linalg.norm(normalized)
    # loss = np.power(loss, 0.5)
    # loss = np.log(loss + 1e-9)
    logging.debug('loss: {:2.2f}'.format(loss))
    reward = -loss
    if loss < .00001:
      reward += 1.
      return ts.termination(state, reward=reward)
    else:
      return ts.transition(state, reward=reward, discount=1.0)

  def _reset(self):
    # random starting state
    state = the_minimum + (the_maximum - the_minimum) * np.random.random_sample(the_minimum.shape)
    state = state.astype(np.float32)
    return ts.restart(state)


@gin.configurable
def load_environment(env_name, env_wrappers: types.Sequence[types.PyEnvWrapper] = ()
                     ) -> PyEnvironment:
  env = MyEnvironment(env_name)
  for wrapper in env_wrappers:
    env = wrapper(env)
  return env


class MyTimeLimit(TimeLimit):

  def __init__(self, env: py_environment.PyEnvironment, duration: types.Int = 50):
    super(MyTimeLimit, self).__init__(env, duration)

  def reward_spec(self) -> types.NestedArraySpec:
    return TensorSpec((), np.dtype('float32'), 'reward')

  def discount_spec(self) -> types.NestedArraySpec:
    return BoundedTensorSpec(
        shape=(), dtype=np.float32, minimum=0., maximum=1., name='discount')


def _normal_projection_net(action_spec,
                           init_action_stddev=.35,
                           init_means_output_factor=0.1,
                           seed_stream_class=tfp.util.SeedStream,
                           seed=None):
  std_bias_initializer_value = np.log(np.exp(init_action_stddev) - 1)
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      init_means_output_factor=init_means_output_factor,
      std_bias_initializer_value=std_bias_initializer_value,
      # scale_distribution=True,  # default is False
      # state_dependent_std=True,  # default is False
      # std_transform=(lambda x: x * (the_maximum - the_minimum) / 100.),  # default is tf.nn.softplus
      seed_stream_class=seed_stream_class,
      # mean_transform=None,  # default is tanh_squash_to_spec
      seed=seed)


@gin.configurable
def train_eval(
    env_name='HalfCheetah-v2',
    eval_env_name=None,
    env_load_fn=load_environment,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=None,
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(2,),
    continuous_projection_net=TanhNormalProjectionNetwork,
    # continuous_projection_net=_normal_projection_net,
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10_000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=10_000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=16,
    learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=.1,  # important parameter according to paper
    gradient_clipping=None,
    use_tf_functions=True,
    train_time_limit=10,
    eval_time_limit=5,
    # Params for eval
    num_eval_episodes=1,
    eval_interval=100,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=100,
    summary_interval=100,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
"""A simple train and eval for SAC."""actor_learning_rate=critic_learning_rate=alpha_learning_rate=learning_rateroot_dir=datetime.now().strftime('%Y-%m-%d_%H%M%S')
root_dir=os.path.join(os.getcwd(), 'runs', root_dir)
root_dir+=f',lr={learning_rate},reward-scale={reward_scale_factor},proj-net={continuous_projection_net.__name__}'root_dir=os.path.expanduser(root_dir)
train_dir=os.path.join(root_dir, 'train')
eval_dir=os.path.join(root_dir, 'eval')
train_summary_writer=tf.compat.v2.summary.create_file_writer(
train_dir, flush_millis=summaries_flush_secs*1000)
train_summary_writer.set_as_default()
eval_summary_writer=tf.compat.v2.summary.create_file_writer(
eval_dir, flush_millis=summaries_flush_secs*1000)
eval_metrics= [
tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
]
global_step=tf.compat.v1.train.get_or_create_global_step()
withtf.compat.v2.summary.record_if(
lambda: tf.math.equal(global_step%summary_interval, 0)):
train_wrappers= [partial(MyTimeLimit, duration=train_time_limit)]
eval_wrappers= [partial(MyTimeLimit, duration=eval_time_limit)]
py_env=env_load_fn(env_name, env_wrappers=train_wrappers)
# num_parallel_environments = 4# py_env = parallel_py_environment.ParallelPyEnvironment(# [lambda: env_load_fn(env_name, env_wrappers=wrappers)] * num_parallel_environments)tf_env=tf_py_environment.TFPyEnvironment(py_env)
eval_env_name=eval_env_nameorenv_nameeval_tf_env=tf_py_environment.TFPyEnvironment(env_load_fn(eval_env_name, env_wrappers=eval_wrappers))
time_step_spec=tf_env.time_step_spec()
observation_spec=time_step_spec.observationaction_spec=tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=continuous_projection_net
    )
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        activation_fn=lambda x: x**2,  # special change here!!
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(os.getcwd(), 'runs', 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    global_step_val = global_step.numpy()

    if replay_buffer.num_frames() == 0:
      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()
      rb_checkpointer.save(global_step=global_step_val)  # save the initial buffer and never save later.

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    logging.info('GLOBAL TRAINING')
    pbar = tqdm.tqdm(total=num_iterations, desc='global steps')
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()
      pbar.update(global_step_val)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        logging.set_verbosity(logging.DEBUG)
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)
        logging.set_verbosity(logging.INFO)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      ## disable saving for this toy example. We want to keep reloading the initial buffer
      ## for different trials.
      # if global_step_val % rb_checkpoint_interval == 0:
      #   rb_checkpointer.save(global_step=global_step_val)

    pbar.close()
    return train_loss


def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval()


if __name__ == '__main__':
  # flags.mark_flag_as_required('root_dir')
  app.run(main)
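
For anyone reproducing this, here is the quick smoke test I'd run on the environment by itself before training. It's a sketch only: it assumes the MyEnvironment class and the constants above are importable, and toy_env is a hypothetical module name.

import numpy as np
from toy_env import MyEnvironment  # hypothetical module holding the class above

env = MyEnvironment('toy')
env.reset()
near = env.step({'CONTINUOUS': np.array([-.2, .2], dtype=np.float32)})   # exactly the_target
env.reset()
far = env.step({'CONTINUOUS': np.array([0.9, -0.9], dtype=np.float32)})  # far from the_target

# Reward is minus the normalized distance to the_target (plus a bonus at the
# target itself), so the near action should score strictly higher.
assert near.reward > far.reward
print(near.reward, far.reward)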