From d684dae6cde09486864778adf5ca0e3f50892707 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 26 Oct 2023 11:12:42 +0200 Subject: [PATCH] Change default number of environments (train=#CPUs, test=1) --- tianshou/highlevel/config.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index bca77e0..02094bc 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -1,3 +1,4 @@ +import multiprocessing from dataclasses import dataclass from tianshou.utils.string import ToStringMixin @@ -11,8 +12,9 @@ class SamplingConfig(ToStringMixin): num_epochs: int = 100 step_per_epoch: int = 30000 batch_size: int = 64 - num_train_envs: int = 64 - num_test_envs: int = 10 + num_train_envs: int = -1 + """the number of training environments to use. If set to -1, use number of CPUs/threads.""" + num_test_envs: int = 1 buffer_size: int = 4096 step_per_collect: int = 2048 repeat_per_collect: int | None = 10 @@ -27,3 +29,7 @@ class SamplingConfig(ToStringMixin): replay_buffer_ignore_obs_next: bool = False replay_buffer_save_only_last_obs: bool = False replay_buffer_stack_num: int = 1 + + def __post_init__(self) -> None: + if self.num_train_envs == -1: + self.num_train_envs = multiprocessing.cpu_count()