debug for unit test

This commit is contained in:
rtz19970824 2017-12-28 19:38:25 +08:00
parent 50e8ea36e8
commit 01f39f40d3
4 changed files with 10 additions and 17 deletions

View File

@ -29,17 +29,10 @@ class ZOTree:
length = len(seq) length = len(seq)
if length != self.depth: if length != self.depth:
raise ValueError("The game is not terminated!") raise ValueError("The game is not terminated!")
ones = 0 result = np.sum(seq)
zeros = 0 if result > self.size:
for i in range(len(seq)):
if seq[i] == 0:
zeros += 1
if seq[i] == 1:
ones += 1
result = ones - zeros
if result > 0:
winner = 1 winner = 1
elif result < 0: elif result < self.size:
winner = -1 winner = -1
else: else:
winner = 0 winner = 0
@ -98,7 +91,7 @@ class ZOTree:
if __name__ == "__main__": if __name__ == "__main__":
size = 2 size = 2
game = ZOTree(size) game = ZOTree(size)
seq = [1, -1, 1, 1] seq = [1, 0, 1, 1]
result = game.executor_do_move([seq, 1], 1) result = game.executor_do_move([seq, 1], 1)
print(result) print(result)
print(seq) print(seq)

View File

@ -17,7 +17,7 @@ class Agent:
def gen_move(self, seq): def gen_move(self, seq):
if len(seq) >= 2 * self.size: if len(seq) >= 2 * self.size:
raise ValueError("Game is terminated.") raise ValueError("Game is terminated.")
mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2) mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2, inverse=True)
mcts.search(max_step=50) mcts.search(max_step=50)
N = mcts.root.N N = mcts.root.N
N = np.power(N, 1.0 / temp) N = np.power(N, 1.0 / temp)

View File

@ -6,7 +6,7 @@ if __name__ == '__main__':
print("Our game has 2 players.") print("Our game has 2 players.")
print("Player 1 has color 1 and plays first. Player 2 has color -1 and plays following player 1.") print("Player 1 has color 1 and plays first. Player 2 has color -1 and plays following player 1.")
print("Both player choose 1 or 0 for an action.") print("Both player choose 1 or 0 for an action.")
size = 1 size = 2
print("This game has {} iterations".format(size)) print("This game has {} iterations".format(size))
print("If the final sequence has more 1 that 0, player 1 wins.") print("If the final sequence has more 1 that 0, player 1 wins.")
print("If the final sequence has less 1 that 0, player 2 wins.") print("If the final sequence has less 1 that 0, player 2 wins.")

View File

@ -187,7 +187,7 @@ class MCTS(object):
prior, value = self.evaluator(next_action.next_state) prior, value = self.evaluator(next_action.next_state)
next_action.expansion(prior, self.action_num) next_action.expansion(prior, self.action_num)
else: else:
value = 0 value = 0.
t2 = time.time() t2 = time.time()
if self.inverse: if self.inverse:
next_action.backpropagation(-value + 0.) next_action.backpropagation(-value + 0.)