diff --git a/.gitignore b/.gitignore
index d697b92..8ee6691 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,8 +4,8 @@ leela-zero
parameters
*.swp
*.sublime*
-checkpoints
-checkpoints_origin
+checkpoint
*.json
.DS_Store
data
+.log
diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index ff1faf5..90d0bf0 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -27,29 +27,30 @@ class Game:
'''
def __init__(self, name="go", checkpoint_path=None):
self.name = name
- if "go" == name:
+ if self.name == "go":
self.size = 9
self.komi = 3.75
self.board = [utils.EMPTY] * (self.size ** 2)
self.history = []
+ self.history_length = 8
self.latest_boards = deque(maxlen=8)
for _ in range(8):
self.latest_boards.append(self.board)
-
- self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
self.game_engine = go.Go(size=self.size, komi=self.komi)
- elif "reversi" == name:
+ elif self.name == "reversi":
self.size = 8
- self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=1)
+ self.history_length = 1
self.game_engine = reversi.Reversi()
self.board = self.game_engine.get_board()
else:
- print(name + " is an unknown game...")
+ raise ValueError(name + " is an unknown game...")
+
+ self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length)
def clear(self):
self.board = [utils.EMPTY] * (self.size ** 2)
self.history = []
- for _ in range(8):
+ for _ in range(self.history_length):
self.latest_boards.append(self.board)
def set_size(self, n):
@@ -76,9 +77,9 @@ class Game:
if vertex == utils.PASS:
return True
# TODO this implementation is not very elegant
- if "go" == self.name:
+ if self.name == "go":
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
- elif "revsersi" == self.name:
+ elif self.name == "reversi":
res = self.game_engine.executor_do_move(self.board, color, vertex)
return res
diff --git a/AlphaGo/player.py b/AlphaGo/player.py
index 0e3daff..e848d2b 100644
--- a/AlphaGo/player.py
+++ b/AlphaGo/player.py
@@ -34,7 +34,7 @@ if __name__ == '__main__':
daemon = Pyro4.Daemon() # make a Pyro daemon
ns = Pyro4.locateNS() # find the name server
- player = Player(role = args.role, engine = engine)
+ player = Player(role=args.role, engine=engine)
print "Init " + args.role + " player finished"
uri = daemon.register(player) # register the greeting maker as a Pyro object
print "Start on name " + args.role
diff --git a/README.md b/README.md
index 9c3af16..fc7d494 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,11 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus
+## examples
+
+During development, run examples under `./examples/` directory with, e.g. `python ppo_example.py`.
+Running them under this directory with `python examples/ppo_example.py` will not work.
+
## About coding style
diff --git a/examples/ppo_example.py b/examples/ppo_example.py
index 02ccb52..985c8f2 100755
--- a/examples/ppo_example.py
+++ b/examples/ppo_example.py
@@ -1,17 +1,16 @@
#!/usr/bin/env python
+from __future__ import absolute_import
import tensorflow as tf
-import numpy as np
-import time
import gym
# our lib imports here!
import sys
sys.path.append('..')
-import tianshou.core.losses as losses
+from tianshou.core import losses
from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy as policy
+import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
def policy_net(observation, action_dim, scope=None):
diff --git a/tianshou/core/policy/__init__.py b/tianshou/core/policy/__init__.py
index ccde775..e69de29 100644
--- a/tianshou/core/policy/__init__.py
+++ b/tianshou/core/policy/__init__.py
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from .base import *
-from .stochastic import *
-from .dqn import *
\ No newline at end of file
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 025abd5..1adeaeb 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -13,11 +13,23 @@ import tensorflow as tf
__all__ = [
'StochasticPolicy',
'QValuePolicy',
+ 'PolicyBase'
]
# TODO: a even more "base" class for policy
+class PolicyBase(object):
+ """
+ base class for policy. only provides `act` method with exploration
+ """
+ def __init__(self):
+ pass
+
+ def act(self, observation, exploration):
+ raise NotImplementedError()
+
+
class QValuePolicy(object):
"""
The policy as in DQN
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index d03dbd4..716e4c4 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -1,16 +1,22 @@
-from tianshou.core.policy.base import QValuePolicy
+from __future__ import absolute_import
+
+from .base import PolicyBase
import tensorflow as tf
-import sys
-sys.path.append('..')
-import value_function.action_value as value_func
+from ..value_function.action_value import DQN
-class DQN_refactor(object):
+class DQNRefactor(PolicyBase):
"""
use DQN from value_function as a member
"""
def __init__(self, value_tensor, observation_placeholder, action_placeholder):
- self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)
+ self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+ self._argmax_action = tf.argmax(value_tensor, axis=1)
+
+ def act(self, observation, exploration):
+ sess = tf.get_default_session()
+ if not exploration: # no exploration
+ action = sess.run(self._argmax_action, feed_dict={})
class DQN(QValuePolicy):
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index cb8acc8..2bda4fa 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
import tensorflow as tf
@@ -15,7 +17,6 @@ class ActionValue(ValueFunctionBase):
def get_value(self, observation, action):
"""
-
:param observation: numpy array of observations, of shape (batchsize, observation_dim).
:param action: numpy array of actions, of shape (batchsize, action_dim)
# TODO: Atari discrete action should have dim 1. Super Mario may should have, say, dim 5, where each can be 0/1
@@ -24,7 +25,7 @@ class ActionValue(ValueFunctionBase):
"""
sess = tf.get_default_session()
return sess.run(self.get_value_tensor(), feed_dict=
- {self._observation_placeholder: observation, self._action_placeholder:action})[:, 0]
+ {self._observation_placeholder: observation, self._action_placeholder: action})
class DQN(ActionValue):
@@ -39,13 +40,21 @@ class DQN(ActionValue):
:param action_placeholder: of shape (batchsize, )
"""
self._value_tensor_all_actions = value_tensor
- canonical_value_tensor = value_tensor[action_placeholder] # maybe a tf.map_fn. for now it's wrong
+
+ batch_size = tf.shape(value_tensor)[0]
+ batch_dim_index = tf.range(batch_size)
+ indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
+ canonical_value_tensor = tf.gather_nd(value_tensor, indices)
super(DQN, self).__init__(value_tensor=canonical_value_tensor,
observation_placeholder=observation_placeholder,
action_placeholder=action_placeholder)
def get_value_all_actions(self, observation):
+ """
+ :param observation:
+ :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
+ """
sess = tf.get_default_session()
return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
diff --git a/tianshou/core/value_function/base.py b/tianshou/core/value_function/base.py
index 0b27759..b15f1bf 100644
--- a/tianshou/core/value_function/base.py
+++ b/tianshou/core/value_function/base.py
@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+
+import tensorflow as tf
# TODO: linear feature baseline also in tf?
class ValueFunctionBase(object):
@@ -6,7 +9,7 @@ class ValueFunctionBase(object):
"""
def __init__(self, value_tensor, observation_placeholder):
self._observation_placeholder = observation_placeholder
- self._value_tensor = value_tensor
+ self._value_tensor = tf.squeeze(value_tensor) # canonical values has shape (batchsize, )
def get_value(self, **kwargs):
"""
diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py
index 04fe442..b7de196 100644
--- a/tianshou/core/value_function/state_value.py
+++ b/tianshou/core/value_function/state_value.py
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
import tensorflow as tf
@@ -17,7 +19,7 @@ class StateValue(ValueFunctionBase):
:param observation: numpy array of observations, of shape (batchsize, observation_dim).
:return: numpy array of state values, of shape (batchsize, )
- # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
+ # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
"""
sess = tf.get_default_session()
- return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
\ No newline at end of file
+ return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
\ No newline at end of file