From 739d360d9d43f4ce2355afed9a468773fc1a65a3 Mon Sep 17 00:00:00 2001 From: haoshengzou Date: Sat, 31 Mar 2018 19:26:48 +0800 Subject: [PATCH] fix episode_cutoff --- tianshou/data/data_collector.py | 1 + tianshou/data/tester.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py index 42efd8e..5f241b1 100644 --- a/tianshou/data/data_collector.py +++ b/tianshou/data/data_collector.py @@ -61,6 +61,7 @@ class DataCollector(object): num_episodes_ = int(num_episodes) for _ in range(num_episodes_): observation = self.env.reset() + self.policy.reset() done = False step_count = 0 while not done: diff --git a/tianshou/data/tester.py b/tianshou/data/tester.py index 8983b7c..2c4990a 100644 --- a/tianshou/data/tester.py +++ b/tianshou/data/tester.py @@ -48,12 +48,16 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa if num_timesteps > 0: current_discount = 1. observation = env_.reset() + step_count_this_episode = 0 for _ in range(num_timesteps): action = policy.act_test(observation) observation, reward, done, _ = env_.step(action) current_return += reward * current_discount current_undiscounted_return += reward current_discount *= discount_factor + step_count_this_episode += 1 + if step_count_this_episode >= episode_cutoff: + done = True if done: returns.append(current_return) @@ -62,6 +66,7 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa current_undiscounted_return = 0. current_discount = 1. observation = env_.reset() + step_count_this_episode = 0 # log if returns: # has at least one finished episode