diff --git a/examples/ppo_example.py b/examples/ppo_example.py
index 02ccb52..985c8f2 100755
--- a/examples/ppo_example.py
+++ b/examples/ppo_example.py
@@ -1,17 +1,16 @@
 #!/usr/bin/env python
+from __future__ import absolute_import
 import tensorflow as tf
-import numpy as np
-import time
 import gym
 
 # our lib imports here!
 import sys
 sys.path.append('..')
-import tianshou.core.losses as losses
+from tianshou.core import losses
 from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy as policy
+import tianshou.core.policy.stochastic as policy
 # TODO: fix imports as zhusuan so that only need to import to policy
 
 
 def policy_net(observation, action_dim, scope=None):
diff --git a/tianshou/core/policy/__init__.py b/tianshou/core/policy/__init__.py
index ccde775..e69de29 100644
--- a/tianshou/core/policy/__init__.py
+++ b/tianshou/core/policy/__init__.py
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from .base import *
-from .stochastic import *
-from .dqn import *
\ No newline at end of file
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 025abd5..1adeaeb 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -13,11 +13,23 @@ import tensorflow as tf
 
 __all__ = [
     'StochasticPolicy',
     'QValuePolicy',
+    'PolicyBase'
 ]
 
 
 # TODO: a even more "base" class for policy
+class PolicyBase(object):
+    """
+    base class for policy. only provides `act` method with exploration
+    """
+    def __init__(self):
+        pass
+
+    def act(self, observation, exploration):
+        raise NotImplementedError()
+
+
 class QValuePolicy(object):
     """
     The policy as in DQN
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index d03dbd4..716e4c4 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -1,16 +1,22 @@
-from tianshou.core.policy.base import QValuePolicy
+from __future__ import absolute_import
+
+from .base import PolicyBase
 import tensorflow as tf
-import sys
-sys.path.append('..')
-import value_function.action_value as value_func
+from ..value_function.action_value import DQN
 
 
-class DQN_refactor(object):
+class DQNRefactor(PolicyBase):
     """
     use DQN from value_function as a member
     """
     def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._argmax_action = tf.argmax(value_tensor, axis=1)
+
+    def act(self, observation, exploration):
+        sess = tf.get_default_session()
+        if not exploration:  # no exploration
+            action = sess.run(self._argmax_action, feed_dict={})
 
 
 class DQN(QValuePolicy):
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index cb8acc8..2bda4fa 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
 
 
@@ -15,7 +17,6 @@ class ActionValue(ValueFunctionBase):
 
     def get_value(self, observation, action):
         """
-        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
        :param action: numpy array of actions, of shape (batchsize, action_dim)
            # TODO: Atari discrete action should have dim 1. Super Mario may should have, say, dim 5, where each can be 0/1
 
@@ -24,7 +25,7 @@ class ActionValue(ValueFunctionBase):
         """
         sess = tf.get_default_session()
         return sess.run(self.get_value_tensor(), feed_dict=
-            {self._observation_placeholder: observation, self._action_placeholder:action})[:, 0]
+            {self._observation_placeholder: observation, self._action_placeholder: action})
 
 
 class DQN(ActionValue):
@@ -39,13 +40,21 @@ class DQN(ActionValue):
         :param action_placeholder: of shape (batchsize, )
         """
         self._value_tensor_all_actions = value_tensor
-        canonical_value_tensor = value_tensor[action_placeholder]  # maybe a tf.map_fn. for now it's wrong
+
+        batch_size = tf.shape(value_tensor)[0]
+        batch_dim_index = tf.range(batch_size)
+        indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
+        canonical_value_tensor = tf.gather_nd(value_tensor, indices)
 
         super(DQN, self).__init__(value_tensor=canonical_value_tensor,
                                   observation_placeholder=observation_placeholder,
                                   action_placeholder=action_placeholder)
 
     def get_value_all_actions(self, observation):
+        """
+        :param observation:
+        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
+        """
         sess = tf.get_default_session()
         return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
 
diff --git a/tianshou/core/value_function/base.py b/tianshou/core/value_function/base.py
index 0b27759..b15f1bf 100644
--- a/tianshou/core/value_function/base.py
+++ b/tianshou/core/value_function/base.py
@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+
+import tensorflow as tf
 
 # TODO: linear feature baseline also in tf?
 class ValueFunctionBase(object):
@@ -6,7 +9,7 @@ class ValueFunctionBase(object):
     """
     def __init__(self, value_tensor, observation_placeholder):
         self._observation_placeholder = observation_placeholder
-        self._value_tensor = value_tensor
+        self._value_tensor = tf.squeeze(value_tensor)  # canonical values has shape (batchsize, )
 
     def get_value(self, **kwargs):
         """
diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py
index 04fe442..b7de196 100644
--- a/tianshou/core/value_function/state_value.py
+++ b/tianshou/core/value_function/state_value.py
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
 
 
@@ -17,7 +19,7 @@ class StateValue(ValueFunctionBase):
 
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
         :return: numpy array of state values, of shape (batchsize, )
-        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
+        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
         """
         sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
\ No newline at end of file
+        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
\ No newline at end of file
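Side note on the `tf.gather_nd` change in `action_value.py`: stacking the batch index with the action index builds `(row, column)` pairs, so one Q-value is picked per row of `value_tensor`. A minimal standalone sketch of the same indexing pattern (TF 1.x graph mode; the names `q_all`, `action`, etc. are illustrative only and not part of this diff):

```python
import numpy as np
import tensorflow as tf

q_all = tf.placeholder(tf.float32, shape=[None, 3])   # Q(s, *) for 3 discrete actions
action = tf.placeholder(tf.int32, shape=[None])       # chosen action index per row
batch_index = tf.range(tf.shape(q_all)[0])            # [0, 1, ..., batchsize - 1]
indices = tf.stack([batch_index, action], axis=1)     # shape (batchsize, 2): (row, column) pairs
q_selected = tf.gather_nd(q_all, indices)             # Q(s, a), shape (batchsize,)

with tf.Session() as sess:
    print(sess.run(q_selected, feed_dict={
        q_all: np.array([[1., 2., 3.],
                         [4., 5., 6.]], dtype=np.float32),
        action: np.array([2, 0]),
    }))  # -> [3. 6.]
```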
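`DQNRefactor.act` in `dqn.py` is still a stub: the `feed_dict` is empty and nothing is returned. A hedged sketch of one way the method could be completed, assuming the constructor also stores `self._observation_placeholder` and `self._num_actions` (neither exists in this diff) and that `exploration`, when given, is an epsilon for epsilon-greedy:

```python
import numpy as np
import tensorflow as tf

def act(self, observation, exploration=None):
    sess = tf.get_default_session()
    # greedy actions: argmax over Q(s, *) for the fed batch of observations
    action = sess.run(self._argmax_action,
                      feed_dict={self._observation_placeholder: observation})  # assumed attribute
    if exploration:  # treat `exploration` as epsilon for epsilon-greedy
        explore = np.random.rand(*action.shape) < exploration
        random_action = np.random.randint(self._num_actions, size=action.shape)  # assumed attribute
        action = np.where(explore, random_action, action)
    return action
```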
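One caveat on `tf.squeeze(value_tensor)` in `value_function/base.py`: without an explicit axis, `tf.squeeze` removes every size-1 dimension found at run time, so a `(1, 1)` value tensor from a batch of size 1 collapses to a scalar instead of shape `(1,)`. Passing the axis keeps the batch dimension; a small sketch (placeholder name is illustrative):

```python
import tensorflow as tf

value = tf.placeholder(tf.float32, shape=[None, 1])  # e.g. output of a value head
squeezed_all = tf.squeeze(value)                     # also drops the batch dim when batchsize == 1
squeezed_last = tf.squeeze(value, axis=1)            # always shape (batchsize,)
```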