debug for unit test
This commit is contained in:
parent
50e8ea36e8
commit
01f39f40d3
@ -29,17 +29,10 @@ class ZOTree:
|
||||
length = len(seq)
|
||||
if length != self.depth:
|
||||
raise ValueError("The game is not terminated!")
|
||||
ones = 0
|
||||
zeros = 0
|
||||
for i in range(len(seq)):
|
||||
if seq[i] == 0:
|
||||
zeros += 1
|
||||
if seq[i] == 1:
|
||||
ones += 1
|
||||
result = ones - zeros
|
||||
if result > 0:
|
||||
result = np.sum(seq)
|
||||
if result > self.size:
|
||||
winner = 1
|
||||
elif result < 0:
|
||||
elif result < self.size:
|
||||
winner = -1
|
||||
else:
|
||||
winner = 0
|
||||
@ -98,7 +91,7 @@ class ZOTree:
|
||||
if __name__ == "__main__":
|
||||
size = 2
|
||||
game = ZOTree(size)
|
||||
seq = [1, -1, 1, 1]
|
||||
seq = [1, 0, 1, 1]
|
||||
result = game.executor_do_move([seq, 1], 1)
|
||||
print(result)
|
||||
print(seq)
|
||||
print(seq)
|
||||
|
@ -17,11 +17,11 @@ class Agent:
|
||||
def gen_move(self, seq):
|
||||
if len(seq) >= 2 * self.size:
|
||||
raise ValueError("Game is terminated.")
|
||||
mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2)
|
||||
mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2, inverse=True)
|
||||
mcts.search(max_step=50)
|
||||
N = mcts.root.N
|
||||
N = np.power(N, 1.0 / temp)
|
||||
prob = N / np.sum(N)
|
||||
print("prob: {}".format(prob))
|
||||
action = int(np.random.binomial(1, prob[1]))
|
||||
return action
|
||||
return action
|
||||
|
@ -6,7 +6,7 @@ if __name__ == '__main__':
|
||||
print("Our game has 2 players.")
|
||||
print("Player 1 has color 1 and plays first. Player 2 has color -1 and plays following player 1.")
|
||||
print("Both player choose 1 or 0 for an action.")
|
||||
size = 1
|
||||
size = 2
|
||||
print("This game has {} iterations".format(size))
|
||||
print("If the final sequence has more 1 that 0, player 1 wins.")
|
||||
print("If the final sequence has less 1 that 0, player 2 wins.")
|
||||
@ -34,4 +34,4 @@ if __name__ == '__main__':
|
||||
break
|
||||
|
||||
print("The choice sequence is {}".format(seq))
|
||||
print("The game result is {}".format(winner))
|
||||
print("The game result is {}".format(winner))
|
||||
|
@ -187,7 +187,7 @@ class MCTS(object):
|
||||
prior, value = self.evaluator(next_action.next_state)
|
||||
next_action.expansion(prior, self.action_num)
|
||||
else:
|
||||
value = 0
|
||||
value = 0.
|
||||
t2 = time.time()
|
||||
if self.inverse:
|
||||
next_action.backpropagation(-value + 0.)
|
||||
|
Loading…
x
Reference in New Issue
Block a user