diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index eecfc4f..025abd5 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -15,7 +15,7 @@ __all__ = [
     'QValuePolicy',
 ]
 
-# TODO: separate actor and critic, we should focus on it once we finish the basic module.
+# TODO: an even more "base" class for policy
 
 
 class QValuePolicy(object):
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index 39f6a16..d03dbd4 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -1,5 +1,16 @@
 from tianshou.core.policy.base import QValuePolicy
 import tensorflow as tf
+import sys
+sys.path.append('..')
+import value_function.action_value as value_func
+
+
+class DQN_refactor(object):
+    """
+    uses the DQN from value_function as a member
+    """
+    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
+        self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)
 
 
 class DQN(QValuePolicy):
diff --git a/tianshou/core/value_function/__init__.py b/tianshou/core/value_function/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
new file mode 100644
index 0000000..cb8acc8
--- /dev/null
+++ b/tianshou/core/value_function/action_value.py
@@ -0,0 +1,53 @@
+from base import ValueFunctionBase
+import tensorflow as tf
+
+
+class ActionValue(ValueFunctionBase):
+    """
+    class of action values Q(s, a).
+    """
+    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
+        self._action_placeholder = action_placeholder
+        super(ActionValue, self).__init__(
+            value_tensor=value_tensor,
+            observation_placeholder=observation_placeholder
+        )
+
+    def get_value(self, observation, action):
+        """
+        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
+        :param action: numpy array of actions, of shape (batchsize, action_dim)
+        # TODO: Atari discrete actions should have dim 1. Super Mario may need, say, dim 5, where each can be 0/1.
+        :return: numpy array of action values, of shape (batchsize, )
+        # TODO: deal with the last dim of 1 in V(s) and Q(s, a)
+        """
+        sess = tf.get_default_session()
+        return sess.run(self.get_value_tensor(), feed_dict={
+            self._observation_placeholder: observation, self._action_placeholder: action})[:, 0]
+
+
+class DQN(ActionValue):
+    """
+    class of the classic DQN architecture. Instead of feeding s and a to the network to get a value, DQN feeds s to
+    the network and the last layer is Q(s, *) for all actions.
+    """
+    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
+        """
+        :param value_tensor: of shape (batchsize, num_actions)
+        :param observation_placeholder: of shape (batchsize, observation_dim)
+        :param action_placeholder: of shape (batchsize, )
+        """
+        self._value_tensor_all_actions = value_tensor
+        canonical_value_tensor = value_tensor[action_placeholder]  # maybe a tf.map_fn. for now it's wrong
+
+        super(DQN, self).__init__(value_tensor=canonical_value_tensor,
+                                  observation_placeholder=observation_placeholder,
+                                  action_placeholder=action_placeholder)
+
+    def get_value_all_actions(self, observation):
+        sess = tf.get_default_session()
+        return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
+
+    def get_value_tensor_all_actions(self):
+        return self._value_tensor_all_actions
\ No newline at end of file
diff --git a/tianshou/core/value_function/base.py b/tianshou/core/value_function/base.py
new file mode 100644
index 0000000..0b27759
--- /dev/null
+++ b/tianshou/core/value_function/base.py
@@ -0,0 +1,23 @@
+
+# TODO: linear feature baseline also in tf?
+class ValueFunctionBase(object):
+    """
+    base class of value functions. Children include state values V(s) and action values Q(s, a).
+    """
+    def __init__(self, value_tensor, observation_placeholder):
+        self._observation_placeholder = observation_placeholder
+        self._value_tensor = value_tensor
+
+    def get_value(self, **kwargs):
+        """
+        :return: batch of the corresponding values as a numpy array
+        """
+        raise NotImplementedError()
+
+    def get_value_tensor(self):
+        """
+        :return: tensor of the corresponding values
+        """
+        return self._value_tensor
diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py
new file mode 100644
index 0000000..04fe442
--- /dev/null
+++ b/tianshou/core/value_function/state_value.py
@@ -0,0 +1,23 @@
+from base import ValueFunctionBase
+import tensorflow as tf
+
+
+class StateValue(ValueFunctionBase):
+    """
+    class of state values V(s).
+    """
+    def __init__(self, value_tensor, observation_placeholder):
+        super(StateValue, self).__init__(
+            value_tensor=value_tensor,
+            observation_placeholder=observation_placeholder
+        )
+
+    def get_value(self, observation):
+        """
+        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
+        :return: numpy array of state values, of shape (batchsize, )
+        # TODO: deal with the last dim of 1 in V(s) and Q(s, a)
+        """
+        sess = tf.get_default_session()
+        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
\ No newline at end of file
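
The new DQN constructor flags canonical_value_tensor = value_tensor[action_placeholder] as wrong. A minimal TF1 sketch of one way to pick each row's Q(s, a) out of the (batchsize, num_actions) tensor, assuming an integer-typed action placeholder; the helper name is illustrative and not part of this diff:

    import tensorflow as tf

    def select_canonical_q(value_tensor, action_placeholder):
        # value_tensor: (batchsize, num_actions); action_placeholder: (batchsize,) integer actions
        batch_size = tf.shape(value_tensor)[0]
        actions = tf.cast(action_placeholder, tf.int32)
        # pair each row index with its chosen action: [[0, a_0], [1, a_1], ...]
        indices = tf.stack([tf.range(batch_size), actions], axis=1)
        q_chosen = tf.gather_nd(value_tensor, indices)   # shape (batchsize,)
        # keep a trailing dim of 1 so ActionValue.get_value's [:, 0] slice still applies
        return tf.expand_dims(q_chosen, axis=1)          # shape (batchsize, 1)

An equivalent alternative is a one-hot mask, e.g. tf.reduce_sum(value_tensor * tf.one_hot(actions, num_actions), axis=1), followed by the same expand_dims.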
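
For reference, a hedged usage sketch of the new ActionValue class, wiring a toy Q(s, a) network to the placeholders and querying values. The network, shapes, and data below are made up for illustration, and it assumes the from-base import in action_value.py resolves (it is a Python-2-style implicit relative import):

    import numpy as np
    import tensorflow as tf
    from tianshou.core.value_function.action_value import ActionValue

    observation_ph = tf.placeholder(tf.float32, shape=(None, 4))
    action_ph = tf.placeholder(tf.float32, shape=(None, 2))

    # toy Q(s, a) network: concatenate s and a, output one value per example
    inputs = tf.concat([observation_ph, action_ph], axis=1)
    hidden = tf.layers.dense(inputs, 32, activation=tf.nn.relu)
    q_tensor = tf.layers.dense(hidden, 1)            # shape (batchsize, 1)

    q_function = ActionValue(q_tensor, observation_ph, action_ph)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        obs = np.zeros((3, 4), dtype=np.float32)
        act = np.zeros((3, 2), dtype=np.float32)
        print(q_function.get_value(obs, act))        # numpy array of shape (3,)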