High-level API: Fix number of test episodes being incorrectly scaled by number of envs (#1071)
This commit is contained in:
parent
6746a80f6d
commit
1714c7f2c7
@ -2,7 +2,7 @@ Contributors
|
|||||||
============
|
============
|
||||||
|
|
||||||
We always welcome contributions to help make Tianshou better!
|
We always welcome contributions to help make Tianshou better!
|
||||||
Tiashou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
|
Tianshou was originally created by the `THU-ML Group <https://ml.cs.tsinghua.edu.cn>`_ at Tsinghua University.
|
||||||
|
|
||||||
Today, it is backed by the `appliedAI Institute for Europe <https://www.appliedai-institute.de/en/>`_,
|
Today, it is backed by the `appliedAI Institute for Europe <https://www.appliedai-institute.de/en/>`_,
|
||||||
which is committed to making Tianshou the go-to resource for reinforcement learning research and development,
|
which is committed to making Tianshou the go-to resource for reinforcement learning research and development,
|
||||||
|
@ -253,3 +253,4 @@ Dominik
|
|||||||
Tsinghua
|
Tsinghua
|
||||||
Tianshou
|
Tianshou
|
||||||
appliedAI
|
appliedAI
|
||||||
|
Panchenko
|
||||||
|
@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
|
|||||||
max_epoch=sampling_config.num_epochs,
|
max_epoch=sampling_config.num_epochs,
|
||||||
step_per_epoch=sampling_config.step_per_epoch,
|
step_per_epoch=sampling_config.step_per_epoch,
|
||||||
repeat_per_collect=sampling_config.repeat_per_collect,
|
repeat_per_collect=sampling_config.repeat_per_collect,
|
||||||
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
episode_per_test=sampling_config.num_test_episodes,
|
||||||
batch_size=sampling_config.batch_size,
|
batch_size=sampling_config.batch_size,
|
||||||
step_per_collect=sampling_config.step_per_collect,
|
step_per_collect=sampling_config.step_per_collect,
|
||||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||||
@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
|
|||||||
max_epoch=sampling_config.num_epochs,
|
max_epoch=sampling_config.num_epochs,
|
||||||
step_per_epoch=sampling_config.step_per_epoch,
|
step_per_epoch=sampling_config.step_per_epoch,
|
||||||
step_per_collect=sampling_config.step_per_collect,
|
step_per_collect=sampling_config.step_per_collect,
|
||||||
episode_per_test=sampling_config.num_test_episodes_per_test_env,
|
episode_per_test=sampling_config.num_test_episodes,
|
||||||
batch_size=sampling_config.batch_size,
|
batch_size=sampling_config.batch_size,
|
||||||
save_best_fn=policy_persistence.get_save_best_fn(world),
|
save_best_fn=policy_persistence.get_save_best_fn(world),
|
||||||
logger=world.logger,
|
logger=world.logger,
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import math
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -9,7 +8,6 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
class SamplingConfig(ToStringMixin):
|
class SamplingConfig(ToStringMixin):
|
||||||
"""Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""
|
"""Configuration of sampling, epochs, parallelization, buffers, collectors, and batching."""
|
||||||
|
|
||||||
# TODO: What are the most reasonable defaults?
|
|
||||||
num_epochs: int = 100
|
num_epochs: int = 100
|
||||||
"""
|
"""
|
||||||
the number of epochs to run training for. An epoch is the outermost iteration level and each
|
the number of epochs to run training for. An epoch is the outermost iteration level and each
|
||||||
@ -55,8 +53,6 @@ class SamplingConfig(ToStringMixin):
|
|||||||
|
|
||||||
num_test_episodes: int = 1
|
num_test_episodes: int = 1
|
||||||
"""the total number of episodes to collect in each test step (across all test environments).
|
"""the total number of episodes to collect in each test step (across all test environments).
|
||||||
This should be a multiple of the number of test environments; if it is not, the effective
|
|
||||||
number of episodes collected will be the nearest multiple (rounded up).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
buffer_size: int = 4096
|
buffer_size: int = 4096
|
||||||
@ -129,8 +125,3 @@ class SamplingConfig(ToStringMixin):
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.num_train_envs == -1:
|
if self.num_train_envs == -1:
|
||||||
self.num_train_envs = multiprocessing.cpu_count()
|
self.num_train_envs = multiprocessing.cpu_count()
|
||||||
|
|
||||||
@property
|
|
||||||
def num_test_episodes_per_test_env(self) -> int:
|
|
||||||
""":return: the number of episodes to collect per test environment in every test step"""
|
|
||||||
return math.ceil(self.num_test_episodes / self.num_test_envs)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user