Add Trainers as generators (#559)
The new proposed feature is to have trainers as generators. The usage pattern is:

```python
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- `epoch` (int): the epoch number
- `epoch_stat` (dict): a large collection of metrics for the current epoch, including `stat`
- `info` (dict): the usual dict returned by the non-generator version of the trainer

You can even iterate on several different trainers at the same time:

```python
trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
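For users who only need the final result, the classic function-style API is kept as a thin wrapper around the new classes (see `tianshou/trainer/offpolicy.py` below), so `offpolicy_trainer(...)` and `OffpolicyTrainer(...).run()` return the same info dict. The early-`break` variant below is a sketch of what the iterator additionally enables; `log_somewhere` and `stop_early` are hypothetical user-defined helpers, not part of this PR:

```python
# Functional style: consume the whole iterator and return the final info dict.
result = OffpolicyTrainer(policy, train_collector, test_collector, ...).run()
# equivalent to: result = offpolicy_trainer(policy, train_collector, test_collector, ...)

# Iterator style: inspect every epoch and stop on a custom condition.
trainer = OffpolicyTrainer(policy, train_collector, test_collector, ...)
for epoch, epoch_stat, info in trainer:
    log_somewhere(epoch, epoch_stat)           # hypothetical user hook
    if stop_early(epoch_stat["best_reward"]):  # hypothetical user predicate
        break
```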
parent 2336a7db1b · commit 10d919052b

.github/ISSUE_TEMPLATE.md (4 changed lines, vendored)
@@ -7,6 +7,6 @@
 - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
   ```python
-  import tianshou, torch, numpy, sys
-  print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
+  import tianshou, gym, torch, numpy, sys
+  print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
   ```
Makefile (4 changed lines)

@@ -22,10 +22,8 @@ lint:
 	flake8 ${LINT_PATHS} --count --show-source --statistics

 format:
-	# sort imports
 	$(call check_install, isort)
 	isort ${LINT_PATHS}
-	# reformat using yapf
 	$(call check_install, yapf)
 	yapf -ir ${LINT_PATHS}

@@ -57,6 +55,6 @@ doc-clean:

 clean: doc-clean

-commit-checks: format lint mypy check-docstyle spelling
+commit-checks: lint check-codestyle mypy check-docstyle spelling

 .PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks
@@ -1,7 +1,49 @@
 tianshou.trainer
 ================

-.. automodule:: tianshou.trainer
+On-policy
+---------
+
+.. autoclass:: tianshou.trainer.OnpolicyTrainer
    :members:
    :undoc-members:
    :show-inheritance:
+
+.. autofunction:: tianshou.trainer.onpolicy_trainer
+
+.. autoclass:: tianshou.trainer.onpolicy_trainer_iter
+
+
+Off-policy
+----------
+
+.. autoclass:: tianshou.trainer.OffpolicyTrainer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. autofunction:: tianshou.trainer.offpolicy_trainer
+
+.. autoclass:: tianshou.trainer.offpolicy_trainer_iter
+
+
+Offline
+-------
+
+.. autoclass:: tianshou.trainer.OfflineTrainer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. autofunction:: tianshou.trainer.offline_trainer
+
+.. autoclass:: tianshou.trainer.offline_trainer_iter
+
+
+utils
+-----
+
+.. autofunction:: tianshou.trainer.test_episode
+
+.. autofunction:: tianshou.trainer.gather_info
@@ -24,12 +24,15 @@ fqf
 iqn
 qrdqn
 rl
+offpolicy
+onpolicy
 quantile
 quantiles
 dqn
 param
 async
 subprocess
+deque
 nn
 equ
 cql
@@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training method

 Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.

+We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic:
+::
+
+    trainer = OnpolicyTrainer(...)
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+        do_something_with_policy()
+        query_something_about_policy()
+        make_a_plot_with(epoch_stat)
+        display(info)
+
+    # or even iterate on several trainers at the same time
+
+    trainer1 = OnpolicyTrainer(...)
+    trainer2 = OnpolicyTrainer(...)
+    for result1, result2, ... in zip(trainer1, trainer2, ...):
+        compare_results(result1, result2, ...)
+
 .. _pseudocode:
@@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import OnpolicyTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -157,7 +157,7 @@ def test_ppo(args=get_args()):
             print("Fail to restore policy and optim.")

     # trainer
-    result = onpolicy_trainer(
+    trainer = OnpolicyTrainer(
         policy,
         train_collector,
         test_collector,
@@ -173,10 +173,16 @@ def test_ppo(args=get_args()):
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn
     )
-    assert stop_fn(result['best_reward'])
+
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    assert stop_fn(info["best_reward"])

     if __name__ == '__main__':
-        pprint.pprint(result)
+        pprint.pprint(info)
         # Let's watch its performance!
         env = gym.make(args.task)
         policy.eval()
@@ -24,7 +24,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--reward-threshold', type=float, default=None)
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-3)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -11,7 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.exploration import GaussianNoise
 from tianshou.policy import TD3Policy
-from tianshou.trainer import offpolicy_trainer
+from tianshou.trainer import OffpolicyTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import Actor, Critic
@@ -135,8 +135,8 @@ def test_td3(args=get_args()):
     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold

-    # trainer
-    result = offpolicy_trainer(
+    # Iterator trainer
+    trainer = OffpolicyTrainer(
         policy,
         train_collector,
         test_collector,
@@ -148,12 +148,17 @@ def test_td3(args=get_args()):
         update_per_step=args.update_per_step,
         stop_fn=stop_fn,
         save_fn=save_fn,
-        logger=logger
+        logger=logger,
     )
-    assert stop_fn(result['best_reward'])
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)

-    if __name__ == '__main__':
-        pprint.pprint(result)
+    assert stop_fn(info["best_reward"])
+
+    if __name__ == "__main__":
+        pprint.pprint(info)
         # Let's watch its performance!
         env = gym.make(args.task)
         policy.eval()
@@ -12,7 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import CQLPolicy
-from tianshou.trainer import offline_trainer
+from tianshou.trainer import OfflineTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -195,7 +195,7 @@ def test_cql(args=get_args()):
         collector.collect(n_episode=1, render=1 / 35)

     # trainer
-    result = offline_trainer(
+    trainer = OfflineTrainer(
         policy,
         buffer,
         test_collector,
@@ -207,11 +207,17 @@ def test_cql(args=get_args()):
         stop_fn=stop_fn,
         logger=logger,
     )
-    assert stop_fn(result['best_reward'])
+
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    assert stop_fn(info["best_reward"])

     # Let's watch its performance!
-    if __name__ == '__main__':
-        pprint.pprint(result)
+    if __name__ == "__main__":
+        pprint.pprint(info)
     env = gym.make(args.task)
     policy.eval()
     collector = Collector(policy, env)
tianshou/trainer/__init__.py

@@ -1,16 +1,34 @@
 """Trainer package."""

-# isort:skip_file
-from tianshou.trainer.utils import test_episode, gather_info
-from tianshou.trainer.onpolicy import onpolicy_trainer
-from tianshou.trainer.offpolicy import offpolicy_trainer
-from tianshou.trainer.offline import offline_trainer
+from tianshou.trainer.base import BaseTrainer
+from tianshou.trainer.offline import (
+    OfflineTrainer,
+    offline_trainer,
+    offline_trainer_iter,
+)
+from tianshou.trainer.offpolicy import (
+    OffpolicyTrainer,
+    offpolicy_trainer,
+    offpolicy_trainer_iter,
+)
+from tianshou.trainer.onpolicy import (
+    OnpolicyTrainer,
+    onpolicy_trainer,
+    onpolicy_trainer_iter,
+)
+from tianshou.trainer.utils import gather_info, test_episode

 __all__ = [
+    "BaseTrainer",
     "offpolicy_trainer",
+    "offpolicy_trainer_iter",
+    "OffpolicyTrainer",
     "onpolicy_trainer",
+    "onpolicy_trainer_iter",
+    "OnpolicyTrainer",
     "offline_trainer",
+    "offline_trainer_iter",
+    "OfflineTrainer",
     "test_episode",
     "gather_info",
 ]
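The package now exposes both calling conventions side by side; the `*_trainer_iter` names are plain aliases for the classes (e.g. `offpolicy_trainer_iter = OffpolicyTrainer` further down in this commit). A minimal sketch of how the two styles relate, assuming a policy and collectors have already been built as in the tests above:

```python
from tianshou.trainer import OffpolicyTrainer, offpolicy_trainer, offpolicy_trainer_iter

# The alias points at the very same class object.
assert offpolicy_trainer_iter is OffpolicyTrainer

# Functional style (unchanged public API): run everything, get the final info dict.
# result = offpolicy_trainer(policy, train_collector, test_collector, ...)

# Class style: same arguments, but you get an iterator you can drive yourself.
# trainer = OffpolicyTrainer(policy, train_collector, test_collector, ...)
```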
tianshou/trainer/base.py (new file, 419 lines)

@@ -0,0 +1,419 @@
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union

import numpy as np
import tqdm

from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer.utils import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config


class BaseTrainer(ABC):
    """An iterator base class for trainers procedure.

    Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results
    on every epoch.

    :param learning_type str: type of learning iterator, available choices are
        "offpolicy", "onpolicy" and "offline".
    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing. If it's None,
        then no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn``
        is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int repeat_per_collect: the number of repeat time for policy learning,
        for example, set it to 2 means the policy needs to learn each given batch
        data twice.
    :param int episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in
        the policy network.
    :param int step_per_collect: the number of transitions the collector would
        collect before the network update, i.e., trainer will collect
        "step_per_collect" transitions and do some policy network update repeatedly
        in each epoch.
    :param int episode_per_collect: the number of episodes the collector would
        collect before the network update, i.e., trainer will collect
        "episode_per_collect" episodes and do some policy network update repeatedly
        in each epoch.
    :param function train_fn: a hook called at the beginning of training in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature
        ``f(policy: BasePolicy) -> None``.
    :param function save_checkpoint_fn: a function to save training process, with
        the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
        you can save whatever you want.
    :param bool resume_from_log: resume env_step/gradient_step and other metadata
        from existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        returns a boolean which indicates whether reaching the goal.
    :param function reward_metric: a function with signature
        ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray
        with shape (num_episode,)``, used in multi-agent RL. We need to return a
        single scalar for each episode's result to monitor training in the
        multi-agent RL setting. This function specifies what is the desired metric,
        e.g., the reward of agent 1 or the average reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase.
        Default to True.
    """

    @staticmethod
    def gen_doc(learning_type: str) -> str:
        """Document string for subclass trainer."""
        step_means = f'The "step" in {learning_type} trainer means '
        if learning_type != "offline":
            step_means += "an environment step (a.k.a. transition)."
        else:  # offline
            step_means += "a gradient step."

        trainer_name = learning_type.capitalize() + "Trainer"

        return f"""An iterator class for {learning_type} trainer procedure.

        Returns an iterator that yields a 3-tuple (epoch, stats, info) of
        train results on every epoch.

        {step_means}

        Example usage:

        ::

            trainer = {trainer_name}(...)
            for epoch, epoch_stat, info in trainer:
                print("Epoch:", epoch)
                print(epoch_stat)
                print(info)
                do_something_with_policy()
                query_something_about_policy()
                make_a_plot_with(epoch_stat)
                display(info)

        - epoch int: the epoch number
        - epoch_stat dict: a large collection of metrics of the current epoch
        - info dict: result returned from :func:`~tianshou.trainer.gather_info`

        You can even iterate on several trainers at the same time:

        ::

            trainer1 = {trainer_name}(...)
            trainer2 = {trainer_name}(...)
            for result1, result2, ... in zip(trainer1, trainer2, ...):
                compare_results(result1, result2, ...)
        """

    def __init__(
        self,
        learning_type: str,
        policy: BasePolicy,
        max_epoch: int,
        batch_size: int,
        train_collector: Optional[Collector] = None,
        test_collector: Optional[Collector] = None,
        buffer: Optional[ReplayBuffer] = None,
        step_per_epoch: Optional[int] = None,
        repeat_per_collect: Optional[int] = None,
        episode_per_test: Optional[int] = None,
        update_per_step: Union[int, float] = 1,
        update_per_epoch: Optional[int] = None,
        step_per_collect: Optional[int] = None,
        episode_per_collect: Optional[int] = None,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
        resume_from_log: bool = False,
        reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        logger: BaseLogger = LazyLogger(),
        verbose: bool = True,
        test_in_train: bool = True,
    ):
        self.policy = policy
        self.buffer = buffer

        self.train_collector = train_collector
        self.test_collector = test_collector

        self.logger = logger
        self.start_time = time.time()
        self.stat: DefaultDict[str, MovAvg] = defaultdict(MovAvg)
        self.best_reward = 0.0
        self.best_reward_std = 0.0
        self.start_epoch = 0
        self.gradient_step = 0
        self.env_step = 0
        self.max_epoch = max_epoch
        self.step_per_epoch = step_per_epoch

        # either on of these two
        self.step_per_collect = step_per_collect
        self.episode_per_collect = episode_per_collect

        self.update_per_step = update_per_step
        self.repeat_per_collect = repeat_per_collect

        self.episode_per_test = episode_per_test

        self.batch_size = batch_size

        self.train_fn = train_fn
        self.test_fn = test_fn
        self.stop_fn = stop_fn
        self.save_fn = save_fn
        self.save_checkpoint_fn = save_checkpoint_fn

        self.reward_metric = reward_metric
        self.verbose = verbose
        self.test_in_train = test_in_train
        self.resume_from_log = resume_from_log

        self.is_run = False
        self.last_rew, self.last_len = 0.0, 0

        self.epoch = self.start_epoch
        self.best_epoch = self.start_epoch
        self.stop_fn_flag = False
        self.iter_num = 0

    def reset(self) -> None:
        """Initialize or reset the instance to yield a new iterator from zero."""
        self.is_run = False
        self.env_step = 0
        if self.resume_from_log:
            self.start_epoch, self.env_step, self.gradient_step = \
                self.logger.restore_data()

        self.last_rew, self.last_len = 0.0, 0
        self.start_time = time.time()
        if self.train_collector is not None:
            self.train_collector.reset_stat()

            if self.train_collector.policy != self.policy:
                self.test_in_train = False
            elif self.test_collector is None:
                self.test_in_train = False

        if self.test_collector is not None:
            assert self.episode_per_test is not None
            self.test_collector.reset_stat()
            test_result = test_episode(
                self.policy, self.test_collector, self.test_fn, self.start_epoch,
                self.episode_per_test, self.logger, self.env_step, self.reward_metric
            )
            self.best_epoch = self.start_epoch
            self.best_reward, self.best_reward_std = \
                test_result["rew"], test_result["rew_std"]
        if self.save_fn:
            self.save_fn(self.policy)

        self.epoch = self.start_epoch
        self.stop_fn_flag = False
        self.iter_num = 0

    def __iter__(self):  # type: ignore
        self.reset()
        return self

    def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:
        """Perform one epoch (both train and eval)."""
        self.epoch += 1
        self.iter_num += 1

        if self.iter_num > 1:

            # iterator exhaustion check
            if self.epoch >= self.max_epoch:
                if self.test_collector is None and self.save_fn:
                    self.save_fn(self.policy)
                raise StopIteration

            # exit flag 1, when stop_fn succeeds in train_step or test_step
            if self.stop_fn_flag:
                raise StopIteration

        # set policy in train mode
        self.policy.train()

        epoch_stat: Dict[str, Any] = dict()
        # perform n step_per_epoch
        with tqdm.tqdm(
            total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
        ) as t:
            while t.n < t.total and not self.stop_fn_flag:
                data: Dict[str, Any] = dict()
                result: Dict[str, Any] = dict()
                if self.train_collector is not None:
                    data, result, self.stop_fn_flag = self.train_step()
                    t.update(result["n/st"])
                    if self.stop_fn_flag:
                        t.set_postfix(**data)
                        break
                else:
                    assert self.buffer, "No train_collector or buffer specified"
                    result["n/ep"] = len(self.buffer)
                    result["n/st"] = int(self.gradient_step)
                    t.update()

                self.policy_update_fn(data, result)
                t.set_postfix(**data)

            if t.n <= t.total and not self.stop_fn_flag:
                t.update()

        if not self.stop_fn_flag:
            self.logger.save_data(
                self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
            )
            # test
            if self.test_collector is not None:
                test_stat, self.stop_fn_flag = self.test_step()
                if not self.is_run:
                    epoch_stat.update(test_stat)

        if not self.is_run:
            epoch_stat.update({k: v.get() for k, v in self.stat.items()})
            epoch_stat["gradient_step"] = self.gradient_step
            epoch_stat.update(
                {
                    "env_step": self.env_step,
                    "rew": self.last_rew,
                    "len": int(self.last_len),
                    "n/ep": int(result["n/ep"]),
                    "n/st": int(result["n/st"]),
                }
            )
            info = gather_info(
                self.start_time, self.train_collector, self.test_collector,
                self.best_reward, self.best_reward_std
            )
            return self.epoch, epoch_stat, info
        else:
            return None

    def test_step(self) -> Tuple[Dict[str, Any], bool]:
        """Perform one testing step."""
        assert self.episode_per_test is not None
        assert self.test_collector is not None
        stop_fn_flag = False
        test_result = test_episode(
            self.policy, self.test_collector, self.test_fn, self.epoch,
            self.episode_per_test, self.logger, self.env_step, self.reward_metric
        )
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if self.best_epoch < 0 or self.best_reward < rew:
            self.best_epoch = self.epoch
            self.best_reward = float(rew)
            self.best_reward_std = rew_std
            if self.save_fn:
                self.save_fn(self.policy)
        if self.verbose:
            print(
                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
                f" best_reward: {self.best_reward:.6f} ± "
                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
            )
        if not self.is_run:
            test_stat = {
                "test_reward": rew,
                "test_reward_std": rew_std,
                "best_reward": self.best_reward,
                "best_reward_std": self.best_reward_std,
                "best_epoch": self.best_epoch
            }
        else:
            test_stat = {}
        if self.stop_fn and self.stop_fn(self.best_reward):
            stop_fn_flag = True

        return test_stat, stop_fn_flag

    def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
        """Perform one training step."""
        assert self.episode_per_test is not None
        assert self.train_collector is not None
        stop_fn_flag = False
        if self.train_fn:
            self.train_fn(self.epoch, self.env_step)
        result = self.train_collector.collect(
            n_step=self.step_per_collect, n_episode=self.episode_per_collect
        )
        if result["n/ep"] > 0 and self.reward_metric:
            rew = self.reward_metric(result["rews"])
            result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
        self.env_step += int(result["n/st"])
        self.logger.log_train_data(result, self.env_step)
        self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew
        self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len
        data = {
            "env_step": str(self.env_step),
            "rew": f"{self.last_rew:.2f}",
            "len": str(int(self.last_len)),
            "n/ep": str(int(result["n/ep"])),
            "n/st": str(int(result["n/st"])),
        }
        if result["n/ep"] > 0:
            if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]):
                assert self.test_collector is not None
                test_result = test_episode(
                    self.policy, self.test_collector, self.test_fn, self.epoch,
                    self.episode_per_test, self.logger, self.env_step
                )
                if self.stop_fn(test_result["rew"]):
                    stop_fn_flag = True
                    self.best_reward = test_result["rew"]
                    self.best_reward_std = test_result["rew_std"]
                else:
                    self.policy.train()

        return data, result, stop_fn_flag

    def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None:
        """Log losses to current logger."""
        for k in losses.keys():
            self.stat[k].add(losses[k])
            losses[k] = self.stat[k].get()
            data[k] = f"{losses[k]:.3f}"
        self.logger.log_update_data(losses, self.gradient_step)

    @abstractmethod
    def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        """Policy update function for different trainer implementation.

        :param data: information in progress bar.
        :param result: collector's return value.
        """

    def run(self) -> Dict[str, Union[float, str]]:
        """Consume iterator.

        See itertools - recipes. Use functions that consume iterators at C speed
        (feed the entire iterator into a zero-length deque).
        """
        try:
            self.is_run = True
            deque(self, maxlen=0)  # feed the entire iterator into a zero-length deque
            info = gather_info(
                self.start_time, self.train_collector, self.test_collector,
                self.best_reward, self.best_reward_std
            )
        finally:
            self.is_run = False

        return info
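`run()` above consumes the iterator through the itertools-recipes trick of feeding it into a zero-length deque, which discards every yielded tuple at C speed. A tiny standalone illustration of that idiom (a toy generator stands in for the trainer):

```python
from collections import deque


def consume(iterator):
    """Exhaust an iterator, discarding all items (itertools 'consume' recipe)."""
    deque(iterator, maxlen=0)  # a zero-length deque keeps nothing, so items are simply dropped


side_effects = []
consume(side_effects.append(i) or i for i in range(3))  # generator with side effects
assert side_effects == [0, 1, 2]
```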
tianshou/trainer/offline.py

@@ -1,131 +1,115 @@
-import time
-from collections import defaultdict
-from typing import Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
-import tqdm

 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import BasePolicy
-from tianshou.trainer import gather_info, test_episode
-from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
+from tianshou.trainer.base import BaseTrainer
+from tianshou.utils import BaseLogger, LazyLogger


-def offline_trainer(
-    policy: BasePolicy,
-    buffer: ReplayBuffer,
-    test_collector: Optional[Collector],
-    max_epoch: int,
-    update_per_epoch: int,
-    episode_per_test: int,
-    batch_size: int,
-    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
-    stop_fn: Optional[Callable[[float], bool]] = None,
-    save_fn: Optional[Callable[[BasePolicy], None]] = None,
-    save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
-    resume_from_log: bool = False,
-    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
-    logger: BaseLogger = LazyLogger(),
-    verbose: bool = True,
-) -> Dict[str, Union[float, str]]:
-    """A wrapper for offline trainer procedure.
+class OfflineTrainer(BaseTrainer):
+    """Create an iterator class for offline training procedure.

-    The "step" in offline trainer means a gradient step.
-
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
-    :param Collector test_collector: the collector used for testing. If it's None, then
-        no testing will be performed.
+    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+        This buffer must be populated with experiences for offline RL.
+    :param Collector test_collector: the collector used for testing. If it's None,
+        then no testing will be performed.
     :param int max_epoch: the maximum number of epochs for training. The training
-        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
+        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
+        set.
     :param int update_per_epoch: the number of policy network updates, so-called
         gradient steps, per epoch.
     :param episode_per_test: the number of episodes for one policy evaluation.
     :param int batch_size: the batch size of sample data, which is going to feed in
         the policy network.
-    :param function test_fn: a hook called at the beginning of testing in each epoch.
-        It can be used to perform custom additional operations, with the signature ``f(
-        num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean reward in
-        evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
-        None``.
-    :param function save_checkpoint_fn: a function to save training process, with the
-        signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
-        save whatever you want. Because offline-RL doesn't have env_step, the env_step
-        is always 0 here.
-    :param bool resume_from_log: resume gradient_step and other metadata from existing
-        tensorboard log. Default to False.
+    :param function test_fn: a hook called at the beginning of testing in each
+        epoch.
+        It can be used to perform custom additional operations, with the signature
+        ``f(num_epoch: int, step_idx: int) -> None``.
+    :param function save_fn: a hook called when the undiscounted average mean
+        reward in evaluation phase gets better, with the signature
+        ``f(policy: BasePolicy) -> None``.
+    :param function save_checkpoint_fn: a function to save training process,
+        with the signature ``f(epoch: int, env_step: int, gradient_step: int) ->
+        None``; you can save whatever you want. Because offline-RL doesn't have
+        env_step, the env_step is always 0 here.
+    :param bool resume_from_log: resume gradient_step and other metadata from
+        existing tensorboard log. Default to False.
     :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
         bool``, receives the average undiscounted returns of the testing result,
         returns a boolean which indicates whether reaching the goal.
-    :param function reward_metric: a function with signature ``f(rewards: np.ndarray
-        with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
-        used in multi-agent RL. We need to return a single scalar for each episode's
-        result to monitor training in the multi-agent RL setting. This function
-        specifies what is the desired metric, e.g., the reward of agent 1 or the
-        average reward over all agents.
-    :param BaseLogger logger: A logger that logs statistics during updating/testing.
-        Default to a logger that doesn't log anything.
+    :param function reward_metric: a function with signature ``f(rewards:
+        np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape
+        (num_episode,)``, used in multi-agent RL. We need to return a single scalar
+        for each episode's result to monitor training in the multi-agent RL
+        setting. This function specifies what is the desired metric, e.g., the
+        reward of agent 1 or the average reward over all agents.
+    :param BaseLogger logger: A logger that logs statistics during
+        updating/testing. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
+    """
+
+    __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
+
+    def __init__(
+        self,
+        policy: BasePolicy,
+        buffer: ReplayBuffer,
+        test_collector: Optional[Collector],
+        max_epoch: int,
+        update_per_epoch: int,
+        episode_per_test: int,
+        batch_size: int,
+        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
+        stop_fn: Optional[Callable[[float], bool]] = None,
+        save_fn: Optional[Callable[[BasePolicy], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        resume_from_log: bool = False,
+        reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+        logger: BaseLogger = LazyLogger(),
+        verbose: bool = True,
+    ):
+        super().__init__(
+            learning_type="offline",
+            policy=policy,
+            buffer=buffer,
+            test_collector=test_collector,
+            max_epoch=max_epoch,
+            update_per_epoch=update_per_epoch,
+            step_per_epoch=update_per_epoch,
+            episode_per_test=episode_per_test,
+            batch_size=batch_size,
+            test_fn=test_fn,
+            stop_fn=stop_fn,
+            save_fn=save_fn,
+            save_checkpoint_fn=save_checkpoint_fn,
+            resume_from_log=resume_from_log,
+            reward_metric=reward_metric,
+            logger=logger,
+            verbose=verbose,
+        )
+
+    def policy_update_fn(
+        self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
+    ) -> None:
+        """Perform one off-line policy update."""
+        assert self.buffer
+        self.gradient_step += 1
+        losses = self.policy.update(self.batch_size, self.buffer)
+        data.update({"gradient_step": str(self.gradient_step)})
+        self.log_update_data(data, losses)
+
+
+def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]:  # type: ignore
+    """Wrapper for offline_trainer run method.
+
+    It is identical to ``OfflineTrainer(...).run()``.

     :return: See :func:`~tianshou.trainer.gather_info`.
     """
-    start_epoch, gradient_step = 0, 0
-    if resume_from_log:
-        start_epoch, _, gradient_step = logger.restore_data()
-    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
-    start_time = time.time()
-
-    if test_collector is not None:
-        test_c: Collector = test_collector
-        test_collector.reset_stat()
-        test_result = test_episode(
-            policy, test_c, test_fn, start_epoch, episode_per_test, logger,
-            gradient_step, reward_metric
-        )
-        best_epoch = start_epoch
-        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
-    if save_fn:
-        save_fn(policy)
-
-    for epoch in range(1 + start_epoch, 1 + max_epoch):
-        policy.train()
-        with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
-            for _ in t:
-                gradient_step += 1
-                losses = policy.update(batch_size, buffer)
-                data = {"gradient_step": str(gradient_step)}
-                for k in losses.keys():
-                    stat[k].add(losses[k])
-                    losses[k] = stat[k].get()
-                    data[k] = f"{losses[k]:.3f}"
-                logger.log_update_data(losses, gradient_step)
-                t.set_postfix(**data)
-        logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
-        # test
-        if test_collector is not None:
-            test_result = test_episode(
-                policy, test_c, test_fn, epoch, episode_per_test, logger,
-                gradient_step, reward_metric
-            )
-            rew, rew_std = test_result["rew"], test_result["rew_std"]
-            if best_epoch < 0 or best_reward < rew:
-                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
-                if save_fn:
-                    save_fn(policy)
-            if verbose:
-                print(
-                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
-                    f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
-                )
-            if stop_fn and stop_fn(best_reward):
-                break
-
-    if test_collector is None and save_fn:
-        save_fn(policy)
-
-    if test_collector is None:
-        return gather_info(start_time, None, None, 0.0, 0.0)
-    else:
-        return gather_info(
-            start_time, None, test_collector, best_reward, best_reward_std
-        )
+    return OfflineTrainer(*args, **kwargs).run()
+
+
+offline_trainer_iter = OfflineTrainer
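The `OfflineTrainer` above is also the simplest template for writing a custom trainer: subclass `BaseTrainer` and implement `policy_update_fn`. The sketch below is hypothetical (the class name and the two-updates-per-call behaviour are illustrative only, not part of this commit); it would be constructed like `BaseTrainer` itself, passing `learning_type="offline"`, a populated buffer, and the usual arguments, then driven either as an iterator or via `.run()`:

```python
from typing import Any, Dict, Optional

from tianshou.trainer.base import BaseTrainer


class TwoStepOfflineTrainer(BaseTrainer):
    """Hypothetical offline trainer that takes two gradient steps per iteration."""

    def policy_update_fn(
        self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
    ) -> None:
        assert self.buffer is not None
        for _ in range(2):  # illustrative twist: two updates instead of one
            self.gradient_step += 1
            losses = self.policy.update(self.batch_size, self.buffer)
            self.log_update_data(data, losses)
        data.update({"gradient_step": str(self.gradient_step)})
```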
@ -1,193 +1,130 @@
|
|||||||
import time
|
from typing import Any, Callable, Dict, Optional, Union
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Callable, Dict, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from tianshou.data import Collector
|
from tianshou.data import Collector
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.trainer import gather_info, test_episode
|
from tianshou.trainer.base import BaseTrainer
|
||||||
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
|
from tianshou.utils import BaseLogger, LazyLogger
|
||||||
|
|
||||||
|
|
||||||
def offpolicy_trainer(
|
class OffpolicyTrainer(BaseTrainer):
|
||||||
policy: BasePolicy,
|
"""Create an iterator wrapper for off-policy training procedure.
|
||||||
train_collector: Collector,
|
|
||||||
test_collector: Optional[Collector],
|
|
||||||
max_epoch: int,
|
|
||||||
step_per_epoch: int,
|
|
||||||
step_per_collect: int,
|
|
||||||
episode_per_test: int,
|
|
||||||
batch_size: int,
|
|
||||||
update_per_step: Union[int, float] = 1,
|
|
||||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
|
||||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
|
||||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
|
||||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
|
||||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
|
||||||
resume_from_log: bool = False,
|
|
||||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
|
||||||
logger: BaseLogger = LazyLogger(),
|
|
||||||
verbose: bool = True,
|
|
||||||
test_in_train: bool = True,
|
|
||||||
) -> Dict[str, Union[float, str]]:
|
|
||||||
"""A wrapper for off-policy trainer procedure.
|
|
||||||
|
|
||||||
The "step" in trainer means an environment step (a.k.a. transition).
|
|
||||||
|
|
||||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
||||||
:param Collector train_collector: the collector used for training.
|
:param Collector train_collector: the collector used for training.
|
||||||
:param Collector test_collector: the collector used for testing. If it's None, then
|
:param Collector test_collector: the collector used for testing. If it's None,
|
||||||
no testing will be performed.
|
then no testing will be performed.
|
||||||
:param int max_epoch: the maximum number of epochs for training. The training
|
:param int max_epoch: the maximum number of epochs for training. The training
|
||||||
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
|
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
|
||||||
|
set.
|
||||||
:param int step_per_epoch: the number of transitions collected per epoch.
|
:param int step_per_epoch: the number of transitions collected per epoch.
|
||||||
:param int step_per_collect: the number of transitions the collector would collect
|
:param int step_per_collect: the number of transitions the collector would
|
||||||
before the network update, i.e., trainer will collect "step_per_collect"
|
collect before the network update, i.e., trainer will collect
|
||||||
transitions and do some policy network update repeatedly in each epoch.
|
"step_per_collect" transitions and do some policy network update repeatedly
|
||||||
|
in each epoch.
|
||||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||||
:param int batch_size: the batch size of sample data, which is going to feed in the
|
:param int batch_size: the batch size of sample data, which is going to feed in
|
||||||
policy network.
|
the policy network.
|
||||||
:param int/float update_per_step: the number of times the policy network would be
|
:param int/float update_per_step: the number of times the policy network would
|
||||||
updated per transition after (step_per_collect) transitions are collected,
|
be updated per transition after (step_per_collect) transitions are
|
||||||
e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will
|
collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256
|
||||||
be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are
|
, policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256
|
||||||
collected by the collector. Default to 1.
|
transitions are collected by the collector. Default to 1.
|
||||||
:param function train_fn: a hook called at the beginning of training in each epoch.
|
:param function train_fn: a hook called at the beginning of training in each
|
||||||
It can be used to perform custom additional operations, with the signature ``f(
|
epoch. It can be used to perform custom additional operations, with the
|
||||||
num_epoch: int, step_idx: int) -> None``.
|
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
||||||
:param function test_fn: a hook called at the beginning of testing in each epoch.
|
:param function test_fn: a hook called at the beginning of testing in each
|
||||||
It can be used to perform custom additional operations, with the signature ``f(
|
epoch. It can be used to perform custom additional operations, with the
|
||||||
num_epoch: int, step_idx: int) -> None``.
|
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
||||||
:param function save_fn: a hook called when the undiscounted average mean reward in
|
:param function save_fn: a hook called when the undiscounted average mean
|
||||||
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
reward in evaluation phase gets better, with the signature
|
||||||
None``.
|
``f(policy: BasePolicy) -> None``.
|
||||||
:param function save_checkpoint_fn: a function to save training process, with the
|
:param function save_checkpoint_fn: a function to save training process, with
|
||||||
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
|
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
|
||||||
save whatever you want.
|
you can save whatever you want.
|
||||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
|
:param bool resume_from_log: resume env_step/gradient_step and other metadata
|
||||||
existing tensorboard log. Default to False.
|
from existing tensorboard log. Default to False.
|
||||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||||
bool``, receives the average undiscounted returns of the testing result,
|
bool``, receives the average undiscounted returns of the testing result,
|
||||||
returns a boolean which indicates whether reaching the goal.
|
returns a boolean which indicates whether reaching the goal.
|
||||||
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
:param function reward_metric: a function with signature
|
||||||
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
|
||||||
used in multi-agent RL. We need to return a single scalar for each episode's
|
np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to
|
||||||
result to monitor training in the multi-agent RL setting. This function
|
return a single scalar for each episode's result to monitor training in the
|
||||||
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
multi-agent RL setting. This function specifies what is the desired metric,
|
||||||
average reward over all agents.
|
e.g., the reward of agent 1 or the average reward over all agents.
|
||||||
:param BaseLogger logger: A logger that logs statistics during
|
:param BaseLogger logger: A logger that logs statistics during
|
||||||
training/testing/updating. Default to a logger that doesn't log anything.
|
training/testing/updating. Default to a logger that doesn't log anything.
|
||||||
:param bool verbose: whether to print the information. Default to True.
|
:param bool verbose: whether to print the information. Default to True.
|
||||||
:param bool test_in_train: whether to test in the training phase. Default to True.
|
:param bool test_in_train: whether to test in the training phase.
|
||||||
|
Default to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__doc__ = BaseTrainer.gen_doc("offpolicy") + "\n".join(__doc__.split("\n")[1:])
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
policy: BasePolicy,
|
||||||
|
train_collector: Collector,
|
||||||
|
test_collector: Optional[Collector],
|
||||||
|
max_epoch: int,
|
||||||
|
step_per_epoch: int,
|
||||||
|
step_per_collect: int,
|
||||||
|
episode_per_test: int,
|
||||||
|
batch_size: int,
|
||||||
|
update_per_step: Union[int, float] = 1,
|
||||||
|
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||||
|
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||||
|
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||||
|
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||||
|
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||||
|
resume_from_log: bool = False,
|
||||||
|
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||||
|
logger: BaseLogger = LazyLogger(),
|
||||||
|
verbose: bool = True,
|
||||||
|
test_in_train: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
learning_type="offpolicy",
|
||||||
|
policy=policy,
|
||||||
|
train_collector=train_collector,
|
||||||
|
test_collector=test_collector,
|
||||||
|
max_epoch=max_epoch,
|
||||||
|
step_per_epoch=step_per_epoch,
|
||||||
|
step_per_collect=step_per_collect,
|
||||||
|
episode_per_test=episode_per_test,
|
||||||
|
batch_size=batch_size,
|
||||||
|
update_per_step=update_per_step,
|
||||||
|
train_fn=train_fn,
|
||||||
|
test_fn=test_fn,
|
||||||
|
stop_fn=stop_fn,
|
||||||
|
save_fn=save_fn,
|
||||||
|
save_checkpoint_fn=save_checkpoint_fn,
|
||||||
|
resume_from_log=resume_from_log,
|
||||||
|
reward_metric=reward_metric,
|
||||||
|
logger=logger,
|
||||||
|
verbose=verbose,
|
||||||
|
test_in_train=test_in_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
|
||||||
|
"""Perform off-policy updates."""
|
||||||
|
assert self.train_collector is not None
|
||||||
|
for _ in range(round(self.update_per_step * result["n/st"])):
|
||||||
|
self.gradient_step += 1
|
||||||
|
losses = self.policy.update(self.batch_size, self.train_collector.buffer)
|
||||||
|
self.log_update_data(data, losses)
|
||||||
|
|
||||||
|
|
||||||
|
For reference, the body of the function-based offpolicy_trainer that this commit removes:

    start_epoch, env_step, gradient_step = 0, 0, 0
    if resume_from_log:
        start_epoch, env_step, gradient_step = logger.restore_data()
    last_rew, last_len = 0.0, 0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_in_train = test_in_train and (
        train_collector.policy == policy and test_collector is not None
    )

    if test_collector is not None:
        test_c: Collector = test_collector  # for mypy
        test_collector.reset_stat()
        test_result = test_episode(
            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
            reward_metric
        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    if save_fn:
        save_fn(policy)

    for epoch in range(1 + start_epoch, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_step=step_per_collect)
                if result["n/ep"] > 0 and reward_metric:
                    rew = reward_metric(result["rews"])
                    result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
                env_step += int(result["n/st"])
                t.update(result["n/st"])
                logger.log_train_data(result, env_step)
                last_rew = result['rew'] if result["n/ep"] > 0 else last_rew
                last_len = result['len'] if result["n/ep"] > 0 else last_len
                data = {
                    "env_step": str(env_step),
                    "rew": f"{last_rew:.2f}",
                    "len": str(int(last_len)),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                }
                if result["n/ep"] > 0:
                    if test_in_train and stop_fn and stop_fn(result["rew"]):
                        test_result = test_episode(
                            policy, test_c, test_fn, epoch, episode_per_test, logger,
                            env_step
                        )
                        if stop_fn(test_result["rew"]):
                            if save_fn:
                                save_fn(policy)
                            logger.save_data(
                                epoch, env_step, gradient_step, save_checkpoint_fn
                            )
                            t.set_postfix(**data)
                            return gather_info(
                                start_time, train_collector, test_collector,
                                test_result["rew"], test_result["rew_std"]
                            )
                        else:
                            policy.train()
                for _ in range(round(update_per_step * result["n/st"])):
                    gradient_step += 1
                    losses = policy.update(batch_size, train_collector.buffer)
                    for k in losses.keys():
                        stat[k].add(losses[k])
                        losses[k] = stat[k].get()
                        data[k] = f"{losses[k]:.3f}"
                    logger.log_update_data(losses, gradient_step)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        # test
        if test_collector is not None:
            test_result = test_episode(
                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
            if best_epoch < 0 or best_reward < rew:
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                    f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
                )
            if stop_fn and stop_fn(best_reward):
                break

    if test_collector is None and save_fn:
        save_fn(policy)

    if test_collector is None:
        return gather_info(start_time, train_collector, None, 0.0, 0.0)
    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )

It is replaced by a thin backward-compatible wrapper plus an alias:

def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]:  # type: ignore
    """Wrapper for OffpolicyTrainer run method.

    It is identical to ``OffpolicyTrainer(...).run()``.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    return OffpolicyTrainer(*args, **kwargs).run()


offpolicy_trainer_iter = OffpolicyTrainer
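A usage sketch of that backward compatibility: ``policy``, ``train_collector`` and ``test_collector`` are assumed to be constructed elsewhere (they are not part of this diff), and the keyword values below are placeholders, not recommendations.

```python
from tianshou.trainer import OffpolicyTrainer, offpolicy_trainer

# Legacy entry point kept by this commit: a thin functional wrapper...
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, step_per_collect=10,
    episode_per_test=10, batch_size=64, update_per_step=0.1,
)

# ...which is equivalent to building the new trainer object and calling run().
result = OffpolicyTrainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, step_per_collect=10,
    episode_per_test=10, batch_size=64, update_per_step=0.1,
).run()
```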
@@ -1,209 +1,147 @@
The module's imports change from:

import time
from collections import defaultdict
from typing import Callable, Dict, Optional, Union

import numpy as np
import tqdm

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.trainer import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config

to:

from typing import Any, Callable, Dict, Optional, Union

import numpy as np

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.trainer.base import BaseTrainer
from tianshou.utils import BaseLogger, LazyLogger
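The switch away from the ``gather_info``/``test_episode`` helpers toward ``tianshou.trainer.base.BaseTrainer`` reflects that both trainers now share one base class and only add their own ``policy_update_fn``. A small sketch, assuming the package-level re-exports referenced in the updated docs:

```python
from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer
from tianshou.trainer.base import BaseTrainer

# Both concrete trainers in this diff derive from the shared BaseTrainer.
print(issubclass(OnpolicyTrainer, BaseTrainer))   # True
print(issubclass(OffpolicyTrainer, BaseTrainer))  # True
```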
The function-based onpolicy_trainer that is removed:

def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Optional[Collector],
    max_epoch: int,
    step_per_epoch: int,
    repeat_per_collect: int,
    episode_per_test: int,
    batch_size: int,
    step_per_collect: Optional[int] = None,
    episode_per_collect: Optional[int] = None,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
    resume_from_log: bool = False,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    logger: BaseLogger = LazyLogger(),
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means an environment step (a.k.a. transition).

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing. If it's None, then
        no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int repeat_per_collect: the number of repeat time for policy learning, for
        example, set it to 2 means the policy needs to learn each given batch data
        twice.
    :param int episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in the
        policy network.
    :param int step_per_collect: the number of transitions the collector would collect
        before the network update, i.e., trainer will collect "step_per_collect"
        transitions and do some policy network update repeatedly in each epoch.
    :param int episode_per_collect: the number of episodes the collector would collect
        before the network update, i.e., trainer will collect "episode_per_collect"
        episodes and do some policy network update repeatedly in each epoch.
    :param function train_fn: a hook called at the beginning of training in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean reward in
        evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
        None``.
    :param function save_checkpoint_fn: a function to save training process, with the
        signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
        save whatever you want.
    :param bool resume_from_log: resume env_step/gradient_step and other metadata from
        existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        returns a boolean which indicates whether reaching the goal.
    :param function reward_metric: a function with signature ``f(rewards: np.ndarray
        with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
        used in multi-agent RL. We need to return a single scalar for each episode's
        result to monitor training in the multi-agent RL setting. This function
        specifies what is the desired metric, e.g., the reward of agent 1 or the
        average reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase. Default to True.

    :return: See :func:`~tianshou.trainer.gather_info`.

    .. note::

        Only either one of step_per_collect and episode_per_collect can be specified.
    """
    start_epoch, env_step, gradient_step = 0, 0, 0
    if resume_from_log:
        start_epoch, env_step, gradient_step = logger.restore_data()
    last_rew, last_len = 0.0, 0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_in_train = test_in_train and (
        train_collector.policy == policy and test_collector is not None
    )

    if test_collector is not None:
        test_c: Collector = test_collector  # for mypy
        test_collector.reset_stat()
        test_result = test_episode(
            policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
            reward_metric
        )
        best_epoch = start_epoch
        best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    if save_fn:
        save_fn(policy)

    for epoch in range(1 + start_epoch, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(
                    n_step=step_per_collect, n_episode=episode_per_collect
                )
                if result["n/ep"] > 0 and reward_metric:
                    rew = reward_metric(result["rews"])
                    result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
                env_step += int(result["n/st"])
                t.update(result["n/st"])
                logger.log_train_data(result, env_step)
                last_rew = result['rew'] if result["n/ep"] > 0 else last_rew
                last_len = result['len'] if result["n/ep"] > 0 else last_len
                data = {
                    "env_step": str(env_step),
                    "rew": f"{last_rew:.2f}",
                    "len": str(int(last_len)),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                }
                if result["n/ep"] > 0:
                    if test_in_train and stop_fn and stop_fn(result["rew"]):
                        test_result = test_episode(
                            policy, test_c, test_fn, epoch, episode_per_test, logger,
                            env_step
                        )
                        if stop_fn(test_result["rew"]):
                            if save_fn:
                                save_fn(policy)
                            logger.save_data(
                                epoch, env_step, gradient_step, save_checkpoint_fn
                            )
                            t.set_postfix(**data)
                            return gather_info(
                                start_time, train_collector, test_collector,
                                test_result["rew"], test_result["rew_std"]
                            )
                        else:
                            policy.train()
                losses = policy.update(
                    0,
                    train_collector.buffer,
                    batch_size=batch_size,
                    repeat=repeat_per_collect
                )
                train_collector.reset_buffer(keep_statistics=True)
                step = max(
                    [1] + [len(v) for v in losses.values() if isinstance(v, list)]
                )
                gradient_step += step
                for k in losses.keys():
                    stat[k].add(losses[k])
                    losses[k] = stat[k].get()
                    data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        # test
        if test_collector is not None:
            test_result = test_episode(
                policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
                reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
            if best_epoch < 0 or best_reward < rew:
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                    f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
                )
            if stop_fn and stop_fn(best_reward):
                break

    if test_collector is None and save_fn:
        save_fn(policy)

    if test_collector is None:
        return gather_info(start_time, train_collector, None, 0.0, 0.0)
    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward, best_reward_std
        )

It is replaced by the class-based OnpolicyTrainer:

class OnpolicyTrainer(BaseTrainer):
    """Create an iterator wrapper for on-policy training procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing. If it's None,
        then no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
        set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int repeat_per_collect: the number of repeat time for policy learning,
        for example, set it to 2 means the policy needs to learn each given batch
        data twice.
    :param int episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in
        the policy network.
    :param int step_per_collect: the number of transitions the collector would
        collect before the network update, i.e., trainer will collect
        "step_per_collect" transitions and do some policy network update repeatedly
        in each epoch.
    :param int episode_per_collect: the number of episodes the collector would
        collect before the network update, i.e., trainer will collect
        "episode_per_collect" episodes and do some policy network update repeatedly
        in each epoch.
    :param function train_fn: a hook called at the beginning of training in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature
        ``f(policy: BasePolicy) -> None``.
    :param function save_checkpoint_fn: a function to save training process,
        with the signature ``f(epoch: int, env_step: int, gradient_step: int)
        -> None``; you can save whatever you want.
    :param bool resume_from_log: resume env_step/gradient_step and other metadata
        from existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        returns a boolean which indicates whether reaching the goal.
    :param function reward_metric: a function with signature
        ``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
        np.ndarray with shape (num_episode,)``, used in multi-agent RL.
        We need to return a single scalar for each episode's result to monitor
        training in the multi-agent RL setting. This function specifies what is the
        desired metric, e.g., the reward of agent 1 or the average reward over
        all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase. Default to
        True.

    .. note::

        Only either one of step_per_collect and episode_per_collect can be specified.
    """

    __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:])

    def __init__(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Optional[Collector],
        max_epoch: int,
        step_per_epoch: int,
        repeat_per_collect: int,
        episode_per_test: int,
        batch_size: int,
        step_per_collect: Optional[int] = None,
        episode_per_collect: Optional[int] = None,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
        resume_from_log: bool = False,
        reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        logger: BaseLogger = LazyLogger(),
        verbose: bool = True,
        test_in_train: bool = True,
    ):
        super().__init__(
            learning_type="onpolicy",
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=max_epoch,
            step_per_epoch=step_per_epoch,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=episode_per_test,
            batch_size=batch_size,
            step_per_collect=step_per_collect,
            episode_per_collect=episode_per_collect,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_fn=save_fn,
            save_checkpoint_fn=save_checkpoint_fn,
            resume_from_log=resume_from_log,
            reward_metric=reward_metric,
            logger=logger,
            verbose=verbose,
            test_in_train=test_in_train,
        )

    def policy_update_fn(
        self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
    ) -> None:
        """Perform one on-policy update."""
        assert self.train_collector is not None
        losses = self.policy.update(
            0,
            self.train_collector.buffer,
            batch_size=self.batch_size,
            repeat=self.repeat_per_collect,
        )
        self.train_collector.reset_buffer(keep_statistics=True)
        step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
        self.gradient_step += step
        self.log_update_data(data, losses)
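Unlike the off-policy hook, the on-policy update above advances the gradient-step counter by the length of the per-minibatch loss lists returned from ``policy.update``. A short sketch with an invented losses dict (the real keys depend on the algorithm) to show the counting rule:

```python
# Invented example output of policy.update(); not taken from the diff.
losses = {
    "loss": [0.9, 0.7, 0.6, 0.5],        # one entry per minibatch update
    "loss/clip": [0.4, 0.3, 0.3, 0.2],
    "ent": 0.01,                         # scalar entries are ignored by the count
}

# Same rule as in policy_update_fn: at least 1, otherwise the longest loss list.
step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
print(step)  # 4 -> gradient_step is advanced by 4
```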
def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]:  # type: ignore
    """Wrapper for OnpolicyTrainer run method.

    It is identical to ``OnpolicyTrainer(...).run()``.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    return OnpolicyTrainer(*args, **kwargs).run()


onpolicy_trainer_iter = OnpolicyTrainer
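Both the class and its wrapper keep the constraint from the docstring note above: only one of ``step_per_collect`` and ``episode_per_collect`` should be given. A hypothetical guard, not code from this diff, that sketches one way to state the rule:

```python
from typing import Optional

def check_collect_args(
    step_per_collect: Optional[int] = None,
    episode_per_collect: Optional[int] = None,
) -> None:
    # Hypothetical stand-in for the validation the trainer/collector performs.
    if (step_per_collect is None) == (episode_per_collect is None):
        raise ValueError(
            "specify exactly one of step_per_collect and episode_per_collect"
        )

check_collect_args(step_per_collect=2048)    # fine: collect by transition count
check_collect_args(episode_per_collect=16)   # fine: collect by episode count
# check_collect_args()                       # would raise ValueError
```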