towards policy/value refactor
This commit is contained in:
parent 8c13d8ebe6
commit b21a55dc88
@@ -1,8 +1,6 @@
#!/usr/bin/env python

import tensorflow as tf
import numpy as np
import time
import gym

# our lib imports here!
@@ -10,7 +8,7 @@ import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay_buffer.utils import get_replay_buffer
-import tianshou.core.policy as policy
+import tianshou.core.policy.dqn as policy


def policy_net(observation, action_dim):
@@ -41,6 +39,8 @@ if __name__ == '__main__':
    # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
    # access this observation variable.
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation")  # network input
+    action = tf.placeholder(dtype=tf.int32, shape=(None,))  # batch of integer actions


    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)
@@ -48,10 +48,9 @@ if __name__ == '__main__':
        q_values_target = policy_net(observation, action_dim)

    # 2. build losses, optimizers
-    q_net = policy.DQN(q_values, observation_placeholder=observation)  # YongRen: policy.DQN
-    target_net = policy.DQN(q_values_target, observation_placeholder=observation)
+    q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action)  # YongRen: policy.DQN
+    target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action)

-    action = tf.placeholder(dtype=tf.int32, shape=[None])  # batch of integer actions
    target = tf.placeholder(dtype=tf.float32, shape=[None])  # target value for DQN

    dqn_loss = losses.dqn_loss(action, target, q_net)  # TongzhengRen
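As an aside, the following is a minimal, self-contained sketch of the wiring this example builds up to: the placeholders, a small Q-network standing in for `policy_net`, the squared-error DQN loss, and an optimizer step. It uses plain TF 1.x only; the network sizes and dimensions are illustrative, not taken from this commit.

    import numpy as np
    import tensorflow as tf  # TF 1.x, as used throughout this commit

    observation_dim = (4,)  # illustrative, e.g. CartPole-like observations
    action_dim = 2

    # placeholders, mirroring the example script above
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
    action = tf.placeholder(tf.int32, shape=(None,))    # batch of integer actions
    target = tf.placeholder(tf.float32, shape=(None,))  # target value for DQN

    # a tiny Q-network standing in for policy_net()
    with tf.variable_scope('q_net'):
        hidden = tf.layers.dense(observation, 32, activation=tf.nn.relu)
        q_values = tf.layers.dense(hidden, action_dim)  # Q(s, *), shape (batch, action_dim)

    # squared-error DQN loss on the Q-value of the sampled action
    sampled_q = tf.reduce_sum(q_values * tf.one_hot(action, action_dim), axis=1)
    dqn_loss = tf.reduce_mean(tf.square(target - sampled_q))
    train_op = tf.train.AdamOptimizer(1e-3).minimize(dqn_loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch_obs = np.random.rand(8, 4).astype(np.float32)
        batch_act = np.random.randint(action_dim, size=8)
        batch_tgt = np.random.rand(8).astype(np.float32)
        loss_value, _ = sess.run([dqn_loss, train_op],
                                 feed_dict={observation: batch_obs, action: batch_act, target: batch_tgt})
        print('dqn loss:', loss_value)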
@@ -22,3 +22,7 @@ referencing QValuePolicy in base.py, should have at least the listed methods.
TongzhengRen

These seem to be direct Python functions, though the management of placeholders may require some discussion; we may also write them in a functional form (see the sketch after this hunk).

# policy, value_function

Naming should be reconsidered; perhaps use plural forms for all nouns.
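On the "functional form" mentioned in the note above: one possible reading (purely a sketch, not part of this commit) is a loss that takes bare tensors instead of policy objects, which sidesteps the placeholder-management question; the name `dqn_loss_functional` is hypothetical.

    import tensorflow as tf  # TF 1.x

    def dqn_loss_functional(sampled_action, sampled_target, q_values):
        """Hypothetical functional-form loss operating on plain tensors only.

        :param sampled_action: int tensor of sampled actions, shape (batch,)
        :param sampled_target: float tensor of TD targets, shape (batch,)
        :param q_values: float tensor Q(s, *), shape (batch, num_actions)
        """
        num_actions = q_values.get_shape().as_list()[-1]  # assumes a statically known action dim
        sampled_q = tf.reduce_sum(q_values * tf.one_hot(sampled_action, num_actions), axis=1)
        return tf.reduce_mean(tf.square(sampled_target - sampled_q))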
@@ -35,17 +35,16 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
    # TODO: Different baseline methods like REINFORCE, etc.
    return vanilla_policy_gradient_loss

-def dqn_loss(sampled_action, sampled_target, q_net):
+def dqn_loss(sampled_action, sampled_target, policy):
    """
    deep q-network

    :param sampled_action: placeholder of sampled actions during the interaction with the environment
    :param sampled_target: estimated Q(s,a)
-    :param q_net: current `policy` to be optimized
+    :param policy: current `policy` to be optimized
    :return:
    """
-    action_num = q_net.values_tensor().get_shape()[1]
-    sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
+    sampled_q = policy.q_net.value_tensor
    return tf.reduce_mean(tf.square(sampled_target - sampled_q))

def deterministic_policy_gradient(sampled_state, critic):
@@ -3,19 +3,12 @@

from __future__ import absolute_import
from __future__ import division
import warnings

import tensorflow as tf

# from zhusuan.utils import add_name_scope


__all__ = [
    'StochasticPolicy',
    'QValuePolicy',
    'PolicyBase'
]

# TODO: a even more "base" class for policy


@@ -23,8 +16,8 @@ class PolicyBase(object):
    """
    base class for policy. only provides `act` method with exploration
    """
-    def __init__(self):
-        pass
+    def __init__(self, observation_placeholder):
+        self._observation_placeholder = observation_placeholder

    def act(self, observation, exploration):
        raise NotImplementedError()
@@ -37,14 +30,14 @@ class QValuePolicy(object):
    def __init__(self, observation_placeholder):
        self._observation_placeholder = observation_placeholder

-    def act(self, observation, exploration=None): # first implement no exploration
+    def act(self, observation, exploration=None):  # first implement no exploration
        """
        return the action (int) to be executed.
        no exploration when exploration=None.
        """
        self._act(observation, exploration)

-    def _act(self, observation, exploration = None):
+    def _act(self, observation, exploration=None):
        raise NotImplementedError()

    def values(self, observation):
@@ -60,7 +53,6 @@ class QValuePolicy(object):
        pass



class StochasticPolicy(object):
    """
    The :class:`Distribution` class is the base class for various probabilistic
@@ -130,7 +122,7 @@ class StochasticPolicy(object):
                 param_dtype,
                 is_continuous,
                 observation_placeholder,
-                 group_ndims=0, # maybe useful for repeat_action
+                 group_ndims=0,  # maybe useful for repeat_action
                 **kwargs):

        self._act_dtype = act_dtype
@@ -10,16 +10,25 @@ class DQNRefactor(PolicyBase):
    use DQN from value_function as a member
    """
    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
        self._argmax_action = tf.argmax(value_tensor, axis=1)

-    def act(self, observation, exploration):
+        super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)
+
+    def act(self, observation, exploration=None):
        sess = tf.get_default_session()
        if not exploration:  # no exploration
-            action = sess.run(self._argmax_action, feed_dict={})
+            action = sess.run(self._argmax_action, feed_dict={self._observation_placeholder: observation})


-class DQN(QValuePolicy):
+        return action
+
+    @property
+    def q_net(self):
+        return self._q_net
+
+
+class DQNOld(QValuePolicy):
    """
    The policy as in DQN
    """
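For reference, a self-contained sketch of the greedy-act pattern that `DQNRefactor.act` implements above (plain TF 1.x; the placeholder and layer are illustrative stand-ins for the wrapped value network):

    import numpy as np
    import tensorflow as tf  # TF 1.x

    observation_ph = tf.placeholder(tf.float32, shape=(None, 4))
    q_values = tf.layers.dense(observation_ph, 3)  # stand-in for the wrapped DQN value network
    argmax_action = tf.argmax(q_values, axis=1)    # greedy action, as built in DQNRefactor.__init__

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        obs = np.random.rand(1, 4).astype(np.float32)  # a single observation, batched
        # the exploration=None branch of act(): feed the observation placeholder, fetch the argmax
        greedy_action = sess.run(argmax_action, feed_dict={observation_ph: obs})
        print(greedy_action)  # e.g. array([2])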
@@ -10,12 +10,6 @@ import tensorflow as tf
from .base import StochasticPolicy


-__all__ = [
-    'OnehotCategorical',
-    'OnehotDiscrete',
-]
-
-
class OnehotCategorical(StochasticPolicy):
    """
    The class of one-hot Categorical distribution.
@@ -15,7 +15,7 @@ class ActionValue(ValueFunctionBase):
            observation_placeholder=observation_placeholder
        )

-    def get_value(self, observation, action):
+    def eval_value(self, observation, action):
        """
        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
        :param action: numpy array of actions, of shape (batchsize, action_dim)
@@ -24,7 +24,7 @@ class ActionValue(ValueFunctionBase):
        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
        """
        sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict=
+        return sess.run(self.value_tensor, feed_dict=
                        {self._observation_placeholder: observation, self._action_placeholder: action})


@@ -50,7 +50,7 @@ class DQN(ActionValue):
            observation_placeholder=observation_placeholder,
            action_placeholder=action_placeholder)

-    def get_value_all_actions(self, observation):
+    def eval_value_all_actions(self, observation):
        """
        :param observation:
        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
@@ -58,5 +58,6 @@ class DQN(ActionValue):
        sess = tf.get_default_session()
        return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})

-    def get_value_tensor_all_actions(self):
+    @property
+    def value_tensor_all_actions(self):
        return self._value_tensor_all_actions
@@ -11,14 +11,15 @@ class ValueFunctionBase(object):
        self._observation_placeholder = observation_placeholder
        self._value_tensor = tf.squeeze(value_tensor)  # canonical values has shape (batchsize, )

-    def get_value(self, **kwargs):
+    def eval_value(self, **kwargs):
        """

        :return: batch of corresponding values in numpy array
        """
        raise NotImplementedError()

-    def get_value_tensor(self):
+    @property
+    def value_tensor(self):
        """

        :return: tensor of the corresponding values
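The getter-to-property change above mainly affects call sites: tensors are now read as attributes rather than via `get_*` methods. A minimal stub illustrating the before/after calling convention (the class here is illustrative, not from the commit):

    import tensorflow as tf  # TF 1.x

    class ValueStub(object):
        """Illustrates the getter -> property refactor in this commit."""
        def __init__(self, value_tensor):
            self._value_tensor = tf.squeeze(value_tensor)  # canonical shape (batchsize,)

        @property
        def value_tensor(self):
            # read as an attribute: stub.value_tensor (previously stub.get_value_tensor())
            return self._value_tensor

    stub = ValueStub(tf.constant([[1.0], [2.0]]))
    print(stub.value_tensor)  # a Tensor of shape (2,); no call parentheses needed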
@@ -14,7 +14,7 @@ class StateValue(ValueFunctionBase):
            observation_placeholder=observation_placeholder
        )

-    def get_value(self, observation):
+    def eval_value(self, observation):
        """

        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
@@ -22,4 +22,4 @@ class StateValue(ValueFunctionBase):
        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
        """
        sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
+        return sess.run(self.value_tensor, feed_dict={self._observation_placeholder: observation})
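To make the return shapes concrete, here is a small self-contained sketch of the evaluation pattern `eval_value` follows (plain TF 1.x, illustrative names): feed a batch of observations, get back a numpy array with one value per state.

    import numpy as np
    import tensorflow as tf  # TF 1.x

    observation_ph = tf.placeholder(tf.float32, shape=(None, 4))
    value_tensor = tf.squeeze(tf.layers.dense(observation_ph, 1))  # V(s), squeezed to shape (batchsize,)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        obs = np.random.rand(5, 4).astype(np.float32)
        values = sess.run(value_tensor, feed_dict={observation_ph: obs})
        print(values.shape)  # (5,) -- one value per observation, as described in the docstring above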