update save_fn in trainer (#459)

- collector.collect() now returns 4 extra keys: rew/rew_std/len/len_std (previously these statistics were computed in the logger); a short usage sketch follows the file summary below
- save_fn() is now called at the beginning of the trainer
Jiayi Weng 2021-10-13 09:25:24 -04:00 committed by GitHub
parent e45e2096d8
commit 926ec0b9b1
8 changed files with 41 additions and 22 deletions
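
A minimal usage sketch of the new collect() keys (hedged: `policy` stands for any already-constructed Tianshou policy, and the environment/buffer sizes are arbitrary placeholders):

    import gym

    from tianshou.data import Collector, VectorReplayBuffer
    from tianshou.env import DummyVectorEnv

    # assumption: `policy` is an already-initialized tianshou policy
    envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(4)])
    collector = Collector(policy, envs, VectorReplayBuffer(20000, 4))

    result = collector.collect(n_episode=8)
    result["rews"], result["lens"]      # per-episode arrays, unchanged
    result["rew"], result["rew_std"]    # new: mean / std of episode returns
    result["len"], result["len_std"]    # new: mean / std of episode lengths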


@@ -66,7 +66,7 @@ setup(
     "isort",
     "pytest",
     "pytest-cov",
-    "ray>=1.0.0",
+    "ray>=1.0.0,<1.7.0",
     "wandb>=0.12.0",
     "networkx",
     "mypy",


@@ -167,6 +167,10 @@ class Collector(object):
             * ``rews`` array of episode reward over collected episodes.
             * ``lens`` array of episode length over collected episodes.
             * ``idxs`` array of episode start index in buffer over collected episodes.
+            * ``rew`` mean of episodic rewards.
+            * ``len`` mean of episodic lengths.
+            * ``rew_std`` standard deviation of episodic rewards.
+            * ``len_std`` standard deviation of episodic lengths.
         """
         assert not self.env.is_async, "Please use AsyncCollector if using async venv."
         if n_step is not None:
@@ -311,8 +315,11 @@ class Collector(object):
                     [episode_rews, episode_lens, episode_start_indices]
                 )
             )
+            rew_mean, rew_std = rews.mean(), rews.std()
+            len_mean, len_std = lens.mean(), lens.std()
         else:
             rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+            rew_mean = rew_std = len_mean = len_std = 0
 
         return {
             "n/ep": episode_count,
@@ -320,6 +327,10 @@ class Collector(object):
             "rews": rews,
             "lens": lens,
             "idxs": idxs,
+            "rew": rew_mean,
+            "len": len_mean,
+            "rew_std": rew_std,
+            "len_std": len_std,
         }
@@ -380,6 +391,10 @@ class AsyncCollector(Collector):
             * ``rews`` array of episode reward over collected episodes.
             * ``lens`` array of episode length over collected episodes.
             * ``idxs`` array of episode start index in buffer over collected episodes.
+            * ``rew`` mean of episodic rewards.
+            * ``len`` mean of episodic lengths.
+            * ``rew_std`` standard deviation of episodic rewards.
+            * ``len_std`` standard deviation of episodic lengths.
         """
         # collect at least n_step or n_episode
         if n_step is not None:
@@ -530,8 +545,11 @@ class AsyncCollector(Collector):
                     [episode_rews, episode_lens, episode_start_indices]
                 )
             )
+            rew_mean, rew_std = rews.mean(), rews.std()
+            len_mean, len_std = lens.mean(), lens.std()
         else:
             rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+            rew_mean = rew_std = len_mean = len_std = 0
 
         return {
             "n/ep": episode_count,
@@ -539,4 +557,8 @@ class AsyncCollector(Collector):
             "rews": rews,
             "lens": lens,
             "idxs": idxs,
+            "rew": rew_mean,
+            "len": len_mean,
+            "rew_std": rew_std,
+            "len_std": len_std,
         }
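
One detail in both collectors above: when no episode finished during the call, the statistics fall back to 0 instead of being computed from an empty array. A tiny illustration of why that guard matters (plain NumPy, independent of Tianshou):

    import numpy as np

    empty = np.array([])
    # mean/std of an empty array are nan (NumPy also emits a RuntimeWarning),
    # which would poison downstream logging, hence the explicit 0 fallback
    print(empty.mean(), empty.std())   # nan nan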


@@ -6,6 +6,7 @@ from torch import nn
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.policy import A2CPolicy
+from tianshou.utils.net.common import ActorCritic
 
 
 class PPOPolicy(A2CPolicy):
@@ -83,6 +84,7 @@ class PPOPolicy(A2CPolicy):
             "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
+        self._actor_critic: ActorCritic
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
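
For orientation, `ActorCritic` in `tianshou.utils.net.common` is a small container module that groups the actor and critic networks (A2CPolicy stores one in `self._actor_critic`); the added annotation only declares that attribute's type for static checkers. A simplified sketch of the idea, not the exact library source:

    from torch import nn

    class ActorCritic(nn.Module):
        """Bundle an actor and a critic so one optimizer can update both."""

        def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
            super().__init__()
            self.actor = actor
            self.critic = critic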


@@ -81,6 +81,8 @@ def offline_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         policy.train()
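
With this change a checkpoint is written right after the initial evaluation, so an interrupted first epoch still leaves a usable file on disk. A hedged example of a typical `save_fn` (the file path is arbitrary, not mandated by the library):

    import os
    import torch

    def save_fn(policy):
        # the trainer calls this with the policy as its only argument
        torch.save(policy.state_dict(), os.path.join("log", "policy.pth"))

    # then pass it to any trainer, e.g.:
    # offline_trainer(policy, buffer, test_collector, ..., save_fn=save_fn)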


@@ -98,6 +98,8 @@ def offpolicy_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         # train
@@ -110,7 +112,8 @@ def offpolicy_trainer(
                     train_fn(epoch, env_step)
                 result = train_collector.collect(n_step=step_per_collect)
                 if result["n/ep"] > 0 and reward_metric:
-                    result["rews"] = reward_metric(result["rews"])
+                    rew = reward_metric(result["rews"])
+                    result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
                 env_step += int(result["n/st"])
                 t.update(result["n/st"])
                 logger.log_train_data(result, env_step)
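
Because the trainer now refreshes `rew`/`rew_std` itself after applying `reward_metric`, the logged statistics stay consistent with the transformed rewards. A hedged sketch of such a metric for a multi-agent run (the two-dimensional reward shape is an assumption for illustration):

    import numpy as np

    def reward_metric(rews: np.ndarray) -> np.ndarray:
        # in a multi-agent setting the collector reports shape (n_episodes, n_agents);
        # reduce it to one scalar per episode, here simply the first agent's return
        return rews[:, 0]

    # passed as offpolicy_trainer(..., reward_metric=reward_metric)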


@@ -104,6 +104,8 @@ def onpolicy_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         # train
@@ -118,7 +120,8 @@ def onpolicy_trainer(
                     n_step=step_per_collect, n_episode=episode_per_collect
                 )
                 if result["n/ep"] > 0 and reward_metric:
-                    result["rews"] = reward_metric(result["rews"])
+                    rew = reward_metric(result["rews"])
+                    result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
                 env_step += int(result["n/st"])
                 t.update(result["n/st"])
                 logger.log_train_data(result, env_step)


@@ -26,7 +26,8 @@ def test_episode(
         test_fn(epoch, global_step)
     result = collector.collect(n_episode=n_episode)
     if reward_metric:
-        result["rews"] = reward_metric(result["rews"])
+        rew = reward_metric(result["rews"])
+        result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
     if logger and global_step is not None:
         logger.log_test_data(result, global_step)
     return result


@@ -47,14 +47,8 @@ class BaseLogger(ABC):
         :param collect_result: a dict containing information of data collected in
             training stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew" and "len" keys.
         """
         if collect_result["n/ep"] > 0:
-            collect_result["rew"] = collect_result["rews"].mean()
-            collect_result["len"] = collect_result["lens"].mean()
             if step - self.last_log_train_step >= self.train_interval:
                 log_data = {
                     "train/episode": collect_result["n/ep"],
@@ -70,23 +64,15 @@ class BaseLogger(ABC):
         :param collect_result: a dict containing information of data collected in
             evaluating stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
-            and "len_std" keys.
         """
         assert collect_result["n/ep"] > 0
-        rews, lens = collect_result["rews"], collect_result["lens"]
-        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
-        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
         if step - self.last_log_test_step >= self.test_interval:
             log_data = {
                 "test/env_step": step,
-                "test/reward": rew,
-                "test/length": len_,
-                "test/reward_std": rew_std,
-                "test/length_std": len_std,
+                "test/reward": collect_result["rew"],
+                "test/length": collect_result["len"],
+                "test/reward_std": collect_result["rew_std"],
+                "test/length_std": collect_result["len_std"],
             }
             self.write("test/env_step", step, log_data)
             self.last_log_test_step = step
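
Since the collector now delivers these statistics precomputed, a custom logger only has to implement the abstract `write` method; a minimal sketch with a print backend (purely illustrative, and assuming `BaseLogger` is exported from `tianshou.utils`):

    from tianshou.utils import BaseLogger

    class PrintLogger(BaseLogger):
        """Dump every scalar dict to stdout instead of TensorBoard / W&B."""

        def write(self, step_type: str, step: int, data: dict) -> None:
            print(step_type, step, data)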