Add VizDoom PPO example and results (#533)
* update vizdoom ppo example * update README with results
This commit is contained in:
parent
23fbc3b712
commit
97df511a13
@ -53,10 +53,6 @@ python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
|
||||
|
||||
See [maps/README.md](maps/README.md)
|
||||
|
||||
## Algorithms
|
||||
|
||||
The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
|
||||
|
||||
## Reward
|
||||
|
||||
1. living reward is bad
|
||||
@ -64,3 +60,28 @@ The setting is exactly the same as Atari. You can definitely try more algorithms
|
||||
3. negative reward for health and ammo2 is really helpful for d3/d4
|
||||
4. only with positive reward for health is really helpful for d1
|
||||
5. remove MOVE_BACKWARD may converge faster but the final performance may be lower
|
||||
|
||||
## Algorithms
|
||||
|
||||
The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
|
||||
|
||||
### C51 (single run)
|
||||
|
||||
| task | best reward | reward curve | parameters |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||
| D2_navigation | 747.52 |  | `python3 vizdoom_c51.py --task "D2_navigation"` |
|
||||
| D3_battle | 1855.29 |  | `python3 vizdoom_c51.py --task "D3_battle"` |
|
||||
|
||||
### PPO (single run)
|
||||
|
||||
| task | best reward | reward curve | parameters |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||
| D2_navigation | 770.75 |  | `python3 vizdoom_ppo.py --task "D2_navigation"` |
|
||||
| D3_battle | 320.59 |  | `python3 vizdoom_ppo.py --task "D3_battle"` |
|
||||
|
||||
### PPO with ICM (single run)
|
||||
|
||||
| task | best reward | reward curve | parameters |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||
| D2_navigation | 844.99 |  | `python3 vizdoom_ppo.py --task "D2_navigation" --icm-lr-scale 10` |
|
||||
| D3_battle | 547.08 |  | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10` |
|
||||
|
BIN
examples/vizdoom/results/c51/D2_navigation_rew.png
Normal file
BIN
examples/vizdoom/results/c51/D2_navigation_rew.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 209 KiB |
BIN
examples/vizdoom/results/c51/D3_battle_rew.png
Normal file
BIN
examples/vizdoom/results/c51/D3_battle_rew.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 156 KiB |
BIN
examples/vizdoom/results/ppo/D2_navigation_rew.png
Normal file
BIN
examples/vizdoom/results/ppo/D2_navigation_rew.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 159 KiB |
BIN
examples/vizdoom/results/ppo/D3_battle_rew.png
Normal file
BIN
examples/vizdoom/results/ppo/D3_battle_rew.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 164 KiB |
BIN
examples/vizdoom/results/ppo_icm/D2_navigation_rew.png
Normal file
BIN
examples/vizdoom/results/ppo_icm/D2_navigation_rew.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 157 KiB |
BIN
examples/vizdoom/results/ppo_icm/D3_battle_rew.png
Normal file
BIN
examples/vizdoom/results/ppo_icm/D3_battle_rew.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 159 KiB |
@ -6,11 +6,12 @@ import numpy as np
|
||||
import torch
|
||||
from env import Env
|
||||
from network import DQN
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.policy import A2CPolicy, ICMPolicy
|
||||
from tianshou.policy import ICMPolicy, PPOPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
@ -21,18 +22,28 @@ def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='D2_navigation')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=2000000)
|
||||
parser.add_argument('--lr', type=float, default=0.0001)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.00002)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--epoch', type=int, default=300)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--episode-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--update-per-step', type=int, default=1)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
|
||||
parser.add_argument('--step-per-collect', type=int, default=1000)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=4)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
parser.add_argument('--hidden-size', type=int, default=512)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--rew-norm', type=int, default=False)
|
||||
parser.add_argument('--vf-coef', type=float, default=0.5)
|
||||
parser.add_argument('--ent-coef', type=float, default=0.01)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--lr-decay', type=int, default=True)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--eps-clip', type=float, default=0.2)
|
||||
parser.add_argument('--dual-clip', type=float, default=None)
|
||||
parser.add_argument('--value-clip', type=int, default=0)
|
||||
parser.add_argument('--norm-adv', type=int, default=1)
|
||||
parser.add_argument('--recompute-adv', type=int, default=0)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
@ -75,7 +86,7 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_a2c(args=get_args()):
|
||||
def test_ppo(args=get_args()):
|
||||
args.cfg_path = f"maps/{args.task}.cfg"
|
||||
args.wad_path = f"maps/{args.task}.wad"
|
||||
args.res = (args.skip_num, 84, 84)
|
||||
@ -105,33 +116,65 @@ def test_a2c(args=get_args()):
|
||||
test_envs.seed(args.seed)
|
||||
# define model
|
||||
net = DQN(
|
||||
*args.state_shape, args.action_shape, device=args.device, features_only=True
|
||||
*args.state_shape,
|
||||
args.action_shape,
|
||||
device=args.device,
|
||||
features_only=True,
|
||||
output_dim=args.hidden_size
|
||||
)
|
||||
actor = Actor(
|
||||
net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
|
||||
)
|
||||
critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
|
||||
critic = Critic(net, device=args.device)
|
||||
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
|
||||
|
||||
lr_scheduler = None
|
||||
if args.lr_decay:
|
||||
# decay learning rate to 0 linearly
|
||||
max_update_num = np.ceil(
|
||||
args.step_per_epoch / args.step_per_collect
|
||||
) * args.epoch
|
||||
|
||||
lr_scheduler = LambdaLR(
|
||||
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
|
||||
)
|
||||
|
||||
# define policy
|
||||
dist = torch.distributions.Categorical
|
||||
policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
|
||||
def dist(p):
|
||||
return torch.distributions.Categorical(logits=p)
|
||||
|
||||
policy = PPOPolicy(
|
||||
actor,
|
||||
critic,
|
||||
optim,
|
||||
dist,
|
||||
discount_factor=args.gamma,
|
||||
gae_lambda=args.gae_lambda,
|
||||
max_grad_norm=args.max_grad_norm,
|
||||
vf_coef=args.vf_coef,
|
||||
ent_coef=args.ent_coef,
|
||||
reward_normalization=args.rew_norm,
|
||||
action_scaling=False,
|
||||
lr_scheduler=lr_scheduler,
|
||||
action_space=env.action_space,
|
||||
eps_clip=args.eps_clip,
|
||||
value_clip=args.value_clip,
|
||||
dual_clip=args.dual_clip,
|
||||
advantage_normalization=args.norm_adv,
|
||||
recompute_advantage=args.recompute_adv
|
||||
).to(args.device)
|
||||
if args.icm_lr_scale > 0:
|
||||
feature_net = DQN(
|
||||
*args.state_shape,
|
||||
args.action_shape,
|
||||
device=args.device,
|
||||
features_only=True
|
||||
features_only=True,
|
||||
output_dim=args.hidden_size
|
||||
)
|
||||
action_dim = np.prod(args.action_shape)
|
||||
feature_dim = feature_net.output_dim
|
||||
icm_net = IntrinsicCuriosityModule(
|
||||
feature_net.net,
|
||||
feature_dim,
|
||||
action_dim,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device
|
||||
feature_net.net, feature_dim, action_dim, device=args.device
|
||||
)
|
||||
icm_optim = torch.optim.adam(icm_net.parameters(), lr=args.lr)
|
||||
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
|
||||
policy = ICMPolicy(
|
||||
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
|
||||
args.icm_forward_loss_weight
|
||||
@ -153,7 +196,8 @@ def test_a2c(args=get_args()):
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'a2c')
|
||||
log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
|
||||
log_path = os.path.join(args.logdir, args.task, log_name)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
@ -162,10 +206,15 @@ def test_a2c(args=get_args()):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return False
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
|
||||
# watch agent's performance
|
||||
def watch():
|
||||
# watch agent's performance
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
@ -210,7 +259,7 @@ def test_a2c(args=get_args()):
|
||||
args.repeat_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
episode_per_collect=args.episode_per_collect,
|
||||
step_per_collect=args.step_per_collect,
|
||||
stop_fn=stop_fn,
|
||||
save_fn=save_fn,
|
||||
logger=logger,
|
||||
@ -222,4 +271,4 @@ def test_a2c(args=get_args()):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_a2c(get_args())
|
||||
test_ppo(get_args())
|
Loading…
x
Reference in New Issue
Block a user