Merge remote-tracking branch 'origin/master'

# Conflicts:
#	README.md
haoshengzou 2018-02-26 11:48:46 +08:00
commit 40190a282e
8 changed files with 199 additions and 85 deletions

View File

@@ -40,6 +40,9 @@ class Game:
             self.history_hashtable = set()
             self.game_engine = go.Go(size=self.size, komi=self.komi)
             self.board = [utils.EMPTY] * (self.size ** 2)
+            self.group_ancestors = {}  # key: idx, value: ancestor idx
+            self.liberty = {}  # key: ancestor idx, value: set of liberties
+            self.stones = {}  # key: ancestor idx, value: set of stones
         elif self.name == "reversi":
             self.size = 8
             self.history_length = 1
@@ -62,6 +65,9 @@ class Game:
         self.board = [utils.EMPTY] * (self.size ** 2)
         del self.history[:]
         self.history_hashtable.clear()
+        self.group_ancestors.clear()
+        self.liberty.clear()
+        self.stones.clear()
         if self.name == "reversi":
             self.board = self.game_engine.get_board()
         for _ in range(self.history_length):
@@ -109,7 +115,7 @@ class Game:
             res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
         if self.name == "go":
             res = self.game_engine.executor_do_move(self.history, self.history_hashtable, self.latest_boards, self.board,
-                                                    color, vertex)
+                                                    self.group_ancestors, self.liberty, self.stones, color, vertex)
         return res

     def think_play_move(self, color):
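
The three attributes added above form the incremental group tracker that the Go engine below relies on: `group_ancestors` is a union-find parent map over flattened board indices, while `liberty` and `stones` store per-group data under the group's representative index only. A minimal sketch of the invariant, with illustrative values that are not part of the commit:

    # two connected stones at indices 12 and 13 on a 9x9 board, representative 12
    group_ancestors = {12: 12, 13: 12}
    liberty = {12: {3, 4, 11, 14, 21, 22}}   # liberties keyed by representative only
    stones = {12: {12, 13}}                  # group members keyed by representative only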

View File

@@ -44,6 +44,23 @@ class Go:
             nei.append((_x, _y))
         return nei

+    def _neighbor_color(self, current_board, vertex, color):
+        # classify the neighbors of vertex by their contents:
+        color_neighbor = []  # 1) neighbors of the same color
+        reverse_color_neighbor = []  # 2) neighbors of the opposite color
+        empty_neighbor = []  # 3) empty neighbors
+        reverse_color = utils.BLACK if color == utils.WHITE else utils.WHITE
+        for n in self._neighbor(vertex):
+            if current_board[self._flatten(n)] == color:
+                color_neighbor.append(self._flatten(n))
+            elif current_board[self._flatten(n)] == utils.EMPTY:
+                empty_neighbor.append(self._flatten(n))
+            elif current_board[self._flatten(n)] == reverse_color:
+                reverse_color_neighbor.append(self._flatten(n))
+            else:
+                raise ValueError("board has positions other than BLACK, WHITE and EMPTY")
+        return color_neighbor, reverse_color_neighbor, empty_neighbor
+
     def _corner(self, vertex):
         x, y = vertex
         corner = []
@@ -56,13 +73,11 @@ class Go:
     def _find_group(self, current_board, vertex):
         color = current_board[self._flatten(vertex)]
-        # print ("color : ", color)
         chain = set()
         frontier = [vertex]
         has_liberty = False
         while frontier:
             current = frontier.pop()
-            # print ("current : ", current)
             chain.add(current)
             for n in self._neighbor(current):
                 if current_board[self._flatten(n)] == color and not n in chain:
@@ -71,21 +86,26 @@ class Go:
                     has_liberty = True
         return has_liberty, chain

-    def _is_suicide(self, current_board, color, vertex):
-        current_board[self._flatten(vertex)] = color  # assume that we already take this move
-        suicide = False
-
-        has_liberty, group = self._find_group(current_board, vertex)
-        if not has_liberty:
-            suicide = True  # no liberty, suicide
-            for n in self._neighbor(vertex):
-                if current_board[self._flatten(n)] == utils.another_color(color):
-                    opponent_liberty, group = self._find_group(current_board, n)
-                    if not opponent_liberty:
-                        suicide = False  # this move is able to take opponent's stone, not suicide
-
-        current_board[self._flatten(vertex)] = utils.EMPTY  # undo this move
-        return suicide
+    def _find_ancestor(self, group_ancestors, idx):
+        r = idx  # union-find lookup with path compression
+        while group_ancestors[r] != r:
+            r = group_ancestors[r]
+        group_ancestors[idx] = r
+        return r
+
+    def _is_suicide(self, current_board, group_ancestors, liberty, color, vertex):
+        color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
+        if empty_neighbor:
+            return False  # the played point keeps at least one liberty
+        for idx in color_neighbor:
+            # a same-colored neighbor group with a spare liberty survives the move
+            if len(liberty[self._find_ancestor(group_ancestors, idx)]) > 1:
+                return False
+        for idx in reverse_color_neighbor:
+            # capturing an opponent group in atari also makes the move legal
+            if len(liberty[self._find_ancestor(group_ancestors, idx)]) == 1:
+                return False
+        return True

     def _process_board(self, current_board, color, vertex):
         nei = self._neighbor(vertex)
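
The rewritten `_is_suicide` replaces two flood fills per query with dictionary lookups: `_find_ancestor` is a textbook union-find `find` with path compression, so each liberty check becomes a near-constant-time set-size test. A self-contained toy illustration of the bookkeeping (the two-stone group and point numbering are made up for this example):

    def find(parent, idx):
        # union-find lookup with path compression, as in _find_ancestor
        root = idx
        while parent[root] != root:
            root = parent[root]
        parent[idx] = root
        return root

    parent = {0: 0, 1: 0}   # stones 0 and 1 form one group with root 0
    liberty = {0: {9}}      # the group's last liberty is point 9

    # an enemy move on point 9 captures the group; a friendly move on
    # point 9 with no other liberties around would be suicide
    assert len(liberty[find(parent, 1)]) == 1
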
@@ -107,25 +127,15 @@ class Go:
         return repeat

     def _is_eye(self, current_board, color, vertex):
-        nei = self._neighbor(vertex)
-        cor = self._corner(vertex)
-        ncolor = {color == current_board[self._flatten(n)] for n in nei}
-        if False in ncolor:
-            # print "not all neighbors are in same color with us"
-            return False
-        _, group = self._find_group(current_board, nei[0])
-        if set(nei) < group:
-            # print "all neighbors are in same group and same color with us"
-            return True
-        else:
-            opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
-            opponent_propotion = float(opponent_number) / float(len(cor))
-            if opponent_propotion < 0.5:
-                # print "few opponents, real eye"
-                return True
-            else:
-                # print "many opponents, fake eye"
-                return False
+        # return whether this position is a real eye of the given color
+        color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
+        if reverse_color_neighbor or empty_neighbor:  # not an eye
+            return False
+        cor = self._corner(vertex)
+        opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
+        opponent_proportion = float(opponent_number) / float(len(cor))
+        # opponent_proportion >= 0.5 marks a fake eye
+        return opponent_proportion < 0.5

     def _knowledge_prunning(self, current_board, color, vertex):
         # forbid some stupid selfplay using human knowledge
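
The simplified `_is_eye` first requires every direct neighbor to be a friendly stone, then falls back to the classical corner heuristic: a point with half or more of its diagonal corners held by the opponent is treated as a fake eye. A hand check of the proportion test (board values assumed to be BLACK=1, WHITE=-1, as in the test boards at the bottom of this file):

    corners = [1, 1, -1, 1]            # interior point, one opponent diagonal
    opponent_number = corners.count(-1)
    opponent_proportion = opponent_number / float(len(corners))
    print(opponent_proportion < 0.5)   # True: counted as a real eye
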
@@ -134,23 +144,6 @@ class Go:
         # forbid position on its own eye.
         return True

-    def _is_game_finished(self, current_board, color):
-        '''
-        for each empty position, if it has both BLACK and WHITE neighbors, the game is still not finished
-        :return: return the game is finished
-        '''
-        board = copy.deepcopy(current_board)
-        empty_idx = [i for i, x in enumerate(board) if x == utils.EMPTY]  # find all empty idx
-        for idx in empty_idx:
-            neighbor_idx = self._neighbor(self.deflatten(idx))
-            if len(neighbor_idx) > 1:
-                first_idx = neighbor_idx[0]
-                for other_idx in neighbor_idx[1:]:
-                    if board[self.flatten(other_idx)] != board[self.flatten(first_idx)]:
-                        return False
-        return True
-
     def _action2vertex(self, action):
         if action == self.size ** 2:
             vertex = (0, 0)
@@ -158,7 +151,7 @@ class Go:
             vertex = self._deflatten(action)
         return vertex

-    def _rule_check(self, history_hashtable, current_board, color, vertex, is_thinking=True):
+    def _rule_check(self, history_hashtable, current_board, group_ancestors, liberty, color, vertex, is_thinking=True):
         ### in board
         if not self._in_board(vertex):
             if not is_thinking:
@@ -174,7 +167,7 @@ class Go:
                 return False

         ### check if it is suicide
-        if self._is_suicide(current_board, color, vertex):
+        if self._is_suicide(current_board, group_ancestors, liberty, color, vertex):
             if not is_thinking:
                 raise ValueError("Target point causes suicide, Current Board: {}, color: {}, vertex: {}".format(current_board, color, vertex))
             else:
@@ -189,34 +182,59 @@ class Go:
                 return False
         return True

-    def _is_valid(self, state, action, history_hashtable):
+    def _is_valid(self, state, action, history_hashtable, group_ancestors, liberty):
         history_boards, color = state
         vertex = self._action2vertex(action)
         current_board = history_boards[-1]
-        if not self._rule_check(history_hashtable, current_board, color, vertex):
+        if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex):
             return False
         if not self._knowledge_prunning(current_board, color, vertex):
             return False
         return True

+    def _get_groups(self, board):
+        group_ancestors = {}  # key: idx, value: ancestor idx
+        liberty = {}  # key: ancestor idx, value: set of liberties
+        for idx, color in enumerate(board):
+            if color and idx not in group_ancestors:
+                # build group
+                group_ancestors[idx] = idx
+                color_neighbor, _, empty_neighbor = \
+                    self._neighbor_color(board, self._deflatten(idx), color)
+                liberty[idx] = set(empty_neighbor)
+                group_list = copy.deepcopy(color_neighbor)
+                while group_list:
+                    add_idx = group_list.pop()
+                    if add_idx not in group_ancestors:
+                        group_ancestors[add_idx] = idx
+                        color_neighbor_add, _, empty_neighbor_add = \
+                            self._neighbor_color(board, self._deflatten(add_idx), color)
+                        group_list += color_neighbor_add
+                        liberty[idx] |= set(empty_neighbor_add)
+        return group_ancestors, liberty
+
     def simulate_get_mask(self, state, action_set):
         # find all the invalid actions
         invalid_action_mask = []
         history_boards, color = state
+        group_ancestors, liberty = self._get_groups(history_boards[-1])
         history_hashtable = set()
         for board in history_boards:
             history_hashtable.add(tuple(board))
         for action_candidate in action_set[:-1]:
             # go through all the actions excluding pass
-            if not self._is_valid(state, action_candidate, history_hashtable):
+            if not self._is_valid(state, action_candidate, history_hashtable, group_ancestors, liberty):
                 invalid_action_mask.append(action_candidate)
         if len(invalid_action_mask) < len(action_set) - 1:
             invalid_action_mask.append(action_set[-1])
             # forbid pass, if we have other choices
             # TODO: In fact we should not do this. In some extreme cases, we should permit pass.
         del history_hashtable
+        del group_ancestors
+        del liberty
+        # del stones
         return invalid_action_mask

     def _do_move(self, board, color, vertex):
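
`_get_groups` rebuilds the union-find state from a bare board, which lets `simulate_get_mask` reuse the fast suicide test during search without threading the incremental dictionaries through the simulator. A hypothetical usage sketch (the constructor arguments follow `game.py` above but the komi value and the lone-stone board are illustrative):

    go_engine = go.Go(size=9, komi=6.5)
    board = [utils.EMPTY] * 81
    board[40] = utils.BLACK                  # a single stone at the center
    group_ancestors, liberty = go_engine._get_groups(board)
    root = go_engine._find_ancestor(group_ancestors, 40)
    print(len(liberty[root]))                # 4: a lone center stone has 4 liberties
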
@@ -243,12 +261,70 @@ class Go:
         # since go is MDP, we only need the last board for hashing
         return tuple(state[0][-1])

-    def executor_do_move(self, history, history_hashtable, latest_boards, current_board, color, vertex):
-        if not self._rule_check(history_hashtable, current_board, color, vertex, is_thinking=False):
-            # raise ValueError("!!! We have more than four ko at the same time !!!")
+    def _join_group(self, idx, idx_list, empty_neighbor, group_ancestors, liberty, stones):
+        # idx joins the groups of its same-colored neighbors in idx_list
+        # empty_neighbor: empty neighbors of idx
+        color_ancestor = set()
+        for color_idx in idx_list:
+            color_ancestor.add(self._find_ancestor(group_ancestors, color_idx))
+        joined_ancestor = color_ancestor.pop()
+        liberty[joined_ancestor] |= set(empty_neighbor)
+        stones[joined_ancestor].add(idx)
+        group_ancestors[idx] = joined_ancestor
+        # merge the remaining groups into the joined one
+        for color_idx in color_ancestor:
+            liberty[joined_ancestor] |= liberty[color_idx]
+            stones[joined_ancestor] |= stones[color_idx]
+            del liberty[color_idx]
+            for stone in stones[color_idx]:
+                group_ancestors[stone] = joined_ancestor
+            del stones[color_idx]
+        liberty[joined_ancestor].remove(idx)
+
+    def _add_captured_liberty(self, board, liberty, group_ancestors, stones):
+        # captured stones turn into liberties of the adjacent opponent groups
+        for captured_stone in stones:
+            color_neighbor, reverse_color_neighbor, empty_neighbor = \
+                self._neighbor_color(board, self._deflatten(captured_stone), board[captured_stone])
+            assert not empty_neighbor  # a captured group has no empty neighbors
+            for reverse_color_idx in reverse_color_neighbor:
+                reverse_color_idx_ancestor = self._find_ancestor(group_ancestors, reverse_color_idx)
+                liberty[reverse_color_idx_ancestor].add(captured_stone)
+
+    def _remove_liberty(self, idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones):
+        # reverse_color_neighbor: stones adjacent to idx in the reverse color
+        reverse_color_ancestor = set()
+        for reverse_idx in reverse_color_neighbor:
+            reverse_color_ancestor.add(self._find_ancestor(group_ancestors, reverse_idx))
+        for reverse_color_ancestor_idx in reverse_color_ancestor:
+            if len(liberty[reverse_color_ancestor_idx]) == 1:
+                # capture this group if no liberty is left
+                self._add_captured_liberty(current_board, liberty, group_ancestors, stones[reverse_color_ancestor_idx])
+                for captured_stone in stones[reverse_color_ancestor_idx]:
+                    current_board[captured_stone] = utils.EMPTY
+                    del group_ancestors[captured_stone]
+                del liberty[reverse_color_ancestor_idx]
+                del stones[reverse_color_ancestor_idx]
+            else:
+                # remove this liberty
+                liberty[reverse_color_ancestor_idx].remove(idx)
+
+    def executor_do_move(self, history, history_hashtable, latest_boards, current_board, group_ancestors, liberty, stones, color, vertex):
+        if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex):
             return False
-        current_board[self._flatten(vertex)] = color
-        self._process_board(current_board, color, vertex)
+        idx = self._flatten(vertex)
+        current_board[idx] = color
+        color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
+        if color_neighbor:  # join nearby groups of the same color
+            self._join_group(idx, color_neighbor, empty_neighbor, group_ancestors, liberty, stones)
+        else:  # build a new group
+            group_ancestors[idx] = idx
+            liberty[idx] = set(empty_neighbor)
+            stones[idx] = {idx}
+        if reverse_color_neighbor:  # remove the played liberty from nearby opposite-colored groups
+            self._remove_liberty(idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones)
         history.append(copy.deepcopy(current_board))
         latest_boards.append(copy.deepcopy(current_board))
         history_hashtable.add(copy.deepcopy(tuple(current_board)))
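
`executor_do_move` now keeps the three dictionaries consistent one move at a time: join the new stone into adjacent friendly groups (or open a singleton group), then remove the played point from opponent liberties, capturing any group whose liberty set would empty. A corner capture on a row-major 9x9 board, tracked by hand (indices and values are illustrative, not from the commit):

    # before black plays: white stone at 0=(0,0), black stone at 1=(0,1)
    group_ancestors = {0: 0, 1: 1}
    liberty = {0: {9}, 1: {2, 10}}   # white's last liberty is 9=(1,0)
    stones = {0: {0}, 1: {1}}

    # black plays index 9: a singleton black group {9} is created with
    # liberties {10, 18}; white group 0 loses its last liberty and is
    # captured, so point 0 re-opens as a liberty of both black groups
    liberty_after = {1: {0, 2, 10}, 9: {0, 10, 18}}
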
@@ -362,6 +438,8 @@ if __name__ == "__main__":
         1, 0, 1, 1, 1, 1, 1, -1, 0,
         1, 1, 0, 1, -1, -1, -1, -1, -1
     ]
+    '''
+
     time0 = time.time()
     score = go.executor_get_score(endgame)
     time1 = time.time()
@@ -370,6 +448,7 @@ if __name__ == "__main__":
     time2 = time.time()
     print(score, time2 - time1)
     '''
+    '''
     ### do unit test for Go class
     pure_test = [
         0, 1, 0, 1, 0, 1, 0, 0, 0,

View File

@@ -1,3 +1,4 @@
+from __future__ import division
 import argparse
 import sys
 import re
@@ -28,10 +29,13 @@ def play(engine, data_path):
     size = {"go": 9, "reversi": 8}
     show = ['.', 'X', 'O']

-    # evaluate_rounds = 100
+    evaluate_rounds = 0
     game_num = 0
+    total = 0
+    f = open('time.txt', 'w')
     while True:
-    # while game_num < evaluate_rounds:
+    # while game_num < evaluate_rounds:
+        start = time.time()
         engine._game.model.check_latest_model()
         num = 0
         pass_flag = [False, False]
@@ -78,6 +82,14 @@ def play(engine, data_path):
             data.reset()
         game_num += 1

+        this_time = time.time() - start
+        total += this_time
+        f.write('time: ' + str(this_time) + '\n')
+        evaluate_rounds += 1
+        f.write('Avg time: ' + str(total / evaluate_rounds) + '\n')
+        f.flush()  # keep the file open: the loop writes again next round
+

 if __name__ == '__main__':
     """

View File

@@ -83,12 +83,6 @@ Try to use full names. Don't use abbreviations for class/function/variable names.

 The """xxx""" comment should be written right after the class/function. Also comment the parts that are not intuitive in the code. We must comment, but for now we don't need to polish them.

-# High Priority TODO
-
-For Haosheng and Tongzheng: separate actor and critic, rewrite the interfaces for policy.
-Others can still focus on the tasks below.
-
 ## TODO

 Search based method parallel.
@@ -106,7 +100,18 @@ Note: install openai/gym first to run the Atari environment; note that interface

 Without preprocessing and other tricks, this example will not train to any meaningful results. Code should pass two tests: the individual module test and a full run of this example.

-## Dependencies
-TensorFlow (Version >= 1.4)
-gym
+## Some bugs to fix
+
+For DQN and other deterministic policies: $\epsilon$-greedy or other exploration during collection?
+
+In Batch.py, notice that we cannot stop by setting num_timesteps.
+
+Magic numbers.
+
+## One idea
+
+Like zhusuan, we can register losses in the background so that we need not declare them in the example.
+
+## Dependencies
+TensorFlow (Version >= 1.4)
+gym
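
For the exploration question above, the straightforward answer (and what this commit wires into `Batch.collect` and the DQN example) is $\epsilon$-greedy action selection with a schedule that decays over training. A minimal sketch, independent of the library code:

    import numpy as np

    def epsilon_greedy(policy_action, n_actions, epsilon):
        # with probability epsilon act uniformly at random,
        # otherwise trust the deterministic policy
        if np.random.random() < epsilon:
            return np.random.randint(n_actions)
        return policy_action

    # linear decay from 1.0 towards 0 across repeat_num epochs,
    # mirroring epsilon_greedy=(repeat_num - i) / repeat_num below
    repeat_num = 100
    schedule = [(repeat_num - i) / float(repeat_num) for i in range(repeat_num)]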

View File

@@ -66,9 +66,11 @@ if __name__ == '__main__':
     pi.sync_weights()  # TODO: automate this for policies with target network

     start_time = time.time()
-    for i in range(100):
+    # TODO: repeat_num should be defined in some configuration file
+    repeat_num = 100
+    for i in range(repeat_num):
         # collect data
-        data_collector.collect(num_episodes=50)
+        data_collector.collect(num_episodes=50, epsilon_greedy=(repeat_num - i + 0.0) / repeat_num)
         # print current return
         print('Epoch {}:'.format(i))

View File

@@ -5,6 +5,7 @@ import tensorflow as tf
 import gym
 import numpy as np
 import time
+import argparse

 # our lib imports here! It's ok to append path in examples
 import sys
@@ -16,6 +17,9 @@ import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan

 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--render", action="store_true", default=False)
+    args = parser.parse_args()
     env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.n
@@ -55,7 +59,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)

     ### 4. start training
     config = tf.ConfigProto()
View File

@@ -18,7 +18,7 @@ class DQN(PolicyBase):
         else:
             self.interaction_count = -1

-    def act(self, observation, exploration=None):
+    def act(self, observation, my_feed_dict):
         sess = tf.get_default_session()
         if self.weight_update > 1:
             if self.interaction_count % self.weight_update == 0:
@@ -30,7 +30,6 @@ class DQN(PolicyBase):
         if self.weight_update > 0:
             self.interaction_count += 1

-        if not exploration:
-            return np.squeeze(action)
+        return np.squeeze(action)

     @property

View File

@@ -9,7 +9,7 @@ class Batch(object):
     class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
     """

-    def __init__(self, env, pi, reward_processors, networks):  # how to name the function?
+    def __init__(self, env, pi, reward_processors, networks, render=False):  # how to name the function?
         """
         constructor
         :param env:
@@ -24,6 +24,7 @@ class Batch(object):
         self.reward_processors = reward_processors
         self.networks = networks
+        self.render = render

         self.required_placeholders = {}
         for net in self.networks:
@@ -33,7 +34,7 @@ class Batch(object):
         self._is_first_collect = True

     def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
-                process_reward=True):  # specify how much data to collect here, or fix it in __init__()
+                process_reward=True, epsilon_greedy=0):  # specify how much data to collect here, or fix it in __init__()
         assert sum(
             [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
@@ -105,15 +106,21 @@ class Batch(object):
                 episode_start_flags.append(True)

                 while True:
-                    ac = self._pi.act(ob, my_feed_dict)
+                    # a simple implementation of epsilon-greedy exploration
+                    if epsilon_greedy > 0 and np.random.random() < epsilon_greedy:
+                        ac = np.random.randint(low=0, high=self._env.action_space.n)
+                    else:
+                        ac = self._pi.act(ob, my_feed_dict)
                     actions.append(ac)

+                    if self.render:
+                        self._env.render()
                     ob, reward, done, _ = self._env.step(ac)
                     rewards.append(reward)

-                    t_count += 1
-                    if t_count >= 100:  # force episode stop, just to test if memory still grows
-                        break
+                    # t_count += 1
+                    # if t_count >= 100:  # force episode stop, just to test if memory still grows
+                    #     break

                     if done:  # end of episode, discard s_T
                         # TODO: for num_timesteps collection, has to store terminal flag instead of start flag!
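
With the new keyword argument, the training loop can anneal exploration per epoch. A hypothetical call site matching the extended signature (`repeat_num` and `i` come from the surrounding loop, as in the DQN example above):

    training_data.collect(num_episodes=50,
                          epsilon_greedy=(repeat_num - i + 0.0) / repeat_num)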