High-level API: Fix number of test episodes being incorrectly scaled by number of envs (#1071)

This commit is contained in:
Dominik Jain 2024-03-07 17:57:11 +01:00 committed by GitHub
parent 6746a80f6d
commit 1714c7f2c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 4 additions and 12 deletions

View File

@ -2,7 +2,7 @@ Contributors
============ ============
We always welcome contributions to help make Tianshou better! We always welcome contributions to help make Tianshou better!
Tiashou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University. Tianshou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
Today, it is backed by the `appliedAI Institute for Europe <https://www.appliedai-institute.de/en/>`_, Today, it is backed by the `appliedAI Institute for Europe <https://www.appliedai-institute.de/en/>`_,
which is committed to making Tianshou the go-to resource for reinforcement learning research and development, which is committed to making Tianshou the go-to resource for reinforcement learning research and development,

View File

@ -253,3 +253,4 @@ Dominik
Tsinghua Tsinghua
Tianshou Tianshou
appliedAI appliedAI
Panchenko

View File

@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs, max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch, step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect, repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_episodes_per_test_env, episode_per_test=sampling_config.num_test_episodes,
batch_size=sampling_config.batch_size, batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect, step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world), save_best_fn=policy_persistence.get_save_best_fn(world),
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
max_epoch=sampling_config.num_epochs, max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch, step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect, step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_episodes_per_test_env, episode_per_test=sampling_config.num_test_episodes,
batch_size=sampling_config.batch_size, batch_size=sampling_config.batch_size,
save_best_fn=policy_persistence.get_save_best_fn(world), save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger, logger=world.logger,

View File

@ -1,4 +1,3 @@
import math
import multiprocessing import multiprocessing
from dataclasses import dataclass from dataclasses import dataclass
@ -9,7 +8,6 @@ from tianshou.utils.string import ToStringMixin
class SamplingConfig(ToStringMixin): class SamplingConfig(ToStringMixin):
"""Configuration of sampling, epochs, parallelization, buffers, collectors, and batching.""" """Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are the most reasonable defaults?
num_epochs: int = 100 num_epochs: int = 100
""" """
the number of epochs to run training for. An epoch is the outermost iteration level and each the number of epochs to run training for. An epoch is the outermost iteration level and each
@ -55,8 +53,6 @@ class SamplingConfig(ToStringMixin):
num_test_episodes: int = 1 num_test_episodes: int = 1
"""the total number of episodes to collect in each test step (across all test environments). """the total number of episodes to collect in each test step (across all test environments).
This should be a multiple of the number of test environments; if it is not, the effective
number of episodes collected will be the nearest multiple (rounded up).
""" """
buffer_size: int = 4096 buffer_size: int = 4096
@ -129,8 +125,3 @@ class SamplingConfig(ToStringMixin):
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.num_train_envs == -1: if self.num_train_envs == -1:
self.num_train_envs = multiprocessing.cpu_count() self.num_train_envs = multiprocessing.cpu_count()
@property
def num_test_episodes_per_test_env(self) -> int:
""":return: the number of episodes to collect per test environment in every test step"""
return math.ceil(self.num_test_episodes / self.num_test_envs)