Fixed the bugs on Jan 14; the fix gives inferior or even no improvement. `group_ndims` was mistaken. The policy module will soon need refactoring.
This commit is contained in: parent d599506dc9, commit ed25bf7586

README.md (24 lines changed)
@@ -15,6 +15,30 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus
Specific network architectures in original paper of DQN, TRPO, A3C, etc. Policy-Value Network of AlphaGo Zero
#### Brief intro of the current implementation:

How to write your own network:

- define the observation placeholder yourself, and pass it to `observation_placeholder` when initializing a policy instance

- pass a callable when initializing a policy instance. The callable has to satisfy only three conditions (a minimal sketch follows below):

    - it accepts no parameters

    - it does not create any new placeholders

    - it returns `action-related tensors, value_head`

Our lib will take care of your observation placeholder from now on, as well as all the placeholders that will be created by our lib.

The other placeholders, such as `keep_prob` in dropout and `clip_param` in the PPO loss, should be managed by yourself (see examples/ppo_cartpole_alternative.py).
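A minimal sketch of such a callable, condensed from examples/actor_critic_cartpole.py in this commit (the fixed `observation_dim`/`action_dim` values are made up for illustration; treat the whole snippet as a sketch, not the canonical API):

```python
import tensorflow as tf
import tianshou.core.policy.stochastic as policy  # as in the examples of this commit

observation_dim, action_dim = 4, 2  # hypothetical sizes for illustration

# you define the observation placeholder yourself and hand it to the policy
observation_ph = tf.placeholder(tf.float32, shape=(None, observation_dim))

def my_network():
    # accepts no parameters and creates no new placeholders;
    # it only reads the externally defined observation_ph
    net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
    action_mean = tf.layers.dense(net, action_dim, activation=None)
    action_logstd = tf.get_variable('action_logstd', shape=(action_dim,))
    value = tf.layers.dense(net, 1, activation=None)
    # returns `action-related tensors, value_head`
    return action_mean, action_logstd, value

actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1)
```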
The `weight_update` parameter (a short sketch follows this list):

- 0 means manually updating the target network

- 1 means no target network (i.e., the "target network" is updated every single minibatch)

- a value in (0, 1) means a soft-updated target network as used in DDPG

- a value greater than 1 means a periodically synced target network as used in DQN

Other comments are in the Python files in examples/ and in the lib code.

A refactor is definitely needed, so don't dwell too much on the annoying details...
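For illustration, the three settings used by the examples in this commit (a sketch assembled from those files, not a prescription):

```python
# weight_update=1: no target network (actor-critic examples)
actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1)

# weight_update=0: manually managed pi_old, synced/updated by hand (PPO example)
pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)
pi.sync_weights()    # assign pi to pi_old once at the start
pi.update_weights()  # assign pi to pi_old after each update phase

# weight_update=10: DQN-style periodically synced target network (DQN example)
pi = policy.DQN(my_policy, observation_placeholder=observation_ph, weight_update=10)
```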
### Algorithm

#### losses
examples/actor_critic_cartpole.py (new executable file, 94 lines)

@@ -0,0 +1,94 @@
#!/usr/bin/env python
from __future__ import absolute_import

import tensorflow as tf
import time
import numpy as np

# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.value_function.state_value as value_function

from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize


# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix

if __name__ == '__main__':
    env = normalize(CartpoleEnv())
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.flat_dim

    clip_param = 0.2
    num_batches = 10
    batch_size = 128

    seed = 10
    np.random.seed(seed)
    tf.set_random_seed(seed)

    ### 1. build network with pure tf
    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

    def my_network():
        # placeholders defined in this function would be very difficult to manage
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

        action_mean = tf.layers.dense(net, action_dim, activation=None)
        action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))
        value = tf.layers.dense(net, 1, activation=None)

        return action_mean, action_logstd, value
    # TODO: overriding seems not able to handle shared layers, unless a new class `SharedPolicyValue`
    # maybe the most desired thing is to freely build policy and value function from any tensor?
    # but for now, only the outputs of the network matters

    ### 2. build policy, critic, loss, optimizer
    actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1)
    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)

    actor_loss = losses.REINFORCE(actor)
    critic_loss = losses.state_value_mse(critic)
    total_loss = actor_loss + critic_loss

    optimizer = tf.train.AdamOptimizer(1e-4)

    # this hack would be unnecessary if we have a `SharedPolicyValue` class, or hack the trainable_variables management
    var_list = list(set(actor.trainable_variables + critic.trainable_variables))

    train_op = optimizer.minimize(total_loss, var_list=var_list)

    ### 3. define data collection
    data_collector = Batch(env, actor,
                           [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
                           [actor, critic])
    # TODO: refactor this, data_collector should be just the top-level abstraction

    ### 4. start training
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        for i in range(100):
            # collect data
            data_collector.collect(num_episodes=20)

            # print current return
            print('Epoch {}:'.format(i))
            data_collector.statistics()

            # update network
            for _ in range(num_batches):
                feed_dict = data_collector.next_batch(batch_size)
                sess.run(train_op, feed_dict=feed_dict)

        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
examples/actor_critic_fail_cartpole.py (new executable file, 98 lines)

@@ -0,0 +1,98 @@
#!/usr/bin/env python
from __future__ import absolute_import

import tensorflow as tf
import time
import numpy as np

# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.value_function.state_value as value_function

from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize


# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix

if __name__ == '__main__':
    env = normalize(CartpoleEnv())
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.flat_dim

    clip_param = 0.2
    num_batches = 10
    batch_size = 128

    seed = 10
    np.random.seed(seed)
    tf.set_random_seed(seed)

    ### 1. build network with pure tf
    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

    def my_actor():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

        action_mean = tf.layers.dense(net, action_dim, activation=None)
        action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))

        return action_mean, action_logstd, None

    def my_critic():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
        value = tf.layers.dense(net, 1, activation=None)

        return None, value

    ### 2. build policy, critic, loss, optimizer
    actor = policy.Normal(my_actor, observation_placeholder=observation_ph, weight_update=1)
    critic = value_function.StateValue(my_critic, observation_placeholder=observation_ph)

    print('actor and critic will share variables in this case')
    sys.exit()

    actor_loss = losses.vanilla_policy_gradient(actor)
    critic_loss = losses.state_value_mse(critic)
    total_loss = actor_loss + critic_loss

    optimizer = tf.train.AdamOptimizer(1e-4)
    train_op = optimizer.minimize(total_loss, var_list=actor.trainable_variables)

    ### 3. define data collection
    training_data = Batch(env, actor, advantage_estimation.full_return)

    ### 4. start training
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # assign actor to pi_old
        actor.sync_weights()  # TODO: automate this for policies with target network

        start_time = time.time()
        for i in range(100):
            # collect data
            training_data.collect(num_episodes=20)

            # print current return
            print('Epoch {}:'.format(i))
            training_data.statistics()

            # update network
            for _ in range(num_batches):
                feed_dict = training_data.next_batch(batch_size)
                sess.run(train_op, feed_dict=feed_dict)

            # assigning actor to pi_old
            actor.update_weights()

        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
examples/actor_critic_separate_cartpole.py (new executable file, 90 lines)

@@ -0,0 +1,90 @@
#!/usr/bin/env python
from __future__ import absolute_import

import tensorflow as tf
import time
import numpy as np

# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.value_function.state_value as value_function

from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize


# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix

if __name__ == '__main__':
    env = normalize(CartpoleEnv())
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.flat_dim

    clip_param = 0.2
    num_batches = 10
    batch_size = 128

    seed = 10
    np.random.seed(seed)
    tf.set_random_seed(seed)

    ### 1. build network with pure tf
    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

    def my_network():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

        action_mean = tf.layers.dense(net, action_dim, activation=None)
        action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))

        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
        value = tf.layers.dense(net, 1, activation=None)

        return action_mean, action_logstd, value

    ### 2. build policy, critic, loss, optimizer
    actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1)
    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)

    actor_loss = losses.REINFORCE(actor)
    critic_loss = losses.state_value_mse(critic)

    actor_optimizer = tf.train.AdamOptimizer(1e-4)
    actor_train_op = actor_optimizer.minimize(actor_loss, var_list=actor.trainable_variables)

    critic_optimizer = tf.train.RMSPropOptimizer(1e-4)
    critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)

    ### 3. define data collection
    data_collector = Batch(env, actor,
                           [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
                           [actor, critic])

    ### 4. start training
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        for i in range(100):
            # collect data
            data_collector.collect(num_episodes=20)

            # print current return
            print('Epoch {}:'.format(i))
            data_collector.statistics()

            # update network
            for _ in range(num_batches):
                feed_dict = data_collector.next_batch(batch_size)
                sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)

        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
examples/contrib_dqn_example.py (new file, 95 lines)

@@ -0,0 +1,95 @@
#!/usr/bin/env python

import tensorflow as tf
import gym

# our lib imports here!
import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay_buffer.utils import get_replay_buffer
import tianshou.core.policy.dqn as policy


# THIS EXAMPLE IS NOT FINISHED YET!!!


def policy_net(observation, action_dim):
    """
    Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf

    :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
    :param action_dim: int. The number of actions.
    :param scope: str. Specifying the scope of the variables.
    """
    net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
    net = tf.layers.flatten(net)
    net = tf.layers.dense(net, 256, activation=tf.nn.relu)

    q_values = tf.layers.dense(net, action_dim)

    return q_values


if __name__ == '__main__':
    env = gym.make('PongNoFrameskip-v4')
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.n

    # 1. build network with pure tf
    # TODO:
    # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
    # access this observation variable.
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation")  # network input
    action = tf.placeholder(dtype=tf.int32, shape=(None,))  # batch of integer actions

    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)
    with tf.variable_scope('target_net'):
        q_values_target = policy_net(observation, action_dim)

    # 2. build losses, optimizers
    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action)  # YongRen: policy.DQN
    target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)

    target = tf.placeholder(dtype=tf.float32, shape=[None])  # target value for DQN

    dqn_loss = losses.dqn_loss(action, target, q_net)  # TongzhengRen
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES)  # TODO: better management of TRAINABLE_VARIABLES
    total_loss = dqn_loss
    optimizer = tf.train.AdamOptimizer(1e-3)
    train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())

    # 3. define data collection
    # configuration should be given as parameters, different replay buffer has different parameters.
    replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
                                      {'size': 1000, 'batch_size': 64, 'learn_start': 20})
    # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
    # maybe a dict to manage the elements to be collected

    # 4. start training
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        minibatch_count = 0
        collection_count = 0
        # need to first collect some then sample, collect_freq must be larger than batch_size
        collect_freq = 100
        while True:  # until some stopping criterion met...
            # collect data
            for i in range(0, collect_freq):
                replay_memory.collect()  # ShihongSong
                collection_count += 1
                print('Collected {} times.'.format(collection_count))

            # update network
            data = replay_memory.next_batch(10)  # YouQiaoben, ShihongSong
            # TODO: auto managing of the placeholders? or add this to params of data.Batch
            sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
            minibatch_count += 1
            print('Trained {} minibatches.'.format(minibatch_count))

            # TODO: assigning pi to pi_old is not implemented yet
@@ -1,95 +1,83 @@
 #!/usr/bin/env python
+from __future__ import absolute_import
 
 import tensorflow as tf
 import gym
+import numpy as np
+import time
 
-# our lib imports here!
+# our lib imports here! It's ok to append path in examples
 import sys
 sys.path.append('..')
-import tianshou.core.losses as losses
-from tianshou.data.replay_buffer.utils import get_replay_buffer
-import tianshou.core.policy.dqn as policy
+from tianshou.core import losses
+from tianshou.data.batch import Batch
+import tianshou.data.advantage_estimation as advantage_estimation
+import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so that only need to import to policy
 
-
-# THIS EXAMPLE IS NOT FINISHED YET!!!
-
-
-def policy_net(observation, action_dim):
-    """
-    Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
-
-    :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
-    :param action_dim: int. The number of actions.
-    :param scope: str. Specifying the scope of the variables.
-    """
-    net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
-    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
-    net = tf.layers.flatten(net)
-    net = tf.layers.dense(net, 256, activation=tf.nn.relu)
-
-    q_values = tf.layers.dense(net, action_dim)
-
-    return q_values
-
 
 if __name__ == '__main__':
-    env = gym.make('PongNoFrameskip-v4')
+    env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.n
 
-    # 1. build network with pure tf
-    # TODO:
-    # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
-    # access this observation variable.
-    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation")  # network input
-    action = tf.placeholder(dtype=tf.int32, shape=(None,))  # batch of integer actions
+    clip_param = 0.2
+    num_batches = 10
+    batch_size = 512
 
+    seed = 0
+    np.random.seed(seed)
+    tf.set_random_seed(seed)
 
-    with tf.variable_scope('q_net'):
-        q_values = policy_net(observation, action_dim)
-    with tf.variable_scope('target_net'):
-        q_values_target = policy_net(observation, action_dim)
+    ### 1. build network with pure tf
+    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
 
-    # 2. build losses, optimizers
-    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action)  # YongRen: policy.DQN
-    target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)
-
-    target = tf.placeholder(dtype=tf.float32, shape=[None])  # target value for DQN
+    def my_policy():
+        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
+        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
 
-    dqn_loss = losses.dqn_loss(action, target, q_net)  # TongzhengRen
-    global_step = tf.Variable(0, name='global_step', trainable=False)
-    train_var_list = tf.get_collection(
-        tf.GraphKeys.TRAINABLE_VARIABLES)  # TODO: better management of TRAINABLE_VARIABLES
+        action_values = tf.layers.dense(net, action_dim, activation=None)
+
+        return action_values, None  # None value head
+
+    # TODO: current implementation of passing function or overriding function has to return a value head
+    # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
+    # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
+    # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
+    # not based on passing function or overriding function.
+
+    ### 2. build policy, loss, optimizer
+    pi = policy.DQN(my_policy, observation_placeholder=observation_ph, weight_update=10)
+
+    dqn_loss = losses.qlearning(pi)
+
     total_loss = dqn_loss
-    optimizer = tf.train.AdamOptimizer(1e-3)
-    train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
-    # 3. define data collection
-    # configuration should be given as parameters, different replay buffer has different parameters.
-    replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
-                                      {'size': 1000, 'batch_size': 64, 'learn_start': 20})
-    # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
-    # maybe a dict to manage the elements to be collected
+    optimizer = tf.train.AdamOptimizer(1e-4)
+    train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
-    # 4. start training
-    with tf.Session() as sess:
+    ### 3. define data collection
+    data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, pi.target_network)], [pi])
+
+    ### 4. start training
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    with tf.Session(config=config) as sess:
         sess.run(tf.global_variables_initializer())
 
-        minibatch_count = 0
-        collection_count = 0
-        # need to first collect some then sample, collect_freq must be larger than batch_size
-        collect_freq = 100
-        while True:  # until some stopping criterion met...
+        # assign actor to pi_old
+        pi.sync_weights()  # TODO: automate this for policies with target network
+
+        start_time = time.time()
+        for i in range(100):
             # collect data
-            for i in range(0, collect_freq):
-                replay_memory.collect()  # ShihongSong
-                collection_count += 1
-                print('Collected {} times.'.format(collection_count))
+            data_collector.collect(num_episodes=50)
+
+            # print current return
+            print('Epoch {}:'.format(i))
+            data_collector.statistics()
 
             # update network
-            data = replay_memory.next_batch(10)  # YouQiaoben, ShihongSong
-            # TODO: auto managing of the placeholders? or add this to params of data.Batch
-            sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
-            minibatch_count += 1
-            print('Trained {} minibatches.'.format(minibatch_count))
+            for _ in range(num_batches):
+                feed_dict = data_collector.next_batch(batch_size)
+                sess.run(train_op, feed_dict=feed_dict)
 
-            # TODO: assigning pi to pi_old is not implemented yet
+        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
@@ -61,7 +61,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    training_data = Batch(env, pi, advantage_estimation.full_return)
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -69,7 +69,7 @@ if __name__ == '__main__':
     with tf.Session(config=config) as sess:
         sess.run(tf.global_variables_initializer())
 
-        # assign pi to pi_old
+        # assign actor to pi_old
         pi.sync_weights()  # TODO: automate this for policies with target network
 
         start_time = time.time()
@@ -86,7 +86,7 @@ if __name__ == '__main__':
                 feed_dict = training_data.next_batch(batch_size)
                 sess.run(train_op, feed_dict=feed_dict)
 
-            # assigning pi to pi_old
+            # assigning actor to pi_old
             pi.update_weights()
 
         print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
@@ -47,7 +47,7 @@ if __name__ == '__main__':
     observation_dim = env.observation_space.shape
    action_dim = env.action_space.flat_dim
 
-    clip_param = 0.2
+    # clip_param = 0.2
     num_batches = 10
     batch_size = 128
 
@@ -65,6 +65,7 @@ if __name__ == '__main__':
     ### 2. build policy, loss, optimizer
     pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)
 
+    clip_param = tf.placeholder(tf.float32, shape=(), name='ppo_loss_clip_param')
     ppo_loss_clip = losses.ppo_clip(pi, clip_param)
 
     total_loss = ppo_loss_clip
@@ -72,7 +73,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    training_data = Batch(env, pi, advantage_estimation.full_return)
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
 
     ### 4. start training
     feed_dict_train = {is_training_ph: True, keep_prob_ph: 0.8}
@@ -83,7 +84,7 @@ if __name__ == '__main__':
     with tf.Session(config=config) as sess:
         sess.run(tf.global_variables_initializer())
 
-        # assign pi to pi_old
+        # assign actor to pi_old
         pi.sync_weights()  # TODO: automate this for policies with target network
 
         start_time = time.time()
@@ -95,13 +96,19 @@ if __name__ == '__main__':
             print('Epoch {}:'.format(i))
             training_data.statistics()
 
+            # manipulate decay_param
+            if i < 30:
+                feed_dict_train[clip_param] = 0.2
+            else:
+                feed_dict_train[clip_param] = 0.1
+
             # update network
             for _ in range(num_batches):
                 feed_dict = training_data.next_batch(batch_size)
                 feed_dict.update(feed_dict_train)
                 sess.run(train_op, feed_dict=feed_dict)
 
-            # assigning pi to pi_old
+            # assigning actor to pi_old
             pi.update_weights()
 
             # approximate test mode
@@ -55,7 +55,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    training_data = Batch(env, pi, advantage_estimation.full_return)
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -63,7 +63,7 @@ if __name__ == '__main__':
     with tf.Session(config=config) as sess:
         sess.run(tf.global_variables_initializer())
 
-        # assign pi to pi_old
+        # assign actor to pi_old
         pi.sync_weights()  # TODO: automate this for policies with target network
 
         start_time = time.time()
@@ -80,7 +80,7 @@ if __name__ == '__main__':
                 feed_dict = training_data.next_batch(batch_size)
                 sess.run(train_op, feed_dict=feed_dict)
 
-            # assigning pi to pi_old
+            # assigning actor to pi_old
             pi.update_weights()
 
         print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
internal_keys.md (new file, 15 lines)

@@ -0,0 +1,15 @@
network.managed_placeholders.keys()

data_collector.raw_data.keys()

data_collector.data.keys()

['observation']

['action']

['reward']

['start_flag']

['advantage'] > ['return'] # they may appear simultaneously
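A rough, self-contained mock of how these keys end up matched when building a feed_dict (plain exact-key lookup stands in for `utils.internal_key_match`; the real logic lives in `Batch.next_batch` in the diff further down):

```python
import numpy as np

# fake collected data carrying the internal keys listed above
raw_data = {'observation': np.zeros((8, 4)), 'action': np.zeros(8, dtype=int),
            'reward': np.ones(8), 'start_flag': np.zeros(8, dtype=bool)}
data = {'advantage': np.ones(8), 'return': np.ones(8)}  # produced by the reward processors

# each network registers the placeholders it needs under these keys
required_placeholders = {'observation': 'observation_ph', 'action': 'action_ph', 'advantage': 'advantage_ph'}

feed_dict = {}
for key, placeholder in required_placeholders.items():
    if key in raw_data:        # raw collected data is checked first
        feed_dict[placeholder] = raw_data[key]
    elif key in data:          # then processed data such as 'advantage' or 'return'
        feed_dict[placeholder] = data[key]
    else:
        raise TypeError('Placeholder {} has no value to feed!'.format(placeholder))

print(sorted(feed_dict.keys()))  # ['action_ph', 'advantage_ph', 'observation_ph']
```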
@@ -14,7 +14,7 @@ def ppo_clip(policy, clip_param):
     action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, name='ppo_clip_loss/action_placeholder')
     advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
     policy.managed_placeholders['action'] = action_ph
-    policy.managed_placeholders['processed_reward'] = advantage_ph
+    policy.managed_placeholders['advantage'] = advantage_ph
 
     log_pi_act = policy.log_prob(action_ph)
     log_pi_old_act = policy.log_prob_old(action_ph)
@@ -24,7 +24,7 @@ def ppo_clip(policy, clip_param):
     return ppo_clip_loss
 
 
-def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
+def REINFORCE(policy):
     """
     vanilla policy gradient
 
@@ -34,10 +34,29 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
     :param baseline: the baseline method used to reduce the variance, default is 'None'
     :return:
     """
-    log_pi_act = pi.log_prob(sampled_action)
-    vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act)
-    # TODO: Different baseline methods like REINFORCE, etc.
-    return vanilla_policy_gradient_loss
+    action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape,
+                               name='REINFORCE/action_placeholder')
+    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='REINFORCE/advantage_placeholder')
+    policy.managed_placeholders['action'] = action_ph
+    policy.managed_placeholders['advantage'] = advantage_ph
+
+    log_pi_act = policy.log_prob(action_ph)
+    REINFORCE_loss = -tf.reduce_mean(advantage_ph * log_pi_act)
+    return REINFORCE_loss
+
+
+def state_value_mse(state_value_function):
+    """
+    L2 loss of state value
+    :param state_value_function: instance of StateValue
+    :return: tensor of the mse loss
+    """
+    state_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
+    state_value_function.managed_placeholders['return'] = state_value_ph
+
+    state_value = state_value_function.value_tensor
+    return tf.losses.mean_squared_error(state_value_ph, state_value)
 
 
 def dqn_loss(sampled_action, sampled_target, policy):
     """
@@ -44,7 +44,8 @@ class OnehotCategorical(StochasticPolicy):
         self.weight_update = weight_update
         self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.
 
-        with tf.variable_scope('network'):
+        # build network, action and value
+        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
             logits, value_head = policy_callable()
             self._logits = tf.convert_to_tensor(logits, dtype=tf.float32)
             self._action = tf.multinomial(self._logits, num_samples=1)
@@ -55,11 +56,12 @@ class OnehotCategorical(StochasticPolicy):
 
         self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
 
+        # deal with target network
         if self.weight_update == 1:
             self.weight_update_ops = None
             self.sync_weights_ops = None
         else:  # then we need to build another tf graph as target network
-            with tf.variable_scope('net_old'):
+            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
                 logits, value_head = policy_callable()
                 self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32)
@@ -173,7 +175,8 @@ class Normal(StochasticPolicy):
         self.weight_update = weight_update
         self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.
 
-        with tf.variable_scope('network'):
+        # build network, action and value
+        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
             mean, logstd, value_head = policy_callable()
             self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
             self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
@@ -188,11 +191,12 @@ class Normal(StochasticPolicy):
 
         self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
 
+        # deal with target network
         if self.weight_update == 1:
             self.weight_update_ops = None
             self.sync_weights_ops = None
         else:  # then we need to build another tf graph as target network
-            with tf.variable_scope('net_old'):
+            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
                 mean, logstd, value_head = policy_callable()
                 self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32)
                 self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32)
@@ -8,7 +8,12 @@ class StateValue(ValueFunctionBase):
     """
     class of state values V(s).
    """
-    def __init__(self, value_tensor, observation_placeholder):
+    def __init__(self, policy_callable, observation_placeholder):
+        self.managed_placeholders = {'observation': observation_placeholder}
+        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
+            value_tensor = policy_callable()[-1]
+        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+
         super(StateValue, self).__init__(
             value_tensor=value_tensor,
             observation_placeholder=observation_placeholder
@@ -6,15 +6,13 @@ def full_return(raw_data):
     naively compute full return
     :param raw_data: dict of specified keys and values.
     """
-    observations = raw_data['observations']
-    actions = raw_data['actions']
-    rewards = raw_data['rewards']
-    episode_start_flags = raw_data['episode_start_flags']
+    observations = raw_data['observation']
+    actions = raw_data['action']
+    rewards = raw_data['reward']
+    episode_start_flags = raw_data['end_flag']
     num_timesteps = rewards.shape[0]
 
     data = {}
-    data['observations'] = observations
-    data['actions'] = actions
 
     returns = rewards.copy()
     episode_start_idx = 0
@@ -33,11 +31,39 @@ def full_return(raw_data):
 
             episode_start_idx = i
 
-    data['returns'] = returns
+    data['return'] = returns
 
     return data
 
 
+class gae_lambda:
+    """
+    Generalized Advantage Estimation (Schulman, 15) to compute advantage
+    """
+    def __init__(self, T, value_function):
+        self.T = T
+        self.value_function = value_function
+
+    def __call__(self, raw_data):
+        reward = raw_data['reward']
+
+        return {'advantage': reward}
+
+
+class nstep_return:
+    """
+    compute the n-step return from n-step rewards and bootstrapped value function
+    """
+    def __init__(self, n, value_function):
+        self.n = n
+        self.value_function = value_function
+
+    def __call__(self, raw_data):
+        reward = raw_data['reward']
+
+        return {'return': reward}
+
+
 class QLearningTarget:
     def __init__(self, policy, gamma):
         self._policy = policy
@@ -68,3 +94,4 @@ class QLearningTarget:
         data['rewards'] = np.array(rewards)
 
         return data
+
@@ -1,6 +1,7 @@
 import numpy as np
 import gc
+import logging
+
+from . import utils
 
 # TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner
 class Batch(object):
@@ -8,14 +9,31 @@ class Batch(object):
     class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
     """
 
-    def __init__(self, env, pi, advantage_estimation_function):  # how to name the function?
+    def __init__(self, env, pi, reward_processors, networks):  # how to name the function?
+        """
+        constructor
+        :param env:
+        :param pi:
+        :param reward_processors: list of functions to process reward
+        :param networks: list of networks to be optimized, so as to match data in feed_dict
+        """
         self._env = env
         self._pi = pi
-        self._advantage_estimation_function = advantage_estimation_function
+        self.raw_data = {}
+        self.data = {}
+
+        self.reward_processors = reward_processors
+        self.networks = networks
+
+        self.required_placeholders = {}
+        for net in self.networks:
+            self.required_placeholders.update(net.managed_placeholders)
+        self.require_advantage = 'advantage' in self.required_placeholders.keys()
 
         self._is_first_collect = True
 
     def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
-                apply_function=True):  # specify how many data to collect here, or fix it in __init__()
+                process_reward=True):  # specify how many data to collect here, or fix it in __init__()
         assert sum(
             [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
@@ -98,6 +116,7 @@ class Batch(object):
                     break
 
                 if done:  # end of episode, discard s_T
+                    # TODO: for num_timesteps collection, has to store terminal flag instead of start flag!
                     break
                 else:
                     observations.append(ob)
@@ -113,33 +132,48 @@ class Batch(object):
         del rewards
         del episode_start_flags
 
-        self.raw_data = {'observations': self.observations, 'actions': self.actions, 'rewards': self.rewards,
-                         'episode_start_flags': self.episode_start_flags}
+        self.raw_data = {'observation': self.observations, 'action': self.actions, 'reward': self.rewards,
+                         'end_flag': self.episode_start_flags}
 
         self._is_first_collect = False
 
-        if apply_function:
+        if process_reward:
             self.apply_advantage_estimation_function()
 
         gc.collect()
 
     def apply_advantage_estimation_function(self):
-        self.data = self._advantage_estimation_function(self.raw_data)
+        for processor in self.reward_processors:
+            self.data.update(processor(self.raw_data))
 
-    def next_batch(self, batch_size, standardize_advantage=True):  # YouQiaoben: referencing other iterate over batches
-        rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size)
-        current_batch = {key: value[rand_idx] for key, value in self.data.items()}
-
-        if standardize_advantage:
-            advantage_mean = np.mean(current_batch['returns'])
-            advantage_std = np.std(current_batch['returns'])
-            current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std
-
+    def next_batch(self, batch_size, standardize_advantage=True):
+        rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size)
+
         feed_dict = {}
-        feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations']
-        feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions']
-        feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns']
-        # TODO: should use the keys in pi.managed_placeholders to find values in self.data and self.raw_data
+        for key, placeholder in self.required_placeholders.items():
+            found, data_key = utils.internal_key_match(key, self.raw_data.keys())
+            if found:
+                feed_dict[placeholder] = self.raw_data[data_key][rand_idx]
+            else:
+                found, data_key = utils.internal_key_match(key, self.data.keys())
+                if found:
+                    feed_dict[placeholder] = self.data[data_key][rand_idx]
+
+            if not found:
+                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
+
+        if standardize_advantage:
+            if self.require_advantage:
+                advantage_value = feed_dict[self.required_placeholders['advantage']]
+                advantage_mean = np.mean(advantage_value)
+                advantage_std = np.std(advantage_value)
+                if advantage_std < 1e-3:
+                    logging.warning('advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues')
+                feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
+
+        # TODO: maybe move all advantage estimation functions to tf, as in tensorforce (though haven't
+        # understood tensorforce after reading) maybe tf.stop_gradient for targets/advantages
+        # this will simplify data collector as it only needs to collect raw data, (s, a, r, done) only
 
         return feed_dict
@@ -149,8 +183,8 @@ class Batch(object):
         compute the statistics of the current sampled paths
         :return:
         """
-        rewards = self.raw_data['rewards']
-        episode_start_flags = self.raw_data['episode_start_flags']
+        rewards = self.raw_data['reward']
+        episode_start_flags = self.raw_data['end_flag']
         num_timesteps = rewards.shape[0]
 
         returns = []
|||||||
Loading…
x
Reference in New Issue
Block a user