From b96fa9448bde1c42cd5a696568a30bda7bddf195 Mon Sep 17 00:00:00 2001
From: rtz19970824 <1289226405@qq.com>
Date: Sat, 23 Dec 2017 14:45:07 +0800
Subject: [PATCH 1/3] minor fixed

---
 .gitignore        |  4 ++--
 AlphaGo/game.py   | 19 ++++++++++---------
 AlphaGo/player.py |  2 +-
 3 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/.gitignore b/.gitignore
index d697b92..8ee6691 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,8 +4,8 @@ leela-zero
 parameters
 *.swp
 *.sublime*
-checkpoints
-checkpoints_origin
+checkpoint
 *.json
 .DS_Store
 data
+.log
diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index ff1faf5..90d0bf0 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -27,29 +27,30 @@ class Game:
     '''
     def __init__(self, name="go", checkpoint_path=None):
         self.name = name
-        if "go" == name:
+        if self.name == "go":
             self.size = 9
             self.komi = 3.75
             self.board = [utils.EMPTY] * (self.size ** 2)
             self.history = []
+            self.history_length = 8
             self.latest_boards = deque(maxlen=8)
             for _ in range(8):
                 self.latest_boards.append(self.board)
-
-            self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
             self.game_engine = go.Go(size=self.size, komi=self.komi)
-        elif "reversi" == name:
+        elif self.name == "reversi":
             self.size = 8
-            self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=1)
+            self.history_length = 1
             self.game_engine = reversi.Reversi()
             self.board = self.game_engine.get_board()
         else:
-            print(name + " is an unknown game...")
+            raise ValueError(name + " is an unknown game...")
+
+        self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length)
 
     def clear(self):
         self.board = [utils.EMPTY] * (self.size ** 2)
         self.history = []
-        for _ in range(8):
+        for _ in range(self.history_length):
             self.latest_boards.append(self.board)
 
     def set_size(self, n):
@@ -76,9 +77,9 @@ class Game:
         if vertex == utils.PASS:
             return True
         # TODO this implementation is not very elegant
-        if "go" == self.name:
+        if self.name == "go":
             res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
-        elif "revsersi" == self.name:
+        elif self.name == "reversi":
             res = self.game_engine.executor_do_move(self.board, color, vertex)
         return res
diff --git a/AlphaGo/player.py b/AlphaGo/player.py
index 0e3daff..e848d2b 100644
--- a/AlphaGo/player.py
+++ b/AlphaGo/player.py
@@ -34,7 +34,7 @@ if __name__ == '__main__':
 
     daemon = Pyro4.Daemon()                 # make a Pyro daemon
     ns = Pyro4.locateNS()                   # find the name server
-    player = Player(role = args.role, engine = engine)
+    player = Player(role=args.role, engine=engine)
    print "Init " + args.role + " player finished"
    uri = daemon.register(player)            # register the greeting maker as a Pyro object
    print "Start on name " + args.role
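The `latest_boards` handling in this patch leans on `collections.deque(maxlen=...)`: appending `history_length` empty boards in `clear()` is enough to reset the history because a bounded deque evicts its oldest entries automatically. A minimal standalone illustration, not part of the patch, with boards simplified to integers:

```python
# Minimal illustration of the bounded deque behind latest_boards
# (not part of the patch; board values are simplified to integers).
from collections import deque

history_length = 8
latest_boards = deque(maxlen=history_length)

# Game.clear() resets the history by appending history_length empty boards;
# with maxlen set, older entries are evicted automatically.
empty_board = 0
for _ in range(history_length):
    latest_boards.append(empty_board)

latest_boards.append(1)        # a new position arrives
print(len(latest_boards))      # still 8: the oldest empty board was dropped
print(latest_boards[-1])       # 1
```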
From 951eed60edeabbcd90ac465fc2df2050584a0238 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sat, 23 Dec 2017 15:34:44 +0800
Subject: [PATCH 2/3] fix imports to support both python2 and python3. move
 contents from __init__.py to leave for work after major development.

---
 README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/README.md b/README.md
index 9c3af16..fc7d494 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,11 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus
 
 
+## examples
+
+During development, run examples under `./examples/` directory with, e.g. `python ppo_example.py`.
+Running them under this directory with `python examples/ppo_example.py` will not work.
+
 ## About coding style

From 04048b78738d1092768c669f37fa63a9e1922d1a Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sat, 23 Dec 2017 15:36:10 +0800
Subject: [PATCH 3/3] fix imports to support both python2 and python3. move
 contents from __init__.py to leave for work after major development.

---
 examples/ppo_example.py                      |  7 +++----
 tianshou/core/policy/__init__.py             |  6 ------
 tianshou/core/policy/base.py                 | 12 ++++++++++++
 tianshou/core/policy/dqn.py                  | 18 ++++++++++++------
 tianshou/core/value_function/action_value.py | 17 +++++++++++++----
 tianshou/core/value_function/base.py         |  5 ++++-
 tianshou/core/value_function/state_value.py  |  8 +++++---
 7 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/examples/ppo_example.py b/examples/ppo_example.py
index 02ccb52..985c8f2 100755
--- a/examples/ppo_example.py
+++ b/examples/ppo_example.py
@@ -1,17 +1,16 @@
 #!/usr/bin/env python
+from __future__ import absolute_import
 import tensorflow as tf
-import numpy as np
-import time
 import gym
 
 # our lib imports here!
 import sys
 sys.path.append('..')
-import tianshou.core.losses as losses
+from tianshou.core import losses
 from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy as policy
+import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
 
 
 def policy_net(observation, action_dim, scope=None):
diff --git a/tianshou/core/policy/__init__.py b/tianshou/core/policy/__init__.py
index ccde775..e69de29 100644
--- a/tianshou/core/policy/__init__.py
+++ b/tianshou/core/policy/__init__.py
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from .base import *
-from .stochastic import *
-from .dqn import *
\ No newline at end of file
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 025abd5..1adeaeb 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -13,11 +13,23 @@ import tensorflow as tf
 
 __all__ = [
     'StochasticPolicy',
     'QValuePolicy',
+    'PolicyBase'
 ]
 
 
 # TODO: a even more "base" class for policy
+class PolicyBase(object):
+    """
+    base class for policy.
+    only provides `act` method with exploration
+    """
+    def __init__(self):
+        pass
+
+    def act(self, observation, exploration):
+        raise NotImplementedError()
+
 class QValuePolicy(object):
     """
     The policy as in DQN
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index d03dbd4..716e4c4 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -1,16 +1,22 @@
-from tianshou.core.policy.base import QValuePolicy
+from __future__ import absolute_import
+
+from .base import PolicyBase
 import tensorflow as tf
-import sys
-sys.path.append('..')
-import value_function.action_value as value_func
+from ..value_function.action_value import DQN
 
 
-class DQN_refactor(object):
+class DQNRefactor(PolicyBase):
     """
     use DQN from value_function as a member
     """
     def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._argmax_action = tf.argmax(value_tensor, axis=1)
+
+    def act(self, observation, exploration):
+        sess = tf.get_default_session()
+        if not exploration:  # no exploration
+            action = sess.run(self._argmax_action, feed_dict={})
 
 
 class DQN(QValuePolicy):
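The `act` method of `DQNRefactor` above is still a stub: the greedy branch runs with an empty `feed_dict` and nothing is returned. One possible completion is sketched below purely as an illustration; the `self._observation_placeholder` and `self._num_actions` attributes and the epsilon-greedy reading of `exploration` are assumptions, not part of this patch.

```python
# Hypothetical completion of DQNRefactor.act -- a sketch, not the patch's code.
# Assumes the instance keeps self._observation_placeholder and self._num_actions,
# and that `exploration` is an epsilon in [0, 1]; all three are assumptions here.
import numpy as np
import tensorflow as tf

def act(self, observation, exploration=None):
    sess = tf.get_default_session()
    # greedy action from the Q head, feeding a batch of one observation
    action = sess.run(self._argmax_action,
                      feed_dict={self._observation_placeholder: observation[np.newaxis]})[0]
    if exploration and np.random.rand() < exploration:
        # epsilon-greedy: occasionally replace the greedy action with a random one
        action = np.random.randint(self._num_actions)
    return action
```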
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index cb8acc8..2bda4fa 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
 
 
@@ -15,7 +17,6 @@ class ActionValue(ValueFunctionBase):
 
     def get_value(self, observation, action):
         """
-
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
         :param action: numpy array of actions, of shape (batchsize, action_dim)
                        # TODO: Atari discrete action should have dim 1. Super Mario may should have, say, dim 5, where each can be 0/1
@@ -24,7 +25,7 @@ class ActionValue(ValueFunctionBase):
         """
         sess = tf.get_default_session()
         return sess.run(self.get_value_tensor(), feed_dict=
-                        {self._observation_placeholder: observation, self._action_placeholder:action})[:, 0]
+                        {self._observation_placeholder: observation, self._action_placeholder: action})
 
 
 class DQN(ActionValue):
@@ -39,13 +40,21 @@ class DQN(ActionValue):
         :param action_placeholder: of shape (batchsize, )
         """
         self._value_tensor_all_actions = value_tensor
-        canonical_value_tensor = value_tensor[action_placeholder]  # maybe a tf.map_fn. for now it's wrong
+
+        batch_size = tf.shape(value_tensor)[0]
+        batch_dim_index = tf.range(batch_size)
+        indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
+        canonical_value_tensor = tf.gather_nd(value_tensor, indices)
 
         super(DQN, self).__init__(value_tensor=canonical_value_tensor,
                                   observation_placeholder=observation_placeholder,
                                   action_placeholder=action_placeholder)
 
     def get_value_all_actions(self, observation):
+        """
+        :param observation:
+        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
+        """
         sess = tf.get_default_session()
         return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
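The `tf.gather_nd` construction above pairs each row of the `(batchsize, num_actions)` Q-value tensor with that row's own action index, which is what the old `value_tensor[action_placeholder]` failed to do. A minimal standalone check of the indexing, assuming the TensorFlow 1.x graph API used throughout this patch; the placeholder names are made up for the example:

```python
# Standalone check of the gather_nd indexing used for the canonical Q(s, a)
# (assumes TensorFlow 1.x; placeholder names are made up for this example).
import numpy as np
import tensorflow as tf

q_values = tf.placeholder(tf.float32, shape=(None, 4))   # Q(s, *) over 4 actions
actions = tf.placeholder(tf.int32, shape=(None,))        # chosen action per sample

batch_index = tf.range(tf.shape(q_values)[0])            # [0, 1, ..., batchsize - 1]
indices = tf.stack([batch_index, actions], axis=1)       # one (row, action) pair per sample
q_selected = tf.gather_nd(q_values, indices)             # shape (batchsize,)

with tf.Session() as sess:
    q = np.array([[1., 2., 3., 4.],
                  [5., 6., 7., 8.]], dtype=np.float32)
    a = np.array([2, 0], dtype=np.int32)
    print(sess.run(q_selected, feed_dict={q_values: q, actions: a}))  # [3. 5.]
```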
diff --git a/tianshou/core/value_function/base.py b/tianshou/core/value_function/base.py
index 0b27759..b15f1bf 100644
--- a/tianshou/core/value_function/base.py
+++ b/tianshou/core/value_function/base.py
@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+
+import tensorflow as tf
 
 # TODO: linear feature baseline also in tf?
 class ValueFunctionBase(object):
@@ -6,7 +9,7 @@ class ValueFunctionBase(object):
     """
     def __init__(self, value_tensor, observation_placeholder):
         self._observation_placeholder = observation_placeholder
-        self._value_tensor = value_tensor
+        self._value_tensor = tf.squeeze(value_tensor)  # canonical values has shape (batchsize, )
 
     def get_value(self, **kwargs):
         """
diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py
index 04fe442..b7de196 100644
--- a/tianshou/core/value_function/state_value.py
+++ b/tianshou/core/value_function/state_value.py
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
 
 
@@ -17,7 +19,7 @@ class StateValue(ValueFunctionBase):
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
 
         :return: numpy array of state values, of shape (batchsize, )
-        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
+        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
         """
         sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
+        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})
\ No newline at end of file
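The `tf.squeeze` in `ValueFunctionBase.__init__` is what lets the `get_value` implementations above drop the old `[:, 0]` indexing: a raw network output of shape `(batchsize, 1)` becomes a canonical value tensor of shape `(batchsize,)`. A minimal sketch, again assuming the TensorFlow 1.x API and a made-up placeholder shape:

```python
# Sketch of the canonical value shape after tf.squeeze
# (assumes TensorFlow 1.x; the placeholder shape is made up for this example).
import numpy as np
import tensorflow as tf

value_tensor = tf.placeholder(tf.float32, shape=(None, 1))  # raw network output V(s)
canonical = tf.squeeze(value_tensor)                        # -> shape (batchsize,)

with tf.Session() as sess:
    v = np.array([[0.5], [1.5], [2.5]], dtype=np.float32)
    print(sess.run(canonical, feed_dict={value_tensor: v}))  # [0.5 1.5 2.5]
```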