diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index bca77e0..02094bc 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -1,3 +1,4 @@ +import multiprocessing from dataclasses import dataclass from tianshou.utils.string import ToStringMixin @@ -11,8 +12,9 @@ class SamplingConfig(ToStringMixin): num_epochs: int = 100 step_per_epoch: int = 30000 batch_size: int = 64 - num_train_envs: int = 64 - num_test_envs: int = 10 + num_train_envs: int = -1 + """the number of training environments to use. If set to -1, use number of CPUs/threads.""" + num_test_envs: int = 1 buffer_size: int = 4096 step_per_collect: int = 2048 repeat_per_collect: int | None = 10 @@ -27,3 +29,7 @@ class SamplingConfig(ToStringMixin): replay_buffer_ignore_obs_next: bool = False replay_buffer_save_only_last_obs: bool = False replay_buffer_stack_num: int = 1 + + def __post_init__(self) -> None: + if self.num_train_envs == -1: + self.num_train_envs = multiprocessing.cpu_count()