stop gradient in policy/distributional
This commit is contained in:
parent
909dc786d1
commit
bdd85f8a27
83
examples/ppo_multivariate_normal.py
Normal file
83
examples/ppo_multivariate_normal.py
Normal file
@ -0,0 +1,83 @@
|
||||
import tensorflow as tf
|
||||
import gym
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import tensorflow_probability as tfp
|
||||
tfd = tfp.distributions # TODO: use zhusuan.distributions
|
||||
|
||||
import tianshou as ts
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
env = gym.make('BipedalWalker-v2')
|
||||
observation_dim = env.observation_space.shape
|
||||
action_dim = env.action_space.shape[0]
|
||||
|
||||
clip_param = 0.2
|
||||
num_batches = 10
|
||||
batch_size = 512
|
||||
|
||||
seed = 0
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
### 1. build network with pure tf
|
||||
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
|
||||
|
||||
def my_policy():
|
||||
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
|
||||
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
|
||||
|
||||
action_logits = tf.layers.dense(net, action_dim, activation=None)
|
||||
action_dist = tfd.MultivariateNormalDiag(loc=action_logits, scale_diag=[0.2] * action_dim)
|
||||
|
||||
return action_dist, None
|
||||
|
||||
### 2. build policy, loss, optimizer
|
||||
pi = ts.policy.Distributional(my_policy, observation_placeholder=observation_ph, has_old_net=True)
|
||||
|
||||
ppo_loss_clip = ts.losses.ppo_clip(pi, clip_param)
|
||||
|
||||
total_loss = ppo_loss_clip
|
||||
optimizer = tf.train.AdamOptimizer(1e-4)
|
||||
train_op = optimizer.minimize(total_loss, var_list=list(pi.trainable_variables))
|
||||
|
||||
### 3. define data collection
|
||||
data_buffer = ts.data.BatchSet()
|
||||
|
||||
data_collector = ts.data.DataCollector(
|
||||
env=env,
|
||||
policy=pi,
|
||||
data_buffer=data_buffer,
|
||||
process_functions=[ts.data.advantage_estimation.full_return],
|
||||
managed_networks=[pi],
|
||||
)
|
||||
|
||||
### 4. start training
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
with tf.Session(config=config) as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# assign actor to pi_old
|
||||
pi.sync_weights()
|
||||
|
||||
start_time = time.time()
|
||||
for i in range(1000):
|
||||
# collect data
|
||||
data_collector.collect(num_episodes=50)
|
||||
|
||||
# print current return
|
||||
print('Epoch {}:'.format(i))
|
||||
data_buffer.statistics()
|
||||
|
||||
# update network
|
||||
for _ in range(num_batches):
|
||||
feed_dict = data_collector.next_batch(batch_size)
|
||||
sess.run(train_op, feed_dict=feed_dict)
|
||||
|
||||
# assigning pi_old to be current pi
|
||||
pi.sync_weights()
|
||||
|
||||
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
|
3
setup.py
3
setup.py
@ -81,7 +81,8 @@ setup(
|
||||
# your project is installed. For an analysis of "install_requires" vs pip's
|
||||
# requirements files see:
|
||||
# https://packaging.python.org/en/latest/requirements.html
|
||||
install_requires=['numpy>=1.14.0'],
|
||||
install_requires=['numpy>=1.14.0',
|
||||
'tensorflow-probability'],
|
||||
|
||||
# List additional groups of dependencies here (e.g. development
|
||||
# dependencies). You can install these using the following syntax,
|
||||
|
@ -36,7 +36,7 @@ class Distributional(PolicyBase):
|
||||
self.action = action_dist.sample()
|
||||
|
||||
weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
||||
self.network_weights = identify_dependent_variables(self.action_dist.log_prob(self.action), weights)
|
||||
self.network_weights = identify_dependent_variables(self.action_dist.log_prob(tf.stop_gradient(self.action)), weights)
|
||||
self._trainable_variables = [var for var in self.network_weights
|
||||
if var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]
|
||||
# deal with target network
|
||||
@ -52,7 +52,7 @@ class Distributional(PolicyBase):
|
||||
# re-filter to rule out some edge cases
|
||||
old_weights = [var for var in old_weights if var.name[:len(net_old_scope)] == net_old_scope]
|
||||
|
||||
self.network_old_weights = identify_dependent_variables(self.action_dist_old.log_prob(self.action_old), old_weights)
|
||||
self.network_old_weights = identify_dependent_variables(self.action_dist_old.log_prob(tf.stop_gradient(self.action_old)), old_weights)
|
||||
assert len(self.network_weights) == len(self.network_old_weights)
|
||||
|
||||
self.sync_weights_ops = [tf.assign(variable_old, variable)
|
||||
|
Loading…
x
Reference in New Issue
Block a user