Add VizDoom PPO example and results (#533)

* update vizdoom ppo example

* update README with results
Yi Su authored on 2022-02-24 17:33:34 -08:00, committed by GitHub
parent 23fbc3b712
commit 97df511a13
8 changed files with 102 additions and 32 deletions


@@ -53,10 +53,6 @@ python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
 See [maps/README.md](maps/README.md)
 
-## Algorithms
-
-The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
-
 ## Reward
 
 1. living reward is bad
@@ -64,3 +60,28 @@ The setting is exactly the same as Atari. You can definitely try more algorithms
 3. negative reward for health and ammo2 is really helpful for d3/d4
 4. only with positive reward for health is really helpful for d1
 5. remove MOVE_BACKWARD may converge faster but the final performance may be lower
+
+## Algorithms
+
+The setting is exactly the same as Atari. You can definitely try more algorithms listed in the Atari example.
+
+### C51 (single run)
+
+| task | best reward | reward curve | parameters |
+| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
+| D2_navigation | 747.52 | ![](results/c51/D2_navigation_rew.png) | `python3 vizdoom_c51.py --task "D2_navigation"` |
+| D3_battle | 1855.29 | ![](results/c51/D3_battle_rew.png) | `python3 vizdoom_c51.py --task "D3_battle"` |
+
+### PPO (single run)
+
+| task | best reward | reward curve | parameters |
+| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
+| D2_navigation | 770.75 | ![](results/ppo/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation"` |
+| D3_battle | 320.59 | ![](results/ppo/D3_battle_rew.png) | `python3 vizdoom_ppo.py --task "D3_battle"` |
+
+### PPO with ICM (single run)
+
+| task | best reward | reward curve | parameters |
+| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
+| D2_navigation | 844.99 | ![](results/ppo_icm/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation" --icm-lr-scale 10` |
+| D3_battle | 547.08 | ![](results/ppo_icm/D3_battle_rew.png) | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10` |
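
To make the reward notes above more concrete, here is a minimal, illustrative sketch of delta-based shaping over VizDoom game variables. This is not the repository's `env.py`: the function name `shaped_reward`, the weights, and the task-name checks are hypothetical; only KILLCOUNT, HEALTH, and AMMO2 (standard VizDoom game variables) are assumed.

```python
from typing import Dict


def shaped_reward(prev: Dict[str, float], cur: Dict[str, float], task: str) -> float:
    """Delta-based shaping over VizDoom game variables (illustrative weights)."""
    # reward each new kill (battle scenarios)
    reward = 1.0 * (cur.get("KILLCOUNT", 0.0) - prev.get("KILLCOUNT", 0.0))
    d_health = cur.get("HEALTH", 0.0) - prev.get("HEALTH", 0.0)
    d_ammo2 = cur.get("AMMO2", 0.0) - prev.get("AMMO2", 0.0)
    if task.startswith(("D3", "D4")):
        # note 3: negative reward for losing health and for spending ammo2
        reward += 0.1 * min(d_health, 0.0) + 0.1 * min(d_ammo2, 0.0)
    elif task.startswith("D1"):
        # note 4: only positive reward for health (e.g. picking up medikits)
        reward += 0.1 * max(d_health, 0.0)
    # note 1: no living (per-step) reward is added anywhere
    return reward


if __name__ == "__main__":
    prev = {"KILLCOUNT": 2, "HEALTH": 80, "AMMO2": 20}
    cur = {"KILLCOUNT": 3, "HEALTH": 70, "AMMO2": 18}
    # one kill, lost 10 health, spent 2 ammo2 -> about -0.2
    print(shaped_reward(prev, cur, "D3_battle"))
```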

(Six new binary image files, 156-209 KiB each: the reward-curve PNGs referenced in the result tables above.)


@@ -6,11 +6,12 @@ import numpy as np
 import torch
 from env import Env
 from network import DQN
+from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import ShmemVectorEnv
-from tianshou.policy import A2CPolicy, ICMPolicy
+from tianshou.policy import ICMPolicy, PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic
@@ -21,18 +22,28 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='D2_navigation')
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--buffer-size', type=int, default=2000000)
-    parser.add_argument('--lr', type=float, default=0.0001)
+    parser.add_argument('--buffer-size', type=int, default=100000)
+    parser.add_argument('--lr', type=float, default=0.00002)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=300)
     parser.add_argument('--step-per-epoch', type=int, default=100000)
-    parser.add_argument('--episode-per-collect', type=int, default=10)
-    parser.add_argument('--update-per-step', type=float, default=0.1)
-    parser.add_argument('--update-per-step', type=int, default=1)
-    parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument('--step-per-collect', type=int, default=1000)
+    parser.add_argument('--repeat-per-collect', type=int, default=4)
+    parser.add_argument('--batch-size', type=int, default=256)
+    parser.add_argument('--hidden-size', type=int, default=512)
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--rew-norm', type=int, default=False)
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--ent-coef', type=float, default=0.01)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--lr-decay', type=int, default=True)
+    parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--eps-clip', type=float, default=0.2)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=int, default=0)
+    parser.add_argument('--norm-adv', type=int, default=1)
+    parser.add_argument('--recompute-adv', type=int, default=0)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument(
@@ -75,7 +86,7 @@ def get_args():
     return parser.parse_args()
 
 
-def test_a2c(args=get_args()):
+def test_ppo(args=get_args()):
     args.cfg_path = f"maps/{args.task}.cfg"
     args.wad_path = f"maps/{args.task}.wad"
     args.res = (args.skip_num, 84, 84)
@@ -105,33 +116,65 @@ def test_a2c(args=get_args()):
     test_envs.seed(args.seed)
     # define model
     net = DQN(
-        *args.state_shape, args.action_shape, device=args.device, features_only=True
+        *args.state_shape,
+        args.action_shape,
+        device=args.device,
+        features_only=True,
+        output_dim=args.hidden_size
     )
-    actor = Actor(
-        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
-    )
-    critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
+    actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
+    critic = Critic(net, device=args.device)
     optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+
+    lr_scheduler = None
+    if args.lr_decay:
+        # decay learning rate to 0 linearly
+        max_update_num = np.ceil(
+            args.step_per_epoch / args.step_per_collect
+        ) * args.epoch
+
+        lr_scheduler = LambdaLR(
+            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+        )
+
     # define policy
-    dist = torch.distributions.Categorical
-    policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
+    def dist(p):
+        return torch.distributions.Categorical(logits=p)
+
+    policy = PPOPolicy(
+        actor,
+        critic,
+        optim,
+        dist,
+        discount_factor=args.gamma,
+        gae_lambda=args.gae_lambda,
+        max_grad_norm=args.max_grad_norm,
+        vf_coef=args.vf_coef,
+        ent_coef=args.ent_coef,
+        reward_normalization=args.rew_norm,
+        action_scaling=False,
+        lr_scheduler=lr_scheduler,
+        action_space=env.action_space,
+        eps_clip=args.eps_clip,
+        value_clip=args.value_clip,
+        dual_clip=args.dual_clip,
+        advantage_normalization=args.norm_adv,
+        recompute_advantage=args.recompute_adv
+    ).to(args.device)
     if args.icm_lr_scale > 0:
         feature_net = DQN(
             *args.state_shape,
             args.action_shape,
             device=args.device,
-            features_only=True
+            features_only=True,
+            output_dim=args.hidden_size
         )
         action_dim = np.prod(args.action_shape)
         feature_dim = feature_net.output_dim
         icm_net = IntrinsicCuriosityModule(
-            feature_net.net,
-            feature_dim,
-            action_dim,
-            hidden_sizes=args.hidden_sizes,
-            device=args.device
+            feature_net.net, feature_dim, action_dim, device=args.device
         )
-        icm_optim = torch.optim.adam(icm_net.parameters(), lr=args.lr)
+        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
         policy = ICMPolicy(
             policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
             args.icm_forward_loss_weight
@@ -153,7 +196,8 @@ def test_a2c(args=get_args()):
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
+    log_path = os.path.join(args.logdir, args.task, log_name)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)
@@ -162,10 +206,15 @@ def test_a2c(args=get_args()):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(mean_rewards):
-        return False
+        if env.spec.reward_threshold:
+            return mean_rewards >= env.spec.reward_threshold
+        elif 'Pong' in args.task:
+            return mean_rewards >= 20
+        else:
+            return False
 
-    # watch agent's performance
     def watch():
+        # watch agent's performance
        print("Setup test envs ...")
         policy.eval()
         test_envs.seed(args.seed)
@@ -210,7 +259,7 @@ def test_a2c(args=get_args()):
         args.repeat_per_collect,
         args.test_num,
         args.batch_size,
-        episode_per_collect=args.episode_per_collect,
+        step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
         save_fn=save_fn,
         logger=logger,
@@ -222,4 +271,4 @@ def test_a2c(args=get_args()):
 if __name__ == '__main__':
-    test_a2c(get_args())
+    test_ppo(get_args())