Implement DQN loss and DPG loss; add TODO for separating actor and critic
parent f496725437
commit 0874d5342f
@@ -53,7 +53,7 @@ if __name__ == '__main__':
     action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
     target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN

-    dqn_loss = losses.dqn_loss(action, target, pi) # TongzhengRen
+    dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen

     total_loss = dqn_loss
     optimizer = tf.train.AdamOptimizer(1e-3)
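The `target` placeholder above is meant to receive precomputed TD targets. A rough sketch of how they could be computed outside the graph (not part of this commit; `batch`, `next_q_values`, and `gamma` are illustrative names), assuming a replay batch of numpy arrays and the target network's Q-values on the next observations:

def td_targets(batch, next_q_values, gamma=0.99):
    # batch: dict of 1-D numpy arrays 'reward' and 'done' (assumed layout);
    # next_q_values: Q_target(s', .) with shape [batch_size, action_num].
    # Terminal transitions bootstrap nothing, hence the (1 - done) mask.
    return batch['reward'] + gamma * (1.0 - batch['done']) * next_q_values.max(axis=1)

The returned array would then be fed into `target` (with the corresponding observations and actions into their own placeholders) when running the optimizer step on `total_loss`.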
@@ -1,3 +1,7 @@
+# TODO:
+
+Separate actor and critic. (Important, we need to focus on this soon.)
+
 # policy

 YongRen
@@ -26,7 +26,7 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):

    :param sampled_action: placeholder of sampled actions during interaction with the environment
    :param reward: placeholder of reward the 'sampled_action' get
-    :param pi: current 'policy' to be optimized
+    :param pi: current `policy` to be optimized
    :param baseline: the baseline method used to reduce the variance, default is 'None'
    :return:
    """
@@ -35,8 +35,25 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
     # TODO: Different baseline methods like REINFORCE, etc.
     return vanilla_policy_gradient_loss

-def temporal_difference_loss():
-    pass
-
-def deterministic_policy_gradient():
-    pass
+def dqn_loss(sampled_action, sampled_target, q_net):
+    """
+    deep Q-network (DQN) loss
+
+    :param sampled_action: placeholder of sampled actions during the interaction with the environment
+    :param sampled_target: placeholder of the estimated target value for Q(s, a)
+    :param q_net: current Q-value network to be optimized
+    :return: the mean squared TD error between `sampled_target` and Q(s, a)
+    """
+    action_num = q_net.get_values().shape.as_list()[1]
+    sampled_q = tf.reduce_sum(q_net.get_values() * tf.one_hot(sampled_action, action_num), axis=1)
+    return tf.reduce_mean(tf.square(sampled_target - sampled_q))
+
+def deterministic_policy_gradient(sampled_state, critic):
+    """
+    deterministic policy gradient
+
+    :param sampled_state: placeholder of sampled states during the interaction with the environment
+    :param critic: current `value` function (critic)
+    :return: the mean critic value over the sampled states
+    """
+    return tf.reduce_mean(critic.get_value(sampled_state))
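For orientation, a minimal sketch of how these two losses might be wired into training ops, assuming TensorFlow 1.x graph mode, an `observation` placeholder for states, and that the actor's variables live in a scope named 'actor' (none of these names come from the commit). Since `deterministic_policy_gradient` returns the mean critic value, which the actor should maximize, its negative is minimized:

import tensorflow as tf

# DQN: minimize the mean squared TD error from losses.dqn_loss.
dqn_train_op = tf.train.AdamOptimizer(1e-3).minimize(
    losses.dqn_loss(action, target, q_net))

# DPG: deterministic_policy_gradient returns mean Q over sampled states, which
# the actor should maximize, so its negative is minimized with respect to the
# actor's variables only (the 'actor' scope name is an assumption).
actor_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
dpg_train_op = tf.train.AdamOptimizer(1e-4).minimize(
    -losses.deterministic_policy_gradient(observation, critic),
    var_list=actor_vars)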
@@ -14,12 +14,14 @@ __all__ = [
     'StochasticPolicy',
 ]

+#TODO: separate actor and critic, we should focus on it once we finish the basic module.
+
 class QValuePolicy(object):
     """
     The policy as in DQN
     """
-    def __init__(self, value_tensor):
-        pass
+    def __init__(self, observation_placeholder):
+        self.observation_placeholder = observation_placeholder

     def act(self, observation, exploration=None): # first implement no exploration
         """
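The `act` stub above is still unimplemented. A rough sketch of the no-exploration case (not part of this commit), assuming the policy keeps its Q-value tensor in an attribute `self._value_tensor` built on `self.observation_placeholder`, and that a default TensorFlow session is available:

import numpy as np
import tensorflow as tf

def act(self, observation, exploration=None):  # first implement no exploration
    # Evaluate Q(observation, .) for a single (numpy) observation and act greedily.
    # self._value_tensor is an assumed attribute holding the Q-value tensor.
    sess = tf.get_default_session()
    q_values = sess.run(self._value_tensor,
                        feed_dict={self.observation_placeholder: observation[None]})
    return int(np.argmax(q_values, axis=1)[0])

An exploration scheme such as epsilon-greedy would then wrap this greedy choice.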
@@ -222,7 +224,3 @@ class StochasticPolicy(object):
         Private method for subclasses to rewrite the :meth:`prob` method.
         """
         raise NotImplementedError()
-
-
-class QValuePolicy(object):
-    pass
@@ -1,3 +1,7 @@
+# TODO:
+
+Notice that we will separate actor and critic, and batch will collect data for optimizing policy while replay will collect data for optimizing critic.
+
 # Batch

 YouQiaoben
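To illustrate the split described in this TODO (purely illustrative; none of these classes exist in the commit): an on-policy collector hands freshly gathered steps to the actor update, while a replay buffer serves sampled transitions to the critic update. A minimal sketch, assuming a gym-style environment (old 4-tuple step API) and the `act` interface above:

import random
from collections import deque

class ReplayBuffer(object):
    # Off-policy storage for critic updates: keeps (s, a, r, done) style transitions.
    def __init__(self, capacity=100000):
        self._storage = deque(maxlen=capacity)

    def add(self, transition):
        self._storage.append(transition)

    def sample(self, batch_size):
        return random.sample(self._storage, batch_size)

class Batch(object):
    # On-policy collector for actor updates: returns freshly collected steps.
    def __init__(self, env, policy):
        self._env, self._policy = env, policy

    def collect(self, num_steps):
        trajectory, observation = [], self._env.reset()
        for _ in range(num_steps):
            action = self._policy.act(observation)
            next_observation, reward, done, _ = self._env.step(action)
            trajectory.append((observation, action, reward, done))
            observation = self._env.reset() if done else next_observation
        return trajectory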