From 72ae304ab3477242dfad48aac22f4b54a208b4c0 Mon Sep 17 00:00:00 2001
From: Haosheng Zou
Date: Wed, 13 Dec 2017 20:47:45 +0800
Subject: [PATCH] preliminary design of dqn_example and the DQN interface;
 identify the assignment between networks

---
 examples/dqn_example.py            | 86 ++++++++++++++++++++++++++++++
 examples/ppo_example.py            |  6 ++-
 tianshou/core/README.md            |  3 +-
 tianshou/core/policy/base.py       | 29 ++++++++++-
 tianshou/core/policy/stochastic.py |  1 +
 5 files changed, 120 insertions(+), 5 deletions(-)
 create mode 100644 examples/dqn_example.py

diff --git a/examples/dqn_example.py b/examples/dqn_example.py
new file mode 100644
index 0000000..0a5c084
--- /dev/null
+++ b/examples/dqn_example.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+import tensorflow as tf
+import numpy as np
+import time
+import gym
+
+# our lib imports here!
+import sys
+sys.path.append('..')
+import tianshou.core.losses as losses
+from tianshou.data.replay import Replay
+import tianshou.data.advantage_estimation as advantage_estimation
+import tianshou.core.policy as policy
+
+
+def policy_net(observation, action_dim):
+    """
+    Constructs the Q-value network. NOT NEEDED IN THE LIBRARY! This is pure tf.
+
+    :param observation: placeholder for the observation, a tensor of shape (batch_size, x, y, channels)
+    :param action_dim: int. The number of actions.
+    :return: tensor of shape (batch_size, action_dim) holding the Q-values of all actions
+    """
+    net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
+    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
+    net = tf.layers.flatten(net)
+    net = tf.layers.dense(net, 256, activation=tf.nn.relu)
+
+    q_values = tf.layers.dense(net, action_dim)
+
+    return q_values
+
+
+if __name__ == '__main__':
+    env = gym.make('PongNoFrameskip-v4')
+    observation_dim = env.observation_space.shape
+    action_dim = env.action_space.n
+
+    # 1. build network with pure tf
+    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim)  # network input
+
+    with tf.variable_scope('q_net'):
+        q_values = policy_net(observation, action_dim)
+    train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)  # TODO: better management of TRAINABLE_VARIABLES
+    with tf.variable_scope('target_net'):
+        q_values_target = policy_net(observation, action_dim)
+
+    # 2. build losses, optimizers
+    q_net = policy.DQN(q_values, observation_placeholder=observation)  # YongRen: policy.DQN
+    target_net = policy.DQN(q_values_target, observation_placeholder=observation)
+
+    action = tf.placeholder(dtype=tf.int32, shape=[None])  # batch of integer actions
+    target = tf.placeholder(dtype=tf.float32, shape=[None])  # target value for DQN
+
+    dqn_loss = losses.dqn_loss(action, target, q_net)  # TongzhengRen
+
+    total_loss = dqn_loss
+    optimizer = tf.train.AdamOptimizer(1e-3)
+    train_op = optimizer.minimize(total_loss, var_list=train_var_list)
+
+    # 3. define data collection
+    training_data = Replay(env, q_net, advantage_estimation.qlearning_target(target_net))
+    # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_net)); use your ReplayMemory and interact as follows. Simplify your advantage_estimation.dqn so it runs before YongRen's DQN is ready.
+    # maybe a dict to manage the elements to be collected
+
+    # 4. start training
+    with tf.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+
+        minibatch_count = 0
+        collection_count = 0
+        while True:  # until some stopping criterion is met
+            # collect data
+            training_data.collect()  # ShihongSong
+            collection_count += 1
+            print('Collected {} times.'.format(collection_count))
+
+            # update network
+            data = training_data.next_batch(64)  # YouQiaoben, ShihongSong
+            # TODO: auto managing of the placeholders? or add this to params of data.Batch
+            sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
+            minibatch_count += 1
+            print('Trained {} minibatches.'.format(minibatch_count))
+
+    # TODO: assigning q_net to target_net is not implemented yet (see the sketch after the patch)
\ No newline at end of file
diff --git a/examples/ppo_example.py b/examples/ppo_example.py
index d085273..02ccb52 100755
--- a/examples/ppo_example.py
+++ b/examples/ppo_example.py
@@ -66,7 +66,7 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
 
     # 3. define data collection
     training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
-    # ShihongSong: Replay(env, pi, advantage_estimation.target_network), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
+    # ShihongSong: Replay(), see dqn_example.py
     # maybe a dict to manage the elements to be collected
 
     # 4. start training
@@ -87,4 +87,6 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
             # TODO: auto managing of the placeholders? or add this to params of data.Batch
             sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
             minibatch_count += 1
-            print('Trained {} minibatches.'.format(minibatch_count))
\ No newline at end of file
+            print('Trained {} minibatches.'.format(minibatch_count))
+
+    # TODO: assigning pi to pi_old is not implemented yet
\ No newline at end of file
diff --git a/tianshou/core/README.md b/tianshou/core/README.md
index 16d915e..1e6d7c7 100644
--- a/tianshou/core/README.md
+++ b/tianshou/core/README.md
@@ -10,8 +10,7 @@ follow OnehotCategorical to write Gaussian, can be in the same file as stochastic
 
 not sure how to write, but should at least have act() method to interact with environment
 
-DQN should have an effective argmax_{actions}() method to use as a value network
-
+see QValuePolicy in base.py; it should have at least the methods listed there.
 
 # losses
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index b0bf28a..0ae20a1 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -14,6 +14,33 @@ __all__ = [
     'StochasticPolicy',
 ]
 
+class QValuePolicy(object):
+    """
+    The base class for Q-value-based policies, as in DQN.
+    """
+    def __init__(self, value_tensor):
+        pass
+
+    def act(self, observation, exploration=None):  # first implement no exploration
+        """
+        Return the action (int) to be executed.
+        No exploration when exploration=None.
+        """
+        pass
+
+    def values(self, observation):
+        """
+        Return the Q(s, a) values (float) of all actions a at observation s.
+        """
+        pass
+
+    def values_tensor(self, observation):
+        """
+        Return the tensor of the Q-values of all actions a at observation s.
+        """
+        pass
+
+
 class StochasticPolicy(object):
     """
@@ -194,4 +221,4 @@ class StochasticPolicy(object):
         """
         Private method for subclasses to rewrite the :meth:`prob` method.
""" - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() + + +class QValuePolicy(object): + pass \ No newline at end of file diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index 37eb1be..3ef463e 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -70,6 +70,7 @@ class OnehotCategorical(StochasticPolicy): def _act(self, observation): sess = tf.get_default_session() # TODO: this may be ugly. also maybe huge problem when parallel sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._observation_placeholder: observation[None]}) + # observation[None] adds one dimension at the beginning sampled_action = sampled_action[0, 0]