fix episode_cutoff

This commit is contained in:
haoshengzou 2018-03-31 19:26:48 +08:00
parent ace59787ed
commit 739d360d9d
2 changed files with 6 additions and 0 deletions

View File

@ -61,6 +61,7 @@ class DataCollector(object):
num_episodes_ = int(num_episodes)
for _ in range(num_episodes_):
observation = self.env.reset()
self.policy.reset()
done = False
step_count = 0
while not done:

View File

@ -48,12 +48,16 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
if num_timesteps > 0:
current_discount = 1.
observation = env_.reset()
step_count_this_episode = 0
for _ in range(num_timesteps):
action = policy.act_test(observation)
observation, reward, done, _ = env_.step(action)
current_return += reward * current_discount
current_undiscounted_return += reward
current_discount *= discount_factor
step_count_this_episode += 1
if step_count_this_episode >= episode_cutoff:
done = True
if done:
returns.append(current_return)
@ -62,6 +66,7 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
current_undiscounted_return = 0.
current_discount = 1.
observation = env_.reset()
step_count_this_episode = 0
# log
if returns: # has at least one finished episode