towards policy/value refactor

This commit is contained in:
haoshengzou 2017-12-23 17:25:16 +08:00
parent 86bf94fde1
commit b33a141373
9 changed files with 41 additions and 42 deletions

View File

@@ -1,8 +1,6 @@
#!/usr/bin/env python
import tensorflow as tf
-import numpy as np
-import time
import gym

# our lib imports here!
@@ -10,7 +8,7 @@ import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay_buffer.utils import get_replay_buffer
-import tianshou.core.policy as policy
+import tianshou.core.policy.dqn as policy


def policy_net(observation, action_dim):
@@ -41,6 +39,8 @@ if __name__ == '__main__':
    # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
    # access this observation variable.
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
+    action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions

    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)
@@ -48,10 +48,9 @@ if __name__ == '__main__':
        q_values_target = policy_net(observation, action_dim)

    # 2. build losses, optimizers
-    q_net = policy.DQN(q_values, observation_placeholder=observation) # YongRen: policy.DQN
-    target_net = policy.DQN(q_values_target, observation_placeholder=observation)
-    action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
+    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN
+    target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)
    target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
    dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
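
For orientation only (not part of this diff): a minimal, self-contained sketch of how the refactored example wires the graph, showing that the integer action placeholder must now exist before the policy objects are built, since DQNRefactor takes it as a constructor argument. The one-hidden-layer policy_net and the Adam optimizer are illustrative assumptions, not the example's actual code.

import tensorflow as tf
import gym

import tianshou.core.losses as losses
import tianshou.core.policy.dqn as policy  # as imported in the example above


def policy_net(observation, action_dim):
    # hypothetical stand-in for the example's network builder
    net = tf.layers.dense(observation, 32, activation=tf.nn.relu)
    return tf.layers.dense(net, action_dim)


if __name__ == '__main__':
    env = gym.make('CartPole-v0')
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.n

    # placeholders come first, including the action placeholder the policy now needs
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation")
    action = tf.placeholder(dtype=tf.int32, shape=(None,))
    target = tf.placeholder(dtype=tf.float32, shape=(None,))

    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)

    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action)
    dqn_loss = losses.dqn_loss(action, target, q_net)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(dqn_loss)  # optimizer choice is an assumption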

View File

@@ -22,3 +22,7 @@ referencing QValuePolicy in base.py, should have at least the listed methods.
TongzhengRen
seems to be direct python functions. Though the management of placeholders may require some discussion. also may write it in a functional form.
+
+# policy, value_function
+
+naming should be reconsidered. Perhaps use plural forms for all nouns

View File

@@ -35,17 +35,16 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
    # TODO: Different baseline methods like REINFORCE, etc.
    return vanilla_policy_gradient_loss


-def dqn_loss(sampled_action, sampled_target, q_net):
+def dqn_loss(sampled_action, sampled_target, policy):
    """
    deep q-network
    :param sampled_action: placeholder of sampled actions during the interaction with the environment
    :param sampled_target: estimated Q(s,a)
-    :param q_net: current `policy` to be optimized
+    :param policy: current `policy` to be optimized
    :return:
    """
-    action_num = q_net.values_tensor().get_shape()[1]
-    sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
+    sampled_q = policy.q_net.value_tensor
    return tf.reduce_mean(tf.square(sampled_target - sampled_q))


def deterministic_policy_gradient(sampled_state, critic):
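
Reading of the change above: the one-hot reduction that used to live in dqn_loss is now expected to happen on the value-function side, so that policy.q_net.value_tensor is already Q(s, a) for the fed actions, of shape (batchsize,). A sketch of that reduction with hypothetical names, assuming it ends up inside the DQN value function rather than the loss (this commit does not show that code):

# hypothetical placement inside the DQN value function; names are illustrative
q_all = value_tensor_all_actions                      # Q(s, *), shape (batchsize, num_actions)
action_num = q_all.get_shape()[1]
value_tensor = tf.reduce_sum(
    q_all * tf.one_hot(action_placeholder, action_num),
    axis=1)                                           # Q(s, a), shape (batchsize,)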

View File

@@ -3,19 +3,12 @@
from __future__ import absolute_import
from __future__ import division

-import warnings
import tensorflow as tf

# from zhusuan.utils import add_name_scope

-__all__ = [
-    'StochasticPolicy',
-    'QValuePolicy',
-    'PolicyBase'
-]

# TODO: a even more "base" class for policy
@@ -23,8 +16,8 @@ class PolicyBase(object):
    """
    base class for policy. only provides `act` method with exploration
    """
-    def __init__(self):
-        pass
+    def __init__(self, observation_placeholder):
+        self._observation_placeholder = observation_placeholder

    def act(self, observation, exploration):
        raise NotImplementedError()
@@ -37,14 +30,14 @@ class QValuePolicy(object):
    def __init__(self, observation_placeholder):
        self._observation_placeholder = observation_placeholder

    def act(self, observation, exploration=None): # first implement no exploration
        """
        return the action (int) to be executed.
        no exploration when exploration=None.
        """
        self._act(observation, exploration)

-    def _act(self, observation, exploration = None):
+    def _act(self, observation, exploration=None):
        raise NotImplementedError()

    def values(self, observation):
@@ -60,7 +53,6 @@ class QValuePolicy(object):
        pass


class StochasticPolicy(object):
    """
    The :class:`Distribution` class is the base class for various probabilistic
@@ -130,7 +122,7 @@ class StochasticPolicy(object):
                 param_dtype,
                 is_continuous,
                 observation_placeholder,
                 group_ndims=0, # maybe useful for repeat_action
                 **kwargs):
        self._act_dtype = act_dtype
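
With the placeholder now stored by PolicyBase, a subclass is expected to pass observation_placeholder up through super().__init__ and implement act itself. A minimal sketch with a hypothetical subclass, mirroring the pattern DQNRefactor uses in the next file:

class GreedyPolicy(PolicyBase):
    def __init__(self, action_tensor, observation_placeholder):
        self._action_tensor = action_tensor
        super(GreedyPolicy, self).__init__(observation_placeholder=observation_placeholder)

    def act(self, observation, exploration=None):
        # relies on a default session being active
        sess = tf.get_default_session()
        return sess.run(self._action_tensor, feed_dict={self._observation_placeholder: observation})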

View File

@@ -10,16 +10,25 @@ class DQNRefactor(PolicyBase):
    use DQN from value_function as a member
    """
    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
        self._argmax_action = tf.argmax(value_tensor, axis=1)

-    def act(self, observation, exploration):
+        super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)
+
+    def act(self, observation, exploration=None):
        sess = tf.get_default_session()
        if not exploration: # no exploration
-            action = sess.run(self._argmax_action, feed_dict={})
+            action = sess.run(self._argmax_action, feed_dict={self._observation_placeholder: observation})

-class DQN(QValuePolicy):
+        return action
+
+    @property
+    def q_net(self):
+        return self._q_net
+
+
+class DQNOld(QValuePolicy):
    """
    The policy as in DQN
    """

View File

@@ -10,12 +10,6 @@ import tensorflow as tf
from .base import StochasticPolicy

-__all__ = [
-    'OnehotCategorical',
-    'OnehotDiscrete',
-]


class OnehotCategorical(StochasticPolicy):
    """
    The class of one-hot Categorical distribution.

View File

@@ -15,7 +15,7 @@ class ActionValue(ValueFunctionBase):
            observation_placeholder=observation_placeholder
        )

-    def get_value(self, observation, action):
+    def eval_value(self, observation, action):
        """
        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
        :param action: numpy array of actions, of shape (batchsize, action_dim)
@@ -24,7 +24,7 @@ class ActionValue(ValueFunctionBase):
        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
        """
        sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict=
+        return sess.run(self.value_tensor, feed_dict=
                        {self._observation_placeholder: observation, self._action_placeholder: action})
@@ -50,7 +50,7 @@ class DQN(ActionValue):
            observation_placeholder=observation_placeholder,
            action_placeholder=action_placeholder)

-    def get_value_all_actions(self, observation):
+    def eval_value_all_actions(self, observation):
        """
        :param observation:
        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
@@ -58,5 +58,6 @@ class DQN(ActionValue):
        sess = tf.get_default_session()
        return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})

-    def get_value_tensor_all_actions(self):
+    @property
+    def value_tensor_all_actions(self):
        return self._value_tensor_all_actions
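
After the rename, the graph-side accessors are properties and the numpy-side evaluators are eval_* methods. A short sketch with hypothetical variable names (dqn_value for the DQN value function, obs_batch/act_batch for numpy batches), assuming a default session is active; per the new dqn_loss above, value_tensor is expected to already be Q(s, a) for the action placeholder:

q_selected_tensor = dqn_value.value_tensor             # tensor: Q(s, a), shape (batchsize,)
q_all_tensor = dqn_value.value_tensor_all_actions      # tensor: Q(s, *), shape (batchsize, num_actions)

q_sa = dqn_value.eval_value(obs_batch, act_batch)      # numpy, shape (batchsize,)
q_s_all = dqn_value.eval_value_all_actions(obs_batch)  # numpy, shape (batchsize, num_actions)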

View File

@@ -11,14 +11,15 @@ class ValueFunctionBase(object):
        self._observation_placeholder = observation_placeholder
        self._value_tensor = tf.squeeze(value_tensor) # canonical values has shape (batchsize, )

-    def get_value(self, **kwargs):
+    def eval_value(self, **kwargs):
        """
        :return: batch of corresponding values in numpy array
        """
        raise NotImplementedError()

-    def get_value_tensor(self):
+    @property
+    def value_tensor(self):
        """
        :return: tensor of the corresponding values

View File

@@ -14,7 +14,7 @@ class StateValue(ValueFunctionBase):
            observation_placeholder=observation_placeholder
        )

-    def get_value(self, observation):
+    def eval_value(self, observation):
        """
        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
@@ -22,4 +22,4 @@ class StateValue(ValueFunctionBase):
        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
        """
        sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
+        return sess.run(self.value_tensor, feed_dict={self._observation_placeholder: observation})