fix episode_cutoff
This commit is contained in:
parent
ace59787ed
commit
739d360d9d
@ -61,6 +61,7 @@ class DataCollector(object):
|
|||||||
num_episodes_ = int(num_episodes)
|
num_episodes_ = int(num_episodes)
|
||||||
for _ in range(num_episodes_):
|
for _ in range(num_episodes_):
|
||||||
observation = self.env.reset()
|
observation = self.env.reset()
|
||||||
|
self.policy.reset()
|
||||||
done = False
|
done = False
|
||||||
step_count = 0
|
step_count = 0
|
||||||
while not done:
|
while not done:
|
||||||
|
@ -48,12 +48,16 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
|
|||||||
if num_timesteps > 0:
|
if num_timesteps > 0:
|
||||||
current_discount = 1.
|
current_discount = 1.
|
||||||
observation = env_.reset()
|
observation = env_.reset()
|
||||||
|
step_count_this_episode = 0
|
||||||
for _ in range(num_timesteps):
|
for _ in range(num_timesteps):
|
||||||
action = policy.act_test(observation)
|
action = policy.act_test(observation)
|
||||||
observation, reward, done, _ = env_.step(action)
|
observation, reward, done, _ = env_.step(action)
|
||||||
current_return += reward * current_discount
|
current_return += reward * current_discount
|
||||||
current_undiscounted_return += reward
|
current_undiscounted_return += reward
|
||||||
current_discount *= discount_factor
|
current_discount *= discount_factor
|
||||||
|
step_count_this_episode += 1
|
||||||
|
if step_count_this_episode >= episode_cutoff:
|
||||||
|
done = True
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
returns.append(current_return)
|
returns.append(current_return)
|
||||||
@ -62,6 +66,7 @@ def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_fa
|
|||||||
current_undiscounted_return = 0.
|
current_undiscounted_return = 0.
|
||||||
current_discount = 1.
|
current_discount = 1.
|
||||||
observation = env_.reset()
|
observation = env_.reset()
|
||||||
|
step_count_this_episode = 0
|
||||||
|
|
||||||
# log
|
# log
|
||||||
if returns: # has at least one finished episode
|
if returns: # has at least one finished episode
|
||||||
|
Loading…
x
Reference in New Issue
Block a user