Fix reward_metric & n_episode bugs in on-policy algorithm (#306)
This commit is contained in:
parent
e605bdea94
commit
5c53f8c1f8
@@ -103,7 +103,7 @@ def onpolicy_trainer(
|
|||||||
train_fn(epoch, env_step)
|
train_fn(epoch, env_step)
|
||||||
result = train_collector.collect(n_step=step_per_collect,
|
result = train_collector.collect(n_step=step_per_collect,
|
||||||
n_episode=episode_per_collect)
|
n_episode=episode_per_collect)
|
||||||
if reward_metric:
|
if result["n/ep"] > 0 and reward_metric:
|
||||||
result["rews"] = reward_metric(result["rews"])
|
result["rews"] = reward_metric(result["rews"])
|
||||||
env_step += int(result["n/st"])
|
env_step += int(result["n/st"])
|
||||||
t.update(result["n/st"])
|
t.update(result["n/st"])
|
||||||
@@ -117,6 +117,7 @@ def onpolicy_trainer(
|
|||||||
"n/ep": str(int(result["n/ep"])),
|
"n/ep": str(int(result["n/ep"])),
|
||||||
"n/st": str(int(result["n/st"])),
|
"n/st": str(int(result["n/st"])),
|
||||||
}
|
}
|
||||||
|
if result["n/ep"] > 0:
|
||||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||||
test_result = test_episode(
|
test_result = test_episode(
|
||||||
policy, test_collector, test_fn,
|
policy, test_collector, test_fn,
|
||||||
@@ -147,7 +148,7 @@ def onpolicy_trainer(
|
|||||||
t.update()
|
t.update()
|
||||||
# test
|
# test
|
||||||
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
||||||
episode_per_test, logger, env_step)
|
episode_per_test, logger, env_step, reward_metric)
|
||||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||||
if best_epoch == -1 or best_reward < rew:
|
if best_epoch == -1 or best_reward < rew:
|
||||||
best_reward, best_reward_std = rew, rew_std
|
best_reward, best_reward_std = rew, rew_std
|
||||||
|
Loading…
x
Reference in New Issue
Block a user