Tianshou/tianshou/core/mcts/mcts_test.py

35 lines
1.3 KiB
Python
Raw Normal View History

2017-11-16 13:21:27 +08:00
import numpy as np
2017-11-16 17:05:54 +08:00
from mcts import MCTS
2017-11-17 15:09:07 +08:00
import matplotlib.pyplot as plt
2017-11-16 13:21:27 +08:00
class TestEnv:
    """Toy binary-tree bandit environment used to exercise MCTS.

    A state is a two-element list ``[arm, depth]``. Starting from ``[0, 0]``,
    each action (0 or 1) sets bit ``depth`` of ``arm`` and increments
    ``depth``; after ``max_step`` actions the state identifies one of the
    ``2**max_step`` leaf arms. Each leaf arm pays a Bernoulli reward whose
    success probability is drawn uniformly at construction time.
    """

    def __init__(self, max_step=5):
        """Draw a success probability for every leaf arm and print the best.

        Args:
            max_step: Depth of the binary tree, i.e. the number of actions
                needed to reach a leaf. There are ``2**max_step`` arms.
        """
        self.max_step = max_step
        # arm index -> probability that evaluating that leaf yields reward 1
        self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
        # (arm, probability) pair with the highest expected reward
        self.best = max(self.reward.items(), key=lambda x: x[1])
        print("The best arm is {} with expected reward {}".format(self.best[0], self.best[1]))

    def step_forward(self, state, action):
        """Apply ``action`` to ``state`` and return the successor state.

        Args:
            state: ``[arm, depth]`` with ``arm < 2**depth`` and
                ``depth < max_step``.
            action: 0 or 1; it becomes bit ``depth`` of the arm index.

        Returns:
            A new ``[arm, depth]`` list. ``state`` itself is left untouched,
            so any caller still holding it keeps a valid state.

        Raises:
            ValueError: If ``action`` is not 0/1, or ``state`` is not a valid
                non-terminal state.
        """
        if action not in (0, 1):
            raise ValueError("Action must be 0 or 1!")
        arm, depth = state[0], state[1]
        if arm >= 2 ** depth or depth >= self.max_step:
            raise ValueError("Invalid State!")
        # Build a fresh list instead of writing into ``state``: mutating the
        # argument would silently change whatever object the caller passed.
        return [arm + 2 ** depth * action, depth + 1]

    def evaluator(self, state):
        """Return ``(reward, is_terminated)`` for ``state``.

        At a leaf (``depth == max_step``) the reward is 1 with probability
        ``self.reward[arm]`` and 0 otherwise, and the state is terminal.
        Any shallower state gives reward 0 and is not terminal.
        """
        if state[1] == self.max_step:
            # uniform() draws from [0, 1), so ``uniform() < p`` holds with
            # probability p: the expected reward of an arm is exactly
            # self.reward[arm], consistent with how __init__ picks the best.
            reward = int(np.random.uniform() < self.reward[state[0]])
            return reward, True
        return 0, False
2017-11-16 13:21:27 +08:00
if __name__ == "__main__":
    # Smallest non-trivial tree: a single decision between two leaf arms.
    env = TestEnv(1)
    # The bound method accepts the same ``state`` argument the lambda wrapper did.
    evaluator = env.evaluator
    # NOTE(review): presumably a uniform prior over the 2 actions — confirm
    # against the MCTS constructor signature.
    uniform_prior = np.ones([2]) / 2
    # [0, 0] is the root state [arm, depth] understood by TestEnv.
    mcts = MCTS(env, evaluator, [0, 0], 2, uniform_prior, max_step=1e4)