Add Trainers as generators (#559)

The proposed feature is to have trainers work as generators (iterators).
The usage pattern is:

```python
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- `epoch` int: the epoch number
- `epoch_stat` dict: a large collection of metrics for the current epoch, including the smoothed loss statistics (`stat`), the latest test results, and collect counters
- `info` dict: the result dict that the non-generator version of the trainer returns

You can even iterate on several different trainers at the same time:

```python
trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```
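
For reference, the one-shot API is unchanged: each functional trainer (`onpolicy_trainer`, `offpolicy_trainer`, `offline_trainer`) is now a thin wrapper that builds the corresponding class and calls its `run()` method. A minimal sketch of both styles, assuming the policy and collectors are already set up (the `reward_threshold` early stop is only an illustration):

```python
from tianshou.trainer import OnpolicyTrainer, onpolicy_trainer


def train_one_shot(policy, train_collector, test_collector, **kwargs):
    # old behavior, identical to OnpolicyTrainer(...).run()
    return onpolicy_trainer(policy, train_collector, test_collector, **kwargs)


def train_with_early_exit(policy, train_collector, test_collector, reward_threshold, **kwargs):
    # new behavior: step through epochs and leave the loop whenever you like
    trainer = OnpolicyTrainer(policy, train_collector, test_collector, **kwargs)
    for epoch, epoch_stat, info in trainer:
        print(f"Epoch #{epoch}: best_reward {epoch_stat['best_reward']:.2f}")
        if epoch_stat["best_reward"] >= reward_threshold:
            break  # `info` holds the latest result from gather_info()
    return info
```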

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
Author: Jose Antonio Martin H, 2022-03-17 17:26:14 +01:00, committed by GitHub
parent 2336a7db1b
commit 10d919052b
14 changed files with 862 additions and 486 deletions

@@ -7,6 +7,6 @@
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, numpy, sys
print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
import tianshou, gym, torch, numpy, sys
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
```

@@ -22,10 +22,8 @@ lint:
flake8 ${LINT_PATHS} --count --show-source --statistics
format:
# sort imports
$(call check_install, isort)
isort ${LINT_PATHS}
# reformat using yapf
$(call check_install, yapf)
yapf -ir ${LINT_PATHS}
@@ -57,6 +55,6 @@ doc-clean:
clean: doc-clean
commit-checks: format lint mypy check-docstyle spelling
commit-checks: lint check-codestyle mypy check-docstyle spelling
.PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks

@@ -1,7 +1,49 @@
tianshou.trainer
================
.. automodule:: tianshou.trainer
On-policy
---------
.. autoclass:: tianshou.trainer.OnpolicyTrainer
:members:
:undoc-members:
:show-inheritance:
.. autofunction:: tianshou.trainer.onpolicy_trainer
.. autoclass:: tianshou.trainer.onpolicy_trainer_iter
Off-policy
----------
.. autoclass:: tianshou.trainer.OffpolicyTrainer
:members:
:undoc-members:
:show-inheritance:
.. autofunction:: tianshou.trainer.offpolicy_trainer
.. autoclass:: tianshou.trainer.offpolicy_trainer_iter
Offline
-------
.. autoclass:: tianshou.trainer.OfflineTrainer
:members:
:undoc-members:
:show-inheritance:
.. autofunction:: tianshou.trainer.offline_trainer
.. autoclass:: tianshou.trainer.offline_trainer_iter
utils
-----
.. autofunction:: tianshou.trainer.test_episode
.. autofunction:: tianshou.trainer.gather_info

@@ -24,12 +24,15 @@ fqf
iqn
qrdqn
rl
offpolicy
onpolicy
quantile
quantiles
dqn
param
async
subprocess
deque
nn
equ
cql

@@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training method
Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.
We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic:
::
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)
do_something_with_policy()
query_something_about_policy()
make_a_plot_with(epoch_stat)
display(info)
# or even iterate on several trainers at the same time
trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
compare_results(result1, result2, ...)
.. _pseudocode:

@@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -157,7 +157,7 @@ def test_ppo(args=get_args()):
print("Fail to restore policy and optim.")
# trainer
result = onpolicy_trainer(
trainer = OnpolicyTrainer(
policy,
train_collector,
test_collector,
@@ -173,10 +173,16 @@ def test_ppo(args=get_args()):
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn
)
assert stop_fn(result['best_reward'])
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)
assert stop_fn(info["best_reward"])
if __name__ == '__main__':
pprint.pprint(result)
pprint.pprint(info)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()

@@ -24,7 +24,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--reward-threshold', type=float, default=None)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)

@@ -11,7 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
@@ -135,8 +135,8 @@ def test_td3(args=get_args()):
def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold
# trainer
result = offpolicy_trainer(
# Iterator trainer
trainer = OffpolicyTrainer(
policy,
train_collector,
test_collector,
@@ -148,12 +148,17 @@ def test_td3(args=get_args()):
update_per_step=args.update_per_step,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)
assert stop_fn(result['best_reward'])
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)
if __name__ == '__main__':
pprint.pprint(result)
assert stop_fn(info["best_reward"])
if __name__ == "__main__":
pprint.pprint(info)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()

@@ -12,7 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import CQLPolicy
from tianshou.trainer import offline_trainer
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -195,7 +195,7 @@ def test_cql(args=get_args()):
collector.collect(n_episode=1, render=1 / 35)
# trainer
result = offline_trainer(
trainer = OfflineTrainer(
policy,
buffer,
test_collector,
@@ -207,11 +207,17 @@ def test_cql(args=get_args()):
stop_fn=stop_fn,
logger=logger,
)
assert stop_fn(result['best_reward'])
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)
assert stop_fn(info["best_reward"])
# Let's watch its performance!
if __name__ == '__main__':
pprint.pprint(result)
if __name__ == "__main__":
pprint.pprint(info)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)

@@ -1,16 +1,34 @@
"""Trainer package."""
# isort:skip_file
from tianshou.trainer.utils import test_episode, gather_info
from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
from tianshou.trainer.offline import offline_trainer
from tianshou.trainer.base import BaseTrainer
from tianshou.trainer.offline import (
OfflineTrainer,
offline_trainer,
offline_trainer_iter,
)
from tianshou.trainer.offpolicy import (
OffpolicyTrainer,
offpolicy_trainer,
offpolicy_trainer_iter,
)
from tianshou.trainer.onpolicy import (
OnpolicyTrainer,
onpolicy_trainer,
onpolicy_trainer_iter,
)
from tianshou.trainer.utils import gather_info, test_episode
__all__ = [
"BaseTrainer",
"offpolicy_trainer",
"offpolicy_trainer_iter",
"OffpolicyTrainer",
"onpolicy_trainer",
"onpolicy_trainer_iter",
"OnpolicyTrainer",
"offline_trainer",
"offline_trainer_iter",
"OfflineTrainer",
"test_episode",
"gather_info",
]
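
The `*_trainer_iter` names exported here are plain aliases of the new classes (see `offline_trainer_iter = OfflineTrainer` and friends at the end of each module), so both spellings resolve to the same object. A quick sketch of the resulting import surface:

```python
from tianshou.trainer import (
    OffpolicyTrainer,
    offpolicy_trainer,
    offpolicy_trainer_iter,
)

# the *_iter name is just an alias for the class ...
assert offpolicy_trainer_iter is OffpolicyTrainer
# ... while the lowercase function is equivalent to OffpolicyTrainer(...).run()
```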

tianshou/trainer/base.py (new file)

@@ -0,0 +1,419 @@
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union
import numpy as np
import tqdm
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer.utils import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
class BaseTrainer(ABC):
"""An iterator base class for trainers procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results
on every epoch.
:param learning_type str: type of learning iterator, available choices are
"offpolicy", "onpolicy" and "offline".
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector train_collector: the collector used for training.
:param Collector test_collector: the collector used for testing. If it's None,
then no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn``
is set.
:param int step_per_epoch: the number of transitions collected per epoch.
:param int repeat_per_collect: the number of repeat time for policy learning,
for example, set it to 2 means the policy needs to learn each given batch
data twice.
:param int episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to feed in
the policy network.
:param int step_per_collect: the number of transitions the collector would
collect before the network update, i.e., trainer will collect
"step_per_collect" transitions and do some policy network update repeatedly
in each epoch.
:param int episode_per_collect: the number of episodes the collector would
collect before the network update, i.e., trainer will collect
"episode_per_collect" episodes and do some policy network update repeatedly
in each epoch.
:param function train_fn: a hook called at the beginning of training in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature
``f(policy: BasePolicy) -> None``.
:param function save_checkpoint_fn: a function to save training process, with
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
you can save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata
from existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
:param function reward_metric: a function with signature
``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray
with shape (num_episode,)``, used in multi-agent RL. We need to return a
single scalar for each episode's result to monitor training in the
multi-agent RL setting. This function specifies what is the desired metric,
e.g., the reward of agent 1 or the average reward over all agents.
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase.
Default to True.
"""
@staticmethod
def gen_doc(learning_type: str) -> str:
"""Document string for subclass trainer."""
step_means = f'The "step" in {learning_type} trainer means '
if learning_type != "offline":
step_means += "an environment step (a.k.a. transition)."
else: # offline
step_means += "a gradient step."
trainer_name = learning_type.capitalize() + "Trainer"
return f"""An iterator class for {learning_type} trainer procedure.
Returns an iterator that yields a 3-tuple (epoch, stats, info) of
train results on every epoch.
{step_means}
Example usage:
::
trainer = {trainer_name}(...)
for epoch, epoch_stat, info in trainer:
print("Epoch:", epoch)
print(epoch_stat)
print(info)
do_something_with_policy()
query_something_about_policy()
make_a_plot_with(epoch_stat)
display(info)
- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch
- info dict: result returned from :func:`~tianshou.trainer.gather_info`
You can even iterate on several trainers at the same time:
::
trainer1 = {trainer_name}(...)
trainer2 = {trainer_name}(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
compare_results(result1, result2, ...)
"""
def __init__(
self,
learning_type: str,
policy: BasePolicy,
max_epoch: int,
batch_size: int,
train_collector: Optional[Collector] = None,
test_collector: Optional[Collector] = None,
buffer: Optional[ReplayBuffer] = None,
step_per_epoch: Optional[int] = None,
repeat_per_collect: Optional[int] = None,
episode_per_test: Optional[int] = None,
update_per_step: Union[int, float] = 1,
update_per_epoch: Optional[int] = None,
step_per_collect: Optional[int] = None,
episode_per_collect: Optional[int] = None,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
):
self.policy = policy
self.buffer = buffer
self.train_collector = train_collector
self.test_collector = test_collector
self.logger = logger
self.start_time = time.time()
self.stat: DefaultDict[str, MovAvg] = defaultdict(MovAvg)
self.best_reward = 0.0
self.best_reward_std = 0.0
self.start_epoch = 0
self.gradient_step = 0
self.env_step = 0
self.max_epoch = max_epoch
self.step_per_epoch = step_per_epoch
# either one of these two
self.step_per_collect = step_per_collect
self.episode_per_collect = episode_per_collect
self.update_per_step = update_per_step
self.repeat_per_collect = repeat_per_collect
self.episode_per_test = episode_per_test
self.batch_size = batch_size
self.train_fn = train_fn
self.test_fn = test_fn
self.stop_fn = stop_fn
self.save_fn = save_fn
self.save_checkpoint_fn = save_checkpoint_fn
self.reward_metric = reward_metric
self.verbose = verbose
self.test_in_train = test_in_train
self.resume_from_log = resume_from_log
self.is_run = False
self.last_rew, self.last_len = 0.0, 0
self.epoch = self.start_epoch
self.best_epoch = self.start_epoch
self.stop_fn_flag = False
self.iter_num = 0
def reset(self) -> None:
"""Initialize or reset the instance to yield a new iterator from zero."""
self.is_run = False
self.env_step = 0
if self.resume_from_log:
self.start_epoch, self.env_step, self.gradient_step = \
self.logger.restore_data()
self.last_rew, self.last_len = 0.0, 0
self.start_time = time.time()
if self.train_collector is not None:
self.train_collector.reset_stat()
if self.train_collector.policy != self.policy:
self.test_in_train = False
elif self.test_collector is None:
self.test_in_train = False
if self.test_collector is not None:
assert self.episode_per_test is not None
self.test_collector.reset_stat()
test_result = test_episode(
self.policy, self.test_collector, self.test_fn, self.start_epoch,
self.episode_per_test, self.logger, self.env_step, self.reward_metric
)
self.best_epoch = self.start_epoch
self.best_reward, self.best_reward_std = \
test_result["rew"], test_result["rew_std"]
if self.save_fn:
self.save_fn(self.policy)
self.epoch = self.start_epoch
self.stop_fn_flag = False
self.iter_num = 0
def __iter__(self): # type: ignore
self.reset()
return self
def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:
"""Perform one epoch (both train and eval)."""
self.epoch += 1
self.iter_num += 1
if self.iter_num > 1:
# iterator exhaustion check
if self.epoch >= self.max_epoch:
if self.test_collector is None and self.save_fn:
self.save_fn(self.policy)
raise StopIteration
# exit flag 1, when stop_fn succeeds in train_step or test_step
if self.stop_fn_flag:
raise StopIteration
# set policy in train mode
self.policy.train()
epoch_stat: Dict[str, Any] = dict()
# perform n step_per_epoch
with tqdm.tqdm(
total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
) as t:
while t.n < t.total and not self.stop_fn_flag:
data: Dict[str, Any] = dict()
result: Dict[str, Any] = dict()
if self.train_collector is not None:
data, result, self.stop_fn_flag = self.train_step()
t.update(result["n/st"])
if self.stop_fn_flag:
t.set_postfix(**data)
break
else:
assert self.buffer, "No train_collector or buffer specified"
result["n/ep"] = len(self.buffer)
result["n/st"] = int(self.gradient_step)
t.update()
self.policy_update_fn(data, result)
t.set_postfix(**data)
if t.n <= t.total and not self.stop_fn_flag:
t.update()
if not self.stop_fn_flag:
self.logger.save_data(
self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
)
# test
if self.test_collector is not None:
test_stat, self.stop_fn_flag = self.test_step()
if not self.is_run:
epoch_stat.update(test_stat)
if not self.is_run:
epoch_stat.update({k: v.get() for k, v in self.stat.items()})
epoch_stat["gradient_step"] = self.gradient_step
epoch_stat.update(
{
"env_step": self.env_step,
"rew": self.last_rew,
"len": int(self.last_len),
"n/ep": int(result["n/ep"]),
"n/st": int(result["n/st"]),
}
)
info = gather_info(
self.start_time, self.train_collector, self.test_collector,
self.best_reward, self.best_reward_std
)
return self.epoch, epoch_stat, info
else:
return None
def test_step(self) -> Tuple[Dict[str, Any], bool]:
"""Perform one testing step."""
assert self.episode_per_test is not None
assert self.test_collector is not None
stop_fn_flag = False
test_result = test_episode(
self.policy, self.test_collector, self.test_fn, self.epoch,
self.episode_per_test, self.logger, self.env_step, self.reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if self.best_epoch < 0 or self.best_reward < rew:
self.best_epoch = self.epoch
self.best_reward = float(rew)
self.best_reward_std = rew_std
if self.save_fn:
self.save_fn(self.policy)
if self.verbose:
print(
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
if not self.is_run:
test_stat = {
"test_reward": rew,
"test_reward_std": rew_std,
"best_reward": self.best_reward,
"best_reward_std": self.best_reward_std,
"best_epoch": self.best_epoch
}
else:
test_stat = {}
if self.stop_fn and self.stop_fn(self.best_reward):
stop_fn_flag = True
return test_stat, stop_fn_flag
def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
"""Perform one training step."""
assert self.episode_per_test is not None
assert self.train_collector is not None
stop_fn_flag = False
if self.train_fn:
self.train_fn(self.epoch, self.env_step)
result = self.train_collector.collect(
n_step=self.step_per_collect, n_episode=self.episode_per_collect
)
if result["n/ep"] > 0 and self.reward_metric:
rew = self.reward_metric(result["rews"])
result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
self.env_step += int(result["n/st"])
self.logger.log_train_data(result, self.env_step)
self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew
self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len
data = {
"env_step": str(self.env_step),
"rew": f"{self.last_rew:.2f}",
"len": str(int(self.last_len)),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if result["n/ep"] > 0:
if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]):
assert self.test_collector is not None
test_result = test_episode(
self.policy, self.test_collector, self.test_fn, self.epoch,
self.episode_per_test, self.logger, self.env_step
)
if self.stop_fn(test_result["rew"]):
stop_fn_flag = True
self.best_reward = test_result["rew"]
self.best_reward_std = test_result["rew_std"]
else:
self.policy.train()
return data, result, stop_fn_flag
def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None:
"""Log losses to current logger."""
for k in losses.keys():
self.stat[k].add(losses[k])
losses[k] = self.stat[k].get()
data[k] = f"{losses[k]:.3f}"
self.logger.log_update_data(losses, self.gradient_step)
@abstractmethod
def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
"""Policy update function for different trainer implementation.
:param data: information in progress bar.
:param result: collector's return value.
"""
def run(self) -> Dict[str, Union[float, str]]:
"""Consume iterator.
See itertools - recipes. Use functions that consume iterators at C speed
(feed the entire iterator into a zero-length deque).
"""
try:
self.is_run = True
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
info = gather_info(
self.start_time, self.train_collector, self.test_collector,
self.best_reward, self.best_reward_std
)
finally:
self.is_run = False
return info
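
Since `policy_update_fn` is the only abstract method, a custom training scheme only needs to subclass `BaseTrainer` and fill it in. A hypothetical sketch (not part of this commit) that mirrors the off-policy update loop:

```python
from typing import Any, Dict

from tianshou.trainer import BaseTrainer


class MyOffpolicyTrainer(BaseTrainer):
    """Hypothetical subclass: update the policy after every collect step."""

    def __init__(self, policy, train_collector, test_collector, **kwargs):
        # kwargs: max_epoch, step_per_epoch, step_per_collect, batch_size, ...
        super().__init__(
            learning_type="offpolicy",
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            **kwargs,
        )

    def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        # one gradient step per collected transition (scaled by update_per_step)
        for _ in range(round(self.update_per_step * result["n/st"])):
            self.gradient_step += 1
            losses = self.policy.update(self.batch_size, self.train_collector.buffer)
            self.log_update_data(data, losses)
```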

@@ -1,131 +1,115 @@
import time
from collections import defaultdict
from typing import Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import tqdm
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
from tianshou.trainer.base import BaseTrainer
from tianshou.utils import BaseLogger, LazyLogger
def offline_trainer(
policy: BasePolicy,
buffer: ReplayBuffer,
test_collector: Optional[Collector],
max_epoch: int,
update_per_epoch: int,
episode_per_test: int,
batch_size: int,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for offline trainer procedure.
The "step" in offline trainer means a gradient step.
class OfflineTrainer(BaseTrainer):
"""Create an iterator class for offline training procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector test_collector: the collector used for testing. If it's None, then
no testing will be performed.
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
This buffer must be populated with experiences for offline RL.
:param Collector test_collector: the collector used for testing. If it's None,
then no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
set.
:param int update_per_epoch: the number of policy network updates, so-called
gradient steps, per epoch.
:param episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to feed in
the policy network.
:param function test_fn: a hook called at the beginning of testing in each epoch.
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_checkpoint_fn: a function to save training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want. Because offline-RL doesn't have env_step, the env_step
is always 0 here.
:param bool resume_from_log: resume gradient_step and other metadata from existing
tensorboard log. Default to False.
:param function test_fn: a hook called at the beginning of testing in each
epoch.
It can be used to perform custom additional operations, with the signature
``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature
``f(policy: BasePolicy) -> None``.
:param function save_checkpoint_fn: a function to save training process,
with the signature ``f(epoch: int, env_step: int, gradient_step: int) ->
None``; you can save whatever you want. Because offline-RL doesn't have
env_step, the env_step is always 0 here.
:param bool resume_from_log: resume gradient_step and other metadata from
existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
used in multi-agent RL. We need to return a single scalar for each episode's
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
:param BaseLogger logger: A logger that logs statistics during updating/testing.
Default to a logger that doesn't log anything.
:param function reward_metric: a function with signature ``f(rewards:
np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape
(num_episode,)``, used in multi-agent RL. We need to return a single scalar
for each episode's result to monitor training in the multi-agent RL
setting. This function specifies what is the desired metric, e.g., the
reward of agent 1 or the average reward over all agents.
:param BaseLogger logger: A logger that logs statistics during
updating/testing. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
"""
__doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
def __init__(
self,
policy: BasePolicy,
buffer: ReplayBuffer,
test_collector: Optional[Collector],
max_epoch: int,
update_per_epoch: int,
episode_per_test: int,
batch_size: int,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
):
super().__init__(
learning_type="offline",
policy=policy,
buffer=buffer,
test_collector=test_collector,
max_epoch=max_epoch,
update_per_epoch=update_per_epoch,
step_per_epoch=update_per_epoch,
episode_per_test=episode_per_test,
batch_size=batch_size,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
save_checkpoint_fn=save_checkpoint_fn,
resume_from_log=resume_from_log,
reward_metric=reward_metric,
logger=logger,
verbose=verbose,
)
def policy_update_fn(
self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
) -> None:
"""Perform one off-line policy update."""
assert self.buffer
self.gradient_step += 1
losses = self.policy.update(self.batch_size, self.buffer)
data.update({"gradient_step": str(self.gradient_step)})
self.log_update_data(data, losses)
def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
"""Wrapper for offline_trainer run method.
It is identical to ``OfflineTrainer(...).run()``.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
start_epoch, gradient_step = 0, 0
if resume_from_log:
start_epoch, _, gradient_step = logger.restore_data()
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
return OfflineTrainer(*args, **kwargs).run()
if test_collector is not None:
test_c: Collector = test_collector
test_collector.reset_stat()
test_result = test_episode(
policy, test_c, test_fn, start_epoch, episode_per_test, logger,
gradient_step, reward_metric
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if save_fn:
save_fn(policy)
for epoch in range(1 + start_epoch, 1 + max_epoch):
policy.train()
with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
for _ in t:
gradient_step += 1
losses = policy.update(batch_size, buffer)
data = {"gradient_step": str(gradient_step)}
for k in losses.keys():
stat[k].add(losses[k])
losses[k] = stat[k].get()
data[k] = f"{losses[k]:.3f}"
logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
# test
if test_collector is not None:
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger,
gradient_step, reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
)
if stop_fn and stop_fn(best_reward):
break
if test_collector is None and save_fn:
save_fn(policy)
if test_collector is None:
return gather_info(start_time, None, None, 0.0, 0.0)
else:
return gather_info(
start_time, None, test_collector, best_reward, best_reward_std
)
offline_trainer_iter = OfflineTrainer

@@ -1,193 +1,130 @@
import time
from collections import defaultdict
from typing import Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import tqdm
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.trainer import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
from tianshou.trainer.base import BaseTrainer
from tianshou.utils import BaseLogger, LazyLogger
def offpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Optional[Collector],
max_epoch: int,
step_per_epoch: int,
step_per_collect: int,
episode_per_test: int,
batch_size: int,
update_per_step: Union[int, float] = 1,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for off-policy trainer procedure.
The "step" in trainer means an environment step (a.k.a. transition).
class OffpolicyTrainer(BaseTrainer):
"""Create an iterator wrapper for off-policy training procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector train_collector: the collector used for training.
:param Collector test_collector: the collector used for testing. If it's None, then
no testing will be performed.
:param Collector test_collector: the collector used for testing. If it's None,
then no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
set.
:param int step_per_epoch: the number of transitions collected per epoch.
:param int step_per_collect: the number of transitions the collector would collect
before the network update, i.e., trainer will collect "step_per_collect"
transitions and do some policy network update repeatedly in each epoch.
:param int step_per_collect: the number of transitions the collector would
collect before the network update, i.e., trainer will collect
"step_per_collect" transitions and do some policy network update repeatedly
in each epoch.
:param episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to feed in the
policy network.
:param int/float update_per_step: the number of times the policy network would be
updated per transition after (step_per_collect) transitions are collected,
e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will
be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are
collected by the collector. Default to 1.
:param function train_fn: a hook called at the beginning of training in each epoch.
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each epoch.
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_checkpoint_fn: a function to save training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
existing tensorboard log. Default to False.
:param int batch_size: the batch size of sample data, which is going to feed in
the policy network.
:param int/float update_per_step: the number of times the policy network would
be updated per transition after (step_per_collect) transitions are
collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256
, policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256
transitions are collected by the collector. Default to 1.
:param function train_fn: a hook called at the beginning of training in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature
``f(policy: BasePolicy) -> None``.
:param function save_checkpoint_fn: a function to save training process, with
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
you can save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata
from existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
used in multi-agent RL. We need to return a single scalar for each episode's
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
:param function reward_metric: a function with signature
``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to
return a single scalar for each episode's result to monitor training in the
multi-agent RL setting. This function specifies what is the desired metric,
e.g., the reward of agent 1 or the average reward over all agents.
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True.
:param bool test_in_train: whether to test in the training phase.
Default to True.
"""
__doc__ = BaseTrainer.gen_doc("offpolicy") + "\n".join(__doc__.split("\n")[1:])
def __init__(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Optional[Collector],
max_epoch: int,
step_per_epoch: int,
step_per_collect: int,
episode_per_test: int,
batch_size: int,
update_per_step: Union[int, float] = 1,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
):
super().__init__(
learning_type="offpolicy",
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
max_epoch=max_epoch,
step_per_epoch=step_per_epoch,
step_per_collect=step_per_collect,
episode_per_test=episode_per_test,
batch_size=batch_size,
update_per_step=update_per_step,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
save_checkpoint_fn=save_checkpoint_fn,
resume_from_log=resume_from_log,
reward_metric=reward_metric,
logger=logger,
verbose=verbose,
test_in_train=test_in_train,
)
def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
"""Perform off-policy updates."""
assert self.train_collector is not None
for _ in range(round(self.update_per_step * result["n/st"])):
self.gradient_step += 1
losses = self.policy.update(self.batch_size, self.train_collector.buffer)
self.log_update_data(data, losses)
def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
"""Wrapper for OffPolicyTrainer run method.
It is identical to ``OffpolicyTrainer(...).run()``.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data()
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_in_train = test_in_train and (
train_collector.policy == policy and test_collector is not None
)
return OffpolicyTrainer(*args, **kwargs).run()
if test_collector is not None:
test_c: Collector = test_collector # for mypy
test_collector.reset_stat()
test_result = test_episode(
policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
reward_metric
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if save_fn:
save_fn(policy)
for epoch in range(1 + start_epoch, 1 + max_epoch):
# train
policy.train()
with tqdm.tqdm(
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
) as t:
while t.n < t.total:
if train_fn:
train_fn(epoch, env_step)
result = train_collector.collect(n_step=step_per_collect)
if result["n/ep"] > 0 and reward_metric:
rew = reward_metric(result["rews"])
result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
env_step += int(result["n/st"])
t.update(result["n/st"])
logger.log_train_data(result, env_step)
last_rew = result['rew'] if result["n/ep"] > 0 else last_rew
last_len = result['len'] if result["n/ep"] > 0 else last_len
data = {
"env_step": str(env_step),
"rew": f"{last_rew:.2f}",
"len": str(int(last_len)),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if result["n/ep"] > 0:
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger,
env_step
)
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
logger.save_data(
epoch, env_step, gradient_step, save_checkpoint_fn
)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result["rew"], test_result["rew_std"]
)
else:
policy.train()
for _ in range(round(update_per_step * result["n/st"])):
gradient_step += 1
losses = policy.update(batch_size, train_collector.buffer)
for k in losses.keys():
stat[k].add(losses[k])
losses[k] = stat[k].get()
data[k] = f"{losses[k]:.3f}"
logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
if t.n <= t.total:
t.update()
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
# test
if test_collector is not None:
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
)
if stop_fn and stop_fn(best_reward):
break
if test_collector is None and save_fn:
save_fn(policy)
if test_collector is None:
return gather_info(start_time, train_collector, None, 0.0, 0.0)
else:
return gather_info(
start_time, train_collector, test_collector, best_reward, best_reward_std
)
offpolicy_trainer_iter = OffpolicyTrainer

@@ -1,209 +1,147 @@
import time
from collections import defaultdict
from typing import Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import tqdm
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.trainer import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
from tianshou.trainer.base import BaseTrainer
from tianshou.utils import BaseLogger, LazyLogger
def onpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Optional[Collector],
max_epoch: int,
step_per_epoch: int,
repeat_per_collect: int,
episode_per_test: int,
batch_size: int,
step_per_collect: Optional[int] = None,
episode_per_collect: Optional[int] = None,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure.
The "step" in trainer means an environment step (a.k.a. transition).
class OnpolicyTrainer(BaseTrainer):
"""Create an iterator wrapper for on-policy training procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param Collector train_collector: the collector used for training.
:param Collector test_collector: the collector used for testing. If it's None, then
no testing will be performed.
:param Collector test_collector: the collector used for testing. If it's None,
then no testing will be performed.
:param int max_epoch: the maximum number of epochs for training. The training
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
set.
:param int step_per_epoch: the number of transitions collected per epoch.
:param int repeat_per_collect: the number of repeat time for policy learning, for
example, set it to 2 means the policy needs to learn each given batch data
twice.
:param int repeat_per_collect: the number of repeat time for policy learning,
for example, set it to 2 means the policy needs to learn each given batch
data twice.
:param int episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to feed in the
policy network.
:param int step_per_collect: the number of transitions the collector would collect
before the network update, i.e., trainer will collect "step_per_collect"
transitions and do some policy network update repeatedly in each epoch.
:param int episode_per_collect: the number of episodes the collector would collect
before the network update, i.e., trainer will collect "episode_per_collect"
episodes and do some policy network update repeatedly in each epoch.
:param function train_fn: a hook called at the beginning of training in each epoch.
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each epoch.
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_checkpoint_fn: a function to save training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
existing tensorboard log. Default to False.
:param int batch_size: the batch size of sample data, which is going to feed in
the policy network.
:param int step_per_collect: the number of transitions the collector would
collect before the network update, i.e., trainer will collect
"step_per_collect" transitions and do some policy network update repeatedly
in each epoch.
:param int episode_per_collect: the number of episodes the collector would
collect before the network update, i.e., trainer will collect
"episode_per_collect" episodes and do some policy network update repeatedly
in each epoch.
:param function train_fn: a hook called at the beginning of training in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function test_fn: a hook called at the beginning of testing in each
epoch. It can be used to perform custom additional operations, with the
signature ``f(num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean
reward in evaluation phase gets better, with the signature
``f(policy: BasePolicy) -> None``.
:param function save_checkpoint_fn: a function to save training process,
with the signature ``f(epoch: int, env_step: int, gradient_step: int)
-> None``; you can save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata
from existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
used in multi-agent RL. We need to return a single scalar for each episode's
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
:param function reward_metric: a function with signature
``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
np.ndarray with shape (num_episode,)``, used in multi-agent RL.
We need to return a single scalar for each episode's result to monitor
training in the multi-agent RL setting. This function specifies what is the
desired metric, e.g., the reward of agent 1 or the average reward over
all agents.
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True.
:return: See :func:`~tianshou.trainer.gather_info`.
:param bool test_in_train: whether to test in the training phase. Default to
True.
.. note::
Only either one of step_per_collect and episode_per_collect can be specified.
"""
start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data()
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_in_train = test_in_train and (
train_collector.policy == policy and test_collector is not None
)
if test_collector is not None:
test_c: Collector = test_collector # for mypy
test_collector.reset_stat()
test_result = test_episode(
policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
reward_metric
__doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:])
def __init__(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Optional[Collector],
max_epoch: int,
step_per_epoch: int,
repeat_per_collect: int,
episode_per_test: int,
batch_size: int,
step_per_collect: Optional[int] = None,
episode_per_collect: Optional[int] = None,
train_fn: Optional[Callable[[int, int], None]] = None,
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
):
super().__init__(
learning_type="onpolicy",
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
max_epoch=max_epoch,
step_per_epoch=step_per_epoch,
repeat_per_collect=repeat_per_collect,
episode_per_test=episode_per_test,
batch_size=batch_size,
step_per_collect=step_per_collect,
episode_per_collect=episode_per_collect,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
save_checkpoint_fn=save_checkpoint_fn,
resume_from_log=resume_from_log,
reward_metric=reward_metric,
logger=logger,
verbose=verbose,
test_in_train=test_in_train,
)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
if save_fn:
save_fn(policy)
for epoch in range(1 + start_epoch, 1 + max_epoch):
# train
policy.train()
with tqdm.tqdm(
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
) as t:
while t.n < t.total:
if train_fn:
train_fn(epoch, env_step)
result = train_collector.collect(
n_step=step_per_collect, n_episode=episode_per_collect
)
if result["n/ep"] > 0 and reward_metric:
rew = reward_metric(result["rews"])
result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
env_step += int(result["n/st"])
t.update(result["n/st"])
logger.log_train_data(result, env_step)
last_rew = result['rew'] if result["n/ep"] > 0 else last_rew
last_len = result['len'] if result["n/ep"] > 0 else last_len
data = {
"env_step": str(env_step),
"rew": f"{last_rew:.2f}",
"len": str(int(last_len)),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if result["n/ep"] > 0:
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger,
env_step
)
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
logger.save_data(
epoch, env_step, gradient_step, save_checkpoint_fn
)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result["rew"], test_result["rew_std"]
)
else:
policy.train()
losses = policy.update(
0,
train_collector.buffer,
batch_size=batch_size,
repeat=repeat_per_collect
)
train_collector.reset_buffer(keep_statistics=True)
step = max(
[1] + [len(v) for v in losses.values() if isinstance(v, list)]
)
gradient_step += step
for k in losses.keys():
stat[k].add(losses[k])
losses[k] = stat[k].get()
data[k] = f"{losses[k]:.3f}"
logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
if t.n <= t.total:
t.update()
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
# test
if test_collector is not None:
test_result = test_episode(
policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
reward_metric
)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
)
if stop_fn and stop_fn(best_reward):
break
if test_collector is None and save_fn:
save_fn(policy)
if test_collector is None:
return gather_info(start_time, train_collector, None, 0.0, 0.0)
else:
return gather_info(
start_time, train_collector, test_collector, best_reward, best_reward_std
def policy_update_fn(
self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
) -> None:
"""Perform one on-policy update."""
assert self.train_collector is not None
losses = self.policy.update(
0,
self.train_collector.buffer,
batch_size=self.batch_size,
repeat=self.repeat_per_collect,
)
self.train_collector.reset_buffer(keep_statistics=True)
step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
self.gradient_step += step
self.log_update_data(data, losses)
def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
"""Wrapper for OnpolicyTrainer run method.
It is identical to ``OnpolicyTrainer(...).run()``.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
return OnpolicyTrainer(*args, **kwargs).run()
onpolicy_trainer_iter = OnpolicyTrainer