add value_function (critic). value_function and policy not finished yet.
This commit is contained in: parent fae273f219, commit 6611d948dd
@@ -15,7 +15,7 @@ __all__ = [
    'QValuePolicy',
]

# TODO: separate actor and critic, we should focus on it once we finish the basic module.
# TODO: a even more "base" class for policy


class QValuePolicy(object):
@@ -1,5 +1,16 @@
from tianshou.core.policy.base import QValuePolicy
import tensorflow as tf

import sys
sys.path.append('..')
import value_function.action_value as value_func


class DQN_refactor(object):
    """
    use DQN from value_function as a member
    """
    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
        self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)


class DQN(QValuePolicy):
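For orientation, here is a minimal sketch of how the wrapped value network might be built and handed to DQN_refactor. It is not part of this commit; the placeholder shapes, the two-action MLP, and the TF 1.x layers API are assumptions for illustration.

import tensorflow as tf

observation_ph = tf.placeholder(tf.float32, shape=(None, 4))   # (batchsize, observation_dim)
action_ph = tf.placeholder(tf.int32, shape=(None,))            # (batchsize, )

# illustrative Q-network: an MLP whose last layer outputs Q(s, *) for all actions
hidden = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
q_values = tf.layers.dense(hidden, 2)                          # (batchsize, num_actions)

policy = DQN_refactor(q_values, observation_ph, action_ph)     # wraps value_func.DQN internally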
tianshou/core/value_function/__init__.py (new file, 0 lines)
tianshou/core/value_function/action_value.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from base import ValueFunctionBase
import tensorflow as tf


class ActionValue(ValueFunctionBase):
    """
    class of action values Q(s, a).
    """
    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
        self._action_placeholder = action_placeholder
        super(ActionValue, self).__init__(
            value_tensor=value_tensor,
            observation_placeholder=observation_placeholder
        )

    def get_value(self, observation, action):
        """
        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
        :param action: numpy array of actions, of shape (batchsize, action_dim)
            # TODO: Atari discrete action should have dim 1. Super Mario may should have, say, dim 5, where each can be 0/1
        :return: numpy array of state values, of shape (batchsize, )
            # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
        """
        sess = tf.get_default_session()
        return sess.run(self.get_value_tensor(), feed_dict=
            {self._observation_placeholder: observation, self._action_placeholder: action})[:, 0]


class DQN(ActionValue):
    """
    class of the very DQN architecture. Instead of feeding s and a to the network to get a value, DQN feed s to the
    network and the last layer is Q(s, *) for all actions
    """
    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
        """
        :param value_tensor: of shape (batchsize, num_actions)
        :param observation_placeholder: of shape (batchsize, observation_dim)
        :param action_placeholder: of shape (batchsize, )
        """
        self._value_tensor_all_actions = value_tensor
        canonical_value_tensor = value_tensor[action_placeholder]  # maybe a tf.map_fn. for now it's wrong

        super(DQN, self).__init__(value_tensor=canonical_value_tensor,
                                  observation_placeholder=observation_placeholder,
                                  action_placeholder=action_placeholder)

    def get_value_all_actions(self, observation):
        sess = tf.get_default_session()
        return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})

    def get_value_tensor_all_actions(self):
        return self._value_tensor_all_actions
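The inline comment above concedes that value_tensor[action_placeholder] does not select one Q-value per batch row. A possible fix, not part of this commit and assuming action_placeholder is an int32 tensor of shape (batchsize, ), is to pair each row index with its action and use tf.gather_nd:

batch_size = tf.shape(value_tensor)[0]
indices = tf.stack([tf.range(batch_size), action_placeholder], axis=1)   # (batchsize, 2) row/action index pairs
selected = tf.gather_nd(value_tensor, indices)                           # (batchsize, ) one Q-value per row
canonical_value_tensor = tf.expand_dims(selected, axis=1)                # (batchsize, 1) so get_value's [:, 0] still applies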
tianshou/core/value_function/base.py (new file, 23 lines)
@@ -0,0 +1,23 @@

# TODO: linear feature baseline also in tf?
class ValueFunctionBase(object):
    """
    base class of value functions. Children include state values V(s) and action values Q(s, a)
    """
    def __init__(self, value_tensor, observation_placeholder):
        self._observation_placeholder = observation_placeholder
        self._value_tensor = value_tensor

    def get_value(self, **kwargs):
        """
        :return: batch of corresponding values in numpy array
        """
        raise NotImplementedError()

    def get_value_tensor(self):
        """
        :return: tensor of the corresponding values
        """
        return self._value_tensor
tianshou/core/value_function/state_value.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from base import ValueFunctionBase
import tensorflow as tf


class StateValue(ValueFunctionBase):
    """
    class of state values V(s).
    """
    def __init__(self, value_tensor, observation_placeholder):
        super(StateValue, self).__init__(
            value_tensor=value_tensor,
            observation_placeholder=observation_placeholder
        )

    def get_value(self, observation):
        """
        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
        :return: numpy array of state values, of shape (batchsize, )
            # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
        """
        sess = tf.get_default_session()
        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
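Because get_value relies on tf.get_default_session(), it has to be called inside an active session. A minimal usage sketch, not part of this commit (the one-layer value network, observation dimension, and random batch are assumptions):

import numpy as np
import tensorflow as tf

observation_ph = tf.placeholder(tf.float32, shape=(None, 3))
value_tensor = tf.layers.dense(observation_ph, 1)      # (batchsize, 1); [:, 0] then yields (batchsize, )
v = StateValue(value_tensor, observation_ph)

with tf.Session() as sess:                             # the with-block installs sess as the default session
    sess.run(tf.global_variables_initializer())
    values = v.get_value(np.random.rand(5, 3))         # numpy array of shape (5, )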