"""High-level API example: DQN on CartPole (discrete example for the README)."""

# Provenance: this module is the file added by the following patch.
#   commit:  8d6df2b27644f6c8a4922262b9d78166d5bf614d
#   author:  Dominik Jain
#   date:    Mon, 8 Jan 2024 18:12:43 +0100
#   subject: Add high-level discrete example (CartPole) for README
#   path:    examples/discrete/discrete_dqn_hl.py (new file, blob c6092f2)

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import (
    EnvFactoryGymnasium,
    VectorEnvType,
)
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.trainer import (
    EpochStopCallbackRewardThreshold,
    EpochTestCallbackDQNSetEps,
    EpochTrainCallbackDQNSetEps,
)
from tianshou.utils.logging import run_main


def main():
    """Assemble a DQN experiment for ``CartPole-v1`` and run it.

    The experiment is put together step by step with ``DQNExperimentBuilder``
    (environment factory, experiment config, sampling config, DQN parameters,
    default model factory and three epoch callbacks) and is then executed via
    ``experiment.run()``. Nothing is returned.
    """
    # The three positional inputs of the builder, created in the order in
    # which they are passed.
    env_factory = EnvFactoryGymnasium(
        task="CartPole-v1",
        seed=0,
        venv_type=VectorEnvType.DUMMY,
    )
    experiment_config = ExperimentConfig(
        persistence_enabled=False,
        watch=True,
        # NOTE(review): presumably a per-frame render delay in seconds — confirm.
        watch_render=1 / 35,
        watch_num_episodes=100,
    )
    sampling_config = SamplingConfig(
        num_epochs=10,
        step_per_epoch=10000,
        batch_size=64,
        num_train_envs=10,
        num_test_envs=100,
        buffer_size=20000,
        step_per_collect=10,
        update_per_step=1 / 10,
    )
    builder = DQNExperimentBuilder(env_factory, experiment_config, sampling_config)

    # Algorithm hyperparameters and the default network with two hidden
    # layers of 64 units each.
    dqn_params = DQNParams(
        lr=1e-3,
        discount_factor=0.9,
        estimation_step=3,
        target_update_freq=320,
    )
    builder = builder.with_dqn_params(dqn_params)
    builder = builder.with_model_factory_default(hidden_sizes=(64, 64))

    # NOTE(review): going by the callback names, these set the DQN epsilon for
    # training (0.3) and testing (0.0) and stop once the reward threshold of
    # 195 is reached — confirm against the callback implementations.
    builder = builder.with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
    builder = builder.with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
    builder = builder.with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))

    experiment = builder.build()
    experiment.run()


if __name__ == "__main__":
    run_main(main)