update save_fn in trainer (#459)
- collector.collect() now returns 4 extra keys: rew/rew_std/len/len_std (previously this work was done in the logger)
- save_fn() is now called at the beginning of the trainer
parent e45e2096d8
commit 926ec0b9b1
setup.py
@@ -66,7 +66,7 @@ setup(
             "isort",
             "pytest",
             "pytest-cov",
-            "ray>=1.0.0",
+            "ray>=1.0.0,<1.7.0",
             "wandb>=0.12.0",
             "networkx",
             "mypy",
@@ -167,6 +167,10 @@ class Collector(object):
             * ``rews`` array of episode reward over collected episodes.
             * ``lens`` array of episode length over collected episodes.
             * ``idxs`` array of episode start index in buffer over collected episodes.
+            * ``rew`` mean of episodic rewards.
+            * ``len`` mean of episodic lengths.
+            * ``rew_std`` standard error of episodic rewards.
+            * ``len_std`` standard error of episodic lengths.
         """
         assert not self.env.is_async, "Please use AsyncCollector if using async venv."
         if n_step is not None:
@@ -311,8 +315,11 @@ class Collector(object):
                     [episode_rews, episode_lens, episode_start_indices]
                 )
             )
+            rew_mean, rew_std = rews.mean(), rews.std()
+            len_mean, len_std = lens.mean(), lens.std()
         else:
             rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+            rew_mean = rew_std = len_mean = len_std = 0
 
         return {
             "n/ep": episode_count,
@@ -320,6 +327,10 @@ class Collector(object):
             "rews": rews,
             "lens": lens,
             "idxs": idxs,
+            "rew": rew_mean,
+            "len": len_mean,
+            "rew_std": rew_std,
+            "len_std": len_std,
         }
 
 
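Taken together, the Collector changes mean the dict returned by collect() now carries the aggregate statistics directly, so downstream code no longer recomputes them. A minimal sketch (not part of the diff) of what a caller sees, assuming CartPole and a made-up RandomPolicy that exists only to drive the collector:

import gym
import numpy as np

from tianshou.data import Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Hypothetical policy that picks random CartPole actions."""

    def forward(self, batch, state=None, **kwargs):
        return Batch(act=np.random.randint(0, 2, size=len(batch.obs)))

    def learn(self, batch, **kwargs):
        return {}


envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(4)])
collector = Collector(RandomPolicy(), envs, VectorReplayBuffer(20000, 4))

result = collector.collect(n_episode=10)
rews, lens, idxs = result["rews"], result["lens"], result["idxs"]  # per-episode arrays, as before
# New aggregate keys added by this commit:
print(f'reward: {result["rew"]:.2f} +/- {result["rew_std"]:.2f}')
print(f'length: {result["len"]:.1f} +/- {result["len_std"]:.1f}')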
@@ -380,6 +391,10 @@ class AsyncCollector(Collector):
             * ``rews`` array of episode reward over collected episodes.
             * ``lens`` array of episode length over collected episodes.
             * ``idxs`` array of episode start index in buffer over collected episodes.
+            * ``rew`` mean of episodic rewards.
+            * ``len`` mean of episodic lengths.
+            * ``rew_std`` standard error of episodic rewards.
+            * ``len_std`` standard error of episodic lengths.
         """
         # collect at least n_step or n_episode
         if n_step is not None:
@@ -530,8 +545,11 @@ class AsyncCollector(Collector):
                     [episode_rews, episode_lens, episode_start_indices]
                 )
             )
+            rew_mean, rew_std = rews.mean(), rews.std()
+            len_mean, len_std = lens.mean(), lens.std()
         else:
             rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+            rew_mean = rew_std = len_mean = len_std = 0
 
         return {
             "n/ep": episode_count,
@@ -539,4 +557,8 @@ class AsyncCollector(Collector):
             "rews": rews,
             "lens": lens,
             "idxs": idxs,
+            "rew": rew_mean,
+            "len": len_mean,
+            "rew_std": rew_std,
+            "len_std": len_std,
         }
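AsyncCollector gets the identical keys, so swapping collector types does not change the shape of the result dict. A rough sketch (not part of the diff), reusing the hypothetical RandomPolicy from the previous example:

import gym

from tianshou.data import AsyncCollector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv

env_fns = [lambda: gym.make("CartPole-v0") for _ in range(4)]
# wait_num/timeout make the vector env asynchronous, which is what AsyncCollector expects.
envs = SubprocVectorEnv(env_fns, wait_num=2, timeout=0.2)
collector = AsyncCollector(RandomPolicy(), envs, VectorReplayBuffer(20000, 4))

result = collector.collect(n_step=2000)
print(result["rew"], result["rew_std"], result["len"], result["len_std"])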
@@ -6,6 +6,7 @@ from torch import nn
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.policy import A2CPolicy
+from tianshou.utils.net.common import ActorCritic
 
 
 class PPOPolicy(A2CPolicy):
@@ -83,6 +84,7 @@ class PPOPolicy(A2CPolicy):
             "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
+        self._actor_critic: ActorCritic
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
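The new `self._actor_critic: ActorCritic` line is a bare attribute annotation: nothing is assigned at runtime (the actual object is created in the A2CPolicy base class); it only declares the attribute's type for the type checker. A generic illustration of the pattern with made-up class names, unrelated to Tianshou's actual classes:

from torch import nn


class Base:
    def __init__(self) -> None:
        self.model = nn.Linear(4, 2)   # the attribute is actually created here


class Child(Base):
    def __init__(self) -> None:
        super().__init__()
        # Bare annotation: no assignment happens at runtime; it (re)declares the
        # attribute's type so mypy can check uses of self.model inside Child.
        self.model: nn.Linear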
@@ -81,6 +81,8 @@ def offline_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         policy.train()
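With this change the trainer calls save_fn(policy) once right after the initial evaluation, before the first epoch, so a checkpoint exists even if training is interrupted immediately or never improves on the initial result. The same two lines are added to the off-policy and on-policy trainers below. A typical save_fn is just a state-dict dump; the directory name here is made up:

import os
import torch

log_path = "log/offline-experiment"  # hypothetical output directory
os.makedirs(log_path, exist_ok=True)

def save_fn(policy):
    # invoked before epoch 1 (new behaviour) and again whenever the best test reward improves
    torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))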
@@ -98,6 +98,8 @@ def offpolicy_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         # train
@@ -110,7 +112,8 @@ def offpolicy_trainer(
                 train_fn(epoch, env_step)
             result = train_collector.collect(n_step=step_per_collect)
             if result["n/ep"] > 0 and reward_metric:
-                result["rews"] = reward_metric(result["rews"])
+                rew = reward_metric(result["rews"])
+                result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
             env_step += int(result["n/st"])
             t.update(result["n/st"])
             logger.log_train_data(result, env_step)
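Because the collector now fills in rew/rew_std itself, the trainers refresh those aggregates whenever reward_metric rewrites the per-episode rewards, which is exactly what the replaced lines do. A common reward_metric reduces a multi-agent reward array to one scalar per episode; a hedged sketch, assuming a two-agent (n_episodes, n_agents) reward shape:

import numpy as np

def reward_metric(rews: np.ndarray) -> np.ndarray:
    # rews has shape (n_episodes, n_agents) in a multi-agent setup;
    # keep agent 0's return so that rew/rew_std stay scalar statistics.
    return rews[:, 0]

# passed to a trainer, e.g. offpolicy_trainer(..., reward_metric=reward_metric)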
@@ -104,6 +104,8 @@ def onpolicy_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         # train
@@ -118,7 +120,8 @@ def onpolicy_trainer(
                 n_step=step_per_collect, n_episode=episode_per_collect
             )
             if result["n/ep"] > 0 and reward_metric:
-                result["rews"] = reward_metric(result["rews"])
+                rew = reward_metric(result["rews"])
+                result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
             env_step += int(result["n/st"])
             t.update(result["n/st"])
             logger.log_train_data(result, env_step)
@@ -26,7 +26,8 @@ def test_episode(
         test_fn(epoch, global_step)
     result = collector.collect(n_episode=n_episode)
     if reward_metric:
-        result["rews"] = reward_metric(result["rews"])
+        rew = reward_metric(result["rews"])
+        result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
     if logger and global_step is not None:
         logger.log_test_data(result, global_step)
     return result
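Putting the pieces together: a compact end-to-end run that exercises the new behaviour (save_fn fired before the first epoch, rew/rew_std present in every collect and test result). It follows the shape of the standard CartPole DQN example; the hyperparameters and file name are illustrative, not tuned:

import gym
import torch

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Net

env = gym.make("CartPole-v0")
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(8)])
test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(8)])

net = Net(env.observation_space.shape, env.action_space.n, hidden_sizes=[64, 64])
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.99, estimation_step=3, target_update_freq=320)

train_collector = Collector(
    policy, train_envs, VectorReplayBuffer(20000, 8), exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=True)


def save_fn(policy):
    # now called once before epoch 1 and after every new best test reward
    torch.save(policy.state_dict(), "dqn_cartpole.pth")


result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=5, step_per_epoch=10000, step_per_collect=8,
    update_per_step=0.125, episode_per_test=10, batch_size=64,
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    save_fn=save_fn,
)
print(result["best_reward"])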
@@ -47,14 +47,8 @@ class BaseLogger(ABC):
         :param collect_result: a dict containing information of data collected in
             training stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew" and "len" keys.
         """
         if collect_result["n/ep"] > 0:
-            collect_result["rew"] = collect_result["rews"].mean()
-            collect_result["len"] = collect_result["lens"].mean()
             if step - self.last_log_train_step >= self.train_interval:
                 log_data = {
                     "train/episode": collect_result["n/ep"],
@@ -70,23 +64,15 @@ class BaseLogger(ABC):
         :param collect_result: a dict containing information of data collected in
             evaluating stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
-            and "len_std" keys.
         """
         assert collect_result["n/ep"] > 0
-        rews, lens = collect_result["rews"], collect_result["lens"]
-        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
-        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
         if step - self.last_log_test_step >= self.test_interval:
             log_data = {
                 "test/env_step": step,
-                "test/reward": rew,
-                "test/length": len_,
-                "test/reward_std": rew_std,
-                "test/length_std": len_std,
+                "test/reward": collect_result["rew"],
+                "test/length": collect_result["len"],
+                "test/reward_std": collect_result["rew_std"],
+                "test/length_std": collect_result["len_std"],
             }
             self.write("test/env_step", step, log_data)
             self.last_log_test_step = step
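On the logging side the contract is now: the collector computes the statistics, the logger only reads them, so log_test_data no longer mutates collect_result and log_train_data no longer back-fills "rew"/"len". Caller-side usage is unchanged; a sketch with the stock TensorBoard logger, where test_collector, train_result and env_step are assumed to exist from an earlier training loop:

from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger

logger = TensorboardLogger(SummaryWriter("log/dqn"))

# training side: the result dict already contains "rew"/"len", nothing is added in place any more
logger.log_train_data(train_result, env_step)

# test side: the logger simply forwards the precomputed aggregates
test_result = test_collector.collect(n_episode=10)
logger.log_test_data(test_result, env_step)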