fix episode_cutoff
This commit is contained in:
parent
ace59787ed
commit
739d360d9d
@@ -61,6 +61,7 @@ class DataCollector(object):
|
||||
num_episodes_ = int(num_episodes)
|
||||
for _ in range(num_episodes_):
|
||||
observation = self.env.reset()
|
||||
self.policy.reset()
|
||||
done = False
|
||||
step_count = 0
|
||||
while not done:
|
||||
|
@@ -48,12 +48,16 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
|
||||
if num_timesteps > 0:
|
||||
current_discount = 1.
|
||||
observation = env_.reset()
|
||||
step_count_this_episode = 0
|
||||
for _ in range(num_timesteps):
|
||||
action = policy.act_test(observation)
|
||||
observation, reward, done, _ = env_.step(action)
|
||||
current_return += reward * current_discount
|
||||
current_undiscounted_return += reward
|
||||
current_discount *= discount_factor
|
||||
step_count_this_episode += 1
|
||||
if step_count_this_episode >= episode_cutoff:
|
||||
done = True
|
||||
|
||||
if done:
|
||||
returns.append(current_return)
|
||||
@@ -62,6 +66,7 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
|
||||
current_undiscounted_return = 0.
|
||||
current_discount = 1.
|
||||
observation = env_.reset()
|
||||
step_count_this_episode = 0
|
||||
|
||||
# log
|
||||
if returns: # has at least one finished episode
|
||||
|
Loading…
x
Reference in New Issue
Block a user