From b21a55dc88fefe7773b842e87af2d6b3eaab821b Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sat, 23 Dec 2017 17:25:16 +0800
Subject: [PATCH] towards policy/value refactor

---
 examples/dqn_example.py                      | 11 +++++------
 tianshou/core/README.md                      |  6 +++++-
 tianshou/core/losses.py                      |  7 +++----
 tianshou/core/policy/base.py                 | 18 +++++-------------
 tianshou/core/policy/dqn.py                  | 17 +++++++++++++----
 tianshou/core/policy/stochastic.py           |  6 ------
 tianshou/core/value_function/action_value.py |  9 +++++----
 tianshou/core/value_function/base.py         |  5 +++--
 tianshou/core/value_function/state_value.py  |  4 ++--
 9 files changed, 41 insertions(+), 42 deletions(-)

diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index b676475..cf20d66 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
 
 import tensorflow as tf
-import numpy as np
-import time
 import gym
 
 # our lib imports here!
@@ -10,7 +8,7 @@ import sys
 sys.path.append('..')
 import tianshou.core.losses as losses
 from tianshou.data.replay_buffer.utils import get_replay_buffer
-import tianshou.core.policy as policy
+import tianshou.core.policy.dqn as policy
 
 
 def policy_net(observation, action_dim):
@@ -41,6 +39,8 @@ if __name__ == '__main__':
     # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
     # access this observation variable.
     observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
+    action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions
+
 
     with tf.variable_scope('q_net'):
         q_values = policy_net(observation, action_dim)
@@ -48,10 +48,9 @@ if __name__ == '__main__':
         q_values_target = policy_net(observation, action_dim)
 
     # 2. build losses, optimizers
-    q_net = policy.DQN(q_values, observation_placeholder=observation) # YongRen: policy.DQN
-    target_net = policy.DQN(q_values_target, observation_placeholder=observation)
+    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN
+    target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)
 
-    action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
     target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
 
     dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
diff --git a/tianshou/core/README.md b/tianshou/core/README.md
index 3617525..a9cda58 100644
--- a/tianshou/core/README.md
+++ b/tianshou/core/README.md
@@ -21,4 +21,8 @@ referencing QValuePolicy in base.py, should have at least the listed methods.
 
 TongzhengRen
 
-seems to be direct python functions. Though the management of placeholders may require some discussion. also may write it in a functional form.
\ No newline at end of file
+seems to be direct python functions. Though the management of placeholders may require some discussion. also may write it in a functional form.
+
+# policy, value_function
+
+naming should be reconsidered. Perhaps use plural forms for all nouns
\ No newline at end of file
diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index 3461afb..5d5d2f3 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -35,17 +35,16 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
     # TODO: Different baseline methods like REINFORCE, etc.
     return vanilla_policy_gradient_loss
 
-def dqn_loss(sampled_action, sampled_target, q_net):
+def dqn_loss(sampled_action, sampled_target, policy):
     """
     deep q-network
 
     :param sampled_action: placeholder of sampled actions during the interaction with the environment
     :param sampled_target: estimated Q(s,a)
-    :param q_net: current `policy` to be optimized
+    :param policy: current `policy` to be optimized
     :return:
     """
-    action_num = q_net.values_tensor().get_shape()[1]
-    sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
+    sampled_q = policy.q_net.value_tensor
     return tf.reduce_mean(tf.square(sampled_target - sampled_q))
 
 def deterministic_policy_gradient(sampled_state, critic):
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 1adeaeb..1c1e1c5 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -3,19 +3,12 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import warnings
 import tensorflow as tf
 
 # from zhusuan.utils import add_name_scope
 
 
-__all__ = [
-    'StochasticPolicy',
-    'QValuePolicy',
-    'PolicyBase'
-]
-
 
 # TODO: a even more "base" class for policy
 
 
@@ -23,8 +16,8 @@ class PolicyBase(object):
     """
     base class for policy. only provides `act` method with exploration
     """
-    def __init__(self):
-        pass
+    def __init__(self, observation_placeholder):
+        self._observation_placeholder = observation_placeholder
 
     def act(self, observation, exploration):
         raise NotImplementedError()
@@ -37,14 +30,14 @@ class QValuePolicy(object):
     def __init__(self, observation_placeholder):
         self._observation_placeholder = observation_placeholder
 
-    def act(self, observation, exploration=None): # first implement no exploration
+    def act(self, observation, exploration=None):  # first implement no exploration
         """
         return the action (int) to be executed.
         no exploration when exploration=None.
         """
         self._act(observation, exploration)
 
-    def _act(self, observation, exploration = None):
+    def _act(self, observation, exploration=None):
         raise NotImplementedError()
 
     def values(self, observation):
@@ -60,7 +53,6 @@ class QValuePolicy(object):
         pass
 
 
-
 class StochasticPolicy(object):
     """
     The :class:`Distribution` class is the base class for various probabilistic
@@ -130,7 +122,7 @@ class StochasticPolicy(object):
                  param_dtype,
                  is_continuous,
                  observation_placeholder,
-                 group_ndims=0, # maybe useful for repeat_action
+                 group_ndims=0,  # maybe useful for repeat_action
                  **kwargs):
 
         self._act_dtype = act_dtype
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index 716e4c4..8533549 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -10,16 +10,25 @@ class DQNRefactor(PolicyBase):
     use DQN from value_function as a member
     """
     def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
         self._argmax_action = tf.argmax(value_tensor, axis=1)
 
-    def act(self, observation, exploration):
+        super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)
+
+    def act(self, observation, exploration=None):
         sess = tf.get_default_session()
         if not exploration: # no exploration
-            action = sess.run(self._argmax_action, feed_dict={})
+            action = sess.run(self._argmax_action, feed_dict={self._observation_placeholder: observation})
 
 
-class DQN(QValuePolicy):
+        return action
+
+    @property
+    def q_net(self):
+        return self._q_net
+
+
+class DQNOld(QValuePolicy):
     """
     The policy as in DQN
     """
diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py
index 3ef463e..d7a75d7 100644
--- a/tianshou/core/policy/stochastic.py
+++ b/tianshou/core/policy/stochastic.py
@@ -10,12 +10,6 @@ import tensorflow as tf
 
 from .base import StochasticPolicy
 
-__all__ = [
-    'OnehotCategorical',
-    'OnehotDiscrete',
-]
-
-
 class OnehotCategorical(StochasticPolicy):
     """
     The class of one-hot Categorical distribution.
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index 2bda4fa..c62dae6 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -15,7 +15,7 @@ class ActionValue(ValueFunctionBase):
             observation_placeholder=observation_placeholder
         )
 
-    def get_value(self, observation, action):
+    def eval_value(self, observation, action):
         """
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
         :param action: numpy array of actions, of shape (batchsize, action_dim)
@@ -24,7 +24,7 @@ class ActionValue(ValueFunctionBase):
         # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
         """
         sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict=
+        return sess.run(self.value_tensor, feed_dict=
             {self._observation_placeholder: observation, self._action_placeholder: action})
 
 
@@ -50,7 +50,7 @@ class DQN(ActionValue):
             observation_placeholder=observation_placeholder,
             action_placeholder=action_placeholder)
 
-    def get_value_all_actions(self, observation):
+    def eval_value_all_actions(self, observation):
         """
         :param observation:
         :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
@@ -57,6 +57,7 @@ class DQN(ActionValue):
         """
         sess = tf.get_default_session()
         return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
 
-    def get_value_tensor_all_actions(self):
+    @property
+    def value_tensor_all_actions(self):
         return self._value_tensor_all_actions
\ No newline at end of file
diff --git a/tianshou/core/value_function/base.py b/tianshou/core/value_function/base.py
index b15f1bf..8ca9dd0 100644
--- a/tianshou/core/value_function/base.py
+++ b/tianshou/core/value_function/base.py
@@ -11,14 +11,15 @@ class ValueFunctionBase(object):
         self._observation_placeholder = observation_placeholder
         self._value_tensor = tf.squeeze(value_tensor) # canonical values has shape (batchsize, )
 
-    def get_value(self, **kwargs):
+    def eval_value(self, **kwargs):
         """
 
         :return: batch of corresponding values in numpy array
         """
         raise NotImplementedError()
 
-    def get_value_tensor(self):
+    @property
+    def value_tensor(self):
         """
 
         :return: tensor of the corresponding values
diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py
index b7de196..02c12fe 100644
--- a/tianshou/core/value_function/state_value.py
+++ b/tianshou/core/value_function/state_value.py
@@ -14,7 +14,7 @@ class StateValue(ValueFunctionBase):
             observation_placeholder=observation_placeholder
         )
 
-    def get_value(self, observation):
+    def eval_value(self, observation):
         """
 
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
@@ -22,4 +22,4 @@ class StateValue(ValueFunctionBase):
         # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
         """
         sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
\ No newline at end of file
+        return sess.run(self.value_tensor, feed_dict={self._observation_placeholder: observation})
\ No newline at end of file
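
A possible follow-up (sketch, not part of this patch): DQNRefactor.act only services the exploration=None branch, so `action` is left undefined whenever an exploration scheme is requested. The snippet below shows one way the argument could eventually be serviced if it is interpreted as an epsilon for epsilon-greedy selection. The helper name act_epsilon_greedy and the epsilon default are illustrative only; _argmax_action, _observation_placeholder, q_net and value_tensor_all_actions are the names this patch introduces.

    import numpy as np
    import tensorflow as tf


    def act_epsilon_greedy(policy, observation, epsilon=0.1):
        """Epsilon-greedy action selection on top of a DQNRefactor-style policy (sketch)."""
        sess = tf.get_default_session()
        # Greedy actions: one integer per observation in the batch, from argmax over Q(s, *).
        greedy = sess.run(policy._argmax_action,
                          feed_dict={policy._observation_placeholder: observation})
        # Number of discrete actions, read off the (batch, num_actions) Q-value tensor.
        num_actions = policy.q_net.value_tensor_all_actions.get_shape().as_list()[-1]
        # With probability epsilon, replace each greedy action with a uniform random one.
        explore = np.random.rand(*greedy.shape) < epsilon
        random_actions = np.random.randint(num_actions, size=greedy.shape)
        return np.where(explore, random_actions, greedy)

Folding something like this into DQNRefactor.act would keep the greedy path identical to the sess.run(self._argmax_action, ...) call the patch already makes, while giving the exploration argument a concrete meaning.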