towards policy/value refactor

haoshengzou 2017-12-23 17:25:16 +08:00
parent 86bf94fde1
commit b33a141373
9 changed files with 41 additions and 42 deletions

View File

@@ -1,8 +1,6 @@
#!/usr/bin/env python
import tensorflow as tf
import numpy as np
import time
import gym
# our lib imports here!
@@ -10,7 +8,7 @@ import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay_buffer.utils import get_replay_buffer
-import tianshou.core.policy as policy
+import tianshou.core.policy.dqn as policy
def policy_net(observation, action_dim):
@@ -41,6 +39,8 @@ if __name__ == '__main__':
# pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
# access this observation variable.
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
+action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions
with tf.variable_scope('q_net'):
q_values = policy_net(observation, action_dim)
@@ -48,10 +48,9 @@ if __name__ == '__main__':
q_values_target = policy_net(observation, action_dim)
# 2. build losses, optimizers
-q_net = policy.DQN(q_values, observation_placeholder=observation) # YongRen: policy.DQN
-target_net = policy.DQN(q_values_target, observation_placeholder=observation)
+q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN
+target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)
-action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen

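Taken together, the example now creates the `action` placeholder up front and hands it to the policy objects instead of declaring it next to the loss. A minimal sketch of the resulting wiring, assuming a CartPole environment, a toy two-layer `policy_net`, and an Adam training step (none of which are part of this commit; only the placeholder / `DQNRefactor` / `dqn_loss` calls are taken from the diff above):

```python
# Sketch only: the env, network size and optimizer below are illustrative assumptions.
import tensorflow as tf
import gym

import tianshou.core.losses as losses
import tianshou.core.policy.dqn as policy


def policy_net(observation, action_dim):
    # assumed two-layer MLP head producing Q(s, *); the real example defines its own network
    hidden = tf.layers.dense(observation, 32, activation=tf.nn.relu)
    return tf.layers.dense(hidden, action_dim, activation=None)


if __name__ == '__main__':
    env = gym.make('CartPole-v0')                      # assumed environment
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.n

    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation")
    action = tf.placeholder(dtype=tf.int32, shape=(None,))    # batch of integer actions
    target = tf.placeholder(dtype=tf.float32, shape=(None,))  # target value for DQN

    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)
    with tf.variable_scope('target_net'):
        q_values_target = policy_net(observation, action_dim)

    # both networks are wrapped by the refactored policy, which now also owns the action placeholder
    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action)
    target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)

    dqn_loss = losses.dqn_loss(action, target, q_net)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(dqn_loss)  # assumed training step, not in the commit
```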
View File

@@ -22,3 +22,7 @@ referencing QValuePolicy in base.py, should have at least the listed methods.
TongzhengRen
These seem to be direct Python functions, though the management of placeholders may require some discussion; they could also be written in a functional form.
# policy, value_function
Naming should be reconsidered; perhaps use plural forms for all nouns.

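To make the "direct function vs. functional form" question in this note concrete, here is a hedged sketch of the two options for a loss over an already-selected Q(s, a) tensor (both functions are illustrative, not part of the commit):

```python
import tensorflow as tf

# Option 1: direct function; the caller creates and manages every placeholder
# (this is the style currently used in tianshou/core/losses.py).
def dqn_loss_direct(sampled_target, sampled_q):
    return tf.reduce_mean(tf.square(sampled_target - sampled_q))

# Option 2: "functional" form; the loss owns its placeholder and hands it back,
# so placeholder management moves out of the example script.
def make_dqn_loss(sampled_q):
    sampled_target = tf.placeholder(tf.float32, shape=(None,), name="dqn_target")
    loss = tf.reduce_mean(tf.square(sampled_target - sampled_q))
    return loss, sampled_target
```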
View File

@@ -35,17 +35,16 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
# TODO: Different baseline methods like REINFORCE, etc.
return vanilla_policy_gradient_loss
-def dqn_loss(sampled_action, sampled_target, q_net):
+def dqn_loss(sampled_action, sampled_target, policy):
"""
deep q-network
:param sampled_action: placeholder of sampled actions during the interaction with the environment
:param sampled_target: estimated Q(s,a)
-:param q_net: current `policy` to be optimized
+:param policy: current `policy` to be optimized
:return:
"""
-action_num = q_net.values_tensor().get_shape()[1]
-sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
+sampled_q = policy.q_net.value_tensor
return tf.reduce_mean(tf.square(sampled_target - sampled_q))
def deterministic_policy_gradient(sampled_state, critic):

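The refactored `dqn_loss` no longer selects Q(s, a) itself; it assumes the policy's wrapped value function already exposes the per-action value as `policy.q_net.value_tensor`. A small sketch of what the old code did explicitly versus what the new code now expects (the stand-in placeholders are illustrative):

```python
import tensorflow as tf

q_all = tf.placeholder(tf.float32, shape=(None, 4))       # stand-in for Q(s, *) over 4 actions
sampled_action = tf.placeholder(tf.int32, shape=(None,))
sampled_target = tf.placeholder(tf.float32, shape=(None,))

# old dqn_loss: mask the full Q-value tensor with a one-hot of the sampled action
num_actions = int(q_all.get_shape()[1])
sampled_q = tf.reduce_sum(q_all * tf.one_hot(sampled_action, num_actions), axis=1)

# new dqn_loss: this already-selected tensor is expected to be provided by the
# value function wrapped inside the policy, i.e. policy.q_net.value_tensor
dqn_loss = tf.reduce_mean(tf.square(sampled_target - sampled_q))
```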
View File

@@ -3,19 +3,12 @@
from __future__ import absolute_import
from __future__ import division
-import warnings
import tensorflow as tf
-# from zhusuan.utils import add_name_scope
-__all__ = [
-'StochasticPolicy',
-'QValuePolicy',
-'PolicyBase'
-]
# TODO: an even more "base" class for policy
@@ -23,8 +16,8 @@ class PolicyBase(object):
"""
base class for policy. only provides `act` method with exploration
"""
-def __init__(self):
-pass
+def __init__(self, observation_placeholder):
+self._observation_placeholder = observation_placeholder
def act(self, observation, exploration):
raise NotImplementedError()
@@ -60,7 +53,6 @@ class QValuePolicy(object):
pass
class StochasticPolicy(object):
"""
The :class:`Distribution` class is the base class for various probabilistic

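With `PolicyBase` now storing the observation placeholder, subclasses are expected to forward it via `super().__init__` and implement `act`. A minimal hypothetical subclass following that contract (the class name and import path are assumptions, mirroring `DQNRefactor` in the next file):

```python
import tensorflow as tf

from tianshou.core.policy.base import PolicyBase  # assumed module path


class GreedyQPolicy(PolicyBase):
    """Hypothetical example of the new contract: hand the observation placeholder
    to PolicyBase and implement act() on top of it."""

    def __init__(self, value_tensor, observation_placeholder):
        self._argmax_action = tf.argmax(value_tensor, axis=1)
        super(GreedyQPolicy, self).__init__(observation_placeholder=observation_placeholder)

    def act(self, observation, exploration=None):
        sess = tf.get_default_session()
        # greedy action w.r.t. the current Q-values; exploration is ignored in this sketch
        return sess.run(self._argmax_action,
                        feed_dict={self._observation_placeholder: observation})
```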
View File

@@ -10,16 +10,25 @@ class DQNRefactor(PolicyBase):
use DQN from value_function as a member
"""
def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
self._argmax_action = tf.argmax(value_tensor, axis=1)
+super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)
-def act(self, observation, exploration):
+def act(self, observation, exploration=None):
sess = tf.get_default_session()
if not exploration: # no exploration
-action = sess.run(self._argmax_action, feed_dict={})
+action = sess.run(self._argmax_action, feed_dict={self._observation_placeholder: observation})
+return action
+@property
+def q_net(self):
+return self._q_net
-class DQN(QValuePolicy):
+class DQNOld(QValuePolicy):
"""
The policy as in DQN
"""

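The new `act` only fills in the greedy (no-exploration) branch. If `exploration` is meant as an epsilon for epsilon-greedy selection, which is an assumption rather than something this commit specifies, the missing branch could look like the standalone helper below (all names are hypothetical):

```python
import numpy as np


def epsilon_greedy(sess, argmax_action, observation_placeholder, observation,
                   num_actions, epsilon=0.0):
    """Hypothetical helper mirroring DQNRefactor.act: greedy when epsilon is 0,
    otherwise random with probability epsilon."""
    greedy = sess.run(argmax_action, feed_dict={observation_placeholder: observation})
    if not epsilon:
        return greedy
    random_actions = np.random.randint(num_actions, size=greedy.shape)
    explore_mask = np.random.rand(*greedy.shape) < epsilon
    return np.where(explore_mask, random_actions, greedy)
```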
View File

@@ -10,12 +10,6 @@ import tensorflow as tf
from .base import StochasticPolicy
-__all__ = [
-'OnehotCategorical',
-'OnehotDiscrete',
-]
class OnehotCategorical(StochasticPolicy):
"""
The class of one-hot Categorical distribution.

View File

@@ -15,7 +15,7 @@ class ActionValue(ValueFunctionBase):
observation_placeholder=observation_placeholder
)
-def get_value(self, observation, action):
+def eval_value(self, observation, action):
"""
:param observation: numpy array of observations, of shape (batchsize, observation_dim).
:param action: numpy array of actions, of shape (batchsize, action_dim)
@@ -24,7 +24,7 @@ class ActionValue(ValueFunctionBase):
# TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
"""
sess = tf.get_default_session()
-return sess.run(self.get_value_tensor(), feed_dict=
+return sess.run(self.value_tensor, feed_dict=
{self._observation_placeholder: observation, self._action_placeholder: action})
@@ -50,7 +50,7 @@ class DQN(ActionValue):
observation_placeholder=observation_placeholder,
action_placeholder=action_placeholder)
-def get_value_all_actions(self, observation):
+def eval_value_all_actions(self, observation):
"""
:param observation:
:return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
@@ -58,5 +58,6 @@ class DQN(ActionValue):
sess = tf.get_default_session()
return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
-def get_value_tensor_all_actions(self):
+@property
+def value_tensor_all_actions(self):
return self._value_tensor_all_actions

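With the evaluation methods renamed to `eval_value*` and the tensor getters turned into properties, a typical call site is computing DQN targets from the target network. A hedged sketch, assuming `target_q_net` is a value_function `DQN` and the batch arrays come from the replay buffer (the function itself is illustrative):

```python
import numpy as np


def compute_dqn_targets(target_q_net, rewards, dones, next_observations, gamma=0.99):
    """Illustrative: must run inside a default TF session, since
    eval_value_all_actions calls tf.get_default_session() internally."""
    # Q(s', *) for every next state, shape (batchsize, num_actions)
    next_q_all = target_q_net.eval_value_all_actions(next_observations)
    max_next_q = next_q_all.max(axis=1)
    # one-step target r + gamma * max_a Q_target(s', a), cut off at episode termination
    return rewards + gamma * (1.0 - dones.astype(np.float32)) * max_next_q
```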
View File

@@ -11,14 +11,15 @@ class ValueFunctionBase(object):
self._observation_placeholder = observation_placeholder
self._value_tensor = tf.squeeze(value_tensor) # canonical values have shape (batchsize, )
-def get_value(self, **kwargs):
+def eval_value(self, **kwargs):
"""
:return: batch of corresponding values in numpy array
"""
raise NotImplementedError()
-def get_value_tensor(self):
+@property
+def value_tensor(self):
"""
:return: tensor of the corresponding values

View File

@@ -14,7 +14,7 @@ class StateValue(ValueFunctionBase):
observation_placeholder=observation_placeholder
)
-def get_value(self, observation):
+def eval_value(self, observation):
"""
:param observation: numpy array of observations, of shape (batchsize, observation_dim).
@@ -22,4 +22,4 @@ class StateValue(ValueFunctionBase):
# TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
"""
sess = tf.get_default_session()
-return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
+return sess.run(self.value_tensor, feed_dict={self._observation_placeholder: observation})
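Finally, the same renaming applies to `StateValue`: `eval_value` for numpy-level evaluation and the `value_tensor` property for graph access. A small usage sketch, assuming a `StateValue(value_tensor, observation_placeholder)` constructor and module path (both inferred from the diff, not confirmed by it):

```python
import numpy as np
import tensorflow as tf

from tianshou.core.value_function.state_value import StateValue  # assumed module path

observation = tf.placeholder(tf.float32, shape=(None, 4), name="observation")
hidden = tf.layers.dense(observation, 32, activation=tf.nn.relu)
v_head = tf.layers.dense(hidden, 1)                      # V(s) head with a trailing dim of 1

state_value = StateValue(v_head, observation_placeholder=observation)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(8, 4).astype(np.float32)
    values = state_value.eval_value(batch)               # numpy array of V(s), shape (8,)
    value_op = state_value.value_tensor                  # graph tensor, replaces get_value_tensor()
```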