fix reward_metric & n_episode bug in on policy algorithm (#306)

This commit is contained in:
ChenDRAG 2021-03-08 14:35:30 +08:00 committed by GitHub
parent e605bdea94
commit 5c53f8c1f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -103,7 +103,7 @@ def onpolicy_trainer(
train_fn(epoch, env_step)
result = train_collector.collect(n_step=step_per_collect,
n_episode=episode_per_collect)
if reward_metric:
if result["n/ep"] > 0 and reward_metric:
result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"])
t.update(result["n/st"])
@ -117,19 +117,20 @@ def onpolicy_trainer(
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, logger, env_step)
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result["rew"], test_result["rew_std"])
else:
policy.train()
if result["n/ep"] > 0:
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, logger, env_step)
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result["rew"], test_result["rew_std"])
else:
policy.train()
losses = policy.update(
0, train_collector.buffer,
batch_size=batch_size, repeat=repeat_per_collect)
@ -147,7 +148,7 @@ def onpolicy_trainer(
t.update()
# test
test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, logger, env_step)
episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std