preliminary design of dqn_example and the DQN interface; identify the assignment of networks

Haosheng Zou 2017-12-13 20:47:45 +08:00
parent d280260a46
commit 72ae304ab3
5 changed files with 124 additions and 5 deletions

examples/dqn_example.py (Normal file, 86 lines)

@@ -0,0 +1,86 @@
#!/usr/bin/env python
import tensorflow as tf
import numpy as np
import time
import gym
# our lib imports here!
import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay import Replay
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy as policy


def policy_net(observation, action_dim):
    """
    Constructs the policy network, i.e. the Q-value network for DQN. NOT NEEDED IN THE LIBRARY! This is pure tf.
    Variable scoping is applied by the caller via tf.variable_scope.

    :param observation: placeholder for the observation, a tensor of shape (batch_size, x, y, channels)
    :param action_dim: int, the number of actions
    :return: a tensor of shape (batch_size, action_dim) holding the Q-values of all actions
    """
    net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)  # 16 filters, 8x8 kernel, stride 4
    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)          # 32 filters, 4x4 kernel, stride 2
    net = tf.layers.flatten(net)
    net = tf.layers.dense(net, 256, activation=tf.nn.relu)
    q_values = tf.layers.dense(net, action_dim)
    return q_values


if __name__ == '__main__':
    env = gym.make('PongNoFrameskip-v4')
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.n

    # 1. build network with pure tf
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim)  # network input

    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)
    train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)  # TODO: better management of TRAINABLE_VARIABLES
    with tf.variable_scope('target_net'):
        q_values_target = policy_net(observation, action_dim)
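
    # A possible way to address the TODO above (a sketch, assuming only the scope
    # names 'q_net' and 'target_net' used in this file): collect each network's
    # trainable variables by scope, so the optimizer and a later target-network
    # sync can refer to them explicitly instead of relying on creation order.
    q_net_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='q_net')
    target_net_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='target_net')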

    # 2. build losses, optimizers
    q_net = policy.DQN(q_values, observation_placeholder=observation)  # YongRen: policy.DQN
    target_net = policy.DQN(q_values_target, observation_placeholder=observation)

    action = tf.placeholder(dtype=tf.int32, shape=[None])  # batch of integer actions
    target = tf.placeholder(dtype=tf.float32, shape=[None])  # target value for DQN

    dqn_loss = losses.dqn_loss(action, target, q_net)  # TongzhengRen
    total_loss = dqn_loss
    optimizer = tf.train.AdamOptimizer(1e-3)
    train_op = optimizer.minimize(total_loss, var_list=train_var_list)
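
    # Rough sketch of the quantity losses.dqn_loss is expected to compute (an
    # illustration only; the library function may be implemented differently):
    # the mean squared TD error between Q(s, a) of the taken action and the
    # externally supplied target value.
    q_selected = tf.reduce_sum(q_values * tf.one_hot(action, action_dim), axis=1)  # Q(s, a) of taken actions
    dqn_loss_sketch = tf.reduce_mean(tf.square(target - q_selected))               # mean squared TD error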

    # 3. define data collection
    training_data = Replay(env, q_net, advantage_estimation.qlearning_target(target_net))
    # ShihongSong: Replay(env, pi, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory,
    # interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN.
    # maybe a dict to manage the elements to be collected
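
    # For reference, the one-step Q-learning target that
    # advantage_estimation.qlearning_target(target_net) is presumably meant to
    # produce for each transition (s, a, r, s') is (gamma is a hypothetical
    # discount factor, not defined in this example):
    #     target = r + gamma * max_a Q_target(s', a)
    # where Q_target is evaluated with the target network, e.g.
    #     max_next_q = tf.reduce_max(q_values_target, axis=1)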

    # 4. start training
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        minibatch_count = 0
        collection_count = 0
        while True:  # until some stopping criterion is met...
            # collect data
            training_data.collect()  # ShihongSong
            collection_count += 1
            print('Collected {} times.'.format(collection_count))

            # update network
            data = training_data.next_batch(64)  # YouQiaoben, ShihongSong
            # TODO: auto managing of the placeholders? or add this to params of data.Batch
            sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'],
                                          target: data['target']})
            minibatch_count += 1
            print('Trained {} minibatches.'.format(minibatch_count))

            # TODO: assigning q_net to target_net (the target network update) is not implemented yet
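            # A minimal sketch of that missing update (an assumption, not part of the
            # planned interface): pair variables by creation order within the two
            # scopes and copy q_net weights into target_net. The assign ops belong
            # outside the loop; running them every C minibatches is the usual DQN recipe.
            #     sync_ops = [tf.assign(t, q) for q, t in zip(
            #         tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_net'),
            #         tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net'))]
            #     if minibatch_count % 1000 == 0:  # 1000 is a hypothetical sync period
            #         sess.run(sync_ops)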


@@ -66,7 +66,7 @@ if __name__ == '__main__':  # a clean version with only policy net, no value net
    # 3. define data collection
    training_data = Batch(env, pi, advantage_estimation.full_return)  # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
-   # ShihongSong: Replay(env, pi, advantage_estimation.target_network), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
+   # ShihongSong: Replay(), see dqn_example.py
    # maybe a dict to manage the elements to be collected

    # 4. start training
@@ -88,3 +88,5 @@ if __name__ == '__main__':  # a clean version with only policy net, no value net
            sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
            minibatch_count += 1
            print('Trained {} minibatches.'.format(minibatch_count))
+
+           # TODO: assigning pi to pi_old is not implemented yet


@@ -10,8 +10,7 @@ follow OnehotCategorical to write Gaussian, can be in the same file as stochasti
not sure how to write, but should at least have act() method to interact with environment
- DQN should have an effective argmax_{actions}() method to use as a value network
+ referencing QValuePolicy in base.py, should have at least the listed methods.
# losses


@@ -14,6 +14,33 @@ __all__ = [
    'StochasticPolicy',
]


class QValuePolicy(object):
    """
    The policy as in DQN
    """
    def __init__(self, value_tensor):
        pass

    def act(self, observation, exploration=None):  # first implement no exploration
        """
        return the action (int) to be executed.
        no exploration when exploration=None.
        """
        pass

    def values(self, observation):
        """
        returns the Q(s, a) values (float) for all actions a at observation s
        """
        pass

    def values_tensor(self, observation):
        """
        returns the tensor of the values for all actions a at observation s
        """
        pass


class StochasticPolicy(object):
    """
@@ -195,3 +222,7 @@ class StochasticPolicy(object):
        Private method for subclasses to rewrite the :meth:`prob` method.
        """
        raise NotImplementedError()
class QValuePolicy(object):
    pass
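
# Illustration only (not part of this diff): a minimal sketch of how a concrete
# DQN policy could satisfy the QValuePolicy methods above, with act() as a greedy
# argmax over the Q-values. Assumes tensorflow as tf and numpy as np are imported;
# the real policy.DQN assigned to YongRen may look quite different.
class GreedyDQNSketch(QValuePolicy):
    def __init__(self, value_tensor, observation_placeholder):
        self._value_tensor = value_tensor                    # shape (batch_size, action_dim)
        self._observation_placeholder = observation_placeholder

    def act(self, observation, exploration=None):            # no exploration for now
        return int(np.argmax(self.values(observation)))      # greedy action index

    def values(self, observation):
        sess = tf.get_default_session()
        return sess.run(self._value_tensor,
                        feed_dict={self._observation_placeholder: observation[None]})[0]

    def values_tensor(self, observation):
        return self._value_tensor                            # the (batch_size, action_dim) tensor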


@@ -70,6 +70,7 @@ class OnehotCategorical(StochasticPolicy):
    def _act(self, observation):
        sess = tf.get_default_session()  # TODO: this may be ugly. also maybe huge problem when parallel
        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
                                  feed_dict={self._observation_placeholder: observation[None]})
        # observation[None] adds one dimension at the beginning
        sampled_action = sampled_action[0, 0]
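
        # One possible answer to the TODO above (a sketch, assuming self.logits and
        # self._observation_placeholder are created in __init__, which this commit
        # does not change): build the sampling op once at graph-construction time
        # instead of calling tf.multinomial inside _act, so no new op is added to
        # the graph on every action.
        #     self._sampled_action_op = tf.multinomial(self.logits, num_samples=1)   # built once, in __init__
        #     sampled_action = sess.run(self._sampled_action_op,
        #                               feed_dict={self._observation_placeholder: observation[None]})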