Add high-level discrete example (CartPole) for README
This commit is contained in:
parent
1e5ebc2a2d
commit
8d6df2b276
55
examples/discrete/discrete_dqn_hl.py
Normal file
55
examples/discrete/discrete_dqn_hl.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
|
from tianshou.highlevel.env import (
|
||||||
|
EnvFactoryGymnasium,
|
||||||
|
VectorEnvType,
|
||||||
|
)
|
||||||
|
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
||||||
|
from tianshou.highlevel.params.policy_params import DQNParams
|
||||||
|
from tianshou.highlevel.trainer import (
|
||||||
|
EpochStopCallbackRewardThreshold,
|
||||||
|
EpochTestCallbackDQNSetEps,
|
||||||
|
EpochTrainCallbackDQNSetEps,
|
||||||
|
)
|
||||||
|
from tianshou.utils.logging import run_main
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
experiment = (
|
||||||
|
DQNExperimentBuilder(
|
||||||
|
EnvFactoryGymnasium(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
|
||||||
|
ExperimentConfig(
|
||||||
|
persistence_enabled=False,
|
||||||
|
watch=True,
|
||||||
|
watch_render=1 / 35,
|
||||||
|
watch_num_episodes=100,
|
||||||
|
),
|
||||||
|
SamplingConfig(
|
||||||
|
num_epochs=10,
|
||||||
|
step_per_epoch=10000,
|
||||||
|
batch_size=64,
|
||||||
|
num_train_envs=10,
|
||||||
|
num_test_envs=100,
|
||||||
|
buffer_size=20000,
|
||||||
|
step_per_collect=10,
|
||||||
|
update_per_step=1 / 10,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_dqn_params(
|
||||||
|
DQNParams(
|
||||||
|
lr=1e-3,
|
||||||
|
discount_factor=0.9,
|
||||||
|
estimation_step=3,
|
||||||
|
target_update_freq=320,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.with_model_factory_default(hidden_sizes=(64, 64))
|
||||||
|
.with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
|
||||||
|
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
|
||||||
|
.with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
experiment.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_main(main)
|
Loading…
x
Reference in New Issue
Block a user