implement dqn loss and dpg loss, add TODO for separate actor and critic

rtz19970824 2017-12-15 14:24:08 +08:00
parent 039c8140e2
commit e5bf7a9270
5 changed files with 35 additions and 12 deletions

View File

@@ -53,7 +53,7 @@ if __name__ == '__main__':
     action = tf.placeholder(dtype=tf.int32, shape=[None])  # batch of integer actions
     target = tf.placeholder(dtype=tf.float32, shape=[None])  # target value for DQN
-    dqn_loss = losses.dqn_loss(action, target, pi)  # TongzhengRen
+    dqn_loss = losses.dqn_loss(action, target, q_net)  # TongzhengRen
     total_loss = dqn_loss
     optimizer = tf.train.AdamOptimizer(1e-3)
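The `target` placeholder above is typically fed with the one-step bootstrapped Q-learning return computed outside the graph. A minimal sketch, assuming a separate (frozen) target network supplies `q_target_values`; all names below are illustrative, not part of this commit:

```python
import numpy as np

def build_dqn_target(rewards, dones, q_target_values, gamma=0.99):
    # Q-learning target: r + gamma * max_a' Q_target(s', a'), cut off at episode end.
    # `q_target_values` has shape [batch, action_num] and would come from a frozen target network.
    return rewards + gamma * (1.0 - dones) * np.max(q_target_values, axis=1)
```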

View File

@@ -1,3 +1,7 @@
+# TODO:
+Separate actor and critic. (Important: we need to focus on this soon.)
+
 # policy
 YongRen

View File

@@ -26,7 +26,7 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
     :param sampled_action: placeholder of sampled actions during interaction with the environment
     :param reward: placeholder of reward the 'sampled_action' get
-    :param pi: current 'policy' to be optimized
+    :param pi: current `policy` to be optimized
     :param baseline: the baseline method used to reduce the variance, default is 'None'
     :return:
     """
@@ -35,8 +35,25 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
     # TODO: Different baseline methods like REINFORCE, etc.
     return vanilla_policy_gradient_loss
 
-def temporal_difference_loss():
-    pass
-
-def deterministic_policy_gradient():
-    pass
+def dqn_loss(sampled_action, sampled_target, q_net):
+    """
+    deep Q-network loss
+    :param sampled_action: placeholder of sampled actions during the interaction with the environment
+    :param sampled_target: estimated Q(s, a), i.e. the regression target
+    :param q_net: current Q-value network to be optimized
+    :return: mean squared error between the sampled Q-values and the targets
+    """
+    action_num = q_net.get_values().shape[1]
+    sampled_q = tf.reduce_sum(q_net.get_values() * tf.one_hot(sampled_action, action_num), axis=1)
+    return tf.reduce_mean(tf.square(sampled_target - sampled_q))
+
+
+def deterministic_policy_gradient(sampled_state, critic):
+    """
+    deterministic policy gradient
+    :param sampled_state: placeholder of sampled states during the interaction with the environment
+    :param critic: current `value` function (critic) evaluating the actor's action
+    :return: mean critic value, to be maximized with respect to the actor's parameters
+    """
+    return tf.reduce_mean(critic.get_value(sampled_state))
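To make the shape manipulation in `dqn_loss` concrete, here is a self-contained sketch of the same computation with a plain Q-value tensor standing in for `q_net.get_values()`; the 4-action space and the placeholder names are assumptions for illustration:

```python
import tensorflow as tf

q_values = tf.placeholder(tf.float32, shape=[None, 4])     # Q(s, .) over 4 actions (stand-in for q_net.get_values())
sampled_action = tf.placeholder(tf.int32, shape=[None])    # actions actually taken
sampled_target = tf.placeholder(tf.float32, shape=[None])  # bootstrapped targets

# Pick Q(s, a) of the taken action via a one-hot mask, then regress it onto the target.
sampled_q = tf.reduce_sum(q_values * tf.one_hot(sampled_action, 4), axis=1)
dqn_loss = tf.reduce_mean(tf.square(sampled_target - sampled_q))
train_op = tf.train.AdamOptimizer(1e-3).minimize(dqn_loss)
```

Note that `deterministic_policy_gradient` returns a value the actor should maximize, so a training op would minimize its negation and restrict the update to the actor's variables; that is one reason the TODO asks for a clean actor/critic separation.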

View File

@@ -14,12 +14,14 @@ __all__ = [
     'StochasticPolicy',
 ]
 
+#TODO: separate actor and critic, we should focus on it once we finish the basic module.
+
 class QValuePolicy(object):
     """
     The policy as in DQN
     """
-    def __init__(self, value_tensor):
-        pass
+    def __init__(self, observation_placeholder):
+        self.observation_placeholder = observation_placeholder
 
     def act(self, observation, exploration=None):  # first implement no exploration
         """
@@ -222,7 +224,3 @@ class StochasticPolicy(object):
         Private method for subclasses to rewrite the :meth:`prob` method.
         """
         raise NotImplementedError()
-
-
-class QValuePolicy(object):
-    pass
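For the `act(self, observation, exploration=None)` stub above ("first implement no exploration"), a greedy implementation might look like the following sketch. The `self.session` and `self.values` attributes are assumptions, since this commit only stores the observation placeholder:

```python
import numpy as np

def act(self, observation, exploration=None):
    # Evaluate Q(observation, .) and take the argmax action; no exploration yet,
    # as the inline comment suggests. `self.values` and `self.session` are hypothetical.
    q = self.session.run(self.values,
                         feed_dict={self.observation_placeholder: observation[None]})
    return int(np.argmax(q, axis=1)[0])
```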

View File

@@ -1,3 +1,7 @@
+# TODO:
+Notice that we will separate actor and critic, and batch will collect data for optimizing policy while replay will collect data for optimizing critic.
+
 # Batch
 YouQiaoben
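A minimal sketch of the replay side described in this TODO; the class name and API are assumptions, and the on-policy `Batch` would instead discard its data after each policy update:

```python
import random
from collections import deque

class ReplayBuffer:
    """Hypothetical off-policy transition store used for critic updates."""
    def __init__(self, capacity=100000):
        self.storage = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.storage.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniformly sample past transitions; the critic can learn off-policy from them.
        return random.sample(self.storage, batch_size)
```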