import numpy as np
import matplotlib.pyplot as plt

from mcts import MCTS


class TestEnv:
    """Binary-tree bandit used to exercise the MCTS implementation."""

    def __init__(self, max_step=5):
        self.max_step = max_step
        # One Bernoulli arm per leaf of a depth-`max_step` binary tree; each
        # arm's success probability is drawn uniformly at random.
        self.reward = {i: np.random.uniform() for i in range(2**max_step)}
        # self.reward = {0: 0.8, 1: 0.2, 2: 0.4, 3: 0.6}
        self.best = max(self.reward.items(), key=lambda x: x[1])
        # print("The best arm is {} with expected reward {}".format(self.best[0], self.best[1]))
        print(self.reward)

    def step_forward(self, state, action):
        # `state` is a pair [node index within the current level, depth];
        # `action` selects the left (0) or right (1) child.
        if action != 0 and action != 1:
            raise ValueError("Action must be 0 or 1! Your action is {}".format(action))
        if state[0] >= 2**state[1] or state[1] >= self.max_step:
            raise ValueError("Invalid State! Your state is {}".format(state))
        # print("Operate action {} at state {}, timestep {}".format(action, state[0], state[1]))
        new_state = [0, 0]
        new_state[0] = state[0] + 2**state[1] * action
        new_state[1] = state[1] + 1
        if new_state[1] == self.max_step:
            # Reached a leaf: draw a Bernoulli reward for the selected arm.
            # The arm is indexed by the full leaf index new_state[0];
            # indexing by state[0] would ignore the final action.
            reward = int(np.random.uniform() < self.reward[new_state[0]])
            is_terminated = True
        else:
            reward = 0
            is_terminated = False
        return new_state, reward, is_terminated


if __name__ == "__main__":
    env = TestEnv(3)
    # Assumes the MCTS implementation calls the evaluator with a (state, action)
    # pair, matching the signature of env.step_forward.
    evaluator = lambda state, action: env.step_forward(state, action)
    mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
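    # Optional sanity check of the environment itself, independent of the MCTS
    # API above (which lives in the local `mcts` module and is not shown here):
    # roll out uniformly random actions and compare the average terminal reward
    # with the best arm's expected reward. This is only an illustrative sketch;
    # `n_episodes` is a name chosen here, not part of the MCTS interface.
    n_episodes = 1000
    total = 0
    for _ in range(n_episodes):
        state, r, done = [0, 0], 0, False
        while not done:
            state, r, done = env.step_forward(state, np.random.randint(2))
        total += r
    print("Random policy average reward: {:.3f} (best arm: {:.3f})".format(
        total / n_episodes, env.best[1]))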