Fixed the bugs from Jan 14, which gives inferior or even no improvement; group_ndims was mistaken. The policy module will soon need refactoring.

This commit is contained in:
haoshengzou 2018-01-17 11:55:51 +08:00
parent d599506dc9
commit ed25bf7586
15 changed files with 619 additions and 119 deletions

View File

@ -15,6 +15,30 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus
    Specific network architectures in original paper of DQN, TRPO, A3C, etc. Policy-Value Network of AlphaGo Zero     Specific network architectures in original paper of DQN, TRPO, A3C, etc. Policy-Value Network of AlphaGo Zero
#### Brief intro of the current implementation:
How to write your own network:
- define the observation placeholder yourself and pass it to `observation_placeholder` when initializing a policy instance
- pass a callable when initializing a policy instance; the callable has to satisfy only three conditions:
- it accepts no parameters
- it does not create any new placeholders
- it returns `action-related tensors, value_head`
Our lib then takes care of your observation placeholder, as well as
all the placeholders created by our lib itself.
Other placeholders, such as `keep_prob` in dropout and `clip_param` in the PPO loss,
have to be managed by yourself (see examples/ppo_cartpole_alternative.py and the sketch below).
The `weight_update` parameter:
- 0: the target network is updated manually
- 1: no target network (equivalently, the target network is updated after every minibatch)
- a value in (0, 1): soft updates of the target network, as used in DDPG
- greater than 1: periodic updates of the target network (every `weight_update` interactions), as used in DQN
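A minimal sketch of these conventions (mirroring examples/ppo_cartpole.py in this commit; the dimensions below are only illustrative):

```python
import tensorflow as tf
import tianshou.core.policy.stochastic as policy
import tianshou.core.value_function.state_value as value_function

observation_dim, action_dim = (4,), 1  # illustrative dimensions
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)  # you define this one

def my_network():  # accepts no parameters, creates no new placeholders
    net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
    action_mean = tf.layers.dense(net, action_dim, activation=None)
    action_logstd = tf.get_variable('action_logstd', shape=(action_dim,))
    value = tf.layers.dense(net, 1, activation=None)
    return action_mean, action_logstd, value  # action-related tensors, value_head

actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1)
critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
```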
Other comments are in the Python files under examples/ and in the library code.
A refactor is definitely needed, so don't dwell too much on annoying details...
### Algorithm ### Algorithm
#### losses #### losses

View File

@ -0,0 +1,94 @@
#!/usr/bin/env python
from __future__ import absolute_import
import tensorflow as tf
import time
import numpy as np
# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.value_function.state_value as value_function
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix
if __name__ == '__main__':
env = normalize(CartpoleEnv())
observation_dim = env.observation_space.shape
action_dim = env.action_space.flat_dim
clip_param = 0.2
num_batches = 10
batch_size = 128
seed = 10
np.random.seed(seed)
tf.set_random_seed(seed)
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
def my_network():
# placeholders defined in this function would be very difficult to manage
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
action_mean = tf.layers.dense(net, action_dim, activation=None)
action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))
value = tf.layers.dense(net, 1, activation=None)
return action_mean, action_logstd, value
# TODO: overriding does not seem able to handle shared layers, unless we add a new class `SharedPolicyValue`.
# Maybe the most desirable thing is to freely build the policy and the value function from any tensor?
# For now, only the outputs of the network matter.
### 2. build policy, critic, loss, optimizer
actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1)
critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
actor_loss = losses.REINFORCE(actor)
critic_loss = losses.state_value_mse(critic)
total_loss = actor_loss + critic_loss
optimizer = tf.train.AdamOptimizer(1e-4)
# this hack would be unnecessary if we had a `SharedPolicyValue` class, or if we reworked the trainable_variables management
var_list = list(set(actor.trainable_variables + critic.trainable_variables))
train_op = optimizer.minimize(total_loss, var_list=var_list)
### 3. define data collection
data_collector = Batch(env, actor,
[advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
[actor, critic])
# TODO: refactor this, data_collector should be just the top-level abstraction
### 4. start training
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
start_time = time.time()
for i in range(100):
# collect data
data_collector.collect(num_episodes=20)
# print current return
print('Epoch {}:'.format(i))
data_collector.statistics()
# update network
for _ in range(num_batches):
feed_dict = data_collector.next_batch(batch_size)
sess.run(train_op, feed_dict=feed_dict)
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@ -0,0 +1,98 @@
#!/usr/bin/env python
from __future__ import absolute_import
import tensorflow as tf
import time
import numpy as np
# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.value_function.state_value as value_function
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix
if __name__ == '__main__':
env = normalize(CartpoleEnv())
observation_dim = env.observation_space.shape
action_dim = env.action_space.flat_dim
clip_param = 0.2
num_batches = 10
batch_size = 128
seed = 10
np.random.seed(seed)
tf.set_random_seed(seed)
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
def my_actor():
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
action_mean = tf.layers.dense(net, action_dim, activation=None)
action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))
return action_mean, action_logstd, None
def my_critic():
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
value = tf.layers.dense(net, 1, activation=None)
return None, value
### 2. build policy, critic, loss, optimizer
actor = policy.Normal(my_actor, observation_placeholder=observation_ph, weight_update=1)
critic = value_function.StateValue(my_critic, observation_placeholder=observation_ph)
# NOTE: intentional early exit -- built from two separate callables in the same scope,
# the actor and the critic would share variables in this case, which is not what we want here
print('actor and critic will share variables in this case')
sys.exit()
actor_loss = losses.REINFORCE(actor)
critic_loss = losses.state_value_mse(critic)
total_loss = actor_loss + critic_loss
optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(total_loss, var_list=actor.trainable_variables)
### 3. define data collection
training_data = Batch(env, actor, [advantage_estimation.full_return], [actor, critic])
### 4. start training
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
# assign actor to pi_old
actor.sync_weights() # TODO: automate this for policies with target network
start_time = time.time()
for i in range(100):
# collect data
training_data.collect(num_episodes=20)
# print current return
print('Epoch {}:'.format(i))
training_data.statistics()
# update network
for _ in range(num_batches):
feed_dict = training_data.next_batch(batch_size)
sess.run(train_op, feed_dict=feed_dict)
# assigning actor to pi_old
actor.update_weights()
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@ -0,0 +1,90 @@
#!/usr/bin/env python
from __future__ import absolute_import
import tensorflow as tf
import time
import numpy as np
# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.value_function.state_value as value_function
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix
if __name__ == '__main__':
env = normalize(CartpoleEnv())
observation_dim = env.observation_space.shape
action_dim = env.action_space.flat_dim
clip_param = 0.2
num_batches = 10
batch_size = 128
seed = 10
np.random.seed(seed)
tf.set_random_seed(seed)
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
def my_network():
# actor branch
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
action_mean = tf.layers.dense(net, action_dim, activation=None)
action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))
# critic branch: a separate tower built from the observation
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
value = tf.layers.dense(net, 1, activation=None)
return action_mean, action_logstd, value
### 2. build policy, critic, loss, optimizer
actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1)
critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
actor_loss = losses.REINFORCE(actor)
critic_loss = losses.state_value_mse(critic)
actor_optimizer = tf.train.AdamOptimizer(1e-4)
actor_train_op = actor_optimizer.minimize(actor_loss, var_list=actor.trainable_variables)
critic_optimizer = tf.train.RMSPropOptimizer(1e-4)
critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)
### 3. define data collection
data_collector = Batch(env, actor,
[advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
[actor, critic])
### 4. start training
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
start_time = time.time()
for i in range(100):
# collect data
data_collector.collect(num_episodes=20)
# print current return
print('Epoch {}:'.format(i))
data_collector.statistics()
# update network
for _ in range(num_batches):
feed_dict = data_collector.next_batch(batch_size)
sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@ -0,0 +1,95 @@
#!/usr/bin/env python
import tensorflow as tf
import gym
# our lib imports here!
import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay_buffer.utils import get_replay_buffer
import tianshou.core.policy.dqn as policy
# THIS EXAMPLE IS NOT FINISHED YET!!!
def policy_net(observation, action_dim):
"""
Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
:param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
:param action_dim: int. The number of actions.
The variable scope ('q_net' or 'target_net') is applied by the caller.
"""
net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
net = tf.layers.flatten(net)
net = tf.layers.dense(net, 256, activation=tf.nn.relu)
q_values = tf.layers.dense(net, action_dim)
return q_values
if __name__ == '__main__':
env = gym.make('PongNoFrameskip-v4')
observation_dim = env.observation_space.shape
action_dim = env.action_space.n
# 1. build network with pure tf
# TODO:
# pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
# access this observation variable.
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions
with tf.variable_scope('q_net'):
q_values = policy_net(observation, action_dim)
with tf.variable_scope('target_net'):
q_values_target = policy_net(observation, action_dim)
# 2. build losses, optimizers
q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN
target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)
target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
global_step = tf.Variable(0, name='global_step', trainable=False)
train_var_list = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
total_loss = dqn_loss
optimizer = tf.train.AdamOptimizer(1e-3)
train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
# 3. define data collection
# configuration should be given as parameters, different replay buffer has different parameters.
replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
{'size': 1000, 'batch_size': 64, 'learn_start': 20})
# ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
# maybe a dict to manage the elements to be collected
# 4. start training
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
minibatch_count = 0
collection_count = 0
# need to first collect some then sample, collect_freq must be larger than batch_size
collect_freq = 100
while True: # until some stopping criterion met...
# collect data
for i in range(0, collect_freq):
replay_memory.collect() # ShihongSong
collection_count += 1
print('Collected {} times.'.format(collection_count))
# update network
data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong
# TODO: auto managing of the placeholders? or add this to params of data.Batch
sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
minibatch_count += 1
print('Trained {} minibatches.'.format(minibatch_count))
# TODO: assigning pi to pi_old is not implemented yet

View File

@ -1,95 +1,83 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import
import tensorflow as tf import tensorflow as tf
import gym import gym
import numpy as np
import time
# our lib imports here! # our lib imports here! It's ok to append path in examples
import sys import sys
sys.path.append('..') sys.path.append('..')
import tianshou.core.losses as losses from tianshou.core import losses
from tianshou.data.replay_buffer.utils import get_replay_buffer from tianshou.data.batch import Batch
import tianshou.core.policy.dqn as policy import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.dqn as policy # TODO: fix imports as zhusuan so that only need to import to policy
# THIS EXAMPLE IS NOT FINISHED YET!!!
def policy_net(observation, action_dim):
"""
Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
:param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
:param action_dim: int. The number of actions.
:param scope: str. Specifying the scope of the variables.
"""
net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
net = tf.layers.flatten(net)
net = tf.layers.dense(net, 256, activation=tf.nn.relu)
q_values = tf.layers.dense(net, action_dim)
return q_values
if __name__ == '__main__': if __name__ == '__main__':
env = gym.make('PongNoFrameskip-v4') env = gym.make('CartPole-v0')
observation_dim = env.observation_space.shape observation_dim = env.observation_space.shape
action_dim = env.action_space.n action_dim = env.action_space.n
# 1. build network with pure tf clip_param = 0.2
# TODO: num_batches = 10
# pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer batch_size = 512
# access this observation variable.
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
with tf.variable_scope('q_net'): ### 1. build network with pure tf
q_values = policy_net(observation, action_dim) observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
with tf.variable_scope('target_net'):
q_values_target = policy_net(observation, action_dim)
# 2. build losses, optimizers def my_policy():
q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action) net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN action_values = tf.layers.dense(net, action_dim, activation=None)
return action_values, None # None value head
# TODO: current implementation of passing function or overriding function has to return a value head
# to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
# imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
# with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
# not based on passing function or overriding function.
### 2. build policy, loss, optimizer
pi = policy.DQN(my_policy, observation_placeholder=observation_ph, weight_update=10)
dqn_loss = losses.qlearning(pi)
dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
global_step = tf.Variable(0, name='global_step', trainable=False)
train_var_list = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
total_loss = dqn_loss total_loss = dqn_loss
optimizer = tf.train.AdamOptimizer(1e-3) optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step()) train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
# 3. define data collection
# configuration should be given as parameters, different replay buffer has different parameters.
replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
{'size': 1000, 'batch_size': 64, 'learn_start': 20})
# ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
# maybe a dict to manage the elements to be collected
# 4. start training ### 3. define data collection
with tf.Session() as sess: data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, pi.target_network)], [pi])
### 4. start training
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
minibatch_count = 0 # assign actor to pi_old
collection_count = 0 pi.sync_weights() # TODO: automate this for policies with target network
# need to first collect some then sample, collect_freq must be larger than batch_size
collect_freq = 100 start_time = time.time()
while True: # until some stopping criterion met... for i in range(100):
# collect data # collect data
for i in range(0, collect_freq): data_collector.collect(num_episodes=50)
replay_memory.collect() # ShihongSong
collection_count += 1 # print current return
print('Collected {} times.'.format(collection_count)) print('Epoch {}:'.format(i))
data_collector.statistics()
# update network # update network
data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong for _ in range(num_batches):
# TODO: auto managing of the placeholders? or add this to params of data.Batch feed_dict = data_collector.next_batch(batch_size)
sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']}) sess.run(train_op, feed_dict=feed_dict)
minibatch_count += 1
print('Trained {} minibatches.'.format(minibatch_count))
# TODO: assigning pi to pi_old is not implemented yet print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@ -61,7 +61,7 @@ if __name__ == '__main__':
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
### 3. define data collection ### 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return) training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
### 4. start training ### 4. start training
config = tf.ConfigProto() config = tf.ConfigProto()
@ -69,7 +69,7 @@ if __name__ == '__main__':
with tf.Session(config=config) as sess: with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
# assign pi to pi_old # assign actor to pi_old
pi.sync_weights() # TODO: automate this for policies with target network pi.sync_weights() # TODO: automate this for policies with target network
start_time = time.time() start_time = time.time()
@ -86,7 +86,7 @@ if __name__ == '__main__':
feed_dict = training_data.next_batch(batch_size) feed_dict = training_data.next_batch(batch_size)
sess.run(train_op, feed_dict=feed_dict) sess.run(train_op, feed_dict=feed_dict)
# assigning pi to pi_old # assigning actor to pi_old
pi.update_weights() pi.update_weights()
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

View File

@ -47,7 +47,7 @@ if __name__ == '__main__':
observation_dim = env.observation_space.shape observation_dim = env.observation_space.shape
action_dim = env.action_space.flat_dim action_dim = env.action_space.flat_dim
clip_param = 0.2 # clip_param = 0.2
num_batches = 10 num_batches = 10
batch_size = 128 batch_size = 128
@ -65,6 +65,7 @@ if __name__ == '__main__':
### 2. build policy, loss, optimizer ### 2. build policy, loss, optimizer
pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0) pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)
clip_param = tf.placeholder(tf.float32, shape=(), name='ppo_loss_clip_param')
ppo_loss_clip = losses.ppo_clip(pi, clip_param) ppo_loss_clip = losses.ppo_clip(pi, clip_param)
total_loss = ppo_loss_clip total_loss = ppo_loss_clip
@ -72,7 +73,7 @@ if __name__ == '__main__':
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
### 3. define data collection ### 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return) training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
### 4. start training ### 4. start training
feed_dict_train = {is_training_ph: True, keep_prob_ph: 0.8} feed_dict_train = {is_training_ph: True, keep_prob_ph: 0.8}
@ -83,7 +84,7 @@ if __name__ == '__main__':
with tf.Session(config=config) as sess: with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
# assign pi to pi_old # assign actor to pi_old
pi.sync_weights() # TODO: automate this for policies with target network pi.sync_weights() # TODO: automate this for policies with target network
start_time = time.time() start_time = time.time()
@ -95,13 +96,19 @@ if __name__ == '__main__':
print('Epoch {}:'.format(i)) print('Epoch {}:'.format(i))
training_data.statistics() training_data.statistics()
# decay clip_param over epochs
if i < 30:
feed_dict_train[clip_param] = 0.2
else:
feed_dict_train[clip_param] = 0.1
# update network # update network
for _ in range(num_batches): for _ in range(num_batches):
feed_dict = training_data.next_batch(batch_size) feed_dict = training_data.next_batch(batch_size)
feed_dict.update(feed_dict_train) feed_dict.update(feed_dict_train)
sess.run(train_op, feed_dict=feed_dict) sess.run(train_op, feed_dict=feed_dict)
# assigning pi to pi_old # assigning actor to pi_old
pi.update_weights() pi.update_weights()
# approximate test mode # approximate test mode

View File

@ -55,7 +55,7 @@ if __name__ == '__main__':
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
### 3. define data collection ### 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return) training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
### 4. start training ### 4. start training
config = tf.ConfigProto() config = tf.ConfigProto()
@ -63,7 +63,7 @@ if __name__ == '__main__':
with tf.Session(config=config) as sess: with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
# assign pi to pi_old # assign actor to pi_old
pi.sync_weights() # TODO: automate this for policies with target network pi.sync_weights() # TODO: automate this for policies with target network
start_time = time.time() start_time = time.time()
@ -80,7 +80,7 @@ if __name__ == '__main__':
feed_dict = training_data.next_batch(batch_size) feed_dict = training_data.next_batch(batch_size)
sess.run(train_op, feed_dict=feed_dict) sess.run(train_op, feed_dict=feed_dict)
# assigning pi to pi_old # assigning actor to pi_old
pi.update_weights() pi.update_weights()
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))

internal_keys.md Normal file
View File

@ -0,0 +1,15 @@
Keys shared by `network.managed_placeholders`, `data_collector.raw_data` and `data_collector.data`:
['observation']
['action']
['reward']
['start_flag']
['advantage'] > ['return'] # they may appear simultaneously
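As a toy illustration of this convention (a standalone sketch, not code from this commit): a loss registers its placeholder under one of these keys, and the data collector later looks the same key up, in raw_data first and then in the processed data:

```python
import numpy as np
import tensorflow as tf

# placeholders registered by losses/policies under the canonical keys
managed_placeholders = {
    'observation': tf.placeholder(tf.float32, shape=(None, 4)),
    'action': tf.placeholder(tf.int32, shape=(None,)),
    'advantage': tf.placeholder(tf.float32, shape=(None,)),
}
# data collected from the environment, and data produced by reward processors
raw_data = {'observation': np.zeros((8, 4)), 'action': np.zeros(8, dtype=np.int32), 'reward': np.zeros(8)}
data = {'advantage': np.ones(8)}

# the feed_dict is assembled by matching each managed key against the available data keys
feed_dict = {}
for key, ph in managed_placeholders.items():
    source = raw_data if key in raw_data else data
    feed_dict[ph] = source[key]
```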

View File

@ -14,7 +14,7 @@ def ppo_clip(policy, clip_param):
action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, name='ppo_clip_loss/action_placeholder') action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, name='ppo_clip_loss/action_placeholder')
advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder') advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
policy.managed_placeholders['action'] = action_ph policy.managed_placeholders['action'] = action_ph
policy.managed_placeholders['processed_reward'] = advantage_ph policy.managed_placeholders['advantage'] = advantage_ph
log_pi_act = policy.log_prob(action_ph) log_pi_act = policy.log_prob(action_ph)
log_pi_old_act = policy.log_prob_old(action_ph) log_pi_old_act = policy.log_prob_old(action_ph)
@ -24,7 +24,7 @@ def ppo_clip(policy, clip_param):
return ppo_clip_loss return ppo_clip_loss
def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"): def REINFORCE(policy):
""" """
vanilla policy gradient vanilla policy gradient
@ -34,10 +34,29 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
:param baseline: the baseline method used to reduce the variance, default is 'None' :param baseline: the baseline method used to reduce the variance, default is 'None'
:return: :return:
""" """
log_pi_act = pi.log_prob(sampled_action) action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape,
vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act) name='REINFORCE/action_placeholder')
# TODO: Different baseline methods like REINFORCE, etc. advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='REINFORCE/advantage_placeholder')
return vanilla_policy_gradient_loss policy.managed_placeholders['action'] = action_ph
policy.managed_placeholders['advantage'] = advantage_ph
log_pi_act = policy.log_prob(action_ph)
REINFORCE_loss = -tf.reduce_mean(advantage_ph * log_pi_act)
return REINFORCE_loss
def state_value_mse(state_value_function):
"""
L2 loss of state value
:param state_value_function: instance of StateValue
:return: tensor of the mse loss
"""
state_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
state_value_function.managed_placeholders['return'] = state_value_ph
state_value = state_value_function.value_tensor
return tf.losses.mean_squared_error(state_value_ph, state_value)
def dqn_loss(sampled_action, sampled_target, policy): def dqn_loss(sampled_action, sampled_target, policy):
""" """

View File

@ -44,7 +44,8 @@ class OnehotCategorical(StochasticPolicy):
self.weight_update = weight_update self.weight_update = weight_update
self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1. self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1.
with tf.variable_scope('network'): # build network, action and value
with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
logits, value_head = policy_callable() logits, value_head = policy_callable()
self._logits = tf.convert_to_tensor(logits, dtype=tf.float32) self._logits = tf.convert_to_tensor(logits, dtype=tf.float32)
self._action = tf.multinomial(self._logits, num_samples=1) self._action = tf.multinomial(self._logits, num_samples=1)
@ -55,11 +56,12 @@ class OnehotCategorical(StochasticPolicy):
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network') self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
# deal with target network
if self.weight_update == 1: if self.weight_update == 1:
self.weight_update_ops = None self.weight_update_ops = None
self.sync_weights_ops = None self.sync_weights_ops = None
else: # then we need to build another tf graph as target network else: # then we need to build another tf graph as target network
with tf.variable_scope('net_old'): with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
logits, value_head = policy_callable() logits, value_head = policy_callable()
self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32) self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32)
@ -173,7 +175,8 @@ class Normal(StochasticPolicy):
self.weight_update = weight_update self.weight_update = weight_update
self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1. self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1.
with tf.variable_scope('network'): # build network, action and value
with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
mean, logstd, value_head = policy_callable() mean, logstd, value_head = policy_callable()
self._mean = tf.convert_to_tensor(mean, dtype = tf.float32) self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32) self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
@ -188,11 +191,12 @@ class Normal(StochasticPolicy):
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network') self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
# deal with target network
if self.weight_update == 1: if self.weight_update == 1:
self.weight_update_ops = None self.weight_update_ops = None
self.sync_weights_ops = None self.sync_weights_ops = None
else: # then we need to build another tf graph as target network else: # then we need to build another tf graph as target network
with tf.variable_scope('net_old'): with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
mean, logstd, value_head = policy_callable() mean, logstd, value_head = policy_callable()
self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32) self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32)
self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32) self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32)
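The diff does not show how `sync_weights_ops` and `weight_update_ops` are built once the `net_old` scope exists. A minimal sketch of one common construction (a hypothetical helper, not the actual tianshou code: variables of the two scopes are paired by name suffix, and `tau` switches between a hard copy and DDPG-style soft updates):

```python
import tensorflow as tf

def build_target_update_ops(online_scope='network', target_scope='net_old', tau=1.0):
    """Pair variables from the two scopes by name suffix and build assign ops.

    tau=1.0 gives a hard copy (sync_weights); 0 < tau < 1 gives DDPG-style soft updates.
    """
    online_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=online_scope)
    target_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=target_scope)
    online_by_suffix = {v.name.split('/', 1)[1]: v for v in online_vars}
    ops = []
    for v_target in target_vars:
        v_online = online_by_suffix[v_target.name.split('/', 1)[1]]
        ops.append(tf.assign(v_target, tau * v_online + (1. - tau) * v_target))
    return tf.group(*ops)
```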

View File

@ -8,7 +8,12 @@ class StateValue(ValueFunctionBase):
""" """
class of state values V(s). class of state values V(s).
""" """
def __init__(self, value_tensor, observation_placeholder): def __init__(self, policy_callable, observation_placeholder):
self.managed_placeholders = {'observation': observation_placeholder}
with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
value_tensor = policy_callable()[-1]
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: this collects all trainable variables in the graph, not only those under the 'network' scope
super(StateValue, self).__init__( super(StateValue, self).__init__(
value_tensor=value_tensor, value_tensor=value_tensor,
observation_placeholder=observation_placeholder observation_placeholder=observation_placeholder

View File

@ -6,15 +6,13 @@ def full_return(raw_data):
naively compute full return naively compute full return
:param raw_data: dict of specified keys and values. :param raw_data: dict of specified keys and values.
""" """
observations = raw_data['observations'] observations = raw_data['observation']
actions = raw_data['actions'] actions = raw_data['action']
rewards = raw_data['rewards'] rewards = raw_data['reward']
episode_start_flags = raw_data['episode_start_flags'] episode_start_flags = raw_data['end_flag']
num_timesteps = rewards.shape[0] num_timesteps = rewards.shape[0]
data = {} data = {}
data['observations'] = observations
data['actions'] = actions
returns = rewards.copy() returns = rewards.copy()
episode_start_idx = 0 episode_start_idx = 0
@ -33,11 +31,39 @@ def full_return(raw_data):
episode_start_idx = i episode_start_idx = i
data['returns'] = returns data['return'] = returns
return data return data
class gae_lambda:
"""
Generalized Advantage Estimation (Schulman, 15) to compute advantage
"""
def __init__(self, T, value_function):
self.T = T
self.value_function = value_function
def __call__(self, raw_data):
# TODO: placeholder implementation -- returns the raw reward as the advantage for now
reward = raw_data['reward']
return {'advantage': reward}
class nstep_return:
"""
compute the n-step return from n-step rewards and bootstrapped value function
"""
def __init__(self, n, value_function):
self.n = n
self.value_function = value_function
def __call__(self, raw_data):
# TODO: placeholder implementation -- returns the raw reward as the n-step return for now
reward = raw_data['reward']
return {'return': reward}
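For reference, a full GAE(λ) computation would look roughly like the sketch below. It is only a sketch under assumptions not in this commit: it presumes the value function exposes a hypothetical `eval_value(observations)` helper returning V(s) for a batch, and it treats the stored flags as marking the first step of each episode; the real interface may differ.

```python
import numpy as np

class gae_lambda_sketch:
    """Hypothetical GAE(lambda) (Schulman et al., 2015); not the code in this commit."""
    def __init__(self, gamma, lam, value_function):
        self.gamma = gamma
        self.lam = lam
        self.value_function = value_function

    def __call__(self, raw_data):
        rewards = raw_data['reward']
        start_flags = raw_data['end_flag']  # currently stores start-of-episode flags (see Batch.collect)
        values = self.value_function.eval_value(raw_data['observation'])  # hypothetical helper: V(s_t)
        advantages = np.zeros_like(rewards, dtype=np.float32)
        gae = 0.
        for t in reversed(range(len(rewards))):
            episode_ends_here = (t == len(rewards) - 1) or start_flags[t + 1]
            next_value = 0. if episode_ends_here else values[t + 1]
            delta = rewards[t] + self.gamma * next_value - values[t]
            gae = delta + self.gamma * self.lam * (0. if episode_ends_here else gae)
            advantages[t] = gae
        return {'advantage': advantages}
```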
class QLearningTarget: class QLearningTarget:
def __init__(self, policy, gamma): def __init__(self, policy, gamma):
self._policy = policy self._policy = policy
@ -68,3 +94,4 @@ class QLearningTarget:
data['rewards'] = np.array(rewards) data['rewards'] = np.array(rewards)
return data return data

View File

@ -1,6 +1,7 @@
import numpy as np import numpy as np
import gc import gc
import logging
from . import utils
# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner # TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner
class Batch(object): class Batch(object):
@ -8,14 +9,31 @@ class Batch(object):
class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy. class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
""" """
def __init__(self, env, pi, advantage_estimation_function): # how to name the function? def __init__(self, env, pi, reward_processors, networks): # how to name the function?
"""
constructor
:param env:
:param pi:
:param reward_processors: list of functions to process reward
:param networks: list of networks to be optimized, so as to match data in feed_dict
"""
self._env = env self._env = env
self._pi = pi self._pi = pi
self._advantage_estimation_function = advantage_estimation_function self.raw_data = {}
self.data = {}
self.reward_processors = reward_processors
self.networks = networks
self.required_placeholders = {}
for net in self.networks:
self.required_placeholders.update(net.managed_placeholders)
self.require_advantage = 'advantage' in self.required_placeholders.keys()
self._is_first_collect = True self._is_first_collect = True
def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
apply_function=True): # specify how many data to collect here, or fix it in __init__() process_reward=True): # specify how many data to collect here, or fix it in __init__()
assert sum( assert sum(
[num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
@ -98,6 +116,7 @@ class Batch(object):
break break
if done: # end of episode, discard s_T if done: # end of episode, discard s_T
# TODO: for num_timesteps collection, has to store terminal flag instead of start flag!
break break
else: else:
observations.append(ob) observations.append(ob)
@ -113,33 +132,48 @@ class Batch(object):
del rewards del rewards
del episode_start_flags del episode_start_flags
self.raw_data = {'observations': self.observations, 'actions': self.actions, 'rewards': self.rewards, self.raw_data = {'observation': self.observations, 'action': self.actions, 'reward': self.rewards,
'episode_start_flags': self.episode_start_flags} 'end_flag': self.episode_start_flags}
self._is_first_collect = False self._is_first_collect = False
if apply_function: if process_reward:
self.apply_advantage_estimation_function() self.apply_advantage_estimation_function()
gc.collect() gc.collect()
def apply_advantage_estimation_function(self): def apply_advantage_estimation_function(self):
self.data = self._advantage_estimation_function(self.raw_data) for processor in self.reward_processors:
self.data.update(processor(self.raw_data))
def next_batch(self, batch_size, standardize_advantage=True): # YouQiaoben: referencing other iterate over batches def next_batch(self, batch_size, standardize_advantage=True):
rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size) rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size)
current_batch = {key: value[rand_idx] for key, value in self.data.items()}
if standardize_advantage:
advantage_mean = np.mean(current_batch['returns'])
advantage_std = np.std(current_batch['returns'])
current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std
feed_dict = {} feed_dict = {}
feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations'] for key, placeholder in self.required_placeholders.items():
feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions'] found, data_key = utils.internal_key_match(key, self.raw_data.keys())
feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns'] if found:
# TODO: should use the keys in pi.managed_placeholders to find values in self.data and self.raw_data feed_dict[placeholder] = self.raw_data[data_key][rand_idx]
else:
found, data_key = utils.internal_key_match(key, self.data.keys())
if found:
feed_dict[placeholder] = self.data[data_key][rand_idx]
if not found:
raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
if standardize_advantage:
if self.require_advantage:
advantage_value = feed_dict[self.required_placeholders['advantage']]
advantage_mean = np.mean(advantage_value)
advantage_std = np.std(advantage_value)
if advantage_std < 1e-3:
logging.warning('advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues')
feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
# TODO: maybe move all advantage estimation functions to tf, as in tensorforce (though haven't
# understood tensorforce after reading) maybe tf.stop_gradient for targets/advantages
# this will simplify data collector as it only needs to collect raw data, (s, a, r, done) only
return feed_dict return feed_dict
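The `utils.internal_key_match` helper used above is not part of this diff; a plausible sketch of its behaviour (a guess, not the actual tianshou.data implementation) is an exact match first, then a partial match between related key names, returning the matched data key:

```python
def internal_key_match(required_key, data_keys):
    """Return (found, matched_key) for a required placeholder key against the available data keys."""
    for key in data_keys:
        if key == required_key:
            return True, key
    for key in data_keys:
        if required_key in key or key in required_key:  # allow partial matches between related names
            return True, key
    return False, None
```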
@ -149,8 +183,8 @@ class Batch(object):
compute the statistics of the current sampled paths compute the statistics of the current sampled paths
:return: :return:
""" """
rewards = self.raw_data['rewards'] rewards = self.raw_data['reward']
episode_start_flags = self.raw_data['episode_start_flags'] episode_start_flags = self.raw_data['end_flag']
num_timesteps = rewards.shape[0] num_timesteps = rewards.shape[0]
returns = [] returns = []