Tianshou/tianshou/core/mcts/mcts_test.py

40 lines
1.4 KiB
Python
Raw Normal View History

2017-11-16 13:21:27 +08:00
import numpy as np
2017-11-16 17:05:54 +08:00
from mcts import MCTS
2017-11-21 22:19:52 +08:00
from evaluator import rollout_policy
2017-11-16 13:21:27 +08:00
class TestEnv:
    """Toy binary-tree bandit environment used to exercise MCTS.

    A state is a two-element list ``[num, step]``: ``step`` is the depth
    reached so far (``0 <= step <= max_step``) and ``num`` encodes the actions
    taken, the action chosen at depth ``d`` being stored as ``2 ** d * action``
    (so ``0 <= num < 2 ** step``). Each of the ``2 ** max_step`` leaves has a
    hidden success probability drawn uniformly at random; entering a leaf
    yields reward 1.0 with that probability, else 0.0.

    Args:
        max_step: Depth of the tree, i.e. number of actions per episode.
        rng: Source of randomness exposing ``uniform()`` (e.g. a seeded
            ``numpy.random.Generator`` or ``RandomState``). Defaults to the
            global ``numpy.random`` module, matching the original behaviour.
        verbose: If true, print the leaf probability table on construction
            (the original behaviour).
    """

    def __init__(self, max_step=5, rng=None, verbose=True):
        self.max_step = max_step
        # Fall back to the global numpy RNG so existing callers see no change;
        # pass a seeded generator for reproducible runs.
        self._rng = np.random if rng is None else rng
        # Leaf index -> probability that entering that leaf yields reward 1.
        self.reward = {i: self._rng.uniform() for i in range(2 ** max_step)}
        # (leaf index, probability) of the best arm, kept for inspection.
        self.best = max(self.reward.items(), key=lambda item: item[1])
        if verbose:
            print(self.reward)

    def step_forward(self, state, action):
        """Apply ``action`` to ``state`` and return ``(new_state, reward)``.

        Args:
            state: ``[num, step]`` with ``0 <= step <= max_step`` and
                ``0 <= num < 2 ** step``.
            action: 0 or 1.

        Returns:
            ``(new_state, reward)``. ``new_state`` is ``None`` when ``state``
            is already at depth ``max_step``. ``reward`` is always a float:
            a Bernoulli sample (0.0 or 1.0) on the transition into a leaf,
            0.0 otherwise.

        Raises:
            ValueError: If ``action`` is not 0/1 or ``state`` is out of range.
        """
        if action not in (0, 1):
            raise ValueError("Action must be 0 or 1! Your action is {}".format(action))
        num, step = state[0], state[1]
        # Negative indices used to slip through and land on the wrong leaf;
        # reject them together with the original upper-bound checks.
        if not (0 <= step <= self.max_step and 0 <= num < 2 ** step):
            raise ValueError("Invalid State! Your state is {}".format(state))
        if step == self.max_step:
            # Already at a leaf: no further transition is possible.
            return None, 0.0
        # Record this action as the `step`-th binary digit of the leaf index.
        new_num = num + 2 ** step * action
        new_step = step + 1
        if new_step == self.max_step:
            reward = float(self._rng.uniform() < self.reward[new_num])
        else:
            reward = 0.0
        return [new_num, new_step], reward
2017-11-16 13:21:27 +08:00
2017-11-21 22:19:52 +08:00
if __name__ == "__main__":
    # Smoke run: build a depth-2 TestEnv and hand it to MCTS with a
    # rollout-based evaluator.
    env = TestEnv(2)
    # The original wrapped this in `lambda state: rollout(state)`; the rollout
    # callable already accepts that single state argument, so pass it directly.
    evaluator = rollout_policy(env, 2)
    # NOTE(review): the positional `2` presumably is the action count of
    # TestEnv — confirm against the MCTS signature. The budget is written as
    # an int literal instead of the float 1e4.
    mcts = MCTS(env, evaluator, [0, 0], 2, max_step=10000)