diff --git a/examples/ppo_multivariate_normal.py b/examples/ppo_multivariate_normal.py new file mode 100644 index 0000000..7209253 --- /dev/null +++ b/examples/ppo_multivariate_normal.py @@ -0,0 +1,83 @@ +import tensorflow as tf +import gym +import numpy as np +import time + +import tensorflow_probability as tfp +tfd = tfp.distributions # TODO: use zhusuan.distributions + +import tianshou as ts + + +if __name__ == '__main__': + env = gym.make('BipedalWalker-v2') + observation_dim = env.observation_space.shape + action_dim = env.action_space.shape[0] + + clip_param = 0.2 + num_batches = 10 + batch_size = 512 + + seed = 0 + np.random.seed(seed) + tf.set_random_seed(seed) + + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) + + def my_policy(): + net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + + action_logits = tf.layers.dense(net, action_dim, activation=None) + action_dist = tfd.MultivariateNormalDiag(loc=action_logits, scale_diag=[0.2] * action_dim) + + return action_dist, None + + ### 2. build policy, loss, optimizer + pi = ts.policy.Distributional(my_policy, observation_placeholder=observation_ph, has_old_net=True) + + ppo_loss_clip = ts.losses.ppo_clip(pi, clip_param) + + total_loss = ppo_loss_clip + optimizer = tf.train.AdamOptimizer(1e-4) + train_op = optimizer.minimize(total_loss, var_list=list(pi.trainable_variables)) + + ### 3. define data collection + data_buffer = ts.data.BatchSet() + + data_collector = ts.data.DataCollector( + env=env, + policy=pi, + data_buffer=data_buffer, + process_functions=[ts.data.advantage_estimation.full_return], + managed_networks=[pi], + ) + + ### 4. start training + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + # assign actor to pi_old + pi.sync_weights() + + start_time = time.time() + for i in range(1000): + # collect data + data_collector.collect(num_episodes=50) + + # print current return + print('Epoch {}:'.format(i)) + data_buffer.statistics() + + # update network + for _ in range(num_batches): + feed_dict = data_collector.next_batch(batch_size) + sess.run(train_op, feed_dict=feed_dict) + + # assigning pi_old to be current pi + pi.sync_weights() + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) diff --git a/setup.py b/setup.py index 2706e01..da46b80 100644 --- a/setup.py +++ b/setup.py @@ -81,7 +81,8 @@ setup( # your project is installed. For an analysis of "install_requires" vs pip's # requirements files see: # https://packaging.python.org/en/latest/requirements.html - install_requires=['numpy>=1.14.0'], + install_requires=['numpy>=1.14.0', + 'tensorflow-probability'], # List additional groups of dependencies here (e.g. development # dependencies). You can install these using the following syntax, diff --git a/tianshou/core/policy/distributional.py b/tianshou/core/policy/distributional.py index e04a8ed..8181c91 100644 --- a/tianshou/core/policy/distributional.py +++ b/tianshou/core/policy/distributional.py @@ -36,7 +36,7 @@ class Distributional(PolicyBase): self.action = action_dist.sample() weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) - self.network_weights = identify_dependent_variables(self.action_dist.log_prob(self.action), weights) + self.network_weights = identify_dependent_variables(self.action_dist.log_prob(tf.stop_gradient(self.action)), weights) self._trainable_variables = [var for var in self.network_weights if var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)] # deal with target network @@ -52,7 +52,7 @@ class Distributional(PolicyBase): # re-filter to rule out some edge cases old_weights = [var for var in old_weights if var.name[:len(net_old_scope)] == net_old_scope] - self.network_old_weights = identify_dependent_variables(self.action_dist_old.log_prob(self.action_old), old_weights) + self.network_old_weights = identify_dependent_variables(self.action_dist_old.log_prob(tf.stop_gradient(self.action_old)), old_weights) assert len(self.network_weights) == len(self.network_old_weights) self.sync_weights_ops = [tf.assign(variable_old, variable)