2020-03-20 19:52:29 +08:00
|
|
|
import time
|
2020-03-27 09:04:29 +08:00
|
|
|
import numpy as np
|
2020-05-12 11:31:47 +08:00
|
|
|
from typing import Dict, List, Union, Callable
|
2020-03-20 19:52:29 +08:00
|
|
|
|
2020-05-12 11:31:47 +08:00
|
|
|
from tianshou.data import Collector
|
|
|
|
from tianshou.policy import BasePolicy
|
2020-03-20 19:52:29 +08:00
|
|
|
|
2020-05-12 11:31:47 +08:00
|
|
|
|
|
|
|
def test_episode(
        policy: BasePolicy,
        collector: Collector,
        test_fn: Callable[[int], None],
        epoch: int,
        n_episode: Union[int, List[int]]) -> Dict[str, float]:
    """A simple wrapper of testing policy in collector.

    Resets the collector's environments and buffer, switches ``policy`` to
    evaluation mode, runs the optional ``test_fn(epoch)`` hook, and then
    collects ``n_episode`` episodes.

    :param policy: the policy under test; ``policy.eval()`` is called on it.
    :param collector: the collector used to run the test episodes.
    :param test_fn: hook called as ``test_fn(epoch)`` before collecting;
        skipped when falsy (e.g. ``None``).
    :param epoch: the current epoch number, forwarded to ``test_fn``.
    :param n_episode: either the total number of episodes, or a list of
        per-environment episode counts. A scalar total is split as evenly as
        possible across the collector's environments when it has more than
        one environment.
    :return: whatever ``collector.collect`` returns.
    """
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch)
    env_num = collector.get_env_num()
    if env_num > 1 and np.isscalar(n_episode):
        # Every environment gets n_episode // env_num episodes; the remainder
        # is handed out one apiece to the first environments. An integer
        # dtype keeps the per-env counts as ints rather than floats.
        per_env = np.full(env_num, n_episode // env_num, dtype=int)
        per_env[:n_episode % env_num] += 1
        n_episode = per_env.tolist()
    return collector.collect(n_episode=n_episode)
|
2020-03-20 19:52:29 +08:00
|
|
|
|
|
|
|
|
2020-05-12 11:31:47 +08:00
|
|
|
def gather_info(start_time: float,
                train_c: Collector,
                test_c: Collector,
                best_reward: float
                ) -> Dict[str, Union[float, str]]:
    """A simple wrapper of gathering information from collectors.

    :param start_time: the ``time.time()`` timestamp taken when training
        started; used to compute the total elapsed time.
    :param train_c: the training collector.
    :param test_c: the test collector.
    :param best_reward: the best test reward seen so far, reported as-is.

    :return: A dictionary with the following keys:

        * ``train_step`` the total number of steps collected by the training
          collector;
        * ``train_episode`` the total number of episodes collected by the
          training collector;
        * ``train_time/collector`` the time spent collecting frames in the
          training collector (formatted string, e.g. ``"1.23s"``);
        * ``train_time/model`` the time spent training models (elapsed time
          minus both collectors' collection time);
        * ``train_speed`` the training speed in steps per second, measured
          over the elapsed time excluding test collection;
        * ``test_step`` the total number of steps collected by the test
          collector;
        * ``test_episode`` the total number of episodes collected by the test
          collector;
        * ``test_time`` the time spent testing;
        * ``test_speed`` the testing speed in steps per second;
        * ``best_reward`` the best reward over the test results;
        * ``duration`` the total elapsed time.

        A speed whose elapsed time is not positive is reported as
        ``0.00 step/s`` instead of raising ``ZeroDivisionError``.
    """
    duration = time.time() - start_time
    model_time = duration - train_c.collect_time - test_c.collect_time
    # Wall time attributable to training, i.e. everything except testing.
    train_duration = duration - test_c.collect_time
    # Guard both divisions: e.g. a test collector that has not collected
    # anything yet reports a collect_time of zero.
    train_speed = (train_c.collect_step / train_duration
                   if train_duration > 0 else 0.0)
    test_speed = (test_c.collect_step / test_c.collect_time
                  if test_c.collect_time > 0 else 0.0)
    return {
        'train_step': train_c.collect_step,
        'train_episode': train_c.collect_episode,
        'train_time/collector': f'{train_c.collect_time:.2f}s',
        'train_time/model': f'{model_time:.2f}s',
        'train_speed': f'{train_speed:.2f} step/s',
        'test_step': test_c.collect_step,
        'test_episode': test_c.collect_episode,
        'test_time': f'{test_c.collect_time:.2f}s',
        'test_speed': f'{test_speed:.2f} step/s',
        'best_reward': best_reward,
        'duration': f'{duration:.2f}s',
    }
|