debug for unit test

This commit is contained in:
rtz19970824 2017-12-28 19:38:25 +08:00
parent 50e8ea36e8
commit 01f39f40d3
4 changed files with 10 additions and 17 deletions

View File

@ -29,17 +29,10 @@ class ZOTree:
length = len(seq)
if length != self.depth:
raise ValueError("The game is not terminated!")
ones = 0
zeros = 0
for i in range(len(seq)):
if seq[i] == 0:
zeros += 1
if seq[i] == 1:
ones += 1
result = ones - zeros
if result > 0:
result = np.sum(seq)
if result > self.size:
winner = 1
elif result < 0:
elif result < self.size:
winner = -1
else:
winner = 0
@ -98,7 +91,7 @@ class ZOTree:
if __name__ == "__main__":
size = 2
game = ZOTree(size)
seq = [1, -1, 1, 1]
seq = [1, 0, 1, 1]
result = game.executor_do_move([seq, 1], 1)
print(result)
print(seq)
print(seq)

View File

@ -17,11 +17,11 @@ class Agent:
def gen_move(self, seq):
if len(seq) >= 2 * self.size:
raise ValueError("Game is terminated.")
mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2)
mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2, inverse=True)
mcts.search(max_step=50)
N = mcts.root.N
N = np.power(N, 1.0 / temp)
prob = N / np.sum(N)
print("prob: {}".format(prob))
action = int(np.random.binomial(1, prob[1]))
return action
return action

View File

@ -6,7 +6,7 @@ if __name__ == '__main__':
print("Our game has 2 players.")
print("Player 1 has color 1 and plays first. Player 2 has color -1 and plays following player 1.")
print("Both player choose 1 or 0 for an action.")
size = 1
size = 2
print("This game has {} iterations".format(size))
print("If the final sequence has more 1 that 0, player 1 wins.")
print("If the final sequence has less 1 that 0, player 2 wins.")
@ -34,4 +34,4 @@ if __name__ == '__main__':
break
print("The choice sequence is {}".format(seq))
print("The game result is {}".format(winner))
print("The game result is {}".format(winner))

View File

@ -187,7 +187,7 @@ class MCTS(object):
prior, value = self.evaluator(next_action.next_state)
next_action.expansion(prior, self.action_num)
else:
value = 0
value = 0.
t2 = time.time()
if self.inverse:
next_action.backpropagation(-value + 0.)