import argparse
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import gym
import numpy as np
import pettingzoo.butterfly.pistonball_v6 as pistonball_v6
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.continuous import ActorProb, Critic


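# Overview: this script trains one PPO agent per piston in the PettingZoo Pistonball
# environment with continuous actions. Each agent gets a small CNN feature extractor
# (the DQN module below) topped with ActorProb/Critic heads; the per-piston policies
# are combined by MultiAgentPolicyManager and trained with onpolicy_trainer.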
class DQN(nn.Module):
    """Reference: Human-level control through deep reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.c = c
        self.h = h
        self.w = w
        self.net = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
            nn.Flatten()
        )
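        # Probe the conv stack once with a dummy zero image to infer the flattened
        # feature size; the actor/critic heads below use it as their input dimension.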
        with torch.no_grad():
            self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*)."""
        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        return self.net(x.reshape(-1, self.c, self.w, self.h)), state


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument(
        '--gamma', type=float, default=0.9, help='a smaller gamma favors an earlier win'
    )
    parser.add_argument(
        '--n-pistons',
        type=int,
        default=3,
        help='Number of pistons (agents) in the env'
    )
    parser.add_argument('--n-step', type=int, default=100)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=500)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--episode-per-collect', type=int, default=16)
    parser.add_argument('--repeat-per-collect', type=int, default=2)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')

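    # Note: several of these flags (e.g. --eps-test/--eps-train, --n-step,
    # --target-update-freq, --update-per-step, --hidden-sizes) are not consumed by
    # the PPO training code below; they appear to be carried over from the DQN
    # variant of this example.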
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='no training, watch the play of pre-trained models'
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    # ppo special
    parser.add_argument('--vf-coef', type=float, default=0.25)
    parser.add_argument('--ent-coef', type=float, default=0.0)
    parser.add_argument('--eps-clip', type=float, default=0.2)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--rew-norm', type=int, default=1)
    parser.add_argument('--dual-clip', type=float, default=None)
    parser.add_argument('--value-clip', type=int, default=1)
    parser.add_argument('--norm-adv', type=int, default=1)
    parser.add_argument('--recompute-adv', type=int, default=0)
    parser.add_argument('--resume', action="store_true")
    parser.add_argument("--save-interval", type=int, default=4)
    parser.add_argument('--render', type=float, default=0.0)

    return parser


def get_args() -> argparse.Namespace:
    parser = get_parser()
    return parser.parse_known_args()[0]


def get_env(args: argparse.Namespace = get_args()):
    return PettingZooEnv(pistonball_v6.env(continuous=True, n_pistons=args.n_pistons))


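# get_agents builds one PPO actor-critic per piston and wraps them in a
# MultiAgentPolicyManager; get_env above doubles as the environment factory handed to
# DummyVectorEnv in train_agent/watch below.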
def get_agents(
    args: argparse.Namespace = get_args(),
    agents: Optional[List[BasePolicy]] = None,
    optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
    env = get_env()
    observation_space = env.observation_space['observation'] if isinstance(
        env.observation_space, gym.spaces.Dict
    ) else env.observation_space
    args.state_shape = observation_space.shape or observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]

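    # The per-piston observation is an image stored channel-last, so shape[2] is the
    # channel count; the DQN extractors below are built with
    # (shape[2], shape[1], shape[0]) as (c, h, w).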
    if agents is None:
        agents = []
        optims = []
        for _ in range(args.n_pistons):
            # model
            net = DQN(
                observation_space.shape[2],
                observation_space.shape[1],
                observation_space.shape[0],
                device=args.device
            ).to(args.device)

            actor = ActorProb(
                net, args.action_shape, max_action=args.max_action, device=args.device
            ).to(args.device)
            net2 = DQN(
                observation_space.shape[2],
                observation_space.shape[1],
                observation_space.shape[0],
                device=args.device
            ).to(args.device)
            critic = Critic(net2, device=args.device).to(args.device)
            # orthogonal initialization of all linear layers in actor and critic
            for m in set(actor.modules()).union(critic.modules()):
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.orthogonal_(m.weight)
                    torch.nn.init.zeros_(m.bias)
            optim = torch.optim.Adam(
                set(actor.parameters()).union(critic.parameters()), lr=args.lr
            )

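            # Treat the action vector as one joint Gaussian: Independent(..., 1)
            # sums the log-probability over the last (action) dimension.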
            def dist(*logits):
                return Independent(Normal(*logits), 1)

            agent = PPOPolicy(
                actor,
                critic,
                optim,
                dist,
                discount_factor=args.gamma,
                max_grad_norm=args.max_grad_norm,
                eps_clip=args.eps_clip,
                vf_coef=args.vf_coef,
                ent_coef=args.ent_coef,
                reward_normalization=args.rew_norm,
                advantage_normalization=args.norm_adv,
                recompute_advantage=args.recompute_adv,
                # dual_clip=args.dual_clip,
                # dual clip causes a monotonically increasing log_std :)
                value_clip=args.value_clip,
                gae_lambda=args.gae_lambda,
                action_space=env.action_space
            )

            agents.append(agent)
            optims.append(optim)

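    # MultiAgentPolicyManager dispatches each piston's observations to its own policy;
    # action_scaling together with the 'clip' bound method maps the raw Gaussian
    # samples into the environment's action bounds.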
    policy = MultiAgentPolicyManager(
        agents, env, action_scaling=True, action_bound_method='clip'
    )
    return policy, optims, env.agents


def train_agent(
    args: argparse.Namespace = get_args(),
    agents: Optional[List[BasePolicy]] = None,
    optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[dict, BasePolicy]:
    train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    policy, optim, agents = get_agents(args, agents=agents, optims=optims)

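    # The training collector splits one VectorReplayBuffer of args.buffer_size
    # transitions evenly across the vectorized training envs.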
    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=False  # True
    )
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, 'pistonball', 'ppo')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

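    # Trainer callbacks: checkpointing and early stopping are left as no-ops in this
    # example.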
    def save_best_fn(policy):
        pass

    def stop_fn(mean_rewards):
        return False

    def train_fn(epoch, env_step):
        [agent.set_eps(args.eps_train) for agent in policy.policies.values()]

    def test_fn(epoch, env_step):
        [agent.set_eps(args.eps_test) for agent in policy.policies.values()]

    def reward_metric(rews):
        return rews[:, 0]

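    # Note: train_fn, test_fn and reward_metric above are defined but never passed to
    # onpolicy_trainer below; they appear to be leftovers from the DQN variant of this
    # example (PPOPolicy does not expose set_eps, so wiring them in as-is would fail).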
    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        resume_from_log=args.resume
    )

    return result, policy


def watch(
    args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
    env = DummyVectorEnv([get_env])
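    # Assumes a trained policy is passed in; with the default of None the next line
    # would raise an AttributeError.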
    policy.eval()
    collector = Collector(policy, env)
    result = collector.collect(n_episode=1, render=args.render)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")
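

# Illustrative usage sketch (not part of the original script): train the per-piston
# PPO agents with the defaults above, print the trainer summary, then roll out one
# episode with the trained policy. Rendering only happens if --render is positive.
if __name__ == '__main__':
    args = get_args()
    result, policy = train_agent(args)
    print(result)
    watch(args, policy)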