from dataclasses import dataclass


@dataclass
class RLSamplingConfig:
    """Sampling, epochs, parallelization, buffers, collectors, and batching."""

    # TODO: What are reasonable defaults?
    num_epochs: int = 100
    step_per_epoch: int = 30000
    batch_size: int = 64
    num_train_envs: int = 64
    num_test_envs: int = 10
    buffer_size: int = 4096
    step_per_collect: int = 2048
    repeat_per_collect: int | None = 10
    update_per_step: float = 1.0
    """
    Only used in off-policy algorithms.
    How many gradient steps to perform per step in the environment (i.e., per sample added to the buffer).
    """
    start_timesteps: int = 0
    start_timesteps_random: bool = False
    # TODO can we set the parameters below intelligently? Perhaps based on env. representation?
    replay_buffer_ignore_obs_next: bool = False
    replay_buffer_save_only_last_obs: bool = False
    replay_buffer_stack_num: int = 1
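

# A minimal usage sketch (not part of the original file): constructing a config with
# Atari-DQN-style overrides. The concrete values are illustrative assumptions only;
# in the surrounding high-level API this object would presumably be passed on to
# whatever builds the environments, replay buffer, collectors, and trainer.
if __name__ == "__main__":
    sampling_config = RLSamplingConfig(
        num_epochs=100,
        step_per_epoch=100000,
        batch_size=32,
        num_train_envs=10,
        num_test_envs=10,
        buffer_size=100000,
        step_per_collect=10,
        update_per_step=0.1,
        # frame-stacking-style buffer settings, as typically used for image observations
        replay_buffer_save_only_last_obs=True,
        replay_buffer_stack_num=4,
    )
    print(sampling_config)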