2017-11-17 19:35:20 +08:00
|
|
|
import numpy as np
|
|
|
|
|
2017-11-21 22:19:52 +08:00
|
|
|
|
2017-11-17 19:35:20 +08:00
|
|
|
class evaluator(object):
    """Abstract base for state evaluators.

    Subclasses must implement ``__call__``; this base only records the
    environment and the size of the discrete action space.
    """

    def __init__(self, env, action_num):
        """Remember the environment and the number of discrete actions.

        Args:
            env: the environment object subclasses will query.
            action_num: number of discrete actions available.
        """
        self.env, self.action_num = env, action_num

    def __call__(self, state):
        """Evaluate ``state``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = "Need to implement the evaluator"
        raise NotImplementedError(message)
|
|
|
|
|
2017-11-21 22:19:52 +08:00
|
|
|
|
2017-11-17 19:35:20 +08:00
|
|
|
class rollout_policy(evaluator):
    """Evaluate a state with a single uniformly random playout.

    The prior returned is uniform over all actions. The value is the sum of
    rewards collected while playing random actions until the environment
    signals termination (or an optional step cap is reached).

    NOTE(review): ``is_terminated`` is set in ``__init__`` but never read in
    this class. It is kept in case external callers inspect it.
    """

    def __init__(self, env, action_num):
        """Store the environment and the action count.

        Args:
            env: object exposing ``simulate_step_forward(state, action)``,
                which returns ``(next_state, reward)``. A ``next_state`` of
                ``None`` ends the rollout.
            action_num: number of discrete actions. Random actions are drawn
                from ``range(action_num)``.
        """
        super(rollout_policy, self).__init__(env, action_num)
        self.is_terminated = False

    def __call__(self, state, max_steps=None):
        """Run one random rollout starting from ``state``.

        Args:
            state: the state to start the rollout from.
            max_steps: optional upper bound on the number of simulated steps.
                At least one step is always taken. ``None`` (the default)
                keeps stepping until the environment returns a ``None``
                state, exactly as before; that never returns if the
                environment never terminates.

        Returns:
            ``(prior, total_reward)``. ``prior`` is a uniform numpy array of
            length ``action_num``. ``total_reward`` is the float sum of the
            rewards received during the rollout.
        """
        # TODO: prior for rollout policy
        total_reward = 0.
        steps = 0
        # Do-while: the first step is always simulated from the given state,
        # then we keep going until the environment reports termination.
        while True:
            action = np.random.randint(0, self.action_num)
            state, reward = self.env.simulate_step_forward(state, action)
            total_reward += reward
            steps += 1
            if state is None:
                break
            if max_steps is not None and steps >= max_steps:
                break
        return np.ones([self.action_num]) / self.action_num, total_reward
|