fix reward_metric & n_episode bug in on-policy algorithm ()

This commit is contained in:
ChenDRAG 2021-03-08 14:35:30 +08:00 committed by GitHub
parent e605bdea94
commit 5c53f8c1f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -103,7 +103,7 @@ def onpolicy_trainer(
train_fn(epoch, env_step) train_fn(epoch, env_step)
result = train_collector.collect(n_step=step_per_collect, result = train_collector.collect(n_step=step_per_collect,
n_episode=episode_per_collect) n_episode=episode_per_collect)
if reward_metric: if result["n/ep"] > 0 and reward_metric:
result["rews"] = reward_metric(result["rews"]) result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"]) env_step += int(result["n/st"])
t.update(result["n/st"]) t.update(result["n/st"])
@ -117,19 +117,20 @@ def onpolicy_trainer(
"n/ep": str(int(result["n/ep"])), "n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])), "n/st": str(int(result["n/st"])),
} }
if test_in_train and stop_fn and stop_fn(result["rew"]): if result["n/ep"] > 0:
test_result = test_episode( if test_in_train and stop_fn and stop_fn(result["rew"]):
policy, test_collector, test_fn, test_result = test_episode(
epoch, episode_per_test, logger, env_step) policy, test_collector, test_fn,
if stop_fn(test_result["rew"]): epoch, episode_per_test, logger, env_step)
if save_fn: if stop_fn(test_result["rew"]):
save_fn(policy) if save_fn:
t.set_postfix(**data) save_fn(policy)
return gather_info( t.set_postfix(**data)
start_time, train_collector, test_collector, return gather_info(
test_result["rew"], test_result["rew_std"]) start_time, train_collector, test_collector,
else: test_result["rew"], test_result["rew_std"])
policy.train() else:
policy.train()
losses = policy.update( losses = policy.update(
0, train_collector.buffer, 0, train_collector.buffer,
batch_size=batch_size, repeat=repeat_per_collect) batch_size=batch_size, repeat=repeat_per_collect)
@ -147,7 +148,7 @@ def onpolicy_trainer(
t.update() t.update()
# test # test
test_result = test_episode(policy, test_collector, test_fn, epoch, test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, logger, env_step) episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"] rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew: if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std best_reward, best_reward_std = rew, rew_std