High-level API: Fix number of test episodes being incorrectly scaled by number of envs (#1071)
This commit is contained in:
parent
6746a80f6d
commit
1714c7f2c7
@ -2,7 +2,7 @@ Contributors
|
||||
============
|
||||
|
||||
We always welcome contributions to help make Tianshou better!
|
||||
Tiashou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
|
||||
Tianshou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
|
||||
|
||||
Today, it is backed by the `appliedAI Institute for Europe <https://www.appliedai-institute.de/en/>`_,
|
||||
which is committed to making Tianshou the go-to resource for reinforcement learning research and development,
|
||||
|
@ -253,3 +253,4 @@ Dominik
|
||||
Tsinghua
|
||||
Tianshou
|
||||
appliedAI
|
||||
Panchenko
|
||||
|
@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
|
||||
max_epoch=sampling_config.num_epochs,
|
||||
step_per_epoch=sampling_config.step_per_epoch,
|
||||
repeat_per_collect=sampling_config.repeat_per_collect,
|
||||
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
||||
episode_per_test=sampling_config.num_test_episodes,
|
||||
batch_size=sampling_config.batch_size,
|
||||
step_per_collect=sampling_config.step_per_collect,
|
||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
|
||||
max_epoch=sampling_config.num_epochs,
|
||||
step_per_epoch=sampling_config.step_per_epoch,
|
||||
step_per_collect=sampling_config.step_per_collect,
|
||||
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
||||
episode_per_test=sampling_config.num_test_episodes,
|
||||
batch_size=sampling_config.batch_size,
|
||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||
logger=world.logger,
|
||||
|
@ -1,4 +1,3 @@
|
||||
import math
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -9,7 +8,6 @@ from tianshou.utils.string import ToStringMixin
|
||||
class SamplingConfig(ToStringMixin):
|
||||
"""Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""
|
||||
|
||||
# TODO: What are the most reasonable defaults?
|
||||
num_epochs: int = 100
|
||||
"""
|
||||
the number of epochs to run training for. An epoch is the outermost iteration level and each
|
||||
@ -55,8 +53,6 @@ class SamplingConfig(ToStringMixin):
|
||||
|
||||
num_test_episodes: int = 1
|
||||
"""the total number of episodes to collect in each test step (across all test environments).
|
||||
This should be a multiple of the number of test environments; if it is not, the effective
|
||||
number of episodes collected will be the nearest multiple (rounded up).
|
||||
"""
|
||||
|
||||
buffer_size: int = 4096
|
||||
@ -129,8 +125,3 @@ class SamplingConfig(ToStringMixin):
|
||||
def __post_init__(self) -> None:
|
||||
if self.num_train_envs == -1:
|
||||
self.num_train_envs = multiprocessing.cpu_count()
|
||||
|
||||
@property
|
||||
def num_test_episodes_per_test_env(self) -> int:
|
||||
""":return: the number of episodes to collect per test environment in every test step"""
|
||||
return math.ceil(self.num_test_episodes / self.num_test_envs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user