Closes: #1058

### API Extensions
- `Batch` received two new methods: `to_dict` and `to_list_of_dicts`. #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Added a convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063

### Internal Improvements
- `Collector`s rely less on state; the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for variables in `Collector`s. #1063
- Generally improved the readability of the Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within the Collector. #1063

### Breaking Changes
- Removed the `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user may have to call `reset` explicitly or pass `reset_before_collect=True` (see the sketch below). #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
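Below is a minimal sketch of the new API described above, assuming a `policy` instance constructed elsewhere; apart from the methods named in the changelog, the setup code is illustrative.

```python
import gymnasium as gym

from tianshou.data import Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

# Batch now supports plain-dict conversions (see "API Extensions").
batch = Batch(a=[1, 2], b=[3, 4])
as_dict = batch.to_dict()              # one dict keyed like the Batch
as_records = batch.to_list_of_dicts()  # roughly: one dict per batch entry

# Collectors no longer reset their environments on construction: either
# call collector.reset() explicitly or pass reset_before_collect=True.
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
collector = Collector(policy, envs, VectorReplayBuffer(2000, len(envs)))
stats = collector.collect(n_episode=4, reset_before_collect=True)
collector.close()  # collectors can now be closed explicitly
```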
import time
from collections.abc import Callable
from dataclasses import asdict

import numpy as np

from tianshou.data import (
    Collector,
    CollectStats,
    InfoStats,
    SequenceSummaryStats,
    TimingStats,
)
from tianshou.policy import BasePolicy
from tianshou.utils import BaseLogger

def test_episode(
    policy: BasePolicy,
    collector: Collector,
    test_fn: Callable[[int, int | None], None] | None,
    epoch: int,
    n_episode: int,
    logger: BaseLogger | None = None,
    global_step: int | None = None,
    reward_metric: Callable[[np.ndarray], np.ndarray] | None = None,
) -> CollectStats:
"""A simple wrapper of testing policy in collector."""
|
|
    collector.reset(reset_stats=False)
    policy.eval()
    if test_fn:
        test_fn(epoch, global_step)
    result = collector.collect(n_episode=n_episode)
    if reward_metric:  # TODO: move into collector
        rew = reward_metric(result.returns)
        result.returns = rew
        result.returns_stat = SequenceSummaryStats.from_sequence(rew)
    if logger and global_step is not None:
        assert result.n_collected_episodes > 0
        logger.log_test_data(asdict(result), global_step)
    return result
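
# A minimal usage sketch (illustrative, not part of the original module):
# how a trainer loop might call test_episode. `policy`, `test_collector`,
# `logger`, and `env_step` are assumed to be constructed elsewhere.
#
#     test_stats = test_episode(
#         policy=policy,
#         collector=test_collector,
#         test_fn=None,
#         epoch=1,
#         n_episode=10,
#         logger=logger,
#         global_step=env_step,
#     )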


def gather_info(
    start_time: float,
    policy_update_time: float,
    gradient_step: int,
    best_reward: float,
    best_reward_std: float,
    train_collector: Collector | None = None,
    test_collector: Collector | None = None,
) -> InfoStats:
"""A simple wrapper of gathering information from collectors.
|
|
|
|
:return: A dataclass object with the following members (depending on available collectors):
|
|
|
|
* ``gradient_step`` the total number of gradient steps;
|
|
* ``best_reward`` the best reward over the test results;
|
|
* ``best_reward_std`` the standard deviation of best reward over the test results;
|
|
* ``train_step`` the total collected step of training collector;
|
|
* ``train_episode`` the total collected episode of training collector;
|
|
* ``test_step`` the total collected step of test collector;
|
|
* ``test_episode`` the total collected episode of test collector;
|
|
* ``timing`` the timing statistics, with the following members:
|
|
* ``total_time`` the total time elapsed;
|
|
* ``train_time`` the total time elapsed for learning training (collecting samples plus model update);
|
|
* ``train_time_collect`` the time for collecting transitions in the \
|
|
training collector;
|
|
* ``train_time_update`` the time for training models;
|
|
* ``test_time`` the time for testing;
|
|
* ``update_speed`` the speed of updating (env_step per second).
|
|
"""
|
|
    duration = max(0.0, time.time() - start_time)
    test_time = 0.0
    update_speed = 0.0
    train_time_collect = 0.0
    if test_collector is not None:
        test_time = test_collector.collect_time

    if train_collector is not None:
        train_time_collect = train_collector.collect_time
        update_speed = train_collector.collect_step / (duration - test_time)

    timing_stat = TimingStats(
        total_time=duration,
        train_time=duration - test_time,
        train_time_collect=train_time_collect,
        train_time_update=policy_update_time,
        test_time=test_time,
        update_speed=update_speed,
    )

    return InfoStats(
        gradient_step=gradient_step,
        best_reward=best_reward,
        best_reward_std=best_reward_std,
        train_step=train_collector.collect_step if train_collector is not None else 0,
        train_episode=train_collector.collect_episode if train_collector is not None else 0,
        test_step=test_collector.collect_step if test_collector is not None else 0,
        test_episode=test_collector.collect_episode if test_collector is not None else 0,
        timing=timing_stat,
    )
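
# A minimal usage sketch (illustrative, not part of the original module):
# summarizing a finished run. The collectors and counters are assumed to
# come from the surrounding trainer.
#
#     info = gather_info(
#         start_time=run_start_time,
#         policy_update_time=policy_update_time,
#         gradient_step=gradient_step,
#         best_reward=best_reward,
#         best_reward_std=best_reward_std,
#         train_collector=train_collector,
#         test_collector=test_collector,
#     )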