n+e 94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add policy.eval() in the "watch performance" part of all test scripts
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, default)` (see the sketch after this list)
4. collect exactly n_episode episodes when n_episode is given as a list (one count per env), and save placeholder data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: the x-axis now stands for env step instead of gradient step; add test results to tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py so that the most common case is checked first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove the default value in batch.split() and add a merge_last argument (#185) (see the sketch after this list)
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve the code coverage (from 90% to 95%) and remove dead code
16. fix some incorrect type annotations
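
Here is a minimal sketch of the new Batch helpers from items 3 and 10, assuming the Batch API at this revision (the exact mini-batch sizes are illustrative):

from tianshou.data import Batch

b = Batch(obs=[0, 1, 2, 3, 4], act=[0, 1, 2, 3, 4])
'obs' in b          # True, via the new __contains__
b.pop('act', None)  # removes and returns b.act; None if the key is absent
# split() now takes an explicit size; with merge_last=True the final short
# mini-batch is merged into the previous one, so 5 items with size=2 yield
# mini-batches of length 2 and 3 instead of 2, 2 and 1.
for mb in b.split(2, shuffle=False, merge_last=True):
    print(len(mb.obs))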

The above improvements also increase the training FPS: on my machine the previous version reaches only ~1800 FPS, while this one reaches ~2050 (faster than v0.2.4.post1).
2020-08-27 12:15:18 +08:00


import time
from typing import Callable, Dict, List, Optional, Union

import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
from tianshou.policy import BasePolicy


def test_episode(
        policy: BasePolicy,
        collector: Collector,
        test_fn: Optional[Callable[[int], None]],
        epoch: int,
        n_episode: Union[int, List[int]],
        writer: Optional[SummaryWriter] = None,
        global_step: Optional[int] = None) -> Dict[str, float]:
    """A simple wrapper for testing the policy in a collector."""
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()  # switch to evaluation mode before collecting test episodes
    if test_fn:
        test_fn(epoch)
    if collector.get_env_num() > 1 and np.isscalar(n_episode):
        # Spread a scalar episode budget evenly across the vectorized envs,
        # e.g. n_episode=10 over 3 envs becomes [4, 3, 3].
        n = collector.get_env_num()
        n_ = np.zeros(n) + n_episode // n
        n_[:n_episode % n] += 1
        n_episode = list(n_)
    result = collector.collect(n_episode=n_episode)
    if writer is not None and global_step is not None:
        for k in result.keys():
            writer.add_scalar('test/' + k, result[k], global_step=global_step)
    return result
def gather_info(start_time: float,
                train_c: Collector,
                test_c: Collector,
                best_reward: float
                ) -> Dict[str, Union[float, str]]:
    """A simple wrapper for gathering information from collectors.

    :return: A dictionary with the following keys:

        * ``train_step`` the total number of steps collected by the \
            training collector;
        * ``train_episode`` the total number of episodes collected by the \
            training collector;
        * ``train_time/collector`` the time spent collecting frames in the \
            training collector;
        * ``train_time/model`` the time spent training the models;
        * ``train_speed`` the speed of training (frames per second);
        * ``test_step`` the total number of steps collected by the test \
            collector;
        * ``test_episode`` the total number of episodes collected by the \
            test collector;
        * ``test_time`` the time spent testing;
        * ``test_speed`` the speed of testing (frames per second);
        * ``best_reward`` the best reward over the test results;
        * ``duration`` the total elapsed time.
    """
    duration = time.time() - start_time
    # Model/training time is whatever remains after subtracting both
    # collectors' environment-interaction time from the total duration.
    model_time = duration - train_c.collect_time - test_c.collect_time
    train_speed = train_c.collect_step / (duration - test_c.collect_time)
    test_speed = test_c.collect_step / test_c.collect_time
    return {
        'train_step': train_c.collect_step,
        'train_episode': train_c.collect_episode,
        'train_time/collector': f'{train_c.collect_time:.2f}s',
        'train_time/model': f'{model_time:.2f}s',
        'train_speed': f'{train_speed:.2f} step/s',
        'test_step': test_c.collect_step,
        'test_episode': test_c.collect_episode,
        'test_time': f'{test_c.collect_time:.2f}s',
        'test_speed': f'{test_speed:.2f} step/s',
        'best_reward': best_reward,
        'duration': f'{duration:.2f}s',
    }
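
A hypothetical usage sketch of the two helpers above; the policy, collectors, writer, and env_step are assumed to be set up by the trainer elsewhere, and 'rew' is assumed to be the mean-reward key in the collector's result dict at this version:

start_time = time.time()
# ... training loop: collect with train_c and update the policy ...
result = test_episode(policy, test_c, test_fn=None, epoch=1,
                      n_episode=10, writer=writer, global_step=env_step)
info = gather_info(start_time, train_c, test_c, best_reward=result['rew'])
print(info['train_speed'], info['duration'])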