Add Trainers as generators (#559)

The new feature turns the trainers into generators: each trainer is an iterator that yields a result tuple once per epoch.
The usage pattern is:

```python
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- `epoch` (int): the epoch number
- `epoch_stat` (dict): a large collection of metrics of the current epoch, including the smoothed training/test statistics
- `info` (dict): the usual dict returned by the non-generator version of the trainer

You can even iterate on several different trainers at the same time:

```python
trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```
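Because control returns to user code after every epoch, you can also stop early on your own criterion, or fall back to the classic single-call behaviour via `run()`. A minimal sketch (the threshold name is a placeholder, and `best_reward` only appears in `epoch_stat` when a test collector is configured):

```python
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    # hypothetical custom early stopping, independent of stop_fn
    if epoch_stat["best_reward"] >= my_reward_threshold:
        break

# the non-generator behaviour is still available: consume all epochs at once
info = OnpolicyTrainer(...).run()
```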

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
Authored by Jose Antonio Martin H on 2022-03-17 17:26:14 +01:00, committed by GitHub
parent 2336a7db1b
commit 10d919052b
14 changed files with 862 additions and 486 deletions

@@ -7,6 +7,6 @@
 - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
   ```python
-  import tianshou, torch, numpy, sys
-  print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
+  import tianshou, gym, torch, numpy, sys
+  print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
   ```

@@ -22,10 +22,8 @@ lint:
 	flake8 ${LINT_PATHS} --count --show-source --statistics

 format:
-	# sort imports
 	$(call check_install, isort)
 	isort ${LINT_PATHS}
-	# reformat using yapf
 	$(call check_install, yapf)
 	yapf -ir ${LINT_PATHS}
@@ -57,6 +55,6 @@ doc-clean:
 clean: doc-clean

-commit-checks: format lint mypy check-docstyle spelling
+commit-checks: lint check-codestyle mypy check-docstyle spelling

 .PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks

@@ -1,7 +1,49 @@
 tianshou.trainer
 ================

-.. automodule:: tianshou.trainer
-    :members:
-    :undoc-members:
-    :show-inheritance:
+
+On-policy
+---------
+
+.. autoclass:: tianshou.trainer.OnpolicyTrainer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. autofunction:: tianshou.trainer.onpolicy_trainer
+
+.. autoclass:: tianshou.trainer.onpolicy_trainer_iter
+
+
+Off-policy
+----------
+
+.. autoclass:: tianshou.trainer.OffpolicyTrainer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. autofunction:: tianshou.trainer.offpolicy_trainer
+
+.. autoclass:: tianshou.trainer.offpolicy_trainer_iter
+
+
+Offline
+-------
+
+.. autoclass:: tianshou.trainer.OfflineTrainer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. autofunction:: tianshou.trainer.offline_trainer
+
+.. autoclass:: tianshou.trainer.offline_trainer_iter
+
+
+utils
+-----
+
+.. autofunction:: tianshou.trainer.test_episode
+
+.. autofunction:: tianshou.trainer.gather_info

@@ -24,12 +24,15 @@ fqf
 iqn
 qrdqn
 rl
+offpolicy
+onpolicy
 quantile
 quantiles
 dqn
 param
 async
 subprocess
+deque
 nn
 equ
 cql

@@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training method
 Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.

+We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, and :class:`~tianshou.trainer.OfflineTrainer`, so that users can write more flexible training logic:
+::
+
+    trainer = OnpolicyTrainer(...)
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+        do_something_with_policy()
+        query_something_about_policy()
+        make_a_plot_with(epoch_stat)
+        display(info)
+
+    # or even iterate on several trainers at the same time
+    trainer1 = OnpolicyTrainer(...)
+    trainer2 = OnpolicyTrainer(...)
+    for result1, result2, ... in zip(trainer1, trainer2, ...):
+        compare_results(result1, result2, ...)
+
 .. _pseudocode:
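For reference, the functional trainers are kept as thin wrappers around the new classes (see tianshou/trainer below), so the two call styles should be interchangeable. A sketch with placeholder arguments:

```python
from tianshou.trainer import OnpolicyTrainer, onpolicy_trainer

# classic one-call style, still supported
result = onpolicy_trainer(policy, train_collector, test_collector, ...)

# equivalent: build the iterator class and consume it in one go
result = OnpolicyTrainer(policy, train_collector, test_collector, ...).run()
```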

@@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import OnpolicyTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -157,7 +157,7 @@ def test_ppo(args=get_args()):
            print("Fail to restore policy and optim.")

     # trainer
-    result = onpolicy_trainer(
+    trainer = OnpolicyTrainer(
         policy,
         train_collector,
         test_collector,
@@ -173,10 +173,16 @@ def test_ppo(args=get_args()):
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn
     )
-    assert stop_fn(result['best_reward'])

+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    assert stop_fn(info["best_reward"])

     if __name__ == '__main__':
-        pprint.pprint(result)
+        pprint.pprint(info)
         # Let's watch its performance!
         env = gym.make(args.task)
         policy.eval()

@@ -24,7 +24,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--reward-threshold', type=float, default=None)
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-3)
     parser.add_argument('--critic-lr', type=float, default=1e-3)

@@ -11,7 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.exploration import GaussianNoise
 from tianshou.policy import TD3Policy
-from tianshou.trainer import offpolicy_trainer
+from tianshou.trainer import OffpolicyTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import Actor, Critic
@@ -135,8 +135,8 @@ def test_td3(args=get_args()):
     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold

-    # trainer
-    result = offpolicy_trainer(
+    # Iterator trainer
+    trainer = OffpolicyTrainer(
         policy,
         train_collector,
         test_collector,
@@ -148,12 +148,17 @@ def test_td3(args=get_args()):
         update_per_step=args.update_per_step,
         stop_fn=stop_fn,
         save_fn=save_fn,
-        logger=logger
+        logger=logger,
     )
-    assert stop_fn(result['best_reward'])
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)

-    if __name__ == '__main__':
-        pprint.pprint(result)
+    assert stop_fn(info["best_reward"])
+
+    if __name__ == "__main__":
+        pprint.pprint(info)
         # Let's watch its performance!
         env = gym.make(args.task)
         policy.eval()

@@ -12,7 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import CQLPolicy
-from tianshou.trainer import offline_trainer
+from tianshou.trainer import OfflineTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -195,7 +195,7 @@ def test_cql(args=get_args()):
         collector.collect(n_episode=1, render=1 / 35)

     # trainer
-    result = offline_trainer(
+    trainer = OfflineTrainer(
         policy,
         buffer,
         test_collector,
@@ -207,11 +207,17 @@ def test_cql(args=get_args()):
         stop_fn=stop_fn,
         logger=logger,
     )
-    assert stop_fn(result['best_reward'])

+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    assert stop_fn(info["best_reward"])

     # Let's watch its performance!
-    if __name__ == '__main__':
-        pprint.pprint(result)
+    if __name__ == "__main__":
+        pprint.pprint(info)
         env = gym.make(args.task)
         policy.eval()
         collector = Collector(policy, env)

@@ -1,16 +1,34 @@
 """Trainer package."""

-# isort:skip_file
-
-from tianshou.trainer.utils import test_episode, gather_info
-from tianshou.trainer.onpolicy import onpolicy_trainer
-from tianshou.trainer.offpolicy import offpolicy_trainer
-from tianshou.trainer.offline import offline_trainer
+from tianshou.trainer.base import BaseTrainer
+from tianshou.trainer.offline import (
+    OfflineTrainer,
+    offline_trainer,
+    offline_trainer_iter,
+)
+from tianshou.trainer.offpolicy import (
+    OffpolicyTrainer,
+    offpolicy_trainer,
+    offpolicy_trainer_iter,
+)
+from tianshou.trainer.onpolicy import (
+    OnpolicyTrainer,
+    onpolicy_trainer,
+    onpolicy_trainer_iter,
+)
+from tianshou.trainer.utils import gather_info, test_episode

 __all__ = [
+    "BaseTrainer",
     "offpolicy_trainer",
+    "offpolicy_trainer_iter",
+    "OffpolicyTrainer",
     "onpolicy_trainer",
+    "onpolicy_trainer_iter",
+    "OnpolicyTrainer",
     "offline_trainer",
+    "offline_trainer_iter",
+    "OfflineTrainer",
     "test_episode",
     "gather_info",
 ]

tianshou/trainer/base.py (new file, 419 lines)

@@ -0,0 +1,419 @@
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union
import numpy as np
import tqdm
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer.utils import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
class BaseTrainer(ABC):
"""An iterator base class for trainers procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results
on every epoch.
:param learning_type str: type of learning iterator, available choices are
"offpolicy", "onpolicy" and "offline".
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector train_collector: the collector used for training.
:param Collector test_collector: the collector used for testing. If it's None,
then no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn``
is set.
:param int step_per_epoch: the number of transitions collected per epoch.
:param int repeat_per_collect: the number of repeat time for policy learning,
for example, set it to 2 means the policy needs to learn each given batch
data twice.
:param int episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to feed in
the policy network.
:param int step_per_collect: the number of transitions the collector would
collect before the network update, i.e., trainer will collect
"step_per_collect" transitions and do some policy network update repeatedly
in each epoch.
:param int episode_per_collect: the number of episodes the collector would
collect before the network update, i.e., trainer will collect
"episode_per_collect" episodes and do some policy network update repeatedly
in each epoch.
:param function train_fn: a hook called at the beginning of training in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature
``f(policy: BasePolicy) -> None``.
:param function save_checkpoint_fn: a function to save training process, with
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
you can save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata
from existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
:param function reward_metric: a function with signature
``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray
with shape (num_episode,)``, used in multi-agent RL. We need to return a
single scalar for each episode's result to monitor training in the
multi-agent RL setting. This function specifies what is the desired metric,
e.g., the reward of agent 1 or the average reward over all agents.
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase.
Default to True.
"""
@staticmethod
def gen_doc(learning_type: str) -> str:
"""Document string for subclass trainer."""
step_means = f'The "step" in {learning_type} trainer means '
if learning_type != "offline":
step_means += "an environment step (a.k.a. transition)."
else: # offline
step_means += "a gradient step."
trainer_name = learning_type.capitalize() + "Trainer"
return f"""An iterator class for {learning_type} trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of
train results on every epoch.
{step_means}
Example usage:
::
trainer = {trainer_name}(...)
for epoch, epoch_stat, info in trainer:
print("Epoch:", epoch)
print(epoch_stat)
print(info)
do_something_with_policy()
query_something_about_policy()
make_a_plot_with(epoch_stat)
display(info)
- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch
- info dict: result returned from :func:`~tianshou.trainer.gather_info`
You can even iterate on several trainers at the same time:
::
trainer1 = {trainer_name}(...)
trainer2 = {trainer_name}(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
compare_results(result1, result2, ...)
"""
def __init__(
self,
learning_type: str,
policy: BasePolicy,
max_epoch: int,
batch_size: int,
train_collector: Optional[Collector] = None,
test_collector: Optional[Collector] = None,
buffer: Optional[ReplayBuffer] = None,
step_per_epoch: Optional[int] = None,
repeat_per_collect: Optional[int] = None,
episode_per_test: Optional[int] = None,
update_per_step: Union[int, float] = 1,
update_per_epoch: Optional[int] = None,
step_per_collect: Optional[int] = None,
episode_per_collect: Optional[int] = None,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
):
self.policy = policy
self.buffer = buffer
self.train_collector = train_collector
self.test_collector = test_collector
self.logger = logger
self.start_time = time.time()
self.stat: DefaultDict[str, MovAvg] = defaultdict(MovAvg)
self.best_reward = 0.0
self.best_reward_std = 0.0
self.start_epoch = 0
self.gradient_step = 0
self.env_step = 0
self.max_epoch = max_epoch
self.step_per_epoch = step_per_epoch
# either one of these two
self.step_per_collect = step_per_collect
self.episode_per_collect = episode_per_collect
self.update_per_step = update_per_step
self.repeat_per_collect = repeat_per_collect
self.episode_per_test = episode_per_test
self.batch_size = batch_size
self.train_fn = train_fn
self.test_fn = test_fn
self.stop_fn = stop_fn
self.save_fn = save_fn
self.save_checkpoint_fn = save_checkpoint_fn
self.reward_metric = reward_metric
self.verbose = verbose
self.test_in_train = test_in_train
self.resume_from_log = resume_from_log
self.is_run = False
self.last_rew, self.last_len = 0.0, 0
self.epoch = self.start_epoch
self.best_epoch = self.start_epoch
self.stop_fn_flag = False
self.iter_num = 0
def reset(self) -> None:
"""Initialize or reset the instance to yield a new iterator from zero."""
self.is_run = False
self.env_step = 0
if self.resume_from_log:
self.start_epoch, self.env_step, self.gradient_step = \
self.logger.restore_data()
self.last_rew, self.last_len = 0.0, 0
self.start_time = time.time()
if self.train_collector is not None:
self.train_collector.reset_stat()
if self.train_collector.policy != self.policy:
self.test_in_train = False
elif self.test_collector is None:
self.test_in_train = False
if self.test_collector is not None:
assert self.episode_per_test is not None
self.test_collector.reset_stat()
test_result = test_episode(
self.policy, self.test_collector, self.test_fn, self.start_epoch,
self.episode_per_test, self.logger, self.env_step, self.reward_metric
)
self.best_epoch = self.start_epoch
self.best_reward, self.best_reward_std = \
test_result["rew"], test_result["rew_std"]
if self.save_fn:
self.save_fn(self.policy)
self.epoch = self.start_epoch
self.stop_fn_flag = False
self.iter_num = 0
def __iter__(self): # type: ignore
self.reset()
return self
def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:
"""Perform one epoch (both train and eval)."""
self.epoch += 1
self.iter_num += 1
if self.iter_num > 1:
# iterator exhaustion check
if self.epoch >= self.max_epoch:
if self.test_collector is None and self.save_fn:
self.save_fn(self.policy)
raise StopIteration
# exit flag 1, when stop_fn succeeds in train_step or test_step
if self.stop_fn_flag:
raise StopIteration
# set policy in train mode
self.policy.train()
epoch_stat: Dict[str, Any] = dict()
# perform n step_per_epoch
with tqdm.tqdm(
total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
) as t:
while t.n < t.total and not self.stop_fn_flag:
data: Dict[str, Any] = dict()
result: Dict[str, Any] = dict()
if self.train_collector is not None:
data, result, self.stop_fn_flag = self.train_step()
t.update(result["n/st"])
if self.stop_fn_flag:
t.set_postfix(**data)
break
else:
assert self.buffer, "No train_collector or buffer specified"
result["n/ep"] = len(self.buffer)
result["n/st"] = int(self.gradient_step)
t.update()
self.policy_update_fn(data, result)
t.set_postfix(**data)
if t.n <= t.total and not self.stop_fn_flag:
t.update()
if not self.stop_fn_flag:
self.logger.save_data(
self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
)
# test
if self.test_collector is not None:
test_stat, self.stop_fn_flag = self.test_step()
if not self.is_run:
epoch_stat.update(test_stat)
if not self.is_run:
epoch_stat.update({k: v.get() for k, v in self.stat.items()})
epoch_stat["gradient_step"] = self.gradient_step
epoch_stat.update(
{
"env_step": self.env_step,
"rew": self.last_rew,
"len": int(self.last_len),
"n/ep": int(result["n/ep"]),
"n/st": int(result["n/st"]),
}
)
info = gather_info(
self.start_time, self.train_collector, self.test_collector,
self.best_reward, self.best_reward_std
)
return self.epoch, epoch_stat, info
else:
return None
def test_step(self) -> Tuple[Dict[str, Any], bool]:
"""Perform one testing step."""
assert self.episode_per_test is not None
assert self.test_collector is not None
stop_fn_flag = False
test_result = test_episode(
self.policy, self.test_collector, self.test_fn, self.epoch,
self.episode_per_test, self.logger, self.env_step, self.reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if self.best_epoch < 0 or self.best_reward < rew:
self.best_epoch = self.epoch
self.best_reward = float(rew)
self.best_reward_std = rew_std
if self.save_fn:
self.save_fn(self.policy)
if self.verbose:
print(
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
if not self.is_run:
test_stat = {
"test_reward": rew,
"test_reward_std": rew_std,
"best_reward": self.best_reward,
"best_reward_std": self.best_reward_std,
"best_epoch": self.best_epoch
}
else:
test_stat = {}
if self.stop_fn and self.stop_fn(self.best_reward):
stop_fn_flag = True
return test_stat, stop_fn_flag
def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
"""Perform one training step."""
assert self.episode_per_test is not None
assert self.train_collector is not None
stop_fn_flag = False
if self.train_fn:
self.train_fn(self.epoch, self.env_step)
result = self.train_collector.collect(
n_step=self.step_per_collect, n_episode=self.episode_per_collect
)
if result["n/ep"] > 0 and self.reward_metric:
rew = self.reward_metric(result["rews"])
result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
self.env_step += int(result["n/st"])
self.logger.log_train_data(result, self.env_step)
self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew
self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len
data = {
"env_step": str(self.env_step),
"rew": f"{self.last_rew:.2f}",
"len": str(int(self.last_len)),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if result["n/ep"] > 0:
if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]):
assert self.test_collector is not None
test_result = test_episode(
self.policy, self.test_collector, self.test_fn, self.epoch,
self.episode_per_test, self.logger, self.env_step
)
if self.stop_fn(test_result["rew"]):
stop_fn_flag = True
self.best_reward = test_result["rew"]
self.best_reward_std = test_result["rew_std"]
else:
self.policy.train()
return data, result, stop_fn_flag
def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None:
"""Log losses to current logger."""
for k in losses.keys():
self.stat[k].add(losses[k])
losses[k] = self.stat[k].get()
data[k] = f"{losses[k]:.3f}"
self.logger.log_update_data(losses, self.gradient_step)
@abstractmethod
def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
"""Policy update function for different trainer implementation.
:param data: information in progress bar.
:param result: collector's return value.
"""
def run(self) -> Dict[str, Union[float, str]]:
"""Consume iterator.
See itertools - recipes. Use functions that consume iterators at C speed
(feed the entire iterator into a zero-length deque).
"""
try:
self.is_run = True
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
info = gather_info(
self.start_time, self.train_collector, self.test_collector,
self.best_reward, self.best_reward_std
)
finally:
self.is_run = False
return info
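As a usage sketch (not part of this commit): a concrete trainer only needs to forward its configuration to `BaseTrainer.__init__` and implement `policy_update_fn`. The hypothetical class below mirrors the offline update step defined further down.

```python
# Hypothetical subclass sketch; class name and kwargs handling are illustrative.
from typing import Any, Dict, Optional

from tianshou.trainer.base import BaseTrainer


class MyBufferTrainer(BaseTrainer):
    """Offline-style trainer: one gradient step per policy_update_fn call."""

    def __init__(self, policy, buffer, **kwargs):
        # kwargs must still provide max_epoch, batch_size, step_per_epoch,
        # episode_per_test, etc., which are consumed by BaseTrainer.__init__.
        super().__init__(
            learning_type="offline",
            policy=policy,
            buffer=buffer,
            **kwargs,
        )

    def policy_update_fn(
        self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
    ) -> None:
        assert self.buffer is not None
        self.gradient_step += 1
        losses = self.policy.update(self.batch_size, self.buffer)
        data.update({"gradient_step": str(self.gradient_step)})
        self.log_update_data(data, losses)
```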

tianshou/trainer/offline.py

@@ -1,131 +1,115 @@
from typing import Any, Callable, Dict, Optional, Union

import numpy as np

from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer.base import BaseTrainer
from tianshou.utils import BaseLogger, LazyLogger


class OfflineTrainer(BaseTrainer):
    """Create an iterator class for offline training procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
        This buffer must be populated with experiences for offline RL.
    :param Collector test_collector: the collector used for testing. If it's None,
        then no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
        set.
    :param int update_per_epoch: the number of policy network updates, so-called
        gradient steps, per epoch.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in
        the policy network.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature
        ``f(policy: BasePolicy) -> None``.
    :param function save_checkpoint_fn: a function to save training process, with
        the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
        you can save whatever you want. Because offline-RL doesn't have env_step,
        the env_step is always 0 here.
    :param bool resume_from_log: resume gradient_step and other metadata from
        existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        returns a boolean which indicates whether reaching the goal.
    :param function reward_metric: a function with signature ``f(rewards:
        np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape
        (num_episode,)``, used in multi-agent RL. We need to return a single scalar
        for each episode's result to monitor training in the multi-agent RL
        setting. This function specifies what is the desired metric, e.g., the
        reward of agent 1 or the average reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during
        updating/testing. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    """

    __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])

    def __init__(
        self,
        policy: BasePolicy,
        buffer: ReplayBuffer,
        test_collector: Optional[Collector],
        max_epoch: int,
        update_per_epoch: int,
        episode_per_test: int,
        batch_size: int,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
        resume_from_log: bool = False,
        reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        logger: BaseLogger = LazyLogger(),
        verbose: bool = True,
    ):
        super().__init__(
            learning_type="offline",
            policy=policy,
            buffer=buffer,
            test_collector=test_collector,
            max_epoch=max_epoch,
            update_per_epoch=update_per_epoch,
            step_per_epoch=update_per_epoch,
            episode_per_test=episode_per_test,
            batch_size=batch_size,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_fn=save_fn,
            save_checkpoint_fn=save_checkpoint_fn,
            resume_from_log=resume_from_log,
            reward_metric=reward_metric,
            logger=logger,
            verbose=verbose,
        )

    def policy_update_fn(
        self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
    ) -> None:
        """Perform one off-line policy update."""
        assert self.buffer
        self.gradient_step += 1
        losses = self.policy.update(self.batch_size, self.buffer)
        data.update({"gradient_step": str(self.gradient_step)})
        self.log_update_data(data, losses)


def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]:  # type: ignore
    """Wrapper for offline_trainer run method.

    It is identical to ``OfflineTrainer(...).run()``.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    return OfflineTrainer(*args, **kwargs).run()


offline_trainer_iter = OfflineTrainer

tianshou/trainer/offpolicy.py

@@ -1,193 +1,130 @@
from typing import Any, Callable, Dict, Optional, Union

import numpy as np

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.trainer.base import BaseTrainer
from tianshou.utils import BaseLogger, LazyLogger


class OffpolicyTrainer(BaseTrainer):
    """Create an iterator wrapper for off-policy training procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing. If it's None,
        then no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
        set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int step_per_collect: the number of transitions the collector would
        collect before the network update, i.e., trainer will collect
        "step_per_collect" transitions and do some policy network update repeatedly
        in each epoch.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in
        the policy network.
    :param int/float update_per_step: the number of times the policy network would
        be updated per transition after (step_per_collect) transitions are
        collected, e.g., if update_per_step set to 0.3, and step_per_collect is
        256, policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256
        transitions are collected by the collector. Default to 1.
    :param function train_fn: a hook called at the beginning of training in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature
        ``f(policy: BasePolicy) -> None``.
    :param function save_checkpoint_fn: a function to save training process, with
        the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
        you can save whatever you want.
    :param bool resume_from_log: resume env_step/gradient_step and other metadata
        from existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        returns a boolean which indicates whether reaching the goal.
    :param function reward_metric: a function with signature
        ``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
        np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to
        return a single scalar for each episode's result to monitor training in the
        multi-agent RL setting. This function specifies what is the desired metric,
        e.g., the reward of agent 1 or the average reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase.
        Default to True.
    """

    __doc__ = BaseTrainer.gen_doc("offpolicy") + "\n".join(__doc__.split("\n")[1:])

    def __init__(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Optional[Collector],
        max_epoch: int,
        step_per_epoch: int,
        step_per_collect: int,
        episode_per_test: int,
        batch_size: int,
        update_per_step: Union[int, float] = 1,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
        resume_from_log: bool = False,
        reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        logger: BaseLogger = LazyLogger(),
        verbose: bool = True,
        test_in_train: bool = True,
    ):
        super().__init__(
            learning_type="offpolicy",
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=max_epoch,
            step_per_epoch=step_per_epoch,
            step_per_collect=step_per_collect,
            episode_per_test=episode_per_test,
            batch_size=batch_size,
            update_per_step=update_per_step,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_fn=save_fn,
            save_checkpoint_fn=save_checkpoint_fn,
            resume_from_log=resume_from_log,
            reward_metric=reward_metric,
            logger=logger,
            verbose=verbose,
            test_in_train=test_in_train,
        )

    def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        """Perform off-policy updates."""
        assert self.train_collector is not None
        for _ in range(round(self.update_per_step * result["n/st"])):
            self.gradient_step += 1
            losses = self.policy.update(self.batch_size, self.train_collector.buffer)
            self.log_update_data(data, losses)


def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]:  # type: ignore
    """Wrapper for OffPolicyTrainer run method.

    It is identical to ``OffpolicyTrainer(...).run()``.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    return OffpolicyTrainer(*args, **kwargs).run()


offpolicy_trainer_iter = OffpolicyTrainer

tianshou/trainer/onpolicy.py

@@ -1,209 +1,147 @@
from typing import Any, Callable, Dict, Optional, Union

import numpy as np

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.trainer.base import BaseTrainer
from tianshou.utils import BaseLogger, LazyLogger


class OnpolicyTrainer(BaseTrainer):
    """Create an iterator wrapper for on-policy training procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing. If it's None,
        then no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
        set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int repeat_per_collect: the number of repeat time for policy learning,
        for example, set it to 2 means the policy needs to learn each given batch
        data twice.
    :param int episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in
        the policy network.
    :param int step_per_collect: the number of transitions the collector would
        collect before the network update, i.e., trainer will collect
        "step_per_collect" transitions and do some policy network update repeatedly
        in each epoch.
    :param int episode_per_collect: the number of episodes the collector would
        collect before the network update, i.e., trainer will collect
        "episode_per_collect" episodes and do some policy network update repeatedly
        in each epoch.
    :param function train_fn: a hook called at the beginning of training in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature
        ``f(policy: BasePolicy) -> None``.
    :param function save_checkpoint_fn: a function to save training process,
        with the signature ``f(epoch: int, env_step: int, gradient_step: int)
        -> None``; you can save whatever you want.
    :param bool resume_from_log: resume env_step/gradient_step and other metadata
        from existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        returns a boolean which indicates whether reaching the goal.
    :param function reward_metric: a function with signature
        ``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
        np.ndarray with shape (num_episode,)``, used in multi-agent RL.
        We need to return a single scalar for each episode's result to monitor
        training in the multi-agent RL setting. This function specifies what is the
        desired metric, e.g., the reward of agent 1 or the average reward over
        all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase. Default to
        True.

    .. note::

        Only either one of step_per_collect and episode_per_collect can be
        specified.
    """

    __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:])

    def __init__(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Optional[Collector],
        max_epoch: int,
        step_per_epoch: int,
        repeat_per_collect: int,
        episode_per_test: int,
        batch_size: int,
        step_per_collect: Optional[int] = None,
        episode_per_collect: Optional[int] = None,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
        resume_from_log: bool = False,
        reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        logger: BaseLogger = LazyLogger(),
        verbose: bool = True,
        test_in_train: bool = True,
    ):
        super().__init__(
            learning_type="onpolicy",
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=max_epoch,
            step_per_epoch=step_per_epoch,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=episode_per_test,
            batch_size=batch_size,
            step_per_collect=step_per_collect,
            episode_per_collect=episode_per_collect,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_fn=save_fn,
            save_checkpoint_fn=save_checkpoint_fn,
            resume_from_log=resume_from_log,
            reward_metric=reward_metric,
            logger=logger,
            verbose=verbose,
            test_in_train=test_in_train,
        )

    def policy_update_fn(
        self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
    ) -> None:
        """Perform one on-policy update."""
        assert self.train_collector is not None
        losses = self.policy.update(
            0,
            self.train_collector.buffer,
            batch_size=self.batch_size,
            repeat=self.repeat_per_collect,
        )
        self.train_collector.reset_buffer(keep_statistics=True)
        step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
        self.gradient_step += step
        self.log_update_data(data, losses)


def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]:  # type: ignore
    """Wrapper for OnpolicyTrainer run method.

    It is identical to ``OnpolicyTrainer(...).run()``.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    return OnpolicyTrainer(*args, **kwargs).run()


onpolicy_trainer_iter = OnpolicyTrainer