finish dqn example. advantage estimation module is not complete yet.

haoshengzou 2018-01-18 12:19:48 +08:00
parent 9f96cc2461
commit 8fbde8283f
6 changed files with 187 additions and 106 deletions

View File

@@ -13,6 +13,7 @@ from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.dqn as policy # TODO: fix imports as in zhusuan so that we only need to import the policy module
import tianshou.core.value_function.action_value as value_function
if __name__ == '__main__':
@@ -31,31 +32,26 @@ if __name__ == '__main__':
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
def my_policy():
def my_network():
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
action_values = tf.layers.dense(net, action_dim, activation=None)
return action_values, None # no value head
# TODO: the current implementation of passing a function or overriding a function has to return a value head
# to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
# semantically imbalanced (though they are naturally imbalanced, since 'policy' is required to interact
# with the environment and 'value_function' is not). I have an idea to solve this imbalance that is
# not based on passing or overriding a function.
return None, action_values # no policy head
### 2. build policy, loss, optimizer
pi = policy.DQN(my_policy, observation_placeholder=observation_ph, weight_update=10)
dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=100)
pi = policy.DQN(dqn)
dqn_loss = losses.qlearning(pi)
dqn_loss = losses.qlearning(dqn)
total_loss = dqn_loss
optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)
### 3. define data collection
data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, pi.target_network)], [pi])
data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn])
### 4. start training
config = tf.ConfigProto()
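For reference, a minimal standalone sketch of the network-callable convention used above: the callable returns a (policy head, value head) pair, and a pure Q-network returns None for the policy head. The observation and action dimensions below are made-up placeholders, not values taken from this example.

import numpy as np
import tensorflow as tf

observation_dim = (4,)   # hypothetical observation shape
action_dim = 2           # hypothetical number of discrete actions

observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

def my_network():
    net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
    action_values = tf.layers.dense(net, action_dim, activation=None)
    return None, action_values  # no policy head, only Q(s, *)

with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
    _, q_all = my_network()
greedy_action = tf.argmax(q_all, axis=1)  # the greedy policy w.r.t. Q

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    obs = np.random.rand(3, *observation_dim).astype(np.float32)
    print(sess.run(greedy_action, feed_dict={observation_ph: obs}))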

View File

@@ -51,24 +51,25 @@ def state_value_mse(state_value_function):
:param state_value_function: instance of StateValue
:return: tensor of the mse loss
"""
state_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
state_value_function.managed_placeholders['return'] = state_value_ph
target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
state_value_function.managed_placeholders['return'] = target_value_ph
state_value = state_value_function.value_tensor
return tf.losses.mean_squared_error(state_value_ph, state_value)
return tf.losses.mean_squared_error(target_value_ph, state_value)
def dqn_loss(sampled_action, sampled_target, policy):
def qlearning(action_value_function):
"""
mean-squared loss for Q-learning (deep Q-network) targets
:param sampled_action: placeholder of sampled actions during the interaction with the environment
:param sampled_target: estimated Q(s,a)
:param policy: current `policy` to be optimized
:param action_value_function: current `action_value` to be optimized
:return: tensor of the q-learning loss
"""
sampled_q = policy.q_net.value_tensor
return tf.reduce_mean(tf.square(sampled_target - sampled_q))
target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='qlearning/action_value_placeholder')
action_value_function.managed_placeholders['return'] = target_value_ph
q_value = action_value_function.value_tensor
return tf.losses.mean_squared_error(target_value_ph, q_value)
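To make the wiring concrete, here is a self-contained sketch of how a qlearning-style loss plugs into a value function through managed_placeholders and value_tensor. The stub class below is a hypothetical stand-in used only for illustration; it is not the tianshou DQN class or training loop.

import numpy as np
import tensorflow as tf

class _StubActionValue(object):  # illustrative stand-in, not the tianshou class
    def __init__(self, observation_dim=4, num_actions=3):
        self.observation_ph = tf.placeholder(tf.float32, shape=(None, observation_dim))
        self.action_ph = tf.placeholder(tf.int32, shape=(None,))
        q_all = tf.layers.dense(self.observation_ph, num_actions)              # Q(s, *)
        idx = tf.stack([tf.range(tf.shape(q_all)[0]), self.action_ph], axis=1)
        self.value_tensor = tf.gather_nd(q_all, idx)                           # Q(s, a)
        self.managed_placeholders = {'observation': self.observation_ph,
                                     'action': self.action_ph}

def qlearning_sketch(action_value_function):
    target_value_ph = tf.placeholder(tf.float32, shape=(None,))
    action_value_function.managed_placeholders['return'] = target_value_ph
    return tf.losses.mean_squared_error(target_value_ph, action_value_function.value_tensor)

dqn = _StubActionValue()
loss = qlearning_sketch(dqn)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {dqn.managed_placeholders['observation']: np.random.rand(8, 4),
            dqn.managed_placeholders['action']: np.random.randint(0, 3, size=8),
            dqn.managed_placeholders['return']: np.random.rand(8)}
    print(sess.run([loss, train_op], feed_dict=feed)[0])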
def deterministic_policy_gradient(sampled_state, critic):
"""

View File

@@ -2,88 +2,53 @@ from __future__ import absolute_import
from .base import PolicyBase
import tensorflow as tf
from ..value_function.action_value import DQN
import numpy as np
class DQNRefactor(PolicyBase):
class DQN(PolicyBase):
"""
use DQN from value_function as a member
"""
def __init__(self, value_tensor, observation_placeholder, action_placeholder):
self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
self._argmax_action = tf.argmax(value_tensor, axis=1)
super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)
def __init__(self, dqn):
self.action_value = dqn
self._argmax_action = tf.argmax(dqn.value_tensor_all_actions, axis=1)
self.weight_update = dqn.weight_update
if self.weight_update > 1:
self.interaction_count = 0
else:
self.interaction_count = -1
def act(self, observation, exploration=None):
sess = tf.get_default_session()
if not exploration: # no exploration
action = sess.run(self._argmax_action, feed_dict={self._observation_placeholder: observation})
if self.weight_update > 1:
if self.interaction_count % self.weight_update == 0:
self.update_weights()
feed_dict = {self.action_value._observation_placeholder: observation[None]}
action = sess.run(self._argmax_action, feed_dict=feed_dict)
return action
if self.weight_update > 0:
self.interaction_count += 1
if not exploration:
return np.squeeze(action)
@property
def q_net(self):
return self._q_net
return self.action_value
def sync_weights(self):
"""
sync the weights of network_old by directly copying the weights of network.
:return:
"""
if self.action_value.sync_weights_ops is not None:
self.action_value.sync_weights()
class DQNOld(QValuePolicy):
"""
The policy as in DQN
"""
def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
# TODO: this version only support non-continuous action space, extend it to support continuous action space
self._logits = tf.convert_to_tensor(logits)
if dtype is None:
dtype = tf.int32
self._n_categories = self._logits.get_shape()[-1].value
super(DQN, self).__init__(observation_placeholder)
# TODO: put the net definition outside of the class
net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
net = tf.layers.flatten(net)
net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
self._value = tf.layers.dense(net, self._n_categories)
def _act(self, observation, exploration=None): # first implement no exploration
def update_weights(self):
"""
return the action (int) to be executed.
no exploration when exploration=None.
updates the weights of the target network according to the weight_update schedule.
:return:
"""
# TODO: ensure thread safety, tf.multinomial to init
sess = tf.get_default_session()
sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
feed_dict={self._observation_placeholder: observation[None]})
return sampled_action
@property
def logits(self):
"""
:return: action values
"""
return self._logits
@property
def n_categories(self):
"""
:return: dimension of action space if not continuous
"""
return self._n_categories
def values(self, observation):
"""
returns the Q(s, a) values (float) for all actions a at observation s
"""
sess = tf.get_default_session()
value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]})
return value
def values_tensor(self):
"""
returns the tensor of the values for all actions a at observation s
"""
return self._value
if self.action_value.weight_update_ops is not None:
self.action_value.update_weights()
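The interaction counting in act() above boils down to a small pattern: with weight_update > 1, the target network is refreshed every weight_update interactions. A plain-Python sketch of just that bookkeeping, with illustrative names:

class PeriodicSync(object):  # illustrative bookkeeping only
    def __init__(self, weight_update):
        self.weight_update = weight_update
        self.interaction_count = 0 if weight_update > 1 else -1

    def act(self, update_fn):
        # update_fn stands for copying 'network' weights into 'net_old'
        if self.weight_update > 1 and self.interaction_count % self.weight_update == 0:
            update_fn()
        if self.weight_update > 0:
            self.interaction_count += 1

syncs = []
syncer = PeriodicSync(weight_update=100)
for step in range(300):
    syncer.act(lambda: syncs.append(step))
print(syncs)  # [0, 100, 200]: a sync every 100 interactions, starting at the first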

View File

@@ -65,7 +65,7 @@ class OnehotCategorical(StochasticPolicy):
logits, value_head = policy_callable()
self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32)
if value_head is not None: # useful in DDPG
if value_head is not None:
pass
network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
@@ -201,7 +201,7 @@ class Normal(StochasticPolicy):
self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32)
self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32)
if value_head is not None: # useful in DDPG
if value_head is not None:
pass
network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
@@ -215,7 +215,7 @@ class Normal(StochasticPolicy):
if weight_update == 0:
self.weight_update_ops = self.sync_weights_ops
elif 0 < weight_update < 1:
elif 0 < weight_update < 1: # useful in DDPG
pass
else:
self.interaction_count = 0
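The 0 < weight_update < 1 branch is still a pass here; a common interpretation (as in DDPG) is a soft/Polyak update of the old network. The following TF sketch shows that reading as an assumption, not this commit's behaviour.

import tensorflow as tf

tau = 0.01  # hypothetical soft-update rate, playing the role of weight_update

with tf.variable_scope('network'):
    w = tf.get_variable('w', shape=(3, 2))
with tf.variable_scope('net_old'):
    w_old = tf.get_variable('w', shape=(3, 2))

network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')

# soft update: move the old weights a small step towards the current ones
soft_update_ops = [tf.assign(variable_old, tau * variable + (1. - tau) * variable_old)
                   for variable_old, variable in zip(network_old_weights, network_weights)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(soft_update_ops)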

View File

@@ -28,27 +28,70 @@ class ActionValue(ValueFunctionBase):
{self._observation_placeholder: observation, self._action_placeholder: action})
class DQN(ActionValue):
class DQN(ValueFunctionBase):
"""
Class for the DQN architecture. Instead of feeding s and a to the network to get a value, DQN feeds s to the
network and the last layer outputs Q(s, *) for all actions.
"""
def __init__(self, value_tensor, observation_placeholder, action_placeholder):
def __init__(self, network_callable, observation_placeholder, weight_update=1):
"""
:param network_callable: callable that builds the network and returns a (policy head, Q-value head) pair,
    where the Q-value head is of shape (batchsize, num_actions)
:param observation_placeholder: of shape (batchsize, observation_dim)
:param weight_update: target network scheme; 1 means no separate target network, an integer > 1 means
    a hard sync every that many interactions, and values in (0, 1) are reserved for soft updates
"""
self._value_tensor_all_actions = value_tensor
self._observation_placeholder = observation_placeholder
self.action_placeholder = action_placeholder = tf.placeholder(tf.int32, shape=(None,), name='action_value.DQN/action_placeholder')
self.managed_placeholders = {'observation': observation_placeholder, 'action': action_placeholder}
self.weight_update = weight_update
self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1.
batch_size = tf.shape(value_tensor)[0]
batch_dim_index = tf.range(batch_size)
indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
canonical_value_tensor = tf.gather_nd(value_tensor, indices)
with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
value_tensor = network_callable()[-1]
self._value_tensor_all_actions = value_tensor
batch_size = tf.shape(value_tensor)[0]
batch_dim_index = tf.range(batch_size)
indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
canonical_value_tensor = tf.gather_nd(value_tensor, indices)
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
super(DQN, self).__init__(canonical_value_tensor, observation_placeholder=observation_placeholder)
# deal with target network
if self.weight_update == 1:
self.weight_update_ops = None
self.sync_weights_ops = None
else: # then we need to build another tf graph as target network
with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
value_tensor = network_callable()[-1]
self.value_tensor_all_actions_old = value_tensor
batch_size = tf.shape(value_tensor)[0]
batch_dim_index = tf.range(batch_size)
indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
canonical_value_tensor = tf.gather_nd(value_tensor, indices)
self.value_tensor_old = canonical_value_tensor
network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
assert len(network_weights) == len(network_old_weights)
self.sync_weights_ops = [tf.assign(variable_old, variable)
for (variable_old, variable) in zip(network_old_weights, network_weights)]
if weight_update == 0:
self.weight_update_ops = self.sync_weights_ops
elif 0 < weight_update < 1: # useful in DDPG
pass
else:
self.interaction_count = 0
import math
self.weight_update = math.ceil(weight_update)
self.weight_update_ops = self.sync_weights_ops
super(DQN, self).__init__(value_tensor=canonical_value_tensor,
observation_placeholder=observation_placeholder,
action_placeholder=action_placeholder)
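The tf.gather_nd construction above simply picks Q(s_i, a_i) out of the (batch_size, num_actions) table Q(s, *); in plain numpy the same selection is:

import numpy as np

q_all = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])   # Q(s, *) for a batch of 2 states
actions = np.array([2, 0])            # sampled action per state

batch_dim_index = np.arange(q_all.shape[0])
canonical_value = q_all[batch_dim_index, actions]   # Q(s_i, a_i)
print(canonical_value)                # [3. 4.]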
def eval_value_all_actions(self, observation):
"""
@@ -60,4 +103,41 @@ class DQN(ActionValue):
@property
def value_tensor_all_actions(self):
return self._value_tensor_all_actions
def eval_value_old(self, observation, action):
"""
eval value using target network
:param observation: numpy array of obs
:param action: numpy array of action
:return: numpy array of action value
"""
sess = tf.get_default_session()
feed_dict = {self._observation_placeholder: observation, self.action_placeholder: action}
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
def eval_value_all_actions_old(self, observation):
"""
:param observation:
:return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
"""
sess = tf.get_default_session()
return sess.run(self.value_tensor_all_actions_old, feed_dict={self._observation_placeholder: observation})
def sync_weights(self):
"""
sync the weights of network_old by directly copying the weights of network.
:return:
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()
sess.run(self.sync_weights_ops)
def update_weights(self):
"""
updates the weights of the target network (net_old) according to the weight_update scheme.
:return:
"""
if self.weight_update_ops is not None:
sess = tf.get_default_session()
sess.run(self.weight_update_ops)
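Putting the target-network pieces together, here is a compact self-contained sketch of the two-scope pattern: the same callable is instantiated under 'network' and 'net_old', and the sync ops copy the live weights into the old copy. Names and dimensions are illustrative, not the class above.

import numpy as np
import tensorflow as tf

obs_ph = tf.placeholder(tf.float32, shape=(None, 4))

def q_network():
    net = tf.layers.dense(obs_ph, 16, activation=tf.nn.relu)
    return tf.layers.dense(net, 2)   # Q(s, *) for 2 hypothetical actions

with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
    q_all = q_network()
with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
    q_all_old = q_network()

network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
assert len(network_weights) == len(network_old_weights)

sync_weights_ops = [tf.assign(variable_old, variable)
                    for variable_old, variable in zip(network_old_weights, network_weights)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(sync_weights_ops)        # hard copy: net_old now equals network
    obs = np.random.rand(5, 4).astype(np.float32)
    a, b = sess.run([q_all, q_all_old], feed_dict={obs_ph: obs})
    print(np.allclose(a, b))          # True right after a sync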

View File

@@ -46,8 +46,14 @@ class gae_lambda:
def __call__(self, raw_data):
reward = raw_data['reward']
observation = raw_data['observation']
return {'advantage': reward}
state_value = self.value_function.eval_value(observation)
# wrong version of advantage just to run
advantage = reward + state_value
return {'advantage': advantage}
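Since the body above is explicitly a placeholder ("wrong version ... just to run"), here, for reference, is the standard GAE(lambda) recursion in numpy that such a module typically computes; the gamma/lambda handling is an assumption about where the module is headed, not code from this commit.

import numpy as np

def gae_lambda_advantage(rewards, state_values, gamma=0.99, lam=0.95):
    """rewards: shape (T,); state_values: shape (T + 1,), with V(s_T) as the bootstrap value."""
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * state_values[t + 1] - state_values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

print(gae_lambda_advantage(np.array([1.0, 1.0, 1.0]),
                           np.array([0.5, 0.4, 0.3, 0.0])))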
class nstep_return:
@@ -60,8 +66,41 @@ class nstep_return:
def __call__(self, raw_data):
reward = raw_data['reward']
observation = raw_data['observation']
return {'return': reward}
state_value = self.value_function.eval_value(observation)
# wrong version of return just to run
return_ = reward + state_value
return {'return': return_}
class nstep_q_return:
"""
compute the n-step return for Q-learning targets
"""
def __init__(self, n, action_value, use_target_network=True):
self.n = n
self.action_value = action_value
self.use_target_network = use_target_network
def __call__(self, raw_data):
# raw_data should contain 'next_observation' from replay memory...?
# maybe the main difference between Batch and Replay is the stored data format?
reward = raw_data['reward']
observation = raw_data['observation']
if self.use_target_network:
action_value_all_actions = self.action_value.eval_value_all_actions_old(observation)
else:
action_value_all_actions = self.action_value.eval_value_all_actions(observation)
action_value_max = np.max(action_value_all_actions, axis=1)
return_ = reward + action_value_max
return {'return': return_}
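The return above bootstraps from the current observation; the usual 1-step Q-learning target instead uses the next observation, a discount, and a terminal mask, which is what the 'next_observation' comment hints at. A numpy sketch of that target, as an assumption about the finished module rather than this commit's code:

import numpy as np

def one_step_q_target(reward, q_next_all_actions, done, gamma=0.99):
    """reward, done: shape (batch,); q_next_all_actions: (batch, num_actions),
    evaluated with the target (old) network on the next observations."""
    return reward + gamma * (1.0 - done) * np.max(q_next_all_actions, axis=1)

print(one_step_q_target(np.array([1.0, 0.0]),
                        np.array([[0.2, 0.7], [1.5, 0.3]]),
                        np.array([0.0, 1.0])))   # -> [1.693, 0.0]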
class QLearningTarget: