This is the PR for the QR-DQN algorithm (https://arxiv.org/abs/1710.10044): 1. Add the QR-DQN policy in `tianshou/policy/modelfree/qrdqn.py`. 2. Add the QR-DQN network in `examples/atari/atari_network.py`. 3. Add the QR-DQN Atari example in `examples/atari/atari_qrdqn.py`. 4. Export QR-DQN in `tianshou/policy/__init__.py`. 5. Add the QR-DQN unit test in `test/discrete/test_qrdqn.py`. 6. Add the QR-DQN Atari results in `examples/atari/results/qrdqn/`. 7. Add `compute_q_value` to `DQNPolicy` and `C51Policy` to simplify their `forward` functions. 8. Move `with torch.no_grad():` from `_target_q` into `BasePolicy`. Running `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` yields `'best_result': '19.8 ± 0.40'` at epoch 8.
37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
from tianshou.policy.base import BasePolicy
|
|
from tianshou.policy.random import RandomPolicy
|
|
from tianshou.policy.imitation.base import ImitationPolicy
|
|
from tianshou.policy.modelfree.dqn import DQNPolicy
|
|
from tianshou.policy.modelfree.c51 import C51Policy
|
|
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
|
|
from tianshou.policy.modelfree.pg import PGPolicy
|
|
from tianshou.policy.modelfree.a2c import A2CPolicy
|
|
from tianshou.policy.modelfree.ddpg import DDPGPolicy
|
|
from tianshou.policy.modelfree.ppo import PPOPolicy
|
|
from tianshou.policy.modelfree.td3 import TD3Policy
|
|
from tianshou.policy.modelfree.sac import SACPolicy
|
|
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
|
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
|
|
from tianshou.policy.modelbase.psrl import PSRLPolicy
|
|
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
|
|
|
|
|
|
# Public API of this package: the names re-exported by
# ``from tianshou.policy import *``. Each entry corresponds to one of the
# policy classes imported at the top of this module, in the same order;
# keep the two lists in sync when a new policy (e.g. QRDQNPolicy) is added.
__all__ = [
    "BasePolicy",
    "RandomPolicy",
    "ImitationPolicy",
    "DQNPolicy",
    "C51Policy",
    "QRDQNPolicy",
    "PGPolicy",
    "A2CPolicy",
    "DDPGPolicy",
    "PPOPolicy",
    "TD3Policy",
    "SACPolicy",
    "DiscreteSACPolicy",
    "DiscreteBCQPolicy",
    "PSRLPolicy",
    "MultiAgentPolicyManager",
]