Add Trainers as generators (#559)
The new proposed feature is to have trainers as generators. The usage pattern is:

```python
trainer = OnPolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch, including stat
- info dict: the usual dict out of the non-generator version of the trainer

You can even iterate on several different trainers at the same time:

```python
trainer1 = OnPolicyTrainer(...)
trainer2 = OnPolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
parent 2336a7db1b
commit 10d919052b

4 .github/ISSUE_TEMPLATE.md (vendored)
@@ -7,6 +7,6 @@
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
  ```python
  import tianshou, torch, numpy, sys
  print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
  import tianshou, gym, torch, numpy, sys
  print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
  ```
4 Makefile
@@ -22,10 +22,8 @@ lint:
    flake8 ${LINT_PATHS} --count --show-source --statistics

format:
    # sort imports
    $(call check_install, isort)
    isort ${LINT_PATHS}
    # reformat using yapf
    $(call check_install, yapf)
    yapf -ir ${LINT_PATHS}

@@ -57,6 +55,6 @@ doc-clean:

clean: doc-clean

commit-checks: format lint mypy check-docstyle spelling
commit-checks: lint check-codestyle mypy check-docstyle spelling

.PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks
@@ -1,7 +1,49 @@
tianshou.trainer
================

.. automodule:: tianshou.trainer

On-policy
---------

.. autoclass:: tianshou.trainer.OnpolicyTrainer
    :members:
    :undoc-members:
    :show-inheritance:

.. autofunction:: tianshou.trainer.onpolicy_trainer

.. autoclass:: tianshou.trainer.onpolicy_trainer_iter


Off-policy
----------

.. autoclass:: tianshou.trainer.OffpolicyTrainer
    :members:
    :undoc-members:
    :show-inheritance:

.. autofunction:: tianshou.trainer.offpolicy_trainer

.. autoclass:: tianshou.trainer.offpolicy_trainer_iter


Offline
-------

.. autoclass:: tianshou.trainer.OfflineTrainer
    :members:
    :undoc-members:
    :show-inheritance:

.. autofunction:: tianshou.trainer.offline_trainer

.. autoclass:: tianshou.trainer.offline_trainer_iter


utils
-----

.. autofunction:: tianshou.trainer.test_episode

.. autofunction:: tianshou.trainer.gather_info
@@ -24,12 +24,15 @@ fqf
iqn
qrdqn
rl
offpolicy
onpolicy
quantile
quantiles
dqn
param
async
subprocess
deque
nn
equ
cql
@@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training method

Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.

We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, and :class:`~tianshou.trainer.OfflineTrainer` to facilitate writing more flexible training logic:
::

    trainer = OnpolicyTrainer(...)
    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)
        do_something_with_policy()
        query_something_about_policy()
        make_a_plot_with(epoch_stat)
        display(info)

    # or even iterate on several trainers at the same time

    trainer1 = OnpolicyTrainer(...)
    trainer2 = OnpolicyTrainer(...)
    for result1, result2, ... in zip(trainer1, trainer2, ...):
        compare_results(result1, result2, ...)
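For reference, the function-style trainers above are kept as thin wrappers that construct the corresponding class and call its ``run()`` method, so the one-shot usage is unchanged. A minimal sketch (argument lists elided):
::

    # both calls produce the same result dictionary
    result = onpolicy_trainer(policy, train_collector, test_collector, ...)
    result = OnpolicyTrainer(policy, train_collector, test_collector, ...).run()
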
.. _pseudocode:
@@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -157,7 +157,7 @@ def test_ppo(args=get_args()):
            print("Fail to restore policy and optim.")

    # trainer
    result = onpolicy_trainer(
    trainer = OnpolicyTrainer(
        policy,
        train_collector,
        test_collector,
@@ -173,10 +173,16 @@ def test_ppo(args=get_args()):
        resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn
    )
    assert stop_fn(result['best_reward'])

    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)

    assert stop_fn(info["best_reward"])

    if __name__ == '__main__':
        pprint.pprint(result)
        pprint.pprint(info)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
@@ -24,7 +24,7 @@ def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--reward-threshold', type=float, default=None)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -11,7 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
@@ -135,8 +135,8 @@ def test_td3(args=get_args()):
    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    # trainer
    result = offpolicy_trainer(
    # Iterator trainer
    trainer = OffpolicyTrainer(
        policy,
        train_collector,
        test_collector,
@@ -148,12 +148,17 @@ def test_td3(args=get_args()):
        update_per_step=args.update_per_step,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
        logger=logger,
    )
    assert stop_fn(result['best_reward'])
    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)

    if __name__ == '__main__':
        pprint.pprint(result)
    assert stop_fn(info["best_reward"])

    if __name__ == "__main__":
        pprint.pprint(info)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
@@ -12,7 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import CQLPolicy
from tianshou.trainer import offline_trainer
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -195,7 +195,7 @@ def test_cql(args=get_args()):
        collector.collect(n_episode=1, render=1 / 35)

    # trainer
    result = offline_trainer(
    trainer = OfflineTrainer(
        policy,
        buffer,
        test_collector,
@@ -207,11 +207,17 @@ def test_cql(args=get_args()):
        stop_fn=stop_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)

    assert stop_fn(info["best_reward"])

    # Let's watch its performance!
    if __name__ == '__main__':
        pprint.pprint(result)
    if __name__ == "__main__":
        pprint.pprint(info)
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
@@ -1,16 +1,34 @@
"""Trainer package."""

# isort:skip_file

from tianshou.trainer.utils import test_episode, gather_info
from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
from tianshou.trainer.offline import offline_trainer
from tianshou.trainer.base import BaseTrainer
from tianshou.trainer.offline import (
    OfflineTrainer,
    offline_trainer,
    offline_trainer_iter,
)
from tianshou.trainer.offpolicy import (
    OffpolicyTrainer,
    offpolicy_trainer,
    offpolicy_trainer_iter,
)
from tianshou.trainer.onpolicy import (
    OnpolicyTrainer,
    onpolicy_trainer,
    onpolicy_trainer_iter,
)
from tianshou.trainer.utils import gather_info, test_episode

__all__ = [
    "BaseTrainer",
    "offpolicy_trainer",
    "offpolicy_trainer_iter",
    "OffpolicyTrainer",
    "onpolicy_trainer",
    "onpolicy_trainer_iter",
    "OnpolicyTrainer",
    "offline_trainer",
    "offline_trainer_iter",
    "OfflineTrainer",
    "test_episode",
    "gather_info",
]
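For context, the `*_trainer_iter` names exported above are plain aliases of the new classes (for example, `offpolicy_trainer_iter = OffpolicyTrainer` later in this diff), so both spellings refer to the same object. A small sketch of the resulting public API; the policy/collector setup is assumed and omitted:

```python
from tianshou.trainer import OffpolicyTrainer, offpolicy_trainer, offpolicy_trainer_iter

assert offpolicy_trainer_iter is OffpolicyTrainer  # alias, not a separate class

# iterator style: yields (epoch, epoch_stat, info) once per epoch
# trainer = OffpolicyTrainer(policy, train_collector, test_collector, ...)

# function style: runs every epoch and returns the final info dict
# result = offpolicy_trainer(policy, train_collector, test_collector, ...)
```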
419 tianshou/trainer/base.py (new file)
@@ -0,0 +1,419 @@
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union

import numpy as np
import tqdm

from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import BasePolicy
from tianshou.trainer.utils import gather_info, test_episode
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config


class BaseTrainer(ABC):
|
||||
"""An iterator base class for trainers procedure.
|
||||
|
||||
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results
|
||||
on every epoch.
|
||||
|
||||
:param learning_type str: type of learning iterator, available choices are
|
||||
"offpolicy", "onpolicy" and "offline".
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
||||
:param Collector train_collector: the collector used for training.
|
||||
:param Collector test_collector: the collector used for testing. If it's None,
|
||||
then no testing will be performed.
|
||||
:param int max_epoch: the maximum number of epochs for training. The training
|
||||
process might be finished before reaching ``max_epoch`` if ``stop_fn``
|
||||
is set.
|
||||
:param int step_per_epoch: the number of transitions collected per epoch.
|
||||
:param int repeat_per_collect: the number of times the policy learns each
collected batch of data; for example, setting it to 2 means each batch is
learned twice.
|
||||
:param int episode_per_test: the number of episodes for one policy evaluation.
|
||||
:param int batch_size: the batch size of sample data, which is going to feed in
|
||||
the policy network.
|
||||
:param int step_per_collect: the number of transitions the collector would
|
||||
collect before the network update, i.e., trainer will collect
|
||||
"step_per_collect" transitions and do some policy network update repeatedly
|
||||
in each epoch.
|
||||
:param int episode_per_collect: the number of episodes the collector would
|
||||
collect before the network update, i.e., trainer will collect
|
||||
"episode_per_collect" episodes and do some policy network update repeatedly
|
||||
in each epoch.
|
||||
:param function train_fn: a hook called at the beginning of training in each
|
||||
epoch. It can be used to perform custom additional operations, with the
|
||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
||||
:param function test_fn: a hook called at the beginning of testing in each
|
||||
epoch. It can be used to perform custom additional operations, with the
|
||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
||||
:param function save_fn: a hook called when the undiscounted average mean
|
||||
reward in evaluation phase gets better, with the signature
|
||||
``f(policy: BasePolicy) -> None``.
|
||||
:param function save_checkpoint_fn: a function to save training process, with
|
||||
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
|
||||
you can save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata
|
||||
from existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
bool``, receives the average undiscounted returns of the testing result,
|
||||
returns a boolean which indicates whether reaching the goal.
|
||||
:param function reward_metric: a function with signature
|
||||
``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray
|
||||
with shape (num_episode,)``, used in multi-agent RL. We need to return a
|
||||
single scalar for each episode's result to monitor training in the
|
||||
multi-agent RL setting. This function specifies what is the desired metric,
|
||||
e.g., the reward of agent 1 or the average reward over all agents.
|
||||
:param BaseLogger logger: A logger that logs statistics during
|
||||
training/testing/updating. Default to a logger that doesn't log anything.
|
||||
:param bool verbose: whether to print the information. Default to True.
|
||||
:param bool test_in_train: whether to test in the training phase.
|
||||
Default to True.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def gen_doc(learning_type: str) -> str:
|
||||
"""Document string for subclass trainer."""
|
||||
step_means = f'The "step" in {learning_type} trainer means '
|
||||
if learning_type != "offline":
|
||||
step_means += "an environment step (a.k.a. transition)."
|
||||
else: # offline
|
||||
step_means += "a gradient step."
|
||||
|
||||
trainer_name = learning_type.capitalize() + "Trainer"
|
||||
|
||||
return f"""An iterator class for {learning_type} trainer procedure.
|
||||
|
||||
Returns an iterator that yields a 3-tuple (epoch, stats, info) of
|
||||
train results on every epoch.
|
||||
|
||||
{step_means}
|
||||
|
||||
Example usage:
|
||||
|
||||
::
|
||||
|
||||
trainer = {trainer_name}(...)
|
||||
for epoch, epoch_stat, info in trainer:
|
||||
print("Epoch:", epoch)
|
||||
print(epoch_stat)
|
||||
print(info)
|
||||
do_something_with_policy()
|
||||
query_something_about_policy()
|
||||
make_a_plot_with(epoch_stat)
|
||||
display(info)
|
||||
|
||||
- epoch int: the epoch number
|
||||
- epoch_stat dict: a large collection of metrics of the current epoch
|
||||
- info dict: result returned from :func:`~tianshou.trainer.gather_info`
|
||||
|
||||
You can even iterate on several trainers at the same time:
|
||||
|
||||
::
|
||||
|
||||
trainer1 = {trainer_name}(...)
|
||||
trainer2 = {trainer_name}(...)
|
||||
for result1, result2, ... in zip(trainer1, trainer2, ...):
|
||||
compare_results(result1, result2, ...)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_type: str,
|
||||
policy: BasePolicy,
|
||||
max_epoch: int,
|
||||
batch_size: int,
|
||||
train_collector: Optional[Collector] = None,
|
||||
test_collector: Optional[Collector] = None,
|
||||
buffer: Optional[ReplayBuffer] = None,
|
||||
step_per_epoch: Optional[int] = None,
|
||||
repeat_per_collect: Optional[int] = None,
|
||||
episode_per_test: Optional[int] = None,
|
||||
update_per_step: Union[int, float] = 1,
|
||||
update_per_epoch: Optional[int] = None,
|
||||
step_per_collect: Optional[int] = None,
|
||||
episode_per_collect: Optional[int] = None,
|
||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
):
|
||||
self.policy = policy
|
||||
self.buffer = buffer
|
||||
|
||||
self.train_collector = train_collector
|
||||
self.test_collector = test_collector
|
||||
|
||||
self.logger = logger
|
||||
self.start_time = time.time()
|
||||
self.stat: DefaultDict[str, MovAvg] = defaultdict(MovAvg)
|
||||
self.best_reward = 0.0
|
||||
self.best_reward_std = 0.0
|
||||
self.start_epoch = 0
|
||||
self.gradient_step = 0
|
||||
self.env_step = 0
|
||||
self.max_epoch = max_epoch
|
||||
self.step_per_epoch = step_per_epoch
|
||||
|
||||
# either one of these two should be set
|
||||
self.step_per_collect = step_per_collect
|
||||
self.episode_per_collect = episode_per_collect
|
||||
|
||||
self.update_per_step = update_per_step
|
||||
self.repeat_per_collect = repeat_per_collect
|
||||
|
||||
self.episode_per_test = episode_per_test
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.train_fn = train_fn
|
||||
self.test_fn = test_fn
|
||||
self.stop_fn = stop_fn
|
||||
self.save_fn = save_fn
|
||||
self.save_checkpoint_fn = save_checkpoint_fn
|
||||
|
||||
self.reward_metric = reward_metric
|
||||
self.verbose = verbose
|
||||
self.test_in_train = test_in_train
|
||||
self.resume_from_log = resume_from_log
|
||||
|
||||
self.is_run = False
|
||||
self.last_rew, self.last_len = 0.0, 0
|
||||
|
||||
self.epoch = self.start_epoch
|
||||
self.best_epoch = self.start_epoch
|
||||
self.stop_fn_flag = False
|
||||
self.iter_num = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Initialize or reset the instance to yield a new iterator from zero."""
|
||||
self.is_run = False
|
||||
self.env_step = 0
|
||||
if self.resume_from_log:
|
||||
self.start_epoch, self.env_step, self.gradient_step = \
|
||||
self.logger.restore_data()
|
||||
|
||||
self.last_rew, self.last_len = 0.0, 0
|
||||
self.start_time = time.time()
|
||||
if self.train_collector is not None:
|
||||
self.train_collector.reset_stat()
|
||||
|
||||
if self.train_collector.policy != self.policy:
|
||||
self.test_in_train = False
|
||||
elif self.test_collector is None:
|
||||
self.test_in_train = False
|
||||
|
||||
if self.test_collector is not None:
|
||||
assert self.episode_per_test is not None
|
||||
self.test_collector.reset_stat()
|
||||
test_result = test_episode(
|
||||
self.policy, self.test_collector, self.test_fn, self.start_epoch,
|
||||
self.episode_per_test, self.logger, self.env_step, self.reward_metric
|
||||
)
|
||||
self.best_epoch = self.start_epoch
|
||||
self.best_reward, self.best_reward_std = \
|
||||
test_result["rew"], test_result["rew_std"]
|
||||
if self.save_fn:
|
||||
self.save_fn(self.policy)
|
||||
|
||||
self.epoch = self.start_epoch
|
||||
self.stop_fn_flag = False
|
||||
self.iter_num = 0
|
||||
|
||||
def __iter__(self): # type: ignore
|
||||
self.reset()
|
||||
return self
|
||||
|
||||
def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:
|
||||
"""Perform one epoch (both train and eval)."""
|
||||
self.epoch += 1
|
||||
self.iter_num += 1
|
||||
|
||||
if self.iter_num > 1:
|
||||
|
||||
# iterator exhaustion check
|
||||
if self.epoch >= self.max_epoch:
|
||||
if self.test_collector is None and self.save_fn:
|
||||
self.save_fn(self.policy)
|
||||
raise StopIteration
|
||||
|
||||
# exit flag 1, when stop_fn succeeds in train_step or test_step
|
||||
if self.stop_fn_flag:
|
||||
raise StopIteration
|
||||
|
||||
# set policy in train mode
|
||||
self.policy.train()
|
||||
|
||||
epoch_stat: Dict[str, Any] = dict()
|
||||
# perform n step_per_epoch
|
||||
with tqdm.tqdm(
|
||||
total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config
|
||||
) as t:
|
||||
while t.n < t.total and not self.stop_fn_flag:
|
||||
data: Dict[str, Any] = dict()
|
||||
result: Dict[str, Any] = dict()
|
||||
if self.train_collector is not None:
|
||||
data, result, self.stop_fn_flag = self.train_step()
|
||||
t.update(result["n/st"])
|
||||
if self.stop_fn_flag:
|
||||
t.set_postfix(**data)
|
||||
break
|
||||
else:
|
||||
assert self.buffer, "No train_collector or buffer specified"
|
||||
result["n/ep"] = len(self.buffer)
|
||||
result["n/st"] = int(self.gradient_step)
|
||||
t.update()
|
||||
|
||||
self.policy_update_fn(data, result)
|
||||
t.set_postfix(**data)
|
||||
|
||||
if t.n <= t.total and not self.stop_fn_flag:
|
||||
t.update()
|
||||
|
||||
if not self.stop_fn_flag:
|
||||
self.logger.save_data(
|
||||
self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
|
||||
)
|
||||
# test
|
||||
if self.test_collector is not None:
|
||||
test_stat, self.stop_fn_flag = self.test_step()
|
||||
if not self.is_run:
|
||||
epoch_stat.update(test_stat)
|
||||
|
||||
if not self.is_run:
|
||||
epoch_stat.update({k: v.get() for k, v in self.stat.items()})
|
||||
epoch_stat["gradient_step"] = self.gradient_step
|
||||
epoch_stat.update(
|
||||
{
|
||||
"env_step": self.env_step,
|
||||
"rew": self.last_rew,
|
||||
"len": int(self.last_len),
|
||||
"n/ep": int(result["n/ep"]),
|
||||
"n/st": int(result["n/st"]),
|
||||
}
|
||||
)
|
||||
info = gather_info(
|
||||
self.start_time, self.train_collector, self.test_collector,
|
||||
self.best_reward, self.best_reward_std
|
||||
)
|
||||
return self.epoch, epoch_stat, info
|
||||
else:
|
||||
return None
|
||||
|
||||
def test_step(self) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Perform one testing step."""
|
||||
assert self.episode_per_test is not None
|
||||
assert self.test_collector is not None
|
||||
stop_fn_flag = False
|
||||
test_result = test_episode(
|
||||
self.policy, self.test_collector, self.test_fn, self.epoch,
|
||||
self.episode_per_test, self.logger, self.env_step, self.reward_metric
|
||||
)
|
||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||
if self.best_epoch < 0 or self.best_reward < rew:
|
||||
self.best_epoch = self.epoch
|
||||
self.best_reward = float(rew)
|
||||
self.best_reward_std = rew_std
|
||||
if self.save_fn:
|
||||
self.save_fn(self.policy)
|
||||
if self.verbose:
|
||||
print(
|
||||
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
|
||||
f" best_reward: {self.best_reward:.6f} ± "
|
||||
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
|
||||
)
|
||||
if not self.is_run:
|
||||
test_stat = {
|
||||
"test_reward": rew,
|
||||
"test_reward_std": rew_std,
|
||||
"best_reward": self.best_reward,
|
||||
"best_reward_std": self.best_reward_std,
|
||||
"best_epoch": self.best_epoch
|
||||
}
|
||||
else:
|
||||
test_stat = {}
|
||||
if self.stop_fn and self.stop_fn(self.best_reward):
|
||||
stop_fn_flag = True
|
||||
|
||||
return test_stat, stop_fn_flag
|
||||
|
||||
def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
|
||||
"""Perform one training step."""
|
||||
assert self.episode_per_test is not None
|
||||
assert self.train_collector is not None
|
||||
stop_fn_flag = False
|
||||
if self.train_fn:
|
||||
self.train_fn(self.epoch, self.env_step)
|
||||
result = self.train_collector.collect(
|
||||
n_step=self.step_per_collect, n_episode=self.episode_per_collect
|
||||
)
|
||||
if result["n/ep"] > 0 and self.reward_metric:
|
||||
rew = self.reward_metric(result["rews"])
|
||||
result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
|
||||
self.env_step += int(result["n/st"])
|
||||
self.logger.log_train_data(result, self.env_step)
|
||||
self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew
|
||||
self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len
|
||||
data = {
|
||||
"env_step": str(self.env_step),
|
||||
"rew": f"{self.last_rew:.2f}",
|
||||
"len": str(int(self.last_len)),
|
||||
"n/ep": str(int(result["n/ep"])),
|
||||
"n/st": str(int(result["n/st"])),
|
||||
}
|
||||
if result["n/ep"] > 0:
|
||||
if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]):
|
||||
assert self.test_collector is not None
|
||||
test_result = test_episode(
|
||||
self.policy, self.test_collector, self.test_fn, self.epoch,
|
||||
self.episode_per_test, self.logger, self.env_step
|
||||
)
|
||||
if self.stop_fn(test_result["rew"]):
|
||||
stop_fn_flag = True
|
||||
self.best_reward = test_result["rew"]
|
||||
self.best_reward_std = test_result["rew_std"]
|
||||
else:
|
||||
self.policy.train()
|
||||
|
||||
return data, result, stop_fn_flag
|
||||
|
||||
def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None:
|
||||
"""Log losses to current logger."""
|
||||
for k in losses.keys():
|
||||
self.stat[k].add(losses[k])
|
||||
losses[k] = self.stat[k].get()
|
||||
data[k] = f"{losses[k]:.3f}"
|
||||
self.logger.log_update_data(losses, self.gradient_step)
|
||||
|
||||
    @abstractmethod
    def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        """Policy update function for different trainer implementations.

        :param data: information in progress bar.
        :param result: collector's return value.
        """
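To illustrate this extension point, here is a hedged sketch of a custom subclass; `MyTrainer` and its update rule are hypothetical and only show where `policy_update_fn` plugs in, mirroring the body of `OffpolicyTrainer.policy_update_fn` later in this diff (without the `update_per_step` loop):

```python
from typing import Any, Dict

from tianshou.trainer import BaseTrainer


class MyTrainer(BaseTrainer):
    """Hypothetical trainer: one off-policy-style update per collect step."""

    def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        # update once from the training buffer, then log the averaged losses
        assert self.train_collector is not None
        self.gradient_step += 1
        losses = self.policy.update(self.batch_size, self.train_collector.buffer)
        self.log_update_data(data, losses)


# usage sketch (arguments assumed, not exhaustive):
# trainer = MyTrainer(learning_type="offpolicy", policy=policy, max_epoch=10,
#                     batch_size=64, train_collector=train_c, test_collector=test_c,
#                     step_per_epoch=1000, step_per_collect=10, episode_per_test=10)
# for epoch, epoch_stat, info in trainer:
#     ...
```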
    def run(self) -> Dict[str, Union[float, str]]:
        """Consume iterator.

        See the itertools recipes: use functions that consume iterators at C speed
        (feed the entire iterator into a zero-length deque).
        """
        try:
            self.is_run = True
            deque(self, maxlen=0)  # feed the entire iterator into a zero-length deque
            info = gather_info(
                self.start_time, self.train_collector, self.test_collector,
                self.best_reward, self.best_reward_std
            )
        finally:
            self.is_run = False

        return info
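The `deque(self, maxlen=0)` call is the standard "consume an iterator at C speed" recipe from the itertools documentation: a zero-length deque discards every yielded item while still driving the iterator to exhaustion. A standalone illustration of the idiom, unrelated to any trainer state:

```python
from collections import deque


def noisy_range(n):
    for i in range(n):
        print("yielding", i)
        yield i


deque(noisy_range(3), maxlen=0)  # prints three lines, keeps nothing
```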
@ -1,131 +1,115 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.trainer import gather_info, test_episode
|
||||
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
|
||||
from tianshou.trainer.base import BaseTrainer
|
||||
from tianshou.utils import BaseLogger, LazyLogger
|
||||
|
||||
|
||||
def offline_trainer(
|
||||
policy: BasePolicy,
|
||||
buffer: ReplayBuffer,
|
||||
test_collector: Optional[Collector],
|
||||
max_epoch: int,
|
||||
update_per_epoch: int,
|
||||
episode_per_test: int,
|
||||
batch_size: int,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for offline trainer procedure.
|
||||
|
||||
The "step" in offline trainer means a gradient step.
|
||||
class OfflineTrainer(BaseTrainer):
|
||||
"""Create an iterator class for offline training procedure.
|
||||
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
||||
:param Collector test_collector: the collector used for testing. If it's None, then
|
||||
no testing will be performed.
|
||||
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
|
||||
This buffer must be populated with experiences for offline RL.
|
||||
:param Collector test_collector: the collector used for testing. If it's None,
|
||||
then no testing will be performed.
|
||||
:param int max_epoch: the maximum number of epochs for training. The training
|
||||
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
|
||||
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
|
||||
set.
|
||||
:param int update_per_epoch: the number of policy network updates, so-called
|
||||
gradient steps, per epoch.
|
||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||
:param int batch_size: the batch size of sample data, which is going to feed in
|
||||
the policy network.
|
||||
:param function test_fn: a hook called at the beginning of testing in each epoch.
|
||||
It can be used to perform custom additional operations, with the signature ``f(
|
||||
num_epoch: int, step_idx: int) -> None``.
|
||||
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||
None``.
|
||||
:param function save_checkpoint_fn: a function to save training process, with the
|
||||
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
|
||||
save whatever you want. Because offline-RL doesn't have env_step, the env_step
|
||||
is always 0 here.
|
||||
:param bool resume_from_log: resume gradient_step and other metadata from existing
|
||||
tensorboard log. Default to False.
|
||||
:param function test_fn: a hook called at the beginning of testing in each
|
||||
epoch.
|
||||
It can be used to perform custom additional operations, with the signature
|
||||
``f(num_epoch: int, step_idx: int) -> None``.
|
||||
:param function save_fn: a hook called when the undiscounted average mean
|
||||
reward in evaluation phase gets better, with the signature
|
||||
``f(policy: BasePolicy) -> None``.
|
||||
:param function save_checkpoint_fn: a function to save training process,
|
||||
with the signature ``f(epoch: int, env_step: int, gradient_step: int) ->
|
||||
None``; you can save whatever you want. Because offline-RL doesn't have
|
||||
env_step, the env_step is always 0 here.
|
||||
:param bool resume_from_log: resume gradient_step and other metadata from
|
||||
existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
bool``, receives the average undiscounted returns of the testing result,
|
||||
returns a boolean which indicates whether reaching the goal.
|
||||
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
||||
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
||||
used in multi-agent RL. We need to return a single scalar for each episode's
|
||||
result to monitor training in the multi-agent RL setting. This function
|
||||
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
||||
average reward over all agents.
|
||||
:param BaseLogger logger: A logger that logs statistics during updating/testing.
|
||||
Default to a logger that doesn't log anything.
|
||||
:param function reward_metric: a function with signature ``f(rewards:
|
||||
np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape
|
||||
(num_episode,)``, used in multi-agent RL. We need to return a single scalar
|
||||
for each episode's result to monitor training in the multi-agent RL
|
||||
setting. This function specifies what is the desired metric, e.g., the
|
||||
reward of agent 1 or the average reward over all agents.
|
||||
:param BaseLogger logger: A logger that logs statistics during
|
||||
updating/testing. Default to a logger that doesn't log anything.
|
||||
:param bool verbose: whether to print the information. Default to True.
|
||||
"""
|
||||
|
||||
__doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
buffer: ReplayBuffer,
|
||||
test_collector: Optional[Collector],
|
||||
max_epoch: int,
|
||||
update_per_epoch: int,
|
||||
episode_per_test: int,
|
||||
batch_size: int,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
learning_type="offline",
|
||||
policy=policy,
|
||||
buffer=buffer,
|
||||
test_collector=test_collector,
|
||||
max_epoch=max_epoch,
|
||||
update_per_epoch=update_per_epoch,
|
||||
step_per_epoch=update_per_epoch,
|
||||
episode_per_test=episode_per_test,
|
||||
batch_size=batch_size,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
save_fn=save_fn,
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
resume_from_log=resume_from_log,
|
||||
reward_metric=reward_metric,
|
||||
logger=logger,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def policy_update_fn(
|
||||
self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Perform one off-line policy update."""
|
||||
assert self.buffer
|
||||
self.gradient_step += 1
|
||||
losses = self.policy.update(self.batch_size, self.buffer)
|
||||
data.update({"gradient_step": str(self.gradient_step)})
|
||||
self.log_update_data(data, losses)
|
||||
|
||||
|
||||
def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
|
||||
"""Wrapper for offline_trainer run method.
|
||||
|
||||
It is identical to ``OfflineTrainer(...).run()``.
|
||||
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
start_epoch, gradient_step = 0, 0
|
||||
if resume_from_log:
|
||||
start_epoch, _, gradient_step = logger.restore_data()
|
||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||
start_time = time.time()
|
||||
return OfflineTrainer(*args, **kwargs).run()
|
||||
|
||||
if test_collector is not None:
|
||||
test_c: Collector = test_collector
|
||||
test_collector.reset_stat()
|
||||
test_result = test_episode(
|
||||
policy, test_c, test_fn, start_epoch, episode_per_test, logger,
|
||||
gradient_step, reward_metric
|
||||
)
|
||||
best_epoch = start_epoch
|
||||
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
|
||||
for epoch in range(1 + start_epoch, 1 + max_epoch):
|
||||
policy.train()
|
||||
with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
|
||||
for _ in t:
|
||||
gradient_step += 1
|
||||
losses = policy.update(batch_size, buffer)
|
||||
data = {"gradient_step": str(gradient_step)}
|
||||
for k in losses.keys():
|
||||
stat[k].add(losses[k])
|
||||
losses[k] = stat[k].get()
|
||||
data[k] = f"{losses[k]:.3f}"
|
||||
logger.log_update_data(losses, gradient_step)
|
||||
t.set_postfix(**data)
|
||||
logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
|
||||
# test
|
||||
if test_collector is not None:
|
||||
test_result = test_episode(
|
||||
policy, test_c, test_fn, epoch, episode_per_test, logger,
|
||||
gradient_step, reward_metric
|
||||
)
|
||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||
if best_epoch < 0 or best_reward < rew:
|
||||
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
|
||||
)
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
|
||||
if test_collector is None and save_fn:
|
||||
save_fn(policy)
|
||||
|
||||
if test_collector is None:
|
||||
return gather_info(start_time, None, None, 0.0, 0.0)
|
||||
else:
|
||||
return gather_info(
|
||||
start_time, None, test_collector, best_reward, best_reward_std
|
||||
)
|
||||
offline_trainer_iter = OfflineTrainer
|
||||
|
@ -1,193 +1,130 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.trainer import gather_info, test_episode
|
||||
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
|
||||
from tianshou.trainer.base import BaseTrainer
|
||||
from tianshou.utils import BaseLogger, LazyLogger
|
||||
|
||||
|
||||
def offpolicy_trainer(
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Optional[Collector],
|
||||
max_epoch: int,
|
||||
step_per_epoch: int,
|
||||
step_per_collect: int,
|
||||
episode_per_test: int,
|
||||
batch_size: int,
|
||||
update_per_step: Union[int, float] = 1,
|
||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for off-policy trainer procedure.
|
||||
|
||||
The "step" in trainer means an environment step (a.k.a. transition).
|
||||
class OffpolicyTrainer(BaseTrainer):
|
||||
"""Create an iterator wrapper for off-policy training procedure.
|
||||
|
||||
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
|
||||
:param Collector train_collector: the collector used for training.
|
||||
:param Collector test_collector: the collector used for testing. If it's None, then
|
||||
no testing will be performed.
|
||||
:param Collector test_collector: the collector used for testing. If it's None,
|
||||
then no testing will be performed.
|
||||
:param int max_epoch: the maximum number of epochs for training. The training
|
||||
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
|
||||
process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
|
||||
set.
|
||||
:param int step_per_epoch: the number of transitions collected per epoch.
|
||||
:param int step_per_collect: the number of transitions the collector would collect
|
||||
before the network update, i.e., trainer will collect "step_per_collect"
|
||||
transitions and do some policy network update repeatedly in each epoch.
|
||||
:param int step_per_collect: the number of transitions the collector would
|
||||
collect before the network update, i.e., trainer will collect
|
||||
"step_per_collect" transitions and do some policy network update repeatedly
|
||||
in each epoch.
|
||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||
:param int batch_size: the batch size of sample data, which is going to feed in the
|
||||
policy network.
|
||||
:param int/float update_per_step: the number of times the policy network would be
|
||||
updated per transition after (step_per_collect) transitions are collected,
|
||||
e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will
|
||||
be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are
|
||||
collected by the collector. Default to 1.
|
||||
:param function train_fn: a hook called at the beginning of training in each epoch.
|
||||
It can be used to perform custom additional operations, with the signature ``f(
|
||||
num_epoch: int, step_idx: int) -> None``.
|
||||
:param function test_fn: a hook called at the beginning of testing in each epoch.
|
||||
It can be used to perform custom additional operations, with the signature ``f(
|
||||
num_epoch: int, step_idx: int) -> None``.
|
||||
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||
None``.
|
||||
:param function save_checkpoint_fn: a function to save training process, with the
|
||||
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
|
||||
save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
|
||||
existing tensorboard log. Default to False.
|
||||
:param int batch_size: the batch size of sample data, which is going to feed in
|
||||
the policy network.
|
||||
:param int/float update_per_step: the number of times the policy network would
|
||||
be updated per transition after (step_per_collect) transitions are
|
||||
collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256
|
||||
, policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256
|
||||
transitions are collected by the collector. Default to 1.
|
||||
:param function train_fn: a hook called at the beginning of training in each
|
||||
epoch. It can be used to perform custom additional operations, with the
|
||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
||||
:param function test_fn: a hook called at the beginning of testing in each
|
||||
epoch. It can be used to perform custom additional operations, with the
|
||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
||||
:param function save_fn: a hook called when the undiscounted average mean
|
||||
reward in evaluation phase gets better, with the signature
|
||||
``f(policy: BasePolicy) -> None``.
|
||||
:param function save_checkpoint_fn: a function to save training process, with
|
||||
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
|
||||
you can save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata
|
||||
from existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
bool``, receives the average undiscounted returns of the testing result,
|
||||
returns a boolean which indicates whether reaching the goal.
|
||||
:param function reward_metric: a function with signature ``f(rewards: np.ndarray
|
||||
with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
|
||||
used in multi-agent RL. We need to return a single scalar for each episode's
|
||||
result to monitor training in the multi-agent RL setting. This function
|
||||
specifies what is the desired metric, e.g., the reward of agent 1 or the
|
||||
average reward over all agents.
|
||||
:param function reward_metric: a function with signature
|
||||
``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
|
||||
np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to
|
||||
return a single scalar for each episode's result to monitor training in the
|
||||
multi-agent RL setting. This function specifies what is the desired metric,
|
||||
e.g., the reward of agent 1 or the average reward over all agents.
|
||||
:param BaseLogger logger: A logger that logs statistics during
|
||||
training/testing/updating. Default to a logger that doesn't log anything.
|
||||
:param bool verbose: whether to print the information. Default to True.
|
||||
:param bool test_in_train: whether to test in the training phase. Default to True.
|
||||
:param bool test_in_train: whether to test in the training phase.
|
||||
Default to True.
|
||||
"""
|
||||
|
||||
__doc__ = BaseTrainer.gen_doc("offpolicy") + "\n".join(__doc__.split("\n")[1:])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Optional[Collector],
|
||||
max_epoch: int,
|
||||
step_per_epoch: int,
|
||||
step_per_collect: int,
|
||||
episode_per_test: int,
|
||||
batch_size: int,
|
||||
update_per_step: Union[int, float] = 1,
|
||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
learning_type="offpolicy",
|
||||
policy=policy,
|
||||
train_collector=train_collector,
|
||||
test_collector=test_collector,
|
||||
max_epoch=max_epoch,
|
||||
step_per_epoch=step_per_epoch,
|
||||
step_per_collect=step_per_collect,
|
||||
episode_per_test=episode_per_test,
|
||||
batch_size=batch_size,
|
||||
update_per_step=update_per_step,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
save_fn=save_fn,
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
resume_from_log=resume_from_log,
|
||||
reward_metric=reward_metric,
|
||||
logger=logger,
|
||||
verbose=verbose,
|
||||
test_in_train=test_in_train,
|
||||
)
|
||||
|
||||
def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
|
||||
"""Perform off-policy updates."""
|
||||
assert self.train_collector is not None
|
||||
for _ in range(round(self.update_per_step * result["n/st"])):
|
||||
self.gradient_step += 1
|
||||
losses = self.policy.update(self.batch_size, self.train_collector.buffer)
|
||||
self.log_update_data(data, losses)
|
||||
|
||||
|
||||
def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore
|
||||
"""Wrapper for OffPolicyTrainer run method.
|
||||
|
||||
It is identical to ``OffpolicyTrainer(...).run()``.
|
||||
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
start_epoch, env_step, gradient_step = 0, 0, 0
|
||||
if resume_from_log:
|
||||
start_epoch, env_step, gradient_step = logger.restore_data()
|
||||
last_rew, last_len = 0.0, 0
|
||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||
start_time = time.time()
|
||||
train_collector.reset_stat()
|
||||
test_in_train = test_in_train and (
|
||||
train_collector.policy == policy and test_collector is not None
|
||||
)
|
||||
return OffpolicyTrainer(*args, **kwargs).run()
|
||||
|
||||
if test_collector is not None:
|
||||
test_c: Collector = test_collector # for mypy
|
||||
test_collector.reset_stat()
|
||||
test_result = test_episode(
|
||||
policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step,
|
||||
reward_metric
|
||||
)
|
||||
best_epoch = start_epoch
|
||||
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
|
||||
for epoch in range(1 + start_epoch, 1 + max_epoch):
|
||||
# train
|
||||
policy.train()
|
||||
with tqdm.tqdm(
|
||||
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||
) as t:
|
||||
while t.n < t.total:
|
||||
if train_fn:
|
||||
train_fn(epoch, env_step)
|
||||
result = train_collector.collect(n_step=step_per_collect)
|
||||
if result["n/ep"] > 0 and reward_metric:
|
||||
rew = reward_metric(result["rews"])
|
||||
result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
|
||||
env_step += int(result["n/st"])
|
||||
t.update(result["n/st"])
|
||||
logger.log_train_data(result, env_step)
|
||||
last_rew = result['rew'] if result["n/ep"] > 0 else last_rew
|
||||
last_len = result['len'] if result["n/ep"] > 0 else last_len
|
||||
data = {
|
||||
"env_step": str(env_step),
|
||||
"rew": f"{last_rew:.2f}",
|
||||
"len": str(int(last_len)),
|
||||
"n/ep": str(int(result["n/ep"])),
|
||||
"n/st": str(int(result["n/st"])),
|
||||
}
|
||||
if result["n/ep"] > 0:
|
||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||
test_result = test_episode(
|
||||
policy, test_c, test_fn, epoch, episode_per_test, logger,
|
||||
env_step
|
||||
)
|
||||
if stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
logger.save_data(
|
||||
epoch, env_step, gradient_step, save_checkpoint_fn
|
||||
)
|
||||
t.set_postfix(**data)
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector,
|
||||
test_result["rew"], test_result["rew_std"]
|
||||
)
|
||||
else:
|
||||
policy.train()
|
||||
for _ in range(round(update_per_step * result["n/st"])):
|
||||
gradient_step += 1
|
||||
losses = policy.update(batch_size, train_collector.buffer)
|
||||
for k in losses.keys():
|
||||
stat[k].add(losses[k])
|
||||
losses[k] = stat[k].get()
|
||||
data[k] = f"{losses[k]:.3f}"
|
||||
logger.log_update_data(losses, gradient_step)
|
||||
t.set_postfix(**data)
|
||||
if t.n <= t.total:
|
||||
t.update()
|
||||
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
|
||||
# test
|
||||
if test_collector is not None:
|
||||
test_result = test_episode(
|
||||
policy, test_c, test_fn, epoch, episode_per_test, logger, env_step,
|
||||
reward_metric
|
||||
)
|
||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||
if best_epoch < 0 or best_reward < rew:
|
||||
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
|
||||
)
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
|
||||
if test_collector is None and save_fn:
|
||||
save_fn(policy)
|
||||
|
||||
if test_collector is None:
|
||||
return gather_info(start_time, train_collector, None, 0.0, 0.0)
|
||||
else:
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector, best_reward, best_reward_std
|
||||
)
|
||||
offpolicy_trainer_iter = OffpolicyTrainer
|
||||
|
@ -1,209 +1,147 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.trainer import gather_info, test_episode
|
||||
from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
|
||||
from tianshou.trainer.base import BaseTrainer
|
||||
from tianshou.utils import BaseLogger, LazyLogger
|
||||
|
||||
|
||||
def onpolicy_trainer(
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Optional[Collector],
|
||||
max_epoch: int,
|
||||
step_per_epoch: int,
|
||||
repeat_per_collect: int,
|
||||
episode_per_test: int,
|
||||
batch_size: int,
|
||||
step_per_collect: Optional[int] = None,
|
||||
episode_per_collect: Optional[int] = None,
|
||||
train_fn: Optional[Callable[[int, int], None]] = None,
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means an environment step (a.k.a. transition).


class OnpolicyTrainer(BaseTrainer):
    """Create an iterator wrapper for on-policy training procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing. If it's None,
        then no testing will be performed.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is
        set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int repeat_per_collect: the number of times the policy learns from each
        collected batch; for example, setting it to 2 means the policy learns each
        given batch of data twice.
    :param int episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in
        the policy network.
    :param int step_per_collect: the number of transitions the collector would
        collect before the network update, i.e., the trainer will collect
        "step_per_collect" transitions and do some policy network updates repeatedly
        in each epoch.
    :param int episode_per_collect: the number of episodes the collector would
        collect before the network update, i.e., the trainer will collect
        "episode_per_collect" episodes and do some policy network updates repeatedly
        in each epoch.
    :param function train_fn: a hook called at the beginning of training in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in the evaluation phase gets better, with the signature
        ``f(policy: BasePolicy) -> None``.
    :param function save_checkpoint_fn: a function to save the training process,
        with the signature ``f(epoch: int, env_step: int, gradient_step: int)
        -> None``; you can save whatever you want.
    :param bool resume_from_log: resume env_step/gradient_step and other metadata
        from an existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``; it receives the average undiscounted returns of the testing result
        and returns a boolean indicating whether the goal has been reached.
    :param function reward_metric: a function with signature
        ``f(rewards: np.ndarray with shape (num_episode, agent_num)) ->
        np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to
        return a single scalar for each episode's result to monitor training in the
        multi-agent RL setting. This function specifies the desired metric, e.g.,
        the reward of agent 1 or the average reward over all agents.
    :param BaseLogger logger: a logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase. Default to
        True.

    :return: See :func:`~tianshou.trainer.gather_info`.

    .. note::

        Only either one of step_per_collect and episode_per_collect can be
        specified.
    """
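The hook parameters documented above are plain Python callables. As a minimal editorial sketch (not part of this diff), the following defines callables matching the documented signatures; the reward threshold and the "use agent 0's reward" metric are illustrative assumptions, not tianshou defaults:

```python
import numpy as np

def train_fn(num_epoch: int, step_idx: int) -> None:
    # room for custom per-epoch setup, e.g. annealing exploration noise
    pass

def stop_fn(mean_rewards: float) -> bool:
    # stop once the average undiscounted test return crosses a chosen target
    return mean_rewards >= 195.0

def reward_metric(rewards: np.ndarray) -> np.ndarray:
    # rewards has shape (num_episode, agent_num); report only agent 0's reward
    return rewards[:, 0]

print(stop_fn(200.0), reward_metric(np.ones((3, 2))).shape)  # True (3,)
```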
    start_epoch, env_step, gradient_step = 0, 0, 0
    if resume_from_log:
        start_epoch, env_step, gradient_step = logger.restore_data()
    last_rew, last_len = 0.0, 0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_in_train = test_in_train and (
        train_collector.policy == policy and test_collector is not None
    )

    if test_collector is not None:
        test_c: Collector = test_collector  # for mypy
        test_collector.reset_stat()
        test_result = test_episode(
            policy, test_c, test_fn, start_epoch, episode_per_test, logger,
            env_step, reward_metric
        )
    __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:])

    def __init__(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Optional[Collector],
        max_epoch: int,
        step_per_epoch: int,
        repeat_per_collect: int,
        episode_per_test: int,
        batch_size: int,
        step_per_collect: Optional[int] = None,
        episode_per_collect: Optional[int] = None,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
        resume_from_log: bool = False,
        reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        logger: BaseLogger = LazyLogger(),
        verbose: bool = True,
        test_in_train: bool = True,
    ):
        super().__init__(
            learning_type="onpolicy",
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=max_epoch,
            step_per_epoch=step_per_epoch,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=episode_per_test,
            batch_size=batch_size,
            step_per_collect=step_per_collect,
            episode_per_collect=episode_per_collect,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
            save_fn=save_fn,
            save_checkpoint_fn=save_checkpoint_fn,
            resume_from_log=resume_from_log,
            reward_metric=reward_metric,
            logger=logger,
            verbose=verbose,
            test_in_train=test_in_train,
        )
    best_epoch = start_epoch
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    if save_fn:
        save_fn(policy)

    for epoch in range(1 + start_epoch, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(
                    n_step=step_per_collect, n_episode=episode_per_collect
                )
                if result["n/ep"] > 0 and reward_metric:
                    rew = reward_metric(result["rews"])
                    result.update(rews=rew, rew=rew.mean(), rew_std=rew.std())
                env_step += int(result["n/st"])
                t.update(result["n/st"])
                logger.log_train_data(result, env_step)
                last_rew = result['rew'] if result["n/ep"] > 0 else last_rew
                last_len = result['len'] if result["n/ep"] > 0 else last_len
                data = {
                    "env_step": str(env_step),
                    "rew": f"{last_rew:.2f}",
                    "len": str(int(last_len)),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                }
                if result["n/ep"] > 0:
                    if test_in_train and stop_fn and stop_fn(result["rew"]):
                        test_result = test_episode(
                            policy, test_c, test_fn, epoch, episode_per_test,
                            logger, env_step
                        )
                        if stop_fn(test_result["rew"]):
                            if save_fn:
                                save_fn(policy)
                            logger.save_data(
                                epoch, env_step, gradient_step, save_checkpoint_fn
                            )
                            t.set_postfix(**data)
                            return gather_info(
                                start_time, train_collector, test_collector,
                                test_result["rew"], test_result["rew_std"]
                            )
                        else:
                            policy.train()
                losses = policy.update(
                    0,
                    train_collector.buffer,
                    batch_size=batch_size,
                    repeat=repeat_per_collect
                )
                train_collector.reset_buffer(keep_statistics=True)
                step = max(
                    [1] + [len(v) for v in losses.values() if isinstance(v, list)]
                )
                gradient_step += step
                for k in losses.keys():
                    stat[k].add(losses[k])
                    losses[k] = stat[k].get()
                    data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        # test
        if test_collector is not None:
            test_result = test_episode(
                policy, test_c, test_fn, epoch, episode_per_test, logger,
                env_step, reward_metric
            )
            rew, rew_std = test_result["rew"], test_result["rew_std"]
            if best_epoch < 0 or best_reward < rew:
                best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
                if save_fn:
                    save_fn(policy)
            if verbose:
                print(
                    f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, "
                    f"best_reward: {best_reward:.6f} ± {best_reward_std:.6f} "
                    f"in #{best_epoch}"
                )
            if stop_fn and stop_fn(best_reward):
                break

    if test_collector is None and save_fn:
        save_fn(policy)

    if test_collector is None:
        return gather_info(start_time, train_collector, None, 0.0, 0.0)
    else:
        return gather_info(
            start_time, train_collector, test_collector, best_reward,
            best_reward_std
        )
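The per-key loss smoothing in the loop above (``stat[k].add(losses[k])`` followed by ``stat[k].get()``) logs a running average rather than the raw per-update value. A small editorial sketch of that idea, assuming ``MovAvg`` behaves like a simple windowed mean (the real class may differ in window size and details):

```python
from collections import deque

class WindowedMean:
    """Toy stand-in for a moving-average tracker."""

    def __init__(self, size: int = 3) -> None:
        self.buf: deque = deque(maxlen=size)

    def add(self, x: float) -> None:
        self.buf.append(x)

    def get(self) -> float:
        return sum(self.buf) / len(self.buf) if self.buf else 0.0

m = WindowedMean(size=3)
for loss in (1.0, 2.0, 3.0, 4.0):
    m.add(loss)
print(m.get())  # 3.0 -> mean of the last three recorded losses
```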
    def policy_update_fn(
        self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None
    ) -> None:
        """Perform one on-policy update."""
        assert self.train_collector is not None
        losses = self.policy.update(
            0,
            self.train_collector.buffer,
            batch_size=self.batch_size,
            repeat=self.repeat_per_collect,
        )
        self.train_collector.reset_buffer(keep_statistics=True)
        step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
        self.gradient_step += step
        self.log_update_data(data, losses)
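The ``step = max([1] + ...)`` line above counts gradient steps from whatever ``policy.update`` returns: if any loss entry is a list of per-update values, its length gives the number of gradient steps taken, with a floor of 1 when no list is present. A self-contained illustration with made-up loss values:

```python
# Illustrative values only; real keys and lengths depend on the policy.
losses = {"loss": [0.31, 0.27, 0.22], "clip_ratio": 0.08}
step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
print(step)  # 3 -> three gradient steps were performed in this update
```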
def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]:  # type: ignore
    """Wrapper for OnpolicyTrainer run method.

    It is identical to ``OnpolicyTrainer(...).run()``.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    return OnpolicyTrainer(*args, **kwargs).run()


onpolicy_trainer_iter = OnpolicyTrainer
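The functional entry point above is kept for backward compatibility: it forwards all arguments to the class and immediately calls ``run()``, while ``onpolicy_trainer_iter`` rebinds the old iterator-style name to the class itself. A toy sketch of the same forwarding pattern, using stand-in names rather than tianshou's real classes:

```python
from typing import Any, Dict

class Trainer:
    """Toy stand-in for a trainer class with a blocking ``run`` method."""

    def __init__(self, max_epoch: int) -> None:
        self.max_epoch = max_epoch

    def run(self) -> Dict[str, Any]:
        return {"epochs_done": self.max_epoch}

def trainer(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    # identical to Trainer(...).run(), mirroring the wrapper above
    return Trainer(*args, **kwargs).run()

trainer_iter = Trainer  # old name kept importable, now pointing at the class

assert trainer(max_epoch=5) == Trainer(max_epoch=5).run()
```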