2024-01-08 18:12:43 +01:00
|
|
|
from tianshou.highlevel.config import SamplingConfig
|
|
|
|
from tianshou.highlevel.env import (
|
2024-01-16 12:22:07 +01:00
|
|
|
EnvFactoryRegistered,
|
2024-01-08 18:12:43 +01:00
|
|
|
VectorEnvType,
|
|
|
|
)
|
|
|
|
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
|
|
|
from tianshou.highlevel.params.policy_params import DQNParams
|
|
|
|
from tianshou.highlevel.trainer import (
|
|
|
|
EpochStopCallbackRewardThreshold,
|
|
|
|
EpochTestCallbackDQNSetEps,
|
|
|
|
EpochTrainCallbackDQNSetEps,
|
|
|
|
)
|
|
|
|
from tianshou.utils.logging import run_main
|
|
|
|
|
|
|
|
|
2024-02-06 14:24:30 +01:00
|
|
|
def main() -> None:
    """Build a DQN experiment for the registered task "CartPole-v1" and run it.

    The environment factory, experiment configuration and sampling
    configuration are created first and handed to ``DQNExperimentBuilder``.
    The builder is then given the DQN parameters, the default model factory
    and the per-epoch train/test/stop callbacks before the experiment is
    built and executed.
    """
    env_factory = EnvFactoryRegistered(
        task="CartPole-v1",
        venv_type=VectorEnvType.DUMMY,
        train_seed=0,
        test_seed=10,
    )
    experiment_config = ExperimentConfig(
        persistence_enabled=False,
        watch=True,
        watch_render=1 / 35,
        watch_num_episodes=100,
    )
    sampling_config = SamplingConfig(
        num_epochs=10,
        step_per_epoch=10000,
        batch_size=64,
        num_train_envs=10,
        num_test_envs=100,
        buffer_size=20000,
        step_per_collect=10,
        update_per_step=1 / 10,
    )

    # The builder is created before the DQN parameters so that objects are
    # constructed in the same order as in the single-expression form.
    builder = DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
    builder = builder.with_dqn_params(
        DQNParams(
            lr=1e-3,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        ),
    )
    builder = builder.with_model_factory_default(hidden_sizes=(64, 64))

    # NOTE(review): judging by the callback names, 0.3 / 0.0 are presumably the
    # exploration eps values for training / testing and 195 the reward
    # threshold that stops training - confirm against tianshou.highlevel.trainer.
    builder = builder.with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
    builder = builder.with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
    builder = builder.with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))

    builder.build().run()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand main to the run_main helper imported from
    # tianshou.utils.logging instead of calling it directly.
    run_main(main)
|