Merge remote-tracking branch 'origin/master'
commit 430b78abf5

.gitignore (vendored)
@@ -4,8 +4,8 @@ leela-zero
 parameters
 *.swp
 *.sublime*
-checkpoints
+checkpoint
-checkpoints_origin
 *.json
 .DS_Store
 data
+.log
@@ -27,29 +27,30 @@ class Game:
     '''
     def __init__(self, name="go", checkpoint_path=None):
         self.name = name
-        if "go" == name:
+        if self.name == "go":
            self.size = 9
            self.komi = 3.75
            self.board = [utils.EMPTY] * (self.size ** 2)
            self.history = []
+            self.history_length = 8
            self.latest_boards = deque(maxlen=8)
            for _ in range(8):
                self.latest_boards.append(self.board)

-            self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
            self.game_engine = go.Go(size=self.size, komi=self.komi)
-        elif "reversi" == name:
+        elif self.name == "reversi":
            self.size = 8
-            self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=1)
+            self.history_length = 1
            self.game_engine = reversi.Reversi()
            self.board = self.game_engine.get_board()
        else:
-            print(name + " is an unknown game...")
+            raise ValueError(name + " is an unknown game...")

+        self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length)

    def clear(self):
        self.board = [utils.EMPTY] * (self.size ** 2)
        self.history = []
-        for _ in range(8):
+        for _ in range(self.history_length):
            self.latest_boards.append(self.board)

    def set_size(self, n):
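(A reading aid, not part of the commit: the hunk above routes the board-history length through a new `self.history_length` attribute and builds the evaluator once after the branches. A minimal standalone sketch of the resulting pattern, with the remaining hardcoded 8s also expressed through `history_length` for illustration and the ResNet call left as a comment since `model` is not shown here:)

from collections import deque

class MiniGame(object):
    """Illustrative sketch only; mirrors the refactor in the hunk above."""
    def __init__(self, name="go"):
        if name == "go":
            self.size, self.history_length = 9, 8
        elif name == "reversi":
            self.size, self.history_length = 8, 1
        else:
            raise ValueError(name + " is an unknown game...")
        self.board = [0] * (self.size ** 2)            # stand-in for utils.EMPTY
        self.latest_boards = deque(maxlen=self.history_length)
        for _ in range(self.history_length):
            self.latest_boards.append(self.board)
        # One evaluator construction for every game, driven by history_length:
        # self.evaluator = model.ResNet(self.size, self.size ** 2 + 1,
        #                               history_length=self.history_length)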
@@ -76,9 +77,9 @@ class Game:
        if vertex == utils.PASS:
            return True
        # TODO this implementation is not very elegant
-        if "go" == self.name:
+        if self.name == "go":
            res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
-        elif "revsersi" == self.name:
+        elif self.name == "reversi":
            res = self.game_engine.executor_do_move(self.board, color, vertex)
        return res
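(Not part of the commit: a tiny reproduction of why the "revsersi" typo fixed above mattered — with the misspelled branch, a reversi move matched neither condition and `res` was never bound:)

def do_move(name):
    if name == "go":
        res = "go move"
    elif name == "revsersi":          # the old typo: never matches "reversi"
        res = "reversi move"
    return res                        # UnboundLocalError when name == "reversi"

try:
    do_move("reversi")
except UnboundLocalError as err:
    print(err)                        # local variable 'res' referenced before assignment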
@@ -34,7 +34,7 @@ if __name__ == '__main__':

    daemon = Pyro4.Daemon()                 # make a Pyro daemon
    ns = Pyro4.locateNS()                   # find the name server
-    player = Player(role = args.role, engine = engine)
+    player = Player(role=args.role, engine=engine)
    print "Init " + args.role + " player finished"
    uri = daemon.register(player)           # register the greeting maker as a Pyro object
    print "Start on name " + args.role
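(For orientation only, not the repository's code: the hunk above follows the standard Pyro4 setup. A generic, self-contained sketch of the full registration sequence, assuming a running name server; the class and names here are made up:)

import Pyro4

@Pyro4.expose
class EchoPlayer(object):               # stand-in for the script's Player class
    def ping(self):
        return "pong"

daemon = Pyro4.Daemon()                 # make a Pyro daemon
ns = Pyro4.locateNS()                   # find the name server (start one with `pyro4-ns`)
uri = daemon.register(EchoPlayer())     # register the object with the daemon
ns.register("player.example", uri)      # advertise it under a lookup name
daemon.requestLoop()                    # serve remote calls until interrupted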
@@ -41,6 +41,11 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus

 <img src="https://github.com/sproblvem/tianshou/blob/master/docs/figures/go.png" height="150"/> <img src="https://github.com/sproblvem/tianshou/blob/master/docs/figures/reversi.jpg" height="150"/> <img src="https://github.com/sproblvem/tianshou/blob/master/docs/figures/warzone.jpg" height="150"/>

+## examples
+
+During development, run the examples from inside the `./examples/` directory, e.g. `python ppo_example.py`.
+Running them from the repository root with `python examples/ppo_example.py` will not work.
+
 ## About coding style
@@ -1,17 +1,16 @@
 #!/usr/bin/env python
+from __future__ import absolute_import

 import tensorflow as tf
-import numpy as np
-import time
 import gym

 # our lib imports here!
 import sys
 sys.path.append('..')
-import tianshou.core.losses as losses
+from tianshou.core import losses
 from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy as policy
+import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy


 def policy_net(observation, action_dim, scope=None):
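(A side note, not part of the commit: the README restriction above exists because `sys.path.append('..')` is resolved against the current working directory, so the in-repository `tianshou` package is only found when the script is launched from `./examples/`. A rough, assumed alternative that anchors the path to the script's own location instead:)

import os
import sys

# Add the repository root (the parent of this script's directory) to sys.path,
# so the in-tree `tianshou` package resolves regardless of the working directory.
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, repo_root)

from tianshou.core import losses  # noqa: E402  (import after path setup)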
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from .base import *
-from .stochastic import *
-from .dqn import *
@@ -13,11 +13,23 @@ import tensorflow as tf
 __all__ = [
     'StochasticPolicy',
     'QValuePolicy',
+    'PolicyBase'
 ]

 # TODO: an even more "base" class for policy


+class PolicyBase(object):
+    """
+    Base class for policy. Only provides the `act` method with exploration.
+    """
+    def __init__(self):
+        pass
+
+    def act(self, observation, exploration):
+        raise NotImplementedError()
+
+
 class QValuePolicy(object):
     """
     The policy as in DQN
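(Illustration only, not in the commit: how a concrete policy would plug into the new `PolicyBase` interface. `RandomDiscretePolicy` and its constructor argument are invented for this sketch.)

import random

class PolicyBase(object):
    """Minimal stand-in for the class added above."""
    def act(self, observation, exploration):
        raise NotImplementedError()

class RandomDiscretePolicy(PolicyBase):
    """Toy subclass: ignores the observation, samples a uniform random action."""
    def __init__(self, num_actions):
        self.num_actions = num_actions

    def act(self, observation, exploration=None):
        return random.randrange(self.num_actions)

policy = RandomDiscretePolicy(num_actions=4)
print(policy.act(observation=None))   # e.g. 2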
@@ -1,16 +1,22 @@
-from tianshou.core.policy.base import QValuePolicy
+from __future__ import absolute_import

+from .base import PolicyBase
 import tensorflow as tf
-import sys
-sys.path.append('..')
-import value_function.action_value as value_func
+from ..value_function.action_value import DQN


-class DQN_refactor(object):
+class DQNRefactor(PolicyBase):
     """
     use DQN from value_function as a member
     """
     def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._network = value_func.DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._network = DQN(value_tensor, observation_placeholder, action_placeholder)
+        self._argmax_action = tf.argmax(value_tensor, axis=1)
+
+    def act(self, observation, exploration):
+        sess = tf.get_default_session()
+        if not exploration:  # no exploration
+            action = sess.run(self._argmax_action, feed_dict={})


 class DQN(QValuePolicy):
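(Sketch only: the new `act` above is still a stub — it feeds an empty `feed_dict` and returns nothing. A self-contained epsilon-greedy version of the same idea, assuming the caller passes the session, the argmax op, the observation placeholder, and the number of actions; none of these names come from the repository:)

import numpy as np

def epsilon_greedy_act(sess, argmax_action_op, observation_ph, observation,
                       num_actions, epsilon=0.0):
    """Greedy action from a Q-network, with optional epsilon exploration."""
    # Greedy actions for the batch, shape (batchsize,).
    action = sess.run(argmax_action_op, feed_dict={observation_ph: observation})
    # With probability epsilon, replace each action by a uniformly random one.
    explore = np.random.rand(*action.shape) < epsilon
    random_action = np.random.randint(num_actions, size=action.shape)
    return np.where(explore, random_action, action)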
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
@@ -15,7 +17,6 @@ class ActionValue(ValueFunctionBase):

     def get_value(self, observation, action):
         """
-
         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
         :param action: numpy array of actions, of shape (batchsize, action_dim)
         # TODO: Atari discrete action should have dim 1. Super Mario may need, say, dim 5, where each can be 0/1
@@ -24,7 +25,7 @@ class ActionValue(ValueFunctionBase):
         """
         sess = tf.get_default_session()
         return sess.run(self.get_value_tensor(), feed_dict=
-            {self._observation_placeholder: observation, self._action_placeholder:action})[:, 0]
+            {self._observation_placeholder: observation, self._action_placeholder: action})


 class DQN(ActionValue):
@@ -39,13 +40,21 @@ class DQN(ActionValue):
         :param action_placeholder: of shape (batchsize, )
         """
         self._value_tensor_all_actions = value_tensor
-        canonical_value_tensor = value_tensor[action_placeholder]  # maybe a tf.map_fn. for now it's wrong
+
+        batch_size = tf.shape(value_tensor)[0]
+        batch_dim_index = tf.range(batch_size)
+        indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
+        canonical_value_tensor = tf.gather_nd(value_tensor, indices)

         super(DQN, self).__init__(value_tensor=canonical_value_tensor,
                                   observation_placeholder=observation_placeholder,
                                   action_placeholder=action_placeholder)

     def get_value_all_actions(self, observation):
+        """
+        :param observation:
+        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
+        """
         sess = tf.get_default_session()
         return sess.run(self._value_tensor_all_actions, feed_dict={self._observation_placeholder: observation})
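(Not part of the commit: the `tf.gather_nd` construction above selects one Q-value per batch row. The same indexing in plain NumPy, for reference:)

import numpy as np

q_values = np.array([[1.0, 2.0, 3.0],     # plays the role of value_tensor: Q(s, *) per row
                     [4.0, 5.0, 6.0]])
actions = np.array([2, 0])                # plays the role of action_placeholder, shape (batchsize,)

# Equivalent of tf.stack([tf.range(batch_size), actions], axis=1) + tf.gather_nd:
batch_index = np.arange(q_values.shape[0])
canonical_values = q_values[batch_index, actions]
print(canonical_values)                   # [3. 1.] -> Q(s_0, a=2), Q(s_1, a=0), shape (batchsize,)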
@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+
+import tensorflow as tf

 # TODO: linear feature baseline also in tf?
 class ValueFunctionBase(object):
@@ -6,7 +9,7 @@ class ValueFunctionBase(object):
     """
     def __init__(self, value_tensor, observation_placeholder):
         self._observation_placeholder = observation_placeholder
-        self._value_tensor = value_tensor
+        self._value_tensor = tf.squeeze(value_tensor)  # canonical values have shape (batchsize, )

     def get_value(self, **kwargs):
         """
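(Side note, not in the commit: the squeeze above is what lets the `[:, 0]` indexing be dropped from the getters below. A quick TF1-style illustration of the shape convention, with the caveat — an observation, not a change in the diff — that an axis-free squeeze also collapses a batch of size 1 to a scalar, so squeezing the last axis explicitly may be safer:)

import tensorflow as tf

value_head = tf.placeholder(tf.float32, shape=(None, 1))  # typical critic output, (batchsize, 1)
canonical = tf.squeeze(value_head, axis=1)                 # canonical values, shape (batchsize,)
# tf.squeeze(value_head) without `axis` gives the same result for batchsize > 1,
# but would squeeze a (1, 1) batch all the way down to a scalar at run time.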
@@ -1,4 +1,6 @@
-from base import ValueFunctionBase
+from __future__ import absolute_import
+
+from .base import ValueFunctionBase
 import tensorflow as tf
@@ -17,7 +19,7 @@ class StateValue(ValueFunctionBase):

         :param observation: numpy array of observations, of shape (batchsize, observation_dim).
         :return: numpy array of state values, of shape (batchsize, )
-        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
+        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
         """
         sess = tf.get_default_session()
-        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})[:, 0]
+        return sess.run(self.get_value_tensor(), feed_dict={self._observation_placeholder: observation})