finish ddpg example. all examples under examples/ (except those containing 'contrib' and 'fail') can run! the advantage estimation module is not complete yet.
This commit is contained in:
parent 8fbde8283f
commit f32e1d9c12
@@ -49,10 +49,10 @@ if __name__ == '__main__':
     ### 2. build policy, critic, loss, optimizer
     actor = policy.OnehotCategorical(my_network, observation_placeholder=observation_ph, weight_update=1)
-    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
+    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)  # no target network
 
     actor_loss = losses.REINFORCE(actor)
-    critic_loss = losses.state_value_mse(critic)
+    critic_loss = losses.value_mse(critic)
     total_loss = actor_loss + critic_loss
 
     optimizer = tf.train.AdamOptimizer(1e-4)
@@ -57,7 +57,7 @@ if __name__ == '__main__':
     actor_loss = losses.vanilla_policy_gradient(actor)
-    critic_loss = losses.state_value_mse(critic)
+    critic_loss = losses.value_mse(critic)
     total_loss = actor_loss + critic_loss
 
     optimizer = tf.train.AdamOptimizer(1e-4)
@@ -51,7 +51,7 @@ if __name__ == '__main__':
     critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
 
     actor_loss = losses.REINFORCE(actor)
-    critic_loss = losses.state_value_mse(critic)
+    critic_loss = losses.value_mse(critic)
 
     actor_optimizer = tf.train.AdamOptimizer(1e-4)
     actor_train_op = actor_optimizer.minimize(actor_loss, var_list=actor.trainable_variables)
examples/ddpg_example.py (new file, 89 lines)
@@ -0,0 +1,89 @@
#!/usr/bin/env python
from __future__ import absolute_import

import tensorflow as tf
import gym
import numpy as np
import time

# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy as policy
import tianshou.core.value_function.action_value as value_function
import tianshou.core.opt as opt


if __name__ == '__main__':
    env = gym.make('Pendulum-v0')
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.shape

    clip_param = 0.2
    num_batches = 10
    batch_size = 512

    seed = 0
    np.random.seed(seed)
    tf.set_random_seed(seed)

    ### 1. build network with pure tf
    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
    action_ph = tf.placeholder(tf.float32, shape=(None,) + action_dim)

    def my_network():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
        net = tf.layers.dense(net, 32, activation=tf.nn.relu)
        action = tf.layers.dense(net, action_dim[0], activation=None)

        action_value_input = tf.concat([observation_ph, action_ph], axis=1)
        net = tf.layers.dense(action_value_input, 32, activation=tf.nn.relu)
        net = tf.layers.dense(net, 32, activation=tf.nn.relu)
        action_value = tf.layers.dense(net, 1, activation=None)

        return action, action_value

    ### 2. build policy, loss, optimizer
    actor = policy.Deterministic(my_network, observation_placeholder=observation_ph, weight_update=1e-3)
    critic = value_function.ActionValue(my_network, observation_placeholder=observation_ph,
                                        action_placeholder=action_ph, weight_update=1e-3)

    critic_loss = losses.value_mse(critic)
    critic_optimizer = tf.train.AdamOptimizer(1e-3)
    critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)

    dpg_grads = opt.DPG(actor, critic)  # not sure if it's correct
    actor_optimizer = tf.train.AdamOptimizer(1e-4)
    actor_train_op = actor_optimizer.apply_gradients(dpg_grads)

    ### 3. define data collection
    data_collector = Batch(env, actor, [advantage_estimation.ddpg_return(actor, critic)], [actor, critic])

    ### 4. start training
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # assign actor to pi_old
        actor.sync_weights()  # TODO: automate this for policies with target network
        critic.sync_weights()

        start_time = time.time()
        for i in range(100):
            # collect data
            data_collector.collect(num_episodes=50)

            # print current return
            print('Epoch {}:'.format(i))
            data_collector.statistics()

            # update network
            for _ in range(num_batches):
                feed_dict = data_collector.next_batch(batch_size)
                sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)

            print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
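Note (not part of this diff): the actor update above routes gradients through opt.DPG and apply_gradients. An equivalent and common alternative feeds the actor's action tensor directly into the critic and minimizes -Q(s, mu(s)) over the actor variables only. A minimal standalone sketch with hypothetical names (obs_ph, toy layer sizes), not this repo's classes:

import tensorflow as tf

obs_ph = tf.placeholder(tf.float32, shape=(None, 3))                      # e.g. Pendulum-v0 observations
with tf.variable_scope('actor'):
    mu = tf.layers.dense(tf.layers.dense(obs_ph, 32, tf.nn.relu), 1)      # deterministic action mu(s)
with tf.variable_scope('critic'):
    q_of_mu = tf.layers.dense(tf.concat([obs_ph, mu], axis=1), 1)         # Q(s, mu(s)), critic fed by the actor tensor

actor_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
actor_loss = -tf.reduce_mean(q_of_mu)                                     # minimizing this ascends Q(s, mu(s))
actor_train_op = tf.train.AdamOptimizer(1e-4).minimize(actor_loss, var_list=actor_vars)

Restricting var_list to the actor variables is what keeps the critic fixed during the actor step, mirroring the var_list split used in the example above.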
@@ -16,6 +16,9 @@ import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so tha
 import tianshou.core.value_function.action_value as value_function
 
+
+# TODO: why this solves cartpole even without training?
+
 
 if __name__ == '__main__':
     env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
@@ -45,13 +45,13 @@ def REINFORCE(policy):
     return REINFORCE_loss
 
 
-def state_value_mse(state_value_function):
+def value_mse(state_value_function):
     """
     L2 loss of state value
     :param state_value_function: instance of StateValue
     :return: tensor of the mse loss
     """
-    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
+    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='value_mse/return_placeholder')
     state_value_function.managed_placeholders['return'] = target_value_ph
 
     state_value = state_value_function.value_tensor
tianshou/core/opt.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import tensorflow as tf


def DPG(policy, action_value):
    """
    construct the gradient tensor of deterministic policy gradient
    :param policy:
    :param action_value:
    :return: list of (gradient, variable) pairs
    """
    trainable_variables = policy.trainable_variables
    critic_action_input = action_value._action_placeholder
    critic_value_loss = -tf.reduce_mean(action_value.value_tensor)
    policy_action_output = policy.action

    grad_ys = tf.gradients(critic_value_loss, critic_action_input)
    grad_policy_vars = tf.gradients(policy_action_output, trainable_variables, grad_ys=grad_ys)

    grads_and_vars = zip(grad_policy_vars, trainable_variables)

    return grads_and_vars
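Note (not part of this diff): DPG implements the deterministic policy gradient chain rule, grad_theta J = E[ grad_a Q(s, a)|_{a=mu(s)} * grad_theta mu_theta(s) ], by passing dL/da as grad_ys into tf.gradients over the policy output; the minus sign in critic_value_loss makes apply_gradients ascend Q. A standalone toy check (hypothetical tensors, not this repo's classes) that the chained gradients match the direct gradient of -mean(Q(mu(s))):

import numpy as np
import tensorflow as tf

obs = tf.placeholder(tf.float32, shape=(None, 3))
with tf.variable_scope('pi'):
    mu = tf.layers.dense(obs, 2)                        # toy deterministic policy mu(s)
theta = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='pi')

a_ph = tf.placeholder(tf.float32, shape=(None, 2))      # critic fed via a placeholder, as in DPG() above
w = tf.constant([[1.0], [2.0]])
q_of_ph = tf.matmul(a_ph, w)                            # toy Q(a); a real critic also takes obs
q_of_mu = tf.matmul(mu, w)                              # same Q with mu(s) plugged in

grad_ys = tf.gradients(-tf.reduce_mean(q_of_ph), a_ph)  # dL/da at the action placeholder
chained = tf.gradients(mu, theta, grad_ys=grad_ys)      # DPG-style chaining through the policy
direct = tf.gradients(-tf.reduce_mean(q_of_mu), theta)  # direct gradient for comparison

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o = np.random.randn(4, 3).astype(np.float32)
    a = sess.run(mu, {obs: o})
    g1, g2 = sess.run([chained, direct], {obs: o, a_ph: a})
    print(all(np.allclose(x, y) for x, y in zip(g1, g2)))   # expected: True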
@@ -0,0 +1,3 @@
+from .deterministic import *
+from .dqn import *
+from .stochastic import *
tianshou/core/policy/deterministic.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import tensorflow as tf
from .base import PolicyBase


class Deterministic(PolicyBase):
    """
    deterministic policy as used in deterministic policy gradient methods
    """
    def __init__(self, policy_callable, observation_placeholder, weight_update=1):
        self._observation_placeholder = observation_placeholder
        self.managed_placeholders = {'observation': observation_placeholder}
        self.weight_update = weight_update
        self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.

        # build network, action and value
        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
            action, _ = policy_callable()
            self.action = action
            # TODO: self._action should be exactly the action tensor to run that directly gives action_dim

        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')

        # deal with target network
        if self.weight_update == 1:
            self.weight_update_ops = None
            self.sync_weights_ops = None
        else:  # then we need to build another tf graph as target network
            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
                action, _ = policy_callable()
                self.action_old = action

            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
            network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
            # TODO: use a scope that the user will almost surely not use. so get_collection will return
            # the correct weights and old_weights, since it filters by regular expression
            # or we write a util to parse the variable names and use only the topmost scope

            assert len(network_weights) == len(network_old_weights)
            self.sync_weights_ops = [tf.assign(variable_old, variable)
                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]

            if weight_update == 0:
                self.weight_update_ops = self.sync_weights_ops
            elif 0 < weight_update < 1:  # as in DDPG
                self.weight_update_ops = [tf.assign(variable_old,
                                                    weight_update * variable + (1 - weight_update) * variable_old)
                                          for (variable_old, variable) in zip(network_old_weights, network_weights)]
            else:
                self.interaction_count = 0  # as in DQN
                import math
                self.weight_update = math.ceil(weight_update)

    @property
    def action_shape(self):
        return self.action.shape.as_list()[1:]

    def act(self, observation, my_feed_dict={}):
        # TODO: this may be ugly. also maybe huge problem when parallel
        sess = tf.get_default_session()
        # observation[None] adds one dimension at the beginning

        feed_dict = {self._observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        sampled_action = sess.run(self.action, feed_dict=feed_dict)

        sampled_action = sampled_action[0]

        return sampled_action

    def update_weights(self):
        """
        updates the weights of policy_old.
        :return:
        """
        if self.weight_update_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.weight_update_ops)

    def sync_weights(self):
        """
        sync the weights of network_old. Direct copy the weights of network.
        :return:
        """
        if self.sync_weights_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.sync_weights_ops)

    def eval_action(self, observation):
        """
        evaluate action in minibatch
        :param observation:
        :return: 2-D numpy array
        """
        sess = tf.get_default_session()

        feed_dict = {self._observation_placeholder: observation}
        action = sess.run(self.action, feed_dict=feed_dict)

        return action

    def eval_action_old(self, observation):
        """
        evaluate action in minibatch
        :param observation:
        :return: 2-D numpy array
        """
        sess = tf.get_default_session()

        feed_dict = {self._observation_placeholder: observation}
        action = sess.run(self.action_old, feed_dict=feed_dict)

        return action
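Note (not part of this diff): a short usage sketch of the weight_update semantics built above, reusing names from examples/ddpg_example.py. Per the in-code comments, the intent is: 1 means no target network, a value in (0, 1) is the DDPG soft-update coefficient tau, and 0 or an integer greater than 1 is a DQN-style hard copy.

actor = policy.Deterministic(my_network, observation_placeholder=observation_ph,
                             weight_update=1e-3)   # tau = 1e-3: theta_old <- tau*theta + (1-tau)*theta_old

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    actor.sync_weights()     # one-time hard copy 'network' -> 'net_old' after initialization
    # ... interact and train ...
    actor.update_weights()   # each call applies the soft-update ops to the target network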
@@ -8,12 +8,47 @@ class ActionValue(ValueFunctionBase):
     """
     class of action values Q(s, a).
     """
-    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._action_placeholder = action_placeholder
-        super(ActionValue, self).__init__(
-            value_tensor=value_tensor,
-            observation_placeholder=observation_placeholder
-        )
+    def __init__(self, network_callable, observation_placeholder, action_placeholder, weight_update=1):
+        self._observation_placeholder = observation_placeholder
+        self._action_placeholder = action_placeholder
+        self.managed_placeholders = {'observation': observation_placeholder, 'action': action_placeholder}
+        self.weight_update = weight_update
+        self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.
+
+        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
+            value_tensor = network_callable()[-1]
+
+        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
+
+        super(ActionValue, self).__init__(value_tensor, observation_placeholder=observation_placeholder)
+
+        # deal with target network
+        if self.weight_update == 1:
+            self.weight_update_ops = None
+            self.sync_weights_ops = None
+        else:  # then we need to build another tf graph as target network
+            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
+                value_tensor = network_callable()[-1]
+                self.value_tensor_old = tf.squeeze(value_tensor)
+
+            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
+            network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
+
+            assert len(network_weights) == len(network_old_weights)
+            self.sync_weights_ops = [tf.assign(variable_old, variable)
+                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]
+
+            if weight_update == 0:
+                self.weight_update_ops = self.sync_weights_ops
+            elif 0 < weight_update < 1:  # useful in DDPG
+                self.weight_update_ops = [tf.assign(variable_old,
+                                                    weight_update * variable + (1 - weight_update) * variable_old)
+                                          for (variable_old, variable) in zip(network_old_weights, network_weights)]
+            else:
+                self.interaction_count = 0
+                import math
+                self.weight_update = math.ceil(weight_update)
+                self.weight_update_ops = self.sync_weights_ops
 
     def eval_value(self, observation, action):
         """
@@ -27,6 +62,35 @@ class ActionValue(ValueFunctionBase):
         return sess.run(self.value_tensor, feed_dict=
                         {self._observation_placeholder: observation, self._action_placeholder: action})
 
+    def eval_value_old(self, observation, action):
+        """
+        eval value using target network
+        :param observation: numpy array of obs
+        :param action: numpy array of action
+        :return: numpy array of action value
+        """
+        sess = tf.get_default_session()
+        feed_dict = {self._observation_placeholder: observation, self._action_placeholder: action}
+        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
+
+    def sync_weights(self):
+        """
+        sync the weights of network_old. Direct copy the weights of network.
+        :return:
+        """
+        if self.sync_weights_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.sync_weights_ops)
+
+    def update_weights(self):
+        """
+        updates the weights of policy_old.
+        :return:
+        """
+        if self.weight_update_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.weight_update_ops)
+
+
 class DQN(ValueFunctionBase):
     """
@@ -76,6 +76,31 @@ class nstep_return:
         return {'return': return_}
 
 
+class ddpg_return:
+    """
+    compute the return as in DDPG. this seems to have to be special
+    """
+    def __init__(self, actor, critic, use_target_network=True):
+        self.actor = actor
+        self.critic = critic
+        self.use_target_network = use_target_network
+
+    def __call__(self, raw_data):
+        observation = raw_data['observation']
+        reward = raw_data['reward']
+
+        if self.use_target_network:
+            action_target = self.actor.eval_action_old(observation)
+            value_target = self.critic.eval_value_old(observation, action_target)
+        else:
+            action_target = self.actor.eval_action(observation)
+            value_target = self.critic.eval_value(observation, action_target)
+
+        return_ = reward + value_target
+
+        return {'return': return_}
+
+
 class nstep_q_return:
     """
     compute the n-step return for Q-learning targets
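Note (not part of this diff): consistent with the commit message, ddpg_return is not the full DDPG target yet; it bootstraps from the current observation and omits the discount factor and termination mask. For reference, the standard target is y_t = r_t + gamma * (1 - done_t) * Q'(s_{t+1}, mu'(s_{t+1})). A hedged sketch using the target-network evaluators added in this commit, where gamma, done and next_observation are assumptions not yet provided by the data layer:

def ddpg_target(reward, done, next_observation, actor, critic, gamma=0.99):
    # bootstrap from the next observation with the target actor and target critic
    action_target = actor.eval_action_old(next_observation)
    value_target = critic.eval_value_old(next_observation, action_target)
    return reward + gamma * (1. - done) * value_target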