fix imports to support both Python 2 and Python 3. move contents out of __init__.py, leaving it to be repopulated after major development.

This commit is contained in:
haoshengzou 2017-12-23 15:36:10 +08:00
parent fe54e4732d
commit 2addef41d2
7 changed files with 49 additions and 24 deletions
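
The Python 2/3 fix follows the standard pattern: each module adds `from __future__ import absolute_import` and replaces implicit intra-package imports with explicit relative (or fully qualified) ones, as the hunks below show. A minimal sketch of the pattern on a hypothetical package layout (module and symbol names here are illustrative, not from this commit):

    # mypackage/consumer.py -- hypothetical module illustrating the import pattern
    from __future__ import absolute_import  # on Python 2, disables implicit relative imports

    from .base import BaseThing              # explicit relative import: same behaviour on 2 and 3
    from mypackage.helpers import helper_fn  # fully qualified absolute import also works on both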

View File

@@ -1,17 +1,16 @@
#!/usr/bin/env python
from __future__ import absolute_import
import tensorflow as tf
import numpy as np
import time
import gym
# our lib imports here!
import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy as policy
import tianshou.core.policy.stochastic as policy # TODO: organize imports as in zhusuan so that importing tianshou.core.policy alone is enough
def policy_net(observation, action_dim, scope=None):

View File

@@ -1,6 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from .base import *
from .stochastic import *
from .dqn import *
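
With the wildcard re-exports deleted from the package `__init__.py`, callers now import the concrete submodule directly, as the example script above already does. Illustrative before/after, based on the lines changed in the example:

    # before: relied on tianshou/core/policy/__init__.py re-exporting submodule contents
    import tianshou.core.policy as policy

    # after: import the concrete submodule explicitly
    import tianshou.core.policy.stochastic as policy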

View File

@@ -13,11 +13,23 @@ import tensorflow as tf
__all__ = [
'StochasticPolicy',
'QValuePolicy',
'PolicyBase'
]
# TODO: an even more "base" class for policy
class PolicyBase(object):
"""
Base class for policies. Only provides the `act` method, which takes an exploration argument.
"""
def __init__(self):
pass
def act(self, observation, exploration):
raise NotImplementedError()
class QValuePolicy(object):
"""
The policy as in DQN
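
The `PolicyBase.act` contract is deliberately small: take an observation and an exploration flag, return an action. A hypothetical subclass, only to illustrate the interface (not part of this commit):

    import numpy as np

    class RandomDiscretePolicy(PolicyBase):
        """Illustrative subclass: samples uniformly from `num_actions` discrete actions."""
        def __init__(self, num_actions):
            super(RandomDiscretePolicy, self).__init__()
            self.num_actions = num_actions

        def act(self, observation, exploration):
            # a uniform random policy ignores both the observation and the exploration flag
            return np.random.randint(self.num_actions)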

View File

@@ -1,16 +1,22 @@
from tianshou.core.policy.base import QValuePolicy
from __future__ import absolute_import
from .base import PolicyBase
import tensorflow as tf
import sys
sys.path.append('..')
import value_function.action_value as value_func
from ..value_function.action_value import DQN
class DQN_refactor(object):
class DQNRefactor(PolicyBase):
"""
uses the DQN class from value_function as a member network
"""
def __init__(self, value_tensor, observation_placeholder, action_placeholder):
self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)
self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
self._argmax_action = tf.argmax(value_tensor, axis=1)
def act(self, observation, exploration):
sess = tf.get_default_session()
if not exploration: # no exploration
action = sess.run(self._argmax_action, feed_dict={})
class DQN(QValuePolicy):
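
The `act` shown above only sketches the greedy branch and feeds an empty dict; a completed version would feed the observation and handle exploration, e.g. epsilon-greedy. A hedged sketch, assuming `exploration` is an epsilon value and the class also stores the observation placeholder and the number of actions (neither is shown in this hunk):

    import numpy as np

    def act(self, observation, exploration):
        sess = tf.get_default_session()
        greedy = lambda: sess.run(self._argmax_action,
                                  feed_dict={self._observation_placeholder: observation})
        if not exploration:                   # purely greedy action
            return greedy()
        if np.random.rand() < exploration:    # epsilon-greedy: random action with prob. epsilon
            return np.random.randint(self._num_actions)
        return greedy()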

View File

@@ -1,4 +1,6 @@
from base import ValueFunctionBase
from __future__ import absolute_import
from .base import ValueFunctionBase
import tensorflow as tf
@@ -15,7 +17,6 @@ class ActionValue(ValueFunctionBase):
def get_value(self, observation, action):
"""
:param observation: numpy array of observations, of shape (batchsize, observation_dim).
:param action: numpy array of actions, of shape (batchsize, action_dim)
# TODO: Atari discrete actions should have dim 1. Super Mario may need, say, dim 5, where each can be 0/1
@@ -24,7 +25,7 @@ class ActionValue(ValueFunctionBase):
"""
sess = tf.get_default_session()
return sess.run(self.get_value_tensor(), feed_dict=
{self._observation_placeholder: observation, self._action_placeholder:action})[:, 0]
{self._observation_placeholder: observation, self._action_placeholder: action})
class DQN(ActionValue):
@@ -39,13 +40,21 @@ class DQN(ActionValue):
:param action_placeholder: of shape (batchsize, )
"""
self._value_tensor_all_actions = value_tensor
canonical_value_tensor = value_tensor[action_placeholder] # maybe a tf.map_fn. for now it's wrong
batch_size = tf.shape(value_tensor)[0]
batch_dim_index = tf.range(batch_size)
indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
canonical_value_tensor = tf.gather_nd(value_tensor, indices)
super(DQN, self).__init__(value_tensor=canonical_value_tensor,
observation_placeholder=observation_placeholder,
action_placeholder=action_placeholder)
def get_value_all_actions(self, observation):
"""
:param observation: numpy array of observations, of shape (batchsize, observation_dim).
:return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
"""
sess = tf.get_default_session()
return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
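
The replacement for the wrong `value_tensor[action_placeholder]` indexing pairs each row index with its action and gathers one Q-value per sample. A small standalone check of the same indexing (the numbers are made up for illustration):

    import tensorflow as tf

    # Q-values for a batch of 2 observations over 3 actions, and one chosen action per row.
    q = tf.constant([[0.1, 0.5, 0.9],
                     [0.3, 0.7, 0.2]])
    actions = tf.constant([2, 0])
    rows = tf.range(tf.shape(q)[0])              # [0, 1]
    indices = tf.stack([rows, actions], axis=1)  # [[0, 2], [1, 0]]
    selected = tf.gather_nd(q, indices)          # picks q[0, 2] and q[1, 0]

    with tf.Session() as sess:
        print(sess.run(selected))                # [0.9 0.3]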

View File

@@ -1,3 +1,6 @@
from __future__ import absolute_import
import tensorflow as tf
# TODO: linear feature baseline also in tf?
class ValueFunctionBase(object):
@@ -6,7 +9,7 @@ class ValueFunctionBase(object):
"""
def __init__(self, value_tensor, observation_placeholder):
self._observation_placeholder = observation_placeholder
self._value_tensor = value_tensor
self._value_tensor = tf.squeeze(value_tensor) # canonical value tensor has shape (batchsize, )
def get_value(self, **kwargs):
"""

View File

@@ -1,4 +1,6 @@
from base import ValueFunctionBase
from __future__ import absolute_import
from .base import ValueFunctionBase
import tensorflow as tf
@@ -17,7 +19,7 @@ class StateValue(ValueFunctionBase):
:param observation: numpy array of observations, of shape (batchsize, observation_dim).
:return: numpy array of state values, of shape (batchsize, )
# TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
# TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
"""
sess = tf.get_default_session()
return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
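
Because the base class now squeezes the value tensor down to shape (batchsize, ), `get_value` can return the session result directly instead of slicing off a trailing size-1 dimension. A hypothetical end-to-end check, assuming `ValueFunctionBase` exposes the `get_value_tensor()` method used by the subclasses above (the placeholder and network here are illustrative):

    import numpy as np
    import tensorflow as tf

    obs_ph = tf.placeholder(tf.float32, shape=(None, 4))
    raw_value = tf.layers.dense(obs_ph, 1)               # raw head output, shape (batchsize, 1)
    value_fn = ValueFunctionBase(value_tensor=raw_value,
                                 observation_placeholder=obs_ph)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        v = sess.run(value_fn.get_value_tensor(),
                     feed_dict={obs_ph: np.zeros((8, 4), dtype=np.float32)})
        print(v.shape)                                   # (8,), thanks to tf.squeeze in the base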