Fix imports to support both Python 2 and Python 3. Move the contents out of __init__.py, leaving that re-export wiring as work for after the major development.
parent: fe54e4732d
commit: 2addef41d2
@@ -1,17 +1,16 @@
 #!/usr/bin/env python
 from __future__ import absolute_import
 
 import tensorflow as tf
 import numpy as np
 import time
 import gym
 
 # our lib imports here!
 import sys
 sys.path.append('..')
-import tianshou.core.losses as losses
+from tianshou.core import losses
 from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy as policy
+import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
 
 
 def policy_net(observation, action_dim, scope=None):
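The example script above keeps its sys.path hack but switches the library imports to the package-qualified form. For readers wondering why this matters for the "Python 2 and Python 3" part of the commit message, here is a minimal sketch, with an invented package layout, of the import style the commit standardises on: from __future__ import absolute_import plus explicit relative imports, which behave the same on both interpreters.

# hypothetical layout for illustration only:
#   mypkg/__init__.py
#   mypkg/base.py        (defines SomeBase)
#   mypkg/dqn.py         (this file)
from __future__ import absolute_import  # Python 2: disable implicit relative imports

# Python 2 would accept "from base import SomeBase" and silently resolve it to the
# sibling module; Python 3 raises ImportError. The explicit relative form below is
# accepted by both interpreters, which is the change applied throughout this commit.
from .base import SomeBase


class SomeConcrete(SomeBase):
    pass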
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from .base import *
-from .stochastic import *
-from .dqn import *
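The deleted file appears to be the policy package's __init__.py, whose star re-exports are removed until after the major development mentioned in the commit message. The practical consequence for callers, sketched below under that assumption (the class name is a placeholder, not taken from this diff): import the concrete submodule instead of relying on the package namespace.

# before: names were re-exported by the package __init__
#   import tianshou.core.policy as policy
#   pi = policy.SomeStochasticPolicy(...)      # worked via "from .stochastic import *"

# after: import the submodule explicitly, as the example script above now does
import tianshou.core.policy.stochastic as policy
# pi = policy.SomeStochasticPolicy(...)        # placeholder name; see stochastic.py for the real classes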
@@ -13,11 +13,23 @@ import tensorflow as tf
 __all__ = [
     'StochasticPolicy',
     'QValuePolicy',
+    'PolicyBase'
 ]
 
 # TODO: a even more "base" class for policy
 
 
+class PolicyBase(object):
+    """
+    base class for policy. only provides `act` method with exploration
+    """
+    def __init__(self):
+        pass
+
+    def act(self, observation, exploration):
+        raise NotImplementedError()
+
+
 class QValuePolicy(object):
     """
     The policy as in DQN
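PolicyBase fixes the interface (an act method taking an exploration flag) and nothing else. As a concrete illustration of how a subclass might fill it in, here is a minimal epsilon-greedy sketch; everything except PolicyBase itself (class name, attribute names, epsilon value) is invented for the example.

import numpy as np
import tensorflow as tf

from tianshou.core.policy.base import PolicyBase


class GreedyQPolicy(PolicyBase):   # hypothetical subclass, not part of the commit
    """acts greedily w.r.t. a Q-value tensor, with optional epsilon-greedy exploration"""
    def __init__(self, q_tensor, observation_placeholder, epsilon=0.1):
        super(GreedyQPolicy, self).__init__()
        self._observation_placeholder = observation_placeholder
        self._argmax_action = tf.argmax(q_tensor, axis=1)
        self._num_actions = q_tensor.shape.as_list()[-1]   # assumes a static action dimension
        self._epsilon = epsilon

    def act(self, observation, exploration):
        if exploration and np.random.rand() < self._epsilon:
            return np.random.randint(self._num_actions)
        sess = tf.get_default_session()
        # feed a single observation as a batch of one and unwrap the result
        action = sess.run(self._argmax_action,
                          feed_dict={self._observation_placeholder: observation[None]})
        return action[0]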
@@ -1,16 +1,22 @@
-from tianshou.core.policy.base import QValuePolicy
+from __future__ import absolute_import
+
+from .base import PolicyBase
 import tensorflow as tf
-import sys
-sys.path.append('..')
-import value_function.action_value as value_func
+from ..value_function.action_value import DQN
 
 
-class DQN_refactor(object):
+class DQNRefactor(PolicyBase):
     """
     use DQN from value_function as a member
     """
     def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
         self._argmax_action = tf.argmax(value_tensor, axis=1)
 
     def act(self, observation, exploration):
         sess = tf.get_default_session()
         if not exploration: # no exploration
             action = sess.run(self._argmax_action, feed_dict={})
 
 
 class DQN(QValuePolicy):
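DQNRefactor precomputes a single tf.argmax(value_tensor, axis=1) op, and the (truncated) act method runs it for greedy action selection. Note that the feed_dict in the shown snippet is still empty, so the observation presumably still needs to be wired to its placeholder; the small self-contained check below, with made-up numbers, feeds it explicitly and shows what the op computes.

import numpy as np
import tensorflow as tf

q = tf.placeholder(tf.float32, shape=(None, 3))   # Q(s, *) over 3 actions
greedy_action = tf.argmax(q, axis=1)              # index of the best action per batch row

with tf.Session() as sess:
    q_batch = np.array([[1.0, 5.0, 2.0],          # made-up Q-values
                        [0.3, 0.1, 0.2]], dtype=np.float32)
    print(sess.run(greedy_action, feed_dict={q: q_batch}))   # -> [1 0]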
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
 
 
@@ -15,7 +17,6 @@ class ActionValue(ValueFunctionBase):
 
     def get_value(self, observation, action):
         """
 
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
         :param action: numpy array of actions, of shape (batchsize, action_dim)
         # TODO: Atari discrete action should have dim 1. Super Mario may should have, say, dim 5, where each can be 0/1
@@ -24,7 +25,7 @@ class ActionValue(ValueFunctionBase):
         """
         sess = tf.get_default_session()
         return sess.run(self.get_value_tensor(), feed_dict=
-                        {self._observation_placeholder: observation, self._action_placeholder:action})[:, 0]
+                        {self._observation_placeholder: observation, self._action_placeholder: action})
 
 
 class DQN(ActionValue):
@@ -39,13 +40,21 @@ class DQN(ActionValue):
         :param action_placeholder: of shape (batchsize, )
         """
         self._value_tensor_all_actions = value_tensor
-        canonical_value_tensor = value_tensor[action_placeholder] # maybe a tf.map_fn. for now it's wrong
+
+        batch_size = tf.shape(value_tensor)[0]
+        batch_dim_index = tf.range(batch_size)
+        indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
+        canonical_value_tensor = tf.gather_nd(value_tensor, indices)
 
         super(DQN, self).__init__(value_tensor=canonical_value_tensor,
                                   observation_placeholder=observation_placeholder,
                                   action_placeholder=action_placeholder)
 
     def get_value_all_actions(self, observation):
         """
         :param observation:
         :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
         """
         sess = tf.get_default_session()
         return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
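The fix for the previously wrong value_tensor[action_placeholder] line pairs each row's batch index with its chosen action and uses tf.gather_nd, so the result is Q(s_i, a_i) with shape (batchsize,). The same recipe in isolation, with invented numbers, to show exactly what it selects:

import numpy as np
import tensorflow as tf

q_values = tf.placeholder(tf.float32, shape=(None, 3))   # Q(s, *) over 3 actions
actions = tf.placeholder(tf.int32, shape=(None,))        # chosen action per sample

batch_size = tf.shape(q_values)[0]
indices = tf.stack([tf.range(batch_size), actions], axis=1)   # rows of [batch_index, action]
q_selected = tf.gather_nd(q_values, indices)                  # Q(s_i, a_i), shape (batchsize,)

with tf.Session() as sess:
    out = sess.run(q_selected, feed_dict={
        q_values: np.array([[1.0, 2.0, 3.0],                  # invented Q-table
                            [4.0, 5.0, 6.0]], dtype=np.float32),
        actions: np.array([2, 0], dtype=np.int32),
    })
    print(out)   # -> [3. 6.]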
@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+
 import tensorflow as tf
 
 # TODO: linear feature baseline also in tf?
 class ValueFunctionBase(object):
@@ -6,7 +9,7 @@ class ValueFunctionBase(object):
     """
     def __init__(self, value_tensor, observation_placeholder):
         self._observation_placeholder = observation_placeholder
-        self._value_tensor = value_tensor
+        self._value_tensor = tf.squeeze(value_tensor) # canonical values has shape (batchsize, )
 
     def get_value(self, **kwargs):
         """
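With tf.squeeze applied in the constructor, the canonical value tensor has shape (batchsize,) whether the network head emits (batchsize,) or (batchsize, 1), which is what lets the subclasses drop their [:, 0] indexing (see the next hunk). A minimal shape check:

import numpy as np
import tensorflow as tf

head = tf.placeholder(tf.float32, shape=(None, 1))   # value head with a trailing dim of 1
canonical = tf.squeeze(head)                          # shape (batchsize,), as in ValueFunctionBase

with tf.Session() as sess:
    out = sess.run(canonical, feed_dict={head: np.array([[1.5], [2.5]], dtype=np.float32)})
    print(out.shape)   # -> (2,)

One caveat worth noting: a bare tf.squeeze drops every size-1 dimension, so a batch containing a single observation would come back as a scalar; squeezing only the last axis would avoid that when the head always emits a trailing dimension of 1.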
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
 
 
@@ -17,7 +19,7 @@ class StateValue(ValueFunctionBase):
 
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
         :return: numpy array of state values, of shape (batchsize, )
-        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
+        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
         """
         sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
+        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
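Since the base class already squeezes the value tensor, the [:, 0] indexing above became redundant, hence its removal here and in ActionValue. An end-to-end sketch of the resulting behaviour, assuming StateValue simply forwards its arguments to ValueFunctionBase (the module path and the toy network below are assumptions, not part of the diff):

import numpy as np
import tensorflow as tf

from tianshou.core.value_function.state_value import StateValue   # path inferred from the diff

obs_ph = tf.placeholder(tf.float32, shape=(None, 4))
value_head = tf.layers.dense(obs_ph, 1)                # toy V(s) head, shape (batchsize, 1)
state_value = StateValue(value_tensor=value_head, observation_placeholder=obs_ph)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    values = state_value.get_value(np.zeros((8, 4), dtype=np.float32))
    print(values.shape)   # expected (8,): the squeeze in the base class removed the trailing 1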