Add Atari SAC examples (#657)
- Add Atari (discrete) SAC examples; - Fix a bug in Discrete SAC evaluation; default to deterministic mode.
@ -121,3 +121,17 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
| MsPacmanNoFrameskip-v4 | 1930 |  | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
|
||||
| SeaquestNoFrameskip-v4 | 904 |  | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5` |
|
||||
| SpaceInvadersNoFrameskip-v4 | 843 |  | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||
|
||||
# SAC (single run)
|
||||
|
||||
One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
|
||||
| task | best reward | reward curve | parameters |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||
| PongNoFrameskip-v4 | 20.1 |  | `python3 atari_sac.py --task "PongNoFrameskip-v4"` |
|
||||
| BreakoutNoFrameskip-v4 | 211.2 |  | `python3 atari_sac.py --task "BreakoutNoFrameskip-v4" --n-step 1 --actor-lr 1e-4 --critic-lr 1e-4` |
|
||||
| EnduroNoFrameskip-v4 | 1290.7 |  | `python3 atari_sac.py --task "EnduroNoFrameskip-v4"` |
|
||||
| QbertNoFrameskip-v4 | 13157.5 |  | `python3 atari_sac.py --task "QbertNoFrameskip-v4"` |
|
||||
| MsPacmanNoFrameskip-v4 | 3836 |  | `python3 atari_sac.py --task "MsPacmanNoFrameskip-v4"` |
|
||||
| SeaquestNoFrameskip-v4 | 1772 |  | `python3 atari_sac.py --task "SeaquestNoFrameskip-v4"` |
|
||||
| SpaceInvadersNoFrameskip-v4 | 649 |  | `python3 atari_sac.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||
|
@ -162,7 +162,7 @@ def test_ppo(args=get_args()):
|
||||
feature_net.net,
|
||||
feature_dim,
|
||||
action_dim,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
hidden_sizes=[args.hidden_size],
|
||||
device=args.device,
|
||||
)
|
||||
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
|
||||
|
269
examples/atari/atari_sac.py
Normal file
@ -0,0 +1,269 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from atari_network import DQN
|
||||
from atari_wrapper import make_atari_env
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.policy import DiscreteSACPolicy, ICMPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=4213)
|
||||
parser.add_argument("--scale-obs", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--actor-lr", type=float, default=1e-5)
|
||||
parser.add_argument("--critic-lr", type=float, default=1e-5)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--tau", type=float, default=0.005)
|
||||
parser.add_argument("--alpha", type=float, default=0.05)
|
||||
parser.add_argument("--auto-alpha", action="store_true", default=False)
|
||||
parser.add_argument("--alpha-lr", type=float, default=3e-4)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=100000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
parser.add_argument("--hidden-size", type=int, default=512)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--rew-norm", type=int, default=False)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument("--frames-stack", type=int, default=4)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
)
|
||||
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="watch the play of pre-trained policy only"
|
||||
)
|
||||
parser.add_argument("--save-buffer-name", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--icm-lr-scale",
|
||||
type=float,
|
||||
default=0.,
|
||||
help="use intrinsic curiosity module with this lr scale"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--icm-reward-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="scaling factor for intrinsic curiosity reward"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--icm-forward-loss-weight",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="weight for the forward model loss in ICM"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_discrete_sac(args=get_args()):
|
||||
env, train_envs, test_envs = make_atari_env(
|
||||
args.task,
|
||||
args.seed,
|
||||
args.training_num,
|
||||
args.test_num,
|
||||
scale=args.scale_obs,
|
||||
frame_stack=args.frames_stack,
|
||||
)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# should be N_FRAMES x H x W
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
# define model
|
||||
net = DQN(
|
||||
*args.state_shape,
|
||||
args.action_shape,
|
||||
device=args.device,
|
||||
features_only=True,
|
||||
output_dim=args.hidden_size
|
||||
)
|
||||
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
critic1 = Critic(net, last_size=args.action_shape, device=args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
critic2 = Critic(net, last_size=args.action_shape, device=args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
# define policy
|
||||
if args.auto_alpha:
|
||||
target_entropy = 0.98 * np.log(np.prod(args.action_shape))
|
||||
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||
|
||||
policy = DiscreteSACPolicy(
|
||||
actor,
|
||||
actor_optim,
|
||||
critic1,
|
||||
critic1_optim,
|
||||
critic2,
|
||||
critic2_optim,
|
||||
args.tau,
|
||||
args.gamma,
|
||||
args.alpha,
|
||||
estimation_step=args.n_step,
|
||||
reward_normalization=args.rew_norm,
|
||||
).to(args.device)
|
||||
if args.icm_lr_scale > 0:
|
||||
feature_net = DQN(
|
||||
*args.state_shape, args.action_shape, args.device, features_only=True
|
||||
)
|
||||
action_dim = np.prod(args.action_shape)
|
||||
feature_dim = feature_net.output_dim
|
||||
icm_net = IntrinsicCuriosityModule(
|
||||
feature_net.net,
|
||||
feature_dim,
|
||||
action_dim,
|
||||
hidden_sizes=[args.hidden_size],
|
||||
device=args.device,
|
||||
)
|
||||
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.actor_lr)
|
||||
policy = ICMPolicy(
|
||||
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
|
||||
args.icm_forward_loss_weight
|
||||
).to(args.device)
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
||||
# when you have enough RAM
|
||||
buffer = VectorReplayBuffer(
|
||||
args.buffer_size,
|
||||
buffer_num=len(train_envs),
|
||||
ignore_obs_next=True,
|
||||
save_only_last_obs=True,
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
args.algo_name = "discrete_sac_icm" if args.icm_lr_scale > 0 else "discrete_sac"
|
||||
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
|
||||
log_path = os.path.join(args.logdir, log_name)
|
||||
|
||||
# logger
|
||||
if args.logger == "wandb":
|
||||
logger = WandbLogger(
|
||||
save_interval=1,
|
||||
name=log_name.replace(os.path.sep, "__"),
|
||||
run_id=args.resume_id,
|
||||
config=args,
|
||||
project=args.wandb_project,
|
||||
)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
if args.logger == "tensorboard":
|
||||
logger = TensorboardLogger(writer)
|
||||
else: # wandb
|
||||
logger.load(writer)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif "Pong" in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
torch.save({"model": policy.state_dict()}, ckpt_path)
|
||||
return ckpt_path
|
||||
|
||||
# watch agent's performance
|
||||
def watch():
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
print(f"Generate buffer with size {args.buffer_size}")
|
||||
buffer = VectorReplayBuffer(
|
||||
args.buffer_size,
|
||||
buffer_num=len(test_envs),
|
||||
ignore_obs_next=True,
|
||||
save_only_last_obs=True,
|
||||
stack_num=args.frames_stack,
|
||||
)
|
||||
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(
|
||||
n_episode=args.test_num, render=args.render
|
||||
)
|
||||
rew = result["rews"].mean()
|
||||
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
exit(0)
|
||||
|
||||
# test train_collector and start filling replay buffer
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
stop_fn=stop_fn,
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
update_per_step=args.update_per_step,
|
||||
test_in_train=False,
|
||||
resume_from_log=args.resume_id is not None,
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
)
|
||||
|
||||
pprint.pprint(result)
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_discrete_sac(get_args())
|
BIN
examples/atari/results/discrete_sac/Breakout_rew.png
Normal file
After Width: | Height: | Size: 228 KiB |
BIN
examples/atari/results/discrete_sac/Enduro_rew.png
Normal file
After Width: | Height: | Size: 194 KiB |
BIN
examples/atari/results/discrete_sac/MsPacman_rew.png
Normal file
After Width: | Height: | Size: 198 KiB |
BIN
examples/atari/results/discrete_sac/Pong_rew.png
Normal file
After Width: | Height: | Size: 134 KiB |
BIN
examples/atari/results/discrete_sac/Qbert_rew.png
Normal file
After Width: | Height: | Size: 261 KiB |
BIN
examples/atari/results/discrete_sac/Seaquest_rew.png
Normal file
After Width: | Height: | Size: 206 KiB |
BIN
examples/atari/results/discrete_sac/SpaceInvaders_rew.png
Normal file
After Width: | Height: | Size: 229 KiB |
@ -53,7 +53,7 @@ def test_discrete_sac(args=get_args()):
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
if args.reward_threshold is None:
|
||||
default_reward_threshold = {"CartPole-v0": 180} # lower the goal
|
||||
default_reward_threshold = {"CartPole-v0": 170} # lower the goal
|
||||
args.reward_threshold = default_reward_threshold.get(
|
||||
args.task, env.spec.reward_threshold
|
||||
)
|
||||
|
@ -80,7 +80,10 @@ class DiscreteSACPolicy(SACPolicy):
|
||||
obs = batch[input]
|
||||
logits, hidden = self.actor(obs, state=state, info=batch.info)
|
||||
dist = Categorical(logits=logits)
|
||||
act = dist.sample()
|
||||
if self._deterministic_eval and not self.training:
|
||||
act = logits.argmax(axis=-1)
|
||||
else:
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=hidden, dist=dist)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
|