From 0874d5342f8bf2a4b32512f701a12affd6093869 Mon Sep 17 00:00:00 2001
From: rtz19970824 <1289226405@qq.com>
Date: Fri, 15 Dec 2017 14:24:08 +0800
Subject: [PATCH] implement DQN loss and DPG loss, add TODO for separating actor and critic

---
 examples/dqn_example.py      |  2 +-
 tianshou/core/README.md      |  4 ++++
 tianshou/core/losses.py      | 27 ++++++++++++++++++++++-----
 tianshou/core/policy/base.py | 10 ++++------
 tianshou/data/README.md      |  4 ++++
 5 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index 0a5c084..6a9e2a6 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -53,7 +53,7 @@ if __name__ == '__main__':
     action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
     target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
 
-    dqn_loss = losses.dqn_loss(action, target, pi) # TongzhengRen
+    dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
     total_loss = dqn_loss
     optimizer = tf.train.AdamOptimizer(1e-3)
 
diff --git a/tianshou/core/README.md b/tianshou/core/README.md
index 1e6d7c7..3617525 100644
--- a/tianshou/core/README.md
+++ b/tianshou/core/README.md
@@ -1,3 +1,7 @@
+# TODO:
+
+Separate the actor and the critic. (Important: we need to focus on this soon.)
+
 # policy
 
 YongRen
diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index f7d798b..d281df9 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -26,7 +26,7 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
 
     :param sampled_action: placeholder of sampled actions during interaction with the environment
     :param reward: placeholder of reward the 'sampled_action' get
-    :param pi: current 'policy' to be optimized
+    :param pi: current `policy` to be optimized
     :param baseline: the baseline method used to reduce the variance, default is 'None'
     :return:
     """
@@ -35,8 +35,25 @@
     # TODO: Different baseline methods like REINFORCE, etc.
     return vanilla_policy_gradient_loss
 
-def temporal_difference_loss():
-    pass
+def dqn_loss(sampled_action, sampled_target, q_net):
+    """
+    deep Q-network (DQN) loss
 
-def deterministic_policy_gradient():
-    pass
\ No newline at end of file
+    :param sampled_action: placeholder of sampled actions during the interaction with the environment
+    :param sampled_target: placeholder of the target values for the sampled (state, action) pairs
+    :param q_net: current Q-value network to be optimized
+    :return:
+    """
+    action_num = q_net.get_values().shape.as_list()[1]
+    sampled_q = tf.reduce_sum(q_net.get_values() * tf.one_hot(sampled_action, action_num), axis=1)  # Q(s, a) of the sampled actions
+    return tf.reduce_mean(tf.square(sampled_target - sampled_q))
+
+def deterministic_policy_gradient(sampled_state, critic):
+    """
+    deterministic policy gradient (DPG) objective
+
+    :param sampled_state: placeholder of sampled states (observations) during the interaction with the environment
+    :param critic: current `value` network used as the critic
+    :return:
+    """
+    return tf.reduce_mean(critic.get_value(sampled_state))
\ No newline at end of file
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 0ae20a1..b6d8d48 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -14,12 +14,14 @@ __all__ = [
     'StochasticPolicy',
 ]
 
+# TODO: separate actor and critic; we should focus on this once we finish the basic module.
+
 class QValuePolicy(object):
     """
     The policy as in DQN
     """
-    def __init__(self, value_tensor):
-        pass
+    def __init__(self, observation_placeholder):
+        self.observation_placeholder = observation_placeholder
 
     def act(self, observation, exploration=None): # first implement no exploration
         """
@@ -222,7 +224,3 @@ class StochasticPolicy(object):
         Private method for subclasses to rewrite the :meth:`prob` method.
         """
         raise NotImplementedError()
-
-
-class QValuePolicy(object):
-    pass
\ No newline at end of file
diff --git a/tianshou/data/README.md b/tianshou/data/README.md
index 241971a..e9e6374 100644
--- a/tianshou/data/README.md
+++ b/tianshou/data/README.md
@@ -1,3 +1,7 @@
+# TODO:
+
+Note that we will separate the actor and the critic: batch will collect data for optimizing the policy, while replay will collect data for optimizing the critic.
+
 # Batch
 
 YouQiaoben
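For context, here is a minimal standalone sketch of what the two new losses compute, written against plain TensorFlow 1.x tensors rather than the tianshou `q_net`/`critic` wrappers. The placeholder shapes, the `q_values`/`critic_value` stand-ins, and the sign handling for the DPG term are illustrative assumptions, not part of this patch.

```python
import tensorflow as tf

# Illustrative placeholders (assumed shapes: 4-dimensional observations, 3 discrete actions).
observation = tf.placeholder(tf.float32, shape=[None, 4])
sampled_action = tf.placeholder(tf.int32, shape=[None])
sampled_target = tf.placeholder(tf.float32, shape=[None])  # e.g. r + gamma * max_a' Q_target(s', a')

# Stand-in for q_net.get_values(): one Q value per discrete action.
q_values = tf.layers.dense(observation, 3)

# Same computation as losses.dqn_loss: select Q(s, a) of the sampled action via a one-hot
# mask, then take the mean squared error against the target.
action_num = q_values.shape.as_list()[1]
sampled_q = tf.reduce_sum(q_values * tf.one_hot(sampled_action, action_num), axis=1)
dqn_loss = tf.reduce_mean(tf.square(sampled_target - sampled_q))

# Same computation as losses.deterministic_policy_gradient: the mean critic value under the
# current actor. This is an objective to maximize, so a minimizer would usually be handed
# its negation (an assumption about usage, not something the patch specifies).
critic_value = tf.layers.dense(observation, 1)  # stand-in for critic.get_value(sampled_state)
dpg_objective = tf.reduce_mean(critic_value)
dpg_loss = -dpg_objective

train_dqn = tf.train.AdamOptimizer(1e-3).minimize(dqn_loss)
```

Note that `deterministic_policy_gradient` in the patch returns the positive mean critic value, so a caller that minimizes it as a loss would presumably negate it first.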