fix reward_metric & n_episode bug in on policy algorithm (#306)
This commit is contained in:
parent
e605bdea94
commit
5c53f8c1f8
@ -103,7 +103,7 @@ def onpolicy_trainer(
|
||||
train_fn(epoch, env_step)
|
||||
result = train_collector.collect(n_step=step_per_collect,
|
||||
n_episode=episode_per_collect)
|
||||
if reward_metric:
|
||||
if result["n/ep"] > 0 and reward_metric:
|
||||
result["rews"] = reward_metric(result["rews"])
|
||||
env_step += int(result["n/st"])
|
||||
t.update(result["n/st"])
|
||||
@ -117,19 +117,20 @@ def onpolicy_trainer(
|
||||
"n/ep": str(int(result["n/ep"])),
|
||||
"n/st": str(int(result["n/st"])),
|
||||
}
|
||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||
test_result = test_episode(
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test, logger, env_step)
|
||||
if stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
t.set_postfix(**data)
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector,
|
||||
test_result["rew"], test_result["rew_std"])
|
||||
else:
|
||||
policy.train()
|
||||
if result["n/ep"] > 0:
|
||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||
test_result = test_episode(
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test, logger, env_step)
|
||||
if stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
t.set_postfix(**data)
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector,
|
||||
test_result["rew"], test_result["rew_std"])
|
||||
else:
|
||||
policy.train()
|
||||
losses = policy.update(
|
||||
0, train_collector.buffer,
|
||||
batch_size=batch_size, repeat=repeat_per_collect)
|
||||
@ -147,7 +148,7 @@ def onpolicy_trainer(
|
||||
t.update()
|
||||
# test
|
||||
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
||||
episode_per_test, logger, env_step)
|
||||
episode_per_test, logger, env_step, reward_metric)
|
||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||
if best_epoch == -1 or best_reward < rew:
|
||||
best_reward, best_reward_std = rew, rew_std
|
||||
|
Loading…
x
Reference in New Issue
Block a user