High-level API: Fix number of test episodes being incorrectly scaled by number of envs (#1071)

This commit is contained in:
Dominik Jain 2024-03-07 17:57:11 +01:00 committed by GitHub
parent 6746a80f6d
commit 1714c7f2c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 4 additions and 12 deletions

View File

@ -2,7 +2,7 @@ Contributors
============
We always welcome contributions to help make Tianshou better!
Tiashou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
Tianshou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
Today, it is backed by the `appliedAI Institute for Europe <https://www.appliedai-institute.de/en/>`_,
which is committed to making Tianshou the go-to resource for reinforcement learning research and development,

View File

@ -253,3 +253,4 @@ Dominik
Tsinghua
Tianshou
appliedAI
Panchenko

View File

@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
episode_per_test=sampling_config.num_test_episodes,
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world),
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
episode_per_test=sampling_config.num_test_episodes,
batch_size=sampling_config.batch_size,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger,

View File

@ -1,4 +1,3 @@
import math
import multiprocessing
from dataclasses import dataclass
@ -9,7 +8,6 @@ from tianshou.utils.string import ToStringMixin
class SamplingConfig(ToStringMixin):
"""Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are the most reasonable defaults?
num_epochs: int = 100
"""
the number of epochs to run training for. An epoch is the outermost iteration level and each
@ -55,8 +53,6 @@ class SamplingConfig(ToStringMixin):
num_test_episodes: int = 1
"""the total number of episodes to collect in each test step (across all test environments).
This should be a multiple of the number of test environments; if it is not, the effective
number of episodes collected will be the nearest multiple (rounded up).
"""
buffer_size: int = 4096
@ -129,8 +125,3 @@ class SamplingConfig(ToStringMixin):
def __post_init__(self) -> None:
if self.num_train_envs == -1:
self.num_train_envs = multiprocessing.cpu_count()
@property
def num_test_episodes_per_test_env(self) -> int:
""":return: the number of episodes to collect per test environment in every test step"""
return math.ceil(self.num_test_episodes / self.num_test_envs)