From f32e1d9c12968a5d31c38404d323ffaaf7508611 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Thu, 18 Jan 2018 17:38:52 +0800
Subject: [PATCH] Finish DDPG example. All examples under examples/ (except
 those containing 'contrib' and 'fail') can run; the advantage estimation
 module is not complete yet.

---
 examples/actor_critic_cartpole.py             |   4 +-
 examples/actor_critic_fail_cartpole.py        |   2 +-
 examples/actor_critic_separate_cartpole.py    |   2 +-
 ...b_dqn_example.py => contrib_dqn_replay.py} |   0
 examples/ddpg_example.py                      |  89 ++++++++++++++
 examples/dqn_example.py                       |   3 +
 tianshou/core/losses.py                       |   4 +-
 tianshou/core/opt.py                          |  21 ++++
 tianshou/core/policy/__init__.py              |   3 +
 tianshou/core/policy/deterministic.py         | 111 ++++++++++++++++++
 tianshou/core/value_function/action_value.py  |  74 +++++++++++-
 tianshou/data/advantage_estimation.py         |  25 ++++
 12 files changed, 327 insertions(+), 11 deletions(-)
 rename examples/{contrib_dqn_example.py => contrib_dqn_replay.py} (100%)
 create mode 100644 examples/ddpg_example.py
 create mode 100644 tianshou/core/opt.py
 create mode 100644 tianshou/core/policy/deterministic.py

diff --git a/examples/actor_critic_cartpole.py b/examples/actor_critic_cartpole.py
index a5ddabf..4aa1020 100755
--- a/examples/actor_critic_cartpole.py
+++ b/examples/actor_critic_cartpole.py
@@ -49,10 +49,10 @@ if __name__ == '__main__':
 
     ### 2. build policy, critic, loss, optimizer
     actor = policy.OnehotCategorical(my_network, observation_placeholder=observation_ph, weight_update=1)
-    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
+    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)  # no target network
 
     actor_loss = losses.REINFORCE(actor)
-    critic_loss = losses.state_value_mse(critic)
+    critic_loss = losses.value_mse(critic)
     total_loss = actor_loss + critic_loss
 
     optimizer = tf.train.AdamOptimizer(1e-4)
diff --git a/examples/actor_critic_fail_cartpole.py b/examples/actor_critic_fail_cartpole.py
index 5cf422c..f708a85 100755
--- a/examples/actor_critic_fail_cartpole.py
+++ b/examples/actor_critic_fail_cartpole.py
@@ -57,7 +57,7 @@ if __name__ == '__main__':
 
 
     actor_loss = losses.vanilla_policy_gradient(actor)
-    critic_loss = losses.state_value_mse(critic)
+    critic_loss = losses.value_mse(critic)
     total_loss = actor_loss + critic_loss
 
     optimizer = tf.train.AdamOptimizer(1e-4)
diff --git a/examples/actor_critic_separate_cartpole.py b/examples/actor_critic_separate_cartpole.py
index 08cf914..e4795b0 100755
--- a/examples/actor_critic_separate_cartpole.py
+++ b/examples/actor_critic_separate_cartpole.py
@@ -51,7 +51,7 @@ if __name__ == '__main__':
     critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
 
     actor_loss = losses.REINFORCE(actor)
-    critic_loss = losses.state_value_mse(critic)
+    critic_loss = losses.value_mse(critic)
 
     actor_optimizer = tf.train.AdamOptimizer(1e-4)
     actor_train_op = actor_optimizer.minimize(actor_loss, var_list=actor.trainable_variables)
diff --git a/examples/contrib_dqn_example.py b/examples/contrib_dqn_replay.py
similarity index 100%
rename from examples/contrib_dqn_example.py
rename to examples/contrib_dqn_replay.py
diff --git a/examples/ddpg_example.py b/examples/ddpg_example.py
new file mode 100644
index 0000000..297c9a1
--- /dev/null
+++ b/examples/ddpg_example.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+from __future__ import absolute_import
+
+import tensorflow as tf
+import gym
+import numpy as np
+import time
+
+# our lib imports here! It's ok to append path in examples
+import sys
+sys.path.append('..')
+from tianshou.core import losses
+from tianshou.data.batch import Batch
+import tianshou.data.advantage_estimation as advantage_estimation
+import tianshou.core.policy as policy
+import tianshou.core.value_function.action_value as value_function
+import tianshou.core.opt as opt
+
+
+if __name__ == '__main__':
+    env = gym.make('Pendulum-v0')
+    observation_dim = env.observation_space.shape
+    action_dim = env.action_space.shape
+
+    clip_param = 0.2
+    num_batches = 10
+    batch_size = 512
+
+    seed = 0
+    np.random.seed(seed)
+    tf.set_random_seed(seed)
+
+    ### 1. build network with pure tf
+    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
+    action_ph = tf.placeholder(tf.float32, shape=(None,) + action_dim)
+
+    def my_network():
+        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
+        net = tf.layers.dense(net, 32, activation=tf.nn.relu)
+        action = tf.layers.dense(net, action_dim[0], activation=None)
+
+        action_value_input = tf.concat([observation_ph, action_ph], axis=1)
+        net = tf.layers.dense(action_value_input, 32, activation=tf.nn.relu)
+        net = tf.layers.dense(net, 32, activation=tf.nn.relu)
+        action_value = tf.layers.dense(net, 1, activation=None)
+
+        return action, action_value
+
+    ### 2. build policy, loss, optimizer
+    actor = policy.Deterministic(my_network, observation_placeholder=observation_ph, weight_update=1e-3)
+    critic = value_function.ActionValue(my_network, observation_placeholder=observation_ph,
+                                        action_placeholder=action_ph, weight_update=1e-3)
+
+    critic_loss = losses.value_mse(critic)
+    critic_optimizer = tf.train.AdamOptimizer(1e-3)
+    critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables)
+
+    dpg_grads = opt.DPG(actor, critic)  # not sure if it's correct
+    actor_optimizer = tf.train.AdamOptimizer(1e-4)
+    actor_train_op = actor_optimizer.apply_gradients(dpg_grads)
+
+    ### 3. define data collection
+    data_collector = Batch(env, actor, [advantage_estimation.ddpg_return(actor, critic)], [actor, critic])
+
+    ### 4. start training
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    with tf.Session(config=config) as sess:
+        sess.run(tf.global_variables_initializer())
+
+        # initialize the target networks with the current weights
+        actor.sync_weights()  # TODO: automate this for policies with target network
+        critic.sync_weights()
+
+        start_time = time.time()
+        for i in range(100):
+            # collect data
+            data_collector.collect(num_episodes=50)
+
+            # print current return
+            print('Epoch {}:'.format(i))
+            data_collector.statistics()
+
+            # update network
+            for _ in range(num_batches):
+                feed_dict = data_collector.next_batch(batch_size)
+                sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
+
+            print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
\ No newline at end of file
diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index 56e84bf..6998373 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -16,6 +16,9 @@ import tianshou.core.policy.dqn as policy # TODO: fix imports as zhusuan so tha
 import tianshou.core.value_function.action_value as value_function
 
 
+# TODO: why this solves cartpole even without training?
+
+
 if __name__ == '__main__':
     env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index f1e77c7..396054a 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -45,13 +45,13 @@ def REINFORCE(policy):
     return REINFORCE_loss
 
 
-def state_value_mse(state_value_function):
+def value_mse(state_value_function):
     """
     L2 loss of state value
     :param state_value_function: instance of StateValue
     :return: tensor of the mse loss
     """
-    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
+    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='value_mse/return_placeholder')
     state_value_function.managed_placeholders['return'] = target_value_ph
 
     state_value = state_value_function.value_tensor
diff --git a/tianshou/core/opt.py b/tianshou/core/opt.py
new file mode 100644
index 0000000..59e4608
--- /dev/null
+++ b/tianshou/core/opt.py
@@ -0,0 +1,21 @@
+import tensorflow as tf
+
+
+def DPG(policy, action_value):
+    """
+    construct the gradients of the deterministic policy gradient update
+    :param policy: the deterministic policy (actor) whose variables receive the gradient
+    :param action_value: the action value function (critic) Q(s, a)
+    :return: list of (gradient, variable) pairs
+    """
+    trainable_variables = policy.trainable_variables
+    critic_action_input = action_value._action_placeholder
+    critic_value_loss = -tf.reduce_mean(action_value.value_tensor)
+    policy_action_output = policy.action
+
+    grad_ys = tf.gradients(critic_value_loss, critic_action_input)
+    grad_policy_vars = tf.gradients(policy_action_output, trainable_variables, grad_ys=grad_ys)
+
+    grads_and_vars = zip(grad_policy_vars, trainable_variables)
+
+    return grads_and_vars
\ No newline at end of file
diff --git a/tianshou/core/policy/__init__.py b/tianshou/core/policy/__init__.py
index e69de29..0b45efe 100644
--- a/tianshou/core/policy/__init__.py
+++ b/tianshou/core/policy/__init__.py
@@ -0,0 +1,3 @@
+from .deterministic import *
+from .dqn import *
+from .stochastic import *
\ No newline at end of file
diff --git a/tianshou/core/policy/deterministic.py b/tianshou/core/policy/deterministic.py
new file mode 100644
index 0000000..7cd7a1c
--- /dev/null
+++ b/tianshou/core/policy/deterministic.py
@@ -0,0 +1,111 @@
+import tensorflow as tf
+from .base import PolicyBase
+
+class Deterministic(PolicyBase):
+    """
+    deterministic policy as used in deterministic policy gradient methods
+    """
+    def __init__(self, policy_callable, observation_placeholder, weight_update=1):
+        self._observation_placeholder = observation_placeholder
+        self.managed_placeholders = {'observation': observation_placeholder}
+        self.weight_update = weight_update
+        self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.
+
+        # build network, action and value
+        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
+            action, _ = policy_callable()
+            self.action = action
+            # TODO: self._action should be exactly the action tensor to run that directly gives action_dim
+
+        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
+
+        # deal with target network
+        if self.weight_update == 1:
+            self.weight_update_ops = None
+            self.sync_weights_ops = None
+        else:  # then we need to build another tf graph as target network
+            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
+                action, _ = policy_callable()
+                self.action_old = action
+
+            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
+            network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
+            # TODO: use a scope that the user will almost surely not use. so get_collection will return
+            # the correct weights and old_weights, since it filters by regular expression
+            # or we write a util to parse the variable names and use only the topmost scope
+
+            assert len(network_weights) == len(network_old_weights)
+            self.sync_weights_ops = [tf.assign(variable_old, variable)
+                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]
+
+            if weight_update == 0:
+                self.weight_update_ops = self.sync_weights_ops
+            elif 0 < weight_update < 1:  # as in DDPG
+                self.weight_update_ops = [tf.assign(variable_old,
+                                                    weight_update * variable + (1 - weight_update) * variable_old)
+                                          for (variable_old, variable) in zip(network_old_weights, network_weights)]
+            else:
+                self.interaction_count = 0  # as in DQN
+                import math
+                self.weight_update = math.ceil(weight_update)
+
+    @property
+    def action_shape(self):
+        return self.action.shape.as_list()[1:]
+
+    def act(self, observation, my_feed_dict={}):
+        # TODO: this may be ugly. also maybe huge problem when parallel
+        sess = tf.get_default_session()
+        # observation[None] adds one dimension at the beginning
+
+        feed_dict = {self._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
+        sampled_action = sess.run(self.action, feed_dict=feed_dict)
+
+        sampled_action = sampled_action[0]
+
+        return sampled_action
+
+    def update_weights(self):
+        """
+        updates the weights of the target network net_old.
+        :return:
+        """
+        if self.weight_update_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.weight_update_ops)
+
+    def sync_weights(self):
+        """
+        sync the weights of net_old by directly copying the weights of network.
+        :return:
+        """
+        if self.sync_weights_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.sync_weights_ops)
+
+    def eval_action(self, observation):
+        """
+        evaluate action in minibatch
+        :param observation: numpy array of observations
+        :return: 2-D numpy array
+        """
+        sess = tf.get_default_session()
+
+        feed_dict = {self._observation_placeholder: observation}
+        action = sess.run(self.action, feed_dict=feed_dict)
+
+        return action
+
+    def eval_action_old(self, observation):
+        """
+        evaluate action in minibatch, using the target network
+        :param observation: numpy array of observations
+        :return: 2-D numpy array
+        """
+        sess = tf.get_default_session()
+
+        feed_dict = {self._observation_placeholder: observation}
+        action = sess.run(self.action_old, feed_dict=feed_dict)
+
+        return action
\ No newline at end of file
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index e6145ec..2773687 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -8,12 +8,47 @@ class ActionValue(ValueFunctionBase):
     """
     class of action values Q(s, a).
     """
-    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
+    def __init__(self, network_callable, observation_placeholder, action_placeholder, weight_update=1):
+        self._observation_placeholder = observation_placeholder
         self._action_placeholder = action_placeholder
-        super(ActionValue, self).__init__(
-            value_tensor=value_tensor,
-            observation_placeholder=observation_placeholder
-        )
+        self.managed_placeholders = {'observation': observation_placeholder, 'action': action_placeholder}
+        self.weight_update = weight_update
+        self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.
+
+        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
+            value_tensor = network_callable()[-1]
+
+        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
+
+        super(ActionValue, self).__init__(value_tensor, observation_placeholder=observation_placeholder)
+
+        # deal with target network
+        if self.weight_update == 1:
+            self.weight_update_ops = None
+            self.sync_weights_ops = None
+        else:  # then we need to build another tf graph as target network
+            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
+                value_tensor = network_callable()[-1]
+                self.value_tensor_old = tf.squeeze(value_tensor)
+
+            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
+            network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
+
+            assert len(network_weights) == len(network_old_weights)
+            self.sync_weights_ops = [tf.assign(variable_old, variable)
+                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]
+
+            if weight_update == 0:
+                self.weight_update_ops = self.sync_weights_ops
+            elif 0 < weight_update < 1:  # useful in DDPG
+                self.weight_update_ops = [tf.assign(variable_old,
+                                                    weight_update * variable + (1 - weight_update) * variable_old)
+                                          for (variable_old, variable) in zip(network_old_weights, network_weights)]
+            else:
+                self.interaction_count = 0
+                import math
+                self.weight_update = math.ceil(weight_update)
+                self.weight_update_ops = self.sync_weights_ops
 
     def eval_value(self, observation, action):
         """
@@ -27,6 +62,35 @@ class ActionValue(ValueFunctionBase):
         return sess.run(self.value_tensor, feed_dict=
                         {self._observation_placeholder: observation, self._action_placeholder: action})
 
+    def eval_value_old(self, observation, action):
+        """
+        evaluate value using the target network
+        :param observation: numpy array of observations
+        :param action: numpy array of actions
+        :return: numpy array of action values
+        """
+        sess = tf.get_default_session()
+        feed_dict = {self._observation_placeholder: observation, self._action_placeholder: action}
+        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
+
+    def sync_weights(self):
+        """
+        sync the weights of net_old by directly copying the weights of network.
+        :return:
+        """
+        if self.sync_weights_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.sync_weights_ops)
+
+    def update_weights(self):
+        """
+        updates the weights of the target network net_old.
+        :return:
+        """
+        if self.weight_update_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.weight_update_ops)
+
 
 class DQN(ValueFunctionBase):
     """
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index 78467b9..b9bf0e3 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -76,6 +76,31 @@ class nstep_return:
         return {'return': return_}
 
 
+class ddpg_return:
+    """
+    compute the return as in DDPG. this seems to require special treatment
+    """
+    def __init__(self, actor, critic, use_target_network=True):
+        self.actor = actor
+        self.critic = critic
+        self.use_target_network = use_target_network
+
+    def __call__(self, raw_data):
+        observation = raw_data['observation']
+        reward = raw_data['reward']
+
+        if self.use_target_network:
+            action_target = self.actor.eval_action_old(observation)
+            value_target = self.critic.eval_value_old(observation, action_target)
+        else:
+            action_target = self.actor.eval_action(observation)
+            value_target = self.critic.eval_value(observation, action_target)
+
+        return_ = reward + value_target
+
+        return {'return': return_}
+
+
 class nstep_q_return:
     """
     compute the n-step return for Q-learning targets
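
Note on opt.DPG: it builds the deterministic policy gradient of Silver et al.,
grad_theta J = E[ grad_a Q(s, a)|_{a=mu_theta(s)} * grad_theta mu_theta(s) ],
split into two tf.gradients calls (grad_ys carries grad_a Q into the actor's graph) because the
critic above is fed from a separate action placeholder rather than from the actor's output tensor.
The following is a minimal, hypothetical sketch in plain TF 1.x, outside tianshou, showing that when
the critic is wired directly to the actor's output, a single tf.gradients call yields the same
update; the names obs_ph, mu, and q are illustrative, not library API.

import numpy as np
import tensorflow as tf

obs_dim, act_dim = 3, 1
obs_ph = tf.placeholder(tf.float32, shape=(None, obs_dim))

with tf.variable_scope('actor'):
    # deterministic action mu(s)
    mu = tf.layers.dense(obs_ph, act_dim, activation=tf.nn.tanh)
with tf.variable_scope('critic'):
    # Q(s, mu(s)): the critic consumes the actor's output directly
    q = tf.layers.dense(tf.concat([obs_ph, mu], axis=1), 1, activation=None)

actor_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')

# maximizing Q is minimizing -Q; differentiating -mean(Q) through mu applies the
# chain rule grad_a(-mean Q) * grad_theta(mu) in one call
dpg_grads = tf.gradients(-tf.reduce_mean(q), actor_vars)
train_op = tf.train.AdamOptimizer(1e-4).apply_gradients(list(zip(dpg_grads, actor_vars)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_obs = np.random.randn(8, obs_dim).astype(np.float32)
    sess.run(train_op, feed_dict={obs_ph: batch_obs})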
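
Note on weight_update in Deterministic and ActionValue: a value of 1 disables the target network,
a value in (0, 1) is the soft (Polyak) update theta_old <- tau * theta + (1 - tau) * theta_old
applied by update_weights (the DDPG example passes tau = 1e-3), and 0 or an integer larger than 1
falls back to periodic hard copies as in DQN. Below is a self-contained sketch of the two kinds of
ops on dummy variables; the variable names are illustrative.

import tensorflow as tf

tau = 1e-3  # soft-update rate, matching weight_update=1e-3 in the DDPG example

with tf.variable_scope('network'):
    w = tf.get_variable('w', shape=(4, 4))
with tf.variable_scope('net_old'):
    w_old = tf.get_variable('w', shape=(4, 4))

net_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
old_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')

# hard copy, as in sync_weights(): target <- online
sync_ops = [tf.assign(v_old, v) for v_old, v in zip(old_vars, net_vars)]
# soft update, as in update_weights() with 0 < weight_update < 1:
# target <- tau * online + (1 - tau) * target
soft_ops = [tf.assign(v_old, tau * v + (1 - tau) * v_old) for v_old, v in zip(old_vars, net_vars)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(sync_ops)  # start from identical weights
    sess.run(soft_ops)  # afterwards the target slowly tracks the online network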
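
Note on ddpg_return: it bootstraps the critic target from the target actor and target critic, and
the commit message flags the advantage-estimation module as still incomplete. For reference, the
DDPG critic target in Lillicrap et al. is y_t = r_t + gamma * (1 - done_t) * Q'(s_{t+1}, mu'(s_{t+1})),
i.e. the bootstrapped value of the next observation is discounted and masked at terminal steps.
A hypothetical numpy sketch, assuming next observations, terminal flags, and a discount gamma are
available in the collected batch (none of these appear in the raw_data interface above):

import numpy as np

def ddpg_target(reward, next_observation, done, actor, critic, gamma=0.99):
    """Standard DDPG critic target: y = r + gamma * (1 - done) * Q'(s', mu'(s')).

    actor and critic are assumed to expose the eval_action_old / eval_value_old
    interface shown above; next_observation and done are assumed to be collected
    alongside reward.
    """
    next_action = actor.eval_action_old(next_observation)              # mu'(s')
    next_value = critic.eval_value_old(next_observation, next_action)  # Q'(s', mu'(s'))
    mask = 1.0 - np.asarray(done, dtype=np.float32)                    # zero out terminal steps
    return np.asarray(reward, dtype=np.float32) + gamma * mask * next_value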