update save_fn in trainer (#459)
- collector.collect() now returns 4 extra keys: rew/rew_std/len/len_std (previously this work was done in the logger); a usage sketch follows below
- save_fn() will be called at the beginning of the trainer
This commit is contained in:
parent
e45e2096d8
commit
926ec0b9b1
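As a quick illustration of the first bullet, here is a hedged sketch (not part of the diff) of consuming the dict that collect() now returns; the helper name and the example call are illustrative only:

from typing import Any, Dict


def summarize_collect(result: Dict[str, Any]) -> str:
    """Format the aggregate keys that collect() now carries directly.

    Sketch only: ``result`` is assumed to be the dict returned by
    Collector.collect() after this commit; "rews"/"lens"/"idxs" are unchanged,
    while "rew", "rew_std", "len" and "len_std" are the four new keys.
    """
    return (
        f"{result['n/ep']} episode(s): "
        f"rew {result['rew']:.2f} +/- {result['rew_std']:.2f}, "
        f"len {result['len']:.1f} +/- {result['len_std']:.1f}"
    )


# e.g. summarize_collect(test_collector.collect(n_episode=10))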
setup.py

@@ -66,7 +66,7 @@ setup(
             "isort",
             "pytest",
             "pytest-cov",
-            "ray>=1.0.0",
+            "ray>=1.0.0,<1.7.0",
             "wandb>=0.12.0",
             "networkx",
             "mypy",
@@ -167,6 +167,10 @@ class Collector(object):
             * ``rews`` array of episode reward over collected episodes.
             * ``lens`` array of episode length over collected episodes.
             * ``idxs`` array of episode start index in buffer over collected episodes.
+            * ``rew`` mean of episodic rewards.
+            * ``len`` mean of episodic lengths.
+            * ``rew_std`` standard error of episodic rewards.
+            * ``len_std`` standard error of episodic lengths.
         """
         assert not self.env.is_async, "Please use AsyncCollector if using async venv."
         if n_step is not None:
@@ -311,8 +315,11 @@ class Collector(object):
                     [episode_rews, episode_lens, episode_start_indices]
                 )
             )
+            rew_mean, rew_std = rews.mean(), rews.std()
+            len_mean, len_std = lens.mean(), lens.std()
         else:
             rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+            rew_mean = rew_std = len_mean = len_std = 0
 
         return {
             "n/ep": episode_count,
@@ -320,6 +327,10 @@ class Collector(object):
             "rews": rews,
             "lens": lens,
             "idxs": idxs,
+            "rew": rew_mean,
+            "len": len_mean,
+            "rew_std": rew_std,
+            "len_std": len_std,
         }
 
 
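One detail worth spelling out from the hunks above: the per-env episode lists are concatenated first and only then reduced to mean/std, and when no episode finished the four new keys fall back to 0, so callers should still check "n/ep". A small self-contained sketch of that aggregation (toy numbers, not from the diff):

import numpy as np

# toy per-env episode results, standing in for episode_rews / episode_lens
episode_rews = [np.array([1.0, 3.0]), np.array([2.0])]
episode_lens = [np.array([10, 12]), np.array([11])]

rews, lens = map(np.concatenate, [episode_rews, episode_lens])
rew_mean, rew_std = rews.mean(), rews.std()   # 2.0, ~0.816
len_mean, len_std = lens.mean(), lens.std()   # 11.0, ~0.816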
@@ -380,6 +391,10 @@ class AsyncCollector(Collector):
             * ``rews`` array of episode reward over collected episodes.
             * ``lens`` array of episode length over collected episodes.
             * ``idxs`` array of episode start index in buffer over collected episodes.
+            * ``rew`` mean of episodic rewards.
+            * ``len`` mean of episodic lengths.
+            * ``rew_std`` standard error of episodic rewards.
+            * ``len_std`` standard error of episodic lengths.
         """
         # collect at least n_step or n_episode
         if n_step is not None:
@@ -530,8 +545,11 @@ class AsyncCollector(Collector):
                     [episode_rews, episode_lens, episode_start_indices]
                 )
             )
+            rew_mean, rew_std = rews.mean(), rews.std()
+            len_mean, len_std = lens.mean(), lens.std()
         else:
             rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
+            rew_mean = rew_std = len_mean = len_std = 0
 
         return {
             "n/ep": episode_count,
@@ -539,4 +557,8 @@ class AsyncCollector(Collector):
             "rews": rews,
             "lens": lens,
             "idxs": idxs,
+            "rew": rew_mean,
+            "len": len_mean,
+            "rew_std": rew_std,
+            "len_std": len_std,
         }
@@ -6,6 +6,7 @@ from torch import nn
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.policy import A2CPolicy
+from tianshou.utils.net.common import ActorCritic
 
 
 class PPOPolicy(A2CPolicy):
@@ -83,6 +84,7 @@ class PPOPolicy(A2CPolicy):
             "value clip is available only when `reward_normalization` is True"
         self._norm_adv = advantage_normalization
         self._recompute_adv = recompute_advantage
+        self._actor_critic: ActorCritic
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
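The bare `self._actor_critic: ActorCritic` statement above assigns nothing; a plausible reading (an assumption, since the assignment itself is not in this diff) is that the attribute is created by the parent A2CPolicy and this line merely declares a more precise type for static checkers such as mypy (already listed in the dev extras above). A generic sketch of that annotation-only pattern:

class Animal:
    def speak(self) -> str:
        return "..."


class Dog(Animal):
    def speak(self) -> str:
        return "woof"


class Base:
    def __init__(self) -> None:
        # the base class creates and owns the attribute
        self.pet: Animal = Dog()


class Derived(Base):
    def __init__(self) -> None:
        super().__init__()
        # annotation-only statement: nothing is assigned here, it just narrows
        # the declared type of the inherited attribute for static checkers
        self.pet: Dog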
@@ -81,6 +81,8 @@ def offline_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         policy.train()
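The two added lines above mean the checkpoint callback now fires once right after the initial evaluation, before any training epoch, so a checkpoint of the starting policy always exists. A hedged sketch of a typical save_fn (the log directory and file name are illustrative, not from the diff):

import os

import torch


def save_fn(policy: torch.nn.Module) -> None:
    # invoked by the trainer with the current policy; after this commit it is
    # also called once before the first epoch
    os.makedirs("log", exist_ok=True)
    torch.save(policy.state_dict(), os.path.join("log", "policy.pth"))


# passed to a trainer as its save_fn argument, e.g.
# offline_trainer(..., save_fn=save_fn)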
@@ -98,6 +98,8 @@ def offpolicy_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         # train
@@ -110,7 +112,8 @@ def offpolicy_trainer(
                     train_fn(epoch, env_step)
                 result = train_collector.collect(n_step=step_per_collect)
                 if result["n/ep"] > 0 and reward_metric:
-                    result["rews"] = reward_metric(result["rews"])
+                    rew = reward_metric(result["rews"])
+                    result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
                 env_step += int(result["n/st"])
                 t.update(result["n/st"])
                 logger.log_train_data(result, env_step)
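The second hunk above changes what happens after reward_metric: the mapped per-episode rewards are written back together with freshly computed rew/rew_std, so the logged scalars always match the mapped values. A hedged sketch (the multi-agent reduction is an assumed use case, not stated in the diff):

import numpy as np


def reward_metric(rews: np.ndarray) -> np.ndarray:
    """Reduce a (n_episode, n_agent) reward array to one value per episode."""
    return rews[:, 0]  # assumed example: keep only the first agent's reward


# what the updated trainer code does with it (mirrors the diff above)
result = {"n/ep": 2, "rews": np.array([[1.0, 9.0], [3.0, 9.0]])}
rew = reward_metric(result["rews"])
result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
# result["rew"] == 2.0, result["rew_std"] == 1.0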
@@ -104,6 +104,8 @@ def onpolicy_trainer(
     )
     best_epoch = start_epoch
     best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
+    if save_fn:
+        save_fn(policy)
 
     for epoch in range(1 + start_epoch, 1 + max_epoch):
         # train
@@ -118,7 +120,8 @@ def onpolicy_trainer(
                     n_step=step_per_collect, n_episode=episode_per_collect
                 )
                 if result["n/ep"] > 0 and reward_metric:
-                    result["rews"] = reward_metric(result["rews"])
+                    rew = reward_metric(result["rews"])
+                    result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
                 env_step += int(result["n/st"])
                 t.update(result["n/st"])
                 logger.log_train_data(result, env_step)
@@ -26,7 +26,8 @@ def test_episode(
         test_fn(epoch, global_step)
     result = collector.collect(n_episode=n_episode)
     if reward_metric:
-        result["rews"] = reward_metric(result["rews"])
+        rew = reward_metric(result["rews"])
+        result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
     if logger and global_step is not None:
         logger.log_test_data(result, global_step)
     return result
@@ -47,14 +47,8 @@ class BaseLogger(ABC):
         :param collect_result: a dict containing information of data collected in
             training stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew" and "len" keys.
         """
         if collect_result["n/ep"] > 0:
-            collect_result["rew"] = collect_result["rews"].mean()
-            collect_result["len"] = collect_result["lens"].mean()
             if step - self.last_log_train_step >= self.train_interval:
                 log_data = {
                     "train/episode": collect_result["n/ep"],
@@ -70,23 +64,15 @@ class BaseLogger(ABC):
         :param collect_result: a dict containing information of data collected in
             evaluating stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
-            and "len_std" keys.
         """
         assert collect_result["n/ep"] > 0
-        rews, lens = collect_result["rews"], collect_result["lens"]
-        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
-        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
         if step - self.last_log_test_step >= self.test_interval:
             log_data = {
                 "test/env_step": step,
-                "test/reward": rew,
-                "test/length": len_,
-                "test/reward_std": rew_std,
-                "test/length_std": len_std,
+                "test/reward": collect_result["rew"],
+                "test/length": collect_result["len"],
+                "test/reward_std": collect_result["rew_std"],
+                "test/length_std": collect_result["len_std"],
             }
             self.write("test/env_step", step, log_data)
             self.last_log_test_step = step