2020-03-20 19:52:29 +08:00
|
|
|
import time
|
2021-09-03 05:05:04 +08:00
|
|
|
from typing import Any, Callable, Dict, Optional, Union
|
|
|
|
|
2020-03-27 09:04:29 +08:00
|
|
|
import numpy as np
|
2020-03-20 19:52:29 +08:00
|
|
|
|
2020-05-12 11:31:47 +08:00
|
|
|
from tianshou.data import Collector
|
|
|
|
from tianshou.policy import BasePolicy
|
2021-02-24 14:48:42 +08:00
|
|
|
from tianshou.utils import BaseLogger
|
2020-03-20 19:52:29 +08:00
|
|
|
|
2020-05-12 11:31:47 +08:00
|
|
|
|
|
|
|
def test_episode(
    policy: BasePolicy,
    collector: Collector,
    test_fn: Optional[Callable[[int, Optional[int]], None]],
    epoch: int,
    n_episode: int,
    logger: Optional[BaseLogger] = None,
    global_step: Optional[int] = None,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> Dict[str, Any]:
    """Run one test pass of ``policy`` through ``collector``.

    :param policy: the policy under test; ``policy.eval()`` is called before
        collecting.
    :param collector: the test collector; its environment and buffer are
        reset before collecting.
    :param test_fn: optional hook called as ``test_fn(epoch, global_step)``
        right before collection starts.
    :param epoch: the epoch number forwarded to ``test_fn``.
    :param n_episode: forwarded to ``collector.collect(n_episode=...)``.
    :param logger: optional logger; used only when ``global_step`` is given.
    :param global_step: the step forwarded to ``test_fn`` and to
        ``logger.log_test_data``.
    :param reward_metric: optional callable applied to ``result["rews"]``;
        its return value replaces that entry.

    :return: the dictionary returned by ``collector.collect``, with
        ``"rews"`` replaced by the ``reward_metric`` output when one is given.
    """
    # Reset collector state and call policy.eval() before collecting.
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()

    if test_fn:
        test_fn(epoch, global_step)

    stats = collector.collect(n_episode=n_episode)

    if reward_metric:
        raw_rewards = stats["rews"]
        stats["rews"] = reward_metric(raw_rewards)

    # Logging needs both a logger and a step to attach the data to.
    if not logger or global_step is None:
        return stats
    logger.log_test_data(stats, global_step)
    return stats
|
2020-03-20 19:52:29 +08:00
|
|
|
|
|
|
|
|
2020-09-12 15:39:01 +08:00
|
|
|
def gather_info(
    start_time: float,
    train_c: Optional[Collector],
    test_c: Collector,
    best_reward: float,
    best_reward_std: float,
) -> Dict[str, Union[float, str]]:
    """A simple wrapper of gathering information from collectors.

    :param start_time: the ``time.time()`` value the elapsed time is measured
        from.
    :param train_c: the training collector, or ``None``; the ``train_*`` keys
        other than ``train_time/model`` are only reported when it is given.
    :param test_c: the test collector.
    :param best_reward: the best reward over the test results.
    :param best_reward_std: the standard deviation reported next to
        ``best_reward`` in ``best_result``.

    :return: A dictionary with the following keys:

        * ``train_step`` the total collected step of training collector;
        * ``train_episode`` the total collected episode of training collector;
        * ``train_time/collector`` the time for collecting transitions in
          the training collector;
        * ``train_time/model`` the time for training models;
        * ``train_speed`` the speed of training (env_step per second);
        * ``test_step`` the total collected step of test collector;
        * ``test_episode`` the total collected episode of test collector;
        * ``test_time`` the time for testing;
        * ``test_speed`` the speed of testing (env_step per second);
        * ``best_reward`` the best reward over the test results;
        * ``best_result`` ``best_reward`` and ``best_reward_std`` formatted
          as ``"mean ± std"``;
        * ``duration`` the total elapsed time.

        A speed is reported as ``0.00 step/s`` when the time it would be
        divided by is not positive.
    """
    duration = time.time() - start_time
    test_time = test_c.collect_time
    # Everything that is not testing; the training collector's share is
    # subtracted below when a training collector is given.
    model_time = duration - test_time
    # Guard the division: a zero collect_time would otherwise raise
    # ZeroDivisionError.
    test_speed = test_c.collect_step / test_time if test_time > 0 else 0.0
    result: Dict[str, Union[float, str]] = {
        "test_step": test_c.collect_step,
        "test_episode": test_c.collect_episode,
        "test_time": f"{test_time:.2f}s",
        "test_speed": f"{test_speed:.2f} step/s",
        "best_reward": best_reward,
        "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
        "duration": f"{duration:.2f}s",
        "train_time/model": f"{model_time:.2f}s",
    }
    if train_c is not None:
        # Training speed is measured over all non-test time (collection plus
        # model updates), so capture it before removing the collector's share.
        train_window = model_time
        model_time -= train_c.collect_time
        train_speed = (
            train_c.collect_step / train_window if train_window > 0 else 0.0
        )
        result.update(
            {
                "train_step": train_c.collect_step,
                "train_episode": train_c.collect_episode,
                "train_time/collector": f"{train_c.collect_time:.2f}s",
                "train_time/model": f"{model_time:.2f}s",
                "train_speed": f"{train_speed:.2f} step/s",
            }
        )
    return result
|